diff --git a/src/auth0_api_python/api_client.py b/src/auth0_api_python/api_client.py index 36f23cf..5c3726b 100644 --- a/src/auth0_api_python/api_client.py +++ b/src/auth0_api_python/api_client.py @@ -1,5 +1,6 @@ import asyncio import time +from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Optional, Union @@ -22,6 +23,7 @@ VerifyAccessTokenError, ) from .utils import ( + aclose_default_httpx_client, calculate_jwk_thumbprint, fetch_jwks, fetch_oidc_metadata, @@ -111,11 +113,24 @@ def __init__(self, options: ApiClientOptions): self._cache_ttl = options.cache_ttl_seconds + # Per-cache-key single-flight locks for OIDC discovery and JWKS + # refetches. Without these, every concurrent request that misses the + # cache at the moment of TTL expiry fires its own outbound HTTP call + # — a thundering herd that Auth0 rate-limits and we time out on. + # The lock guarantees only ONE coroutine per cache key refetches; + # the rest await the result and read from the now-warm cache. + self._discovery_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._jwks_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._jwt = JsonWebToken(["RS256"]) self._dpop_algorithms = ["ES256"] self._dpop_jwt = JsonWebToken(self._dpop_algorithms) + async def aclose(self) -> None: + """Release the shared default httpx client. Idempotent; no-op when a `custom_fetch` is in use.""" + await aclose_default_httpx_client() + def is_dpop_required(self) -> bool: """Check if DPoP authentication is required.""" return getattr(self.options, "dpop_required", False) @@ -1029,19 +1044,26 @@ async def _discover(self, issuer: Optional[str] = None) -> dict[str, Any]: if cached: return cached - metadata, max_age = await fetch_oidc_metadata( - domain=domain, - custom_fetch=self.options.custom_fetch - ) + # Single-flight: only one coroutine per cache_key refetches; the + # rest await it and re-check the cache after acquiring the lock. + async with self._discovery_locks[cache_key]: + cached = self._discovery_cache.get(cache_key) + if cached: + return cached + + metadata, max_age = await fetch_oidc_metadata( + domain=domain, + custom_fetch=self.options.custom_fetch + ) - effective_ttl = self._cache_ttl - if max_age is not None and self._cache_ttl is not None: - effective_ttl = min(max_age, self._cache_ttl) - elif max_age is not None: - effective_ttl = max_age + effective_ttl = self._cache_ttl + if max_age is not None and self._cache_ttl is not None: + effective_ttl = min(max_age, self._cache_ttl) + elif max_age is not None: + effective_ttl = max_age - self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl) - return metadata + self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl) + return metadata async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]: """ @@ -1060,19 +1082,26 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]: if cached: return cached - jwks_data, max_age = await fetch_jwks( - jwks_uri=jwks_uri, - custom_fetch=self.options.custom_fetch - ) + # Single-flight: only one coroutine per cache_key refetches; the + # rest await it and re-check the cache after acquiring the lock. + async with self._jwks_locks[cache_key]: + cached = self._jwks_cache.get(cache_key) + if cached: + return cached + + jwks_data, max_age = await fetch_jwks( + jwks_uri=jwks_uri, + custom_fetch=self.options.custom_fetch + ) - effective_ttl = self._cache_ttl - if max_age is not None and self._cache_ttl is not None: - effective_ttl = min(max_age, self._cache_ttl) - elif max_age is not None: - effective_ttl = max_age + effective_ttl = self._cache_ttl + if max_age is not None and self._cache_ttl is not None: + effective_ttl = min(max_age, self._cache_ttl) + elif max_age is not None: + effective_ttl = max_age - self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl) - return jwks_data + self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl) + return jwks_data def _validate_claims_presence( self, diff --git a/src/auth0_api_python/utils.py b/src/auth0_api_python/utils.py index 72fbef8..7349782 100644 --- a/src/auth0_api_python/utils.py +++ b/src/auth0_api_python/utils.py @@ -3,6 +3,7 @@ using httpx or a custom fetch approach. """ +import asyncio import base64 import hashlib import json @@ -13,6 +14,40 @@ import httpx from ada_url import URL +_DEFAULT_HTTPX_CLIENT: Optional[httpx.AsyncClient] = None +_DEFAULT_HTTPX_CLIENT_LOCK = asyncio.Lock() + + +def _build_default_httpx_client() -> httpx.AsyncClient: + """Construct the default shared client used when no `custom_fetch` is set.""" + return httpx.AsyncClient( + timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0), + limits=httpx.Limits( + max_connections=200, + max_keepalive_connections=50, + ), + transport=httpx.AsyncHTTPTransport(retries=2), + ) + + +async def _get_default_httpx_client() -> httpx.AsyncClient: + """Return the shared default client, creating it on first use.""" + global _DEFAULT_HTTPX_CLIENT + if _DEFAULT_HTTPX_CLIENT is not None and not _DEFAULT_HTTPX_CLIENT.is_closed: + return _DEFAULT_HTTPX_CLIENT + async with _DEFAULT_HTTPX_CLIENT_LOCK: + if _DEFAULT_HTTPX_CLIENT is None or _DEFAULT_HTTPX_CLIENT.is_closed: + _DEFAULT_HTTPX_CLIENT = _build_default_httpx_client() + return _DEFAULT_HTTPX_CLIENT + + +async def aclose_default_httpx_client() -> None: + """Close the module-level shared httpx client. Idempotent.""" + global _DEFAULT_HTTPX_CLIENT + if _DEFAULT_HTTPX_CLIENT is not None and not _DEFAULT_HTTPX_CLIENT.is_closed: + await _DEFAULT_HTTPX_CLIENT.aclose() + _DEFAULT_HTTPX_CLIENT = None + def parse_cache_control_max_age(headers: Mapping[str, str]) -> Optional[int]: """ @@ -102,11 +137,11 @@ async def fetch_oidc_metadata( return data, max_age return response, None else: - async with httpx.AsyncClient() as client: - resp = await client.get(url) - resp.raise_for_status() - max_age = parse_cache_control_max_age(resp.headers) - return resp.json(), max_age + client = await _get_default_httpx_client() + resp = await client.get(url) + resp.raise_for_status() + max_age = parse_cache_control_max_age(resp.headers) + return resp.json(), max_age async def fetch_jwks( @@ -128,11 +163,11 @@ async def fetch_jwks( return data, max_age return response, None else: - async with httpx.AsyncClient() as client: - resp = await client.get(jwks_uri) - resp.raise_for_status() - max_age = parse_cache_control_max_age(resp.headers) - return resp.json(), max_age + client = await _get_default_httpx_client() + resp = await client.get(jwks_uri) + resp.raise_for_status() + max_age = parse_cache_control_max_age(resp.headers) + return resp.json(), max_age def _decode_jwt_segment(token: Union[str, bytes], segment_index: int) -> dict: diff --git a/tests/test_concurrent_fetch.py b/tests/test_concurrent_fetch.py new file mode 100644 index 0000000..551ef9a --- /dev/null +++ b/tests/test_concurrent_fetch.py @@ -0,0 +1,137 @@ +import asyncio + +import pytest +import pytest_asyncio +from conftest import DISCOVERY_URL, JWKS_URL +from pytest_httpx import HTTPXMock + +from auth0_api_python import ApiClient, ApiClientOptions +from auth0_api_python import utils as auth0_utils + +# ===== Fixtures ===== + +@pytest_asyncio.fixture(autouse=True) +async def _reset_default_httpx_client(): + await auth0_utils.aclose_default_httpx_client() + yield + await auth0_utils.aclose_default_httpx_client() + + +# ===== Single-flight: JWKS ===== + +@pytest.mark.asyncio +async def test_concurrent_jwks_misses_trigger_single_fetch(httpx_mock: HTTPXMock): + """N concurrent JWKS cache misses for the same URI cause exactly one upstream fetch.""" + httpx_mock.add_response( + method="GET", + url=JWKS_URL, + json={"keys": []}, + is_reusable=True, + ) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + results = await asyncio.gather( + *(api_client._fetch_jwks(JWKS_URL) for _ in range(50)) + ) + + assert all(r == {"keys": []} for r in results) + requests = [r for r in httpx_mock.get_requests() if str(r.url) == JWKS_URL] + assert len(requests) == 1 + + +# ===== Single-flight: OIDC discovery ===== + +@pytest.mark.asyncio +async def test_concurrent_oidc_misses_trigger_single_fetch(httpx_mock: HTTPXMock): + """N concurrent OIDC discovery cache misses cause exactly one upstream fetch.""" + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + json={"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL}, + is_reusable=True, + ) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + results = await asyncio.gather( + *(api_client._discover() for _ in range(50)) + ) + + expected = {"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL} + assert all(r == expected for r in results) + requests = [r for r in httpx_mock.get_requests() if str(r.url) == DISCOVERY_URL] + assert len(requests) == 1 + + +# ===== Per-key locking ===== + +@pytest.mark.asyncio +async def test_jwks_locks_are_per_cache_key(httpx_mock: HTTPXMock): + """Concurrent misses for different JWKS URIs are not serialized behind one global lock.""" + uri_a = "https://a.auth0.local/.well-known/jwks.json" + uri_b = "https://b.auth0.local/.well-known/jwks.json" + httpx_mock.add_response(method="GET", url=uri_a, json={"keys": ["a"]}) + httpx_mock.add_response(method="GET", url=uri_b, json={"keys": ["b"]}) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + a, b = await asyncio.gather( + api_client._fetch_jwks(uri_a), + api_client._fetch_jwks(uri_b), + ) + + assert a == {"keys": ["a"]} + assert b == {"keys": ["b"]} + + requests_a = [r for r in httpx_mock.get_requests() if str(r.url) == uri_a] + requests_b = [r for r in httpx_mock.get_requests() if str(r.url) == uri_b] + assert len(requests_a) == 1 + assert len(requests_b) == 1 + + +# ===== Default httpx client ===== + +@pytest.mark.asyncio +async def test_default_httpx_client_is_shared(): + """The default httpx client is a singleton across calls.""" + first = await auth0_utils._get_default_httpx_client() + second = await auth0_utils._get_default_httpx_client() + + assert first is second + + +@pytest.mark.asyncio +async def test_default_httpx_client_has_explicit_timeouts(): + """The default httpx client sets explicit, non-default timeouts.""" + client = await auth0_utils._get_default_httpx_client() + + assert client.timeout.connect is not None + assert client.timeout.read is not None + assert client.timeout.write is not None + assert client.timeout.pool is not None + assert client.timeout.read >= 5.0 + + +# ===== Shutdown ===== + +@pytest.mark.asyncio +async def test_aclose_is_idempotent(): + """`aclose()` is safe to call repeatedly and the client can be re-created afterward.""" + await auth0_utils.aclose_default_httpx_client() + await auth0_utils.aclose_default_httpx_client() + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + await api_client.aclose() + await api_client.aclose() # idempotent + + new_client = await auth0_utils._get_default_httpx_client() + assert not new_client.is_closed