Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions src/auth0_api_python/api_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union

Expand All @@ -22,6 +23,7 @@
VerifyAccessTokenError,
)
from .utils import (
aclose_default_httpx_client,
calculate_jwk_thumbprint,
fetch_jwks,
fetch_oidc_metadata,
Expand Down Expand Up @@ -111,11 +113,24 @@ def __init__(self, options: ApiClientOptions):

self._cache_ttl = options.cache_ttl_seconds

# Per-cache-key single-flight locks for OIDC discovery and JWKS
# refetches. Without these, every concurrent request that misses the
# cache at the moment of TTL expiry fires its own outbound HTTP call
# — a thundering herd that Auth0 rate-limits and we time out on.
# The lock guarantees only ONE coroutine per cache key refetches;
# the rest await the result and read from the now-warm cache.
self._discovery_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._jwks_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

self._jwt = JsonWebToken(["RS256"])

self._dpop_algorithms = ["ES256"]
self._dpop_jwt = JsonWebToken(self._dpop_algorithms)

async def aclose(self) -> None:
    """Release the shared default httpx client.

    Idempotent; safe to call multiple times. When a `custom_fetch` is
    configured the shared default client is never created, so closing it
    is a no-op.
    """
    await aclose_default_httpx_client()

def is_dpop_required(self) -> bool:
    """Check if DPoP authentication is required.

    Reads the optional `dpop_required` flag from the client options,
    defaulting to False when the attribute is absent.
    """
    return getattr(self.options, "dpop_required", False)
Expand Down Expand Up @@ -1029,19 +1044,26 @@ async def _discover(self, issuer: Optional[str] = None) -> dict[str, Any]:
if cached:
return cached

metadata, max_age = await fetch_oidc_metadata(
domain=domain,
custom_fetch=self.options.custom_fetch
)
# Single-flight: only one coroutine per cache_key refetches; the
# rest await it and re-check the cache after acquiring the lock.
async with self._discovery_locks[cache_key]:
cached = self._discovery_cache.get(cache_key)
if cached:
return cached

metadata, max_age = await fetch_oidc_metadata(
domain=domain,
custom_fetch=self.options.custom_fetch
)

effective_ttl = self._cache_ttl
if max_age is not None and self._cache_ttl is not None:
effective_ttl = min(max_age, self._cache_ttl)
elif max_age is not None:
effective_ttl = max_age
effective_ttl = self._cache_ttl
if max_age is not None and self._cache_ttl is not None:
effective_ttl = min(max_age, self._cache_ttl)
elif max_age is not None:
effective_ttl = max_age

self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl)
return metadata
self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl)
return metadata

async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]:
"""
Expand All @@ -1060,19 +1082,26 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]:
if cached:
return cached

jwks_data, max_age = await fetch_jwks(
jwks_uri=jwks_uri,
custom_fetch=self.options.custom_fetch
)
# Single-flight: only one coroutine per cache_key refetches; the
# rest await it and re-check the cache after acquiring the lock.
async with self._jwks_locks[cache_key]:
cached = self._jwks_cache.get(cache_key)
if cached:
return cached

jwks_data, max_age = await fetch_jwks(
jwks_uri=jwks_uri,
custom_fetch=self.options.custom_fetch
)

effective_ttl = self._cache_ttl
if max_age is not None and self._cache_ttl is not None:
effective_ttl = min(max_age, self._cache_ttl)
elif max_age is not None:
effective_ttl = max_age
effective_ttl = self._cache_ttl
if max_age is not None and self._cache_ttl is not None:
effective_ttl = min(max_age, self._cache_ttl)
elif max_age is not None:
effective_ttl = max_age

self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl)
return jwks_data
self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl)
return jwks_data

def _validate_claims_presence(
self,
Expand Down
55 changes: 45 additions & 10 deletions src/auth0_api_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using httpx or a custom fetch approach.
"""

import asyncio
import base64
import hashlib
import json
Expand All @@ -13,6 +14,40 @@
import httpx
from ada_url import URL

_DEFAULT_HTTPX_CLIENT: Optional[httpx.AsyncClient] = None
_DEFAULT_HTTPX_CLIENT_LOCK = asyncio.Lock()


def _build_default_httpx_client() -> httpx.AsyncClient:
    """Construct the default shared client used when no `custom_fetch` is set.

    Every request phase gets an explicit timeout, the connection pool is
    bounded, and the transport is configured with retries=2 for failed
    connection attempts.
    """
    timeout = httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
    limits = httpx.Limits(
        max_connections=200,
        max_keepalive_connections=50,
    )
    transport = httpx.AsyncHTTPTransport(retries=2)
    return httpx.AsyncClient(timeout=timeout, limits=limits, transport=transport)


async def _get_default_httpx_client() -> httpx.AsyncClient:
    """Return the shared default client, creating it on first use.

    Fast path: reuse the cached client when it exists and is still open.
    Slow path: take the creation lock, re-check, and build a new client
    only if no usable one appeared while waiting.
    """
    global _DEFAULT_HTTPX_CLIENT
    client = _DEFAULT_HTTPX_CLIENT
    if client is not None and not client.is_closed:
        return client
    async with _DEFAULT_HTTPX_CLIENT_LOCK:
        client = _DEFAULT_HTTPX_CLIENT
        if client is None or client.is_closed:
            client = _build_default_httpx_client()
            _DEFAULT_HTTPX_CLIENT = client
        return client


async def aclose_default_httpx_client() -> None:
    """Close the module-level shared httpx client. Idempotent.

    The swap-then-close runs under the same lock that guards creation in
    `_get_default_httpx_client`; without it, a concurrent getter could
    install a fresh client between this coroutine's close and its
    `= None` reset, and that fresh client would then be discarded without
    ever being closed (leaking its connection pool). The global is always
    cleared — even when the client was already closed — so the next
    caller starts from a clean slate.
    """
    global _DEFAULT_HTTPX_CLIENT
    async with _DEFAULT_HTTPX_CLIENT_LOCK:
        client = _DEFAULT_HTTPX_CLIENT
        _DEFAULT_HTTPX_CLIENT = None
        if client is not None and not client.is_closed:
            await client.aclose()


def parse_cache_control_max_age(headers: Mapping[str, str]) -> Optional[int]:
"""
Expand Down Expand Up @@ -102,11 +137,11 @@ async def fetch_oidc_metadata(
return data, max_age
return response, None
else:
async with httpx.AsyncClient() as client:
resp = await client.get(url)
resp.raise_for_status()
max_age = parse_cache_control_max_age(resp.headers)
return resp.json(), max_age
client = await _get_default_httpx_client()
resp = await client.get(url)
resp.raise_for_status()
max_age = parse_cache_control_max_age(resp.headers)
return resp.json(), max_age


async def fetch_jwks(
Expand All @@ -128,11 +163,11 @@ async def fetch_jwks(
return data, max_age
return response, None
else:
async with httpx.AsyncClient() as client:
resp = await client.get(jwks_uri)
resp.raise_for_status()
max_age = parse_cache_control_max_age(resp.headers)
return resp.json(), max_age
client = await _get_default_httpx_client()
resp = await client.get(jwks_uri)
resp.raise_for_status()
max_age = parse_cache_control_max_age(resp.headers)
return resp.json(), max_age


def _decode_jwt_segment(token: Union[str, bytes], segment_index: int) -> dict:
Expand Down
137 changes: 137 additions & 0 deletions tests/test_concurrent_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import asyncio

import pytest
import pytest_asyncio
from conftest import DISCOVERY_URL, JWKS_URL
from pytest_httpx import HTTPXMock

from auth0_api_python import ApiClient, ApiClientOptions
from auth0_api_python import utils as auth0_utils

# ===== Fixtures =====

@pytest_asyncio.fixture(autouse=True)
async def _reset_default_httpx_client():
    """Ensure each test starts and ends with no shared default httpx client.

    Closing before the test isolates it from a client left behind by an
    earlier test; closing after prevents this test's client from leaking
    into subsequent tests.
    """
    await auth0_utils.aclose_default_httpx_client()
    yield
    await auth0_utils.aclose_default_httpx_client()


# ===== Single-flight: JWKS =====

@pytest.mark.asyncio
async def test_concurrent_jwks_misses_trigger_single_fetch(httpx_mock: HTTPXMock):
    """N concurrent JWKS cache misses for the same URI cause exactly one upstream fetch."""
    httpx_mock.add_response(
        method="GET",
        url=JWKS_URL,
        json={"keys": []},
        is_reusable=True,
    )

    api_client = ApiClient(
        ApiClientOptions(domain="auth0.local", audience="my-audience")
    )

    # Fire 50 simultaneous misses; the single-flight lock must collapse
    # them into one outbound request.
    tasks = [api_client._fetch_jwks(JWKS_URL) for _ in range(50)]
    results = await asyncio.gather(*tasks)

    for result in results:
        assert result == {"keys": []}

    jwks_requests = [
        req for req in httpx_mock.get_requests() if str(req.url) == JWKS_URL
    ]
    assert len(jwks_requests) == 1


# ===== Single-flight: OIDC discovery =====

@pytest.mark.asyncio
async def test_concurrent_oidc_misses_trigger_single_fetch(httpx_mock: HTTPXMock):
    """N concurrent OIDC discovery cache misses cause exactly one upstream fetch."""
    metadata = {"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL}
    httpx_mock.add_response(
        method="GET",
        url=DISCOVERY_URL,
        json=metadata,
        is_reusable=True,
    )

    api_client = ApiClient(
        ApiClientOptions(domain="auth0.local", audience="my-audience")
    )

    # 50 simultaneous discovery calls must be collapsed into one request.
    results = await asyncio.gather(*(api_client._discover() for _ in range(50)))

    for result in results:
        assert result == metadata

    discovery_requests = [
        req for req in httpx_mock.get_requests() if str(req.url) == DISCOVERY_URL
    ]
    assert len(discovery_requests) == 1


# ===== Per-key locking =====

@pytest.mark.asyncio
async def test_jwks_locks_are_per_cache_key(httpx_mock: HTTPXMock):
    """Concurrent misses for two different JWKS URIs each complete with one fetch apiece.

    NOTE(review): these assertions check per-URI fetch counts, not true
    overlap — a single global lock would also pass. Proving the fetches
    are not serialized would need request-ordering instrumentation.
    """
    uri_a = "https://a.auth0.local/.well-known/jwks.json"
    uri_b = "https://b.auth0.local/.well-known/jwks.json"
    httpx_mock.add_response(method="GET", url=uri_a, json={"keys": ["a"]})
    httpx_mock.add_response(method="GET", url=uri_b, json={"keys": ["b"]})

    api_client = ApiClient(
        ApiClientOptions(domain="auth0.local", audience="my-audience")
    )

    result_a, result_b = await asyncio.gather(
        api_client._fetch_jwks(uri_a),
        api_client._fetch_jwks(uri_b),
    )

    assert result_a == {"keys": ["a"]}
    assert result_b == {"keys": ["b"]}

    # Each URI was fetched exactly once.
    for uri in (uri_a, uri_b):
        matching = [r for r in httpx_mock.get_requests() if str(r.url) == uri]
        assert len(matching) == 1


# ===== Default httpx client =====

@pytest.mark.asyncio
async def test_default_httpx_client_is_shared():
    """The default httpx client is a singleton across calls."""
    clients = [await auth0_utils._get_default_httpx_client() for _ in range(2)]
    assert clients[0] is clients[1]


@pytest.mark.asyncio
async def test_default_httpx_client_has_explicit_timeouts():
    """The default httpx client sets explicit, non-default timeouts."""
    client = await auth0_utils._get_default_httpx_client()
    timeout = client.timeout

    # Every request phase must have a bounded timeout configured.
    for phase in (timeout.connect, timeout.read, timeout.write, timeout.pool):
        assert phase is not None

    assert timeout.read >= 5.0


# ===== Shutdown =====

@pytest.mark.asyncio
async def test_aclose_is_idempotent():
    """`aclose()` is safe to call repeatedly and the client can be re-created afterward."""
    # Repeated module-level closes must not raise.
    for _ in range(2):
        await auth0_utils.aclose_default_httpx_client()

    api_client = ApiClient(
        ApiClientOptions(domain="auth0.local", audience="my-audience")
    )

    # Repeated instance-level closes must not raise either.
    for _ in range(2):
        await api_client.aclose()

    # A fresh, open client can still be obtained after shutdown.
    recreated = await auth0_utils._get_default_httpx_client()
    assert not recreated.is_closed
Loading