diff --git a/tests/test_auth.py b/tests/test_auth.py index f2c7da6..fcd266e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,9 @@ +from base64 import b64encode +from contextlib import asynccontextmanager from os import environ from subprocess import run from tempfile import NamedTemporaryFile +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -133,3 +136,28 @@ def test_precedence(scenario, expected, monkeypatch): match="api_key is not available when using an Ethereum private key", ): client.api_key # noqa: B018 + + +@pytest.mark.asyncio +async def test_basic_auth_header(): + api_key = "testkey" + captured_headers = {} + + response_mock = MagicMock() + response_mock.status = 200 + response_mock.headers = {} + response_mock.__aenter__ = AsyncMock(return_value=response_mock) + response_mock.__aexit__ = AsyncMock(return_value=False) + response_mock.json = AsyncMock(return_value={"url": "https://a.example"}) + + @asynccontextmanager + async def fake_post(**kwargs): + captured_headers.update(kwargs.get("headers", {})) + yield response_mock + + with patch("zyte_api._async._post_func", return_value=fake_post): + client = AsyncZyteAPI(api_key=api_key) + await client.get({"url": "https://a.example", "httpResponseBody": True}) + + expected = "Basic " + b64encode(f"{api_key}:".encode()).decode() + assert captured_headers.get("Authorization") == expected diff --git a/zyte_api/_async.py b/zyte_api/_async.py index 39588d4..69804b1 100644 --- a/zyte_api/_async.py +++ b/zyte_api/_async.py @@ -203,9 +203,11 @@ async def get( query = _process_query(query) headers = {"User-Agent": self.user_agent, "Accept-Encoding": "br"} - auth_kwargs = {} if isinstance(self._auth, str): - auth_kwargs["auth"] = aiohttp.BasicAuth(self._auth) + if hasattr(aiohttp, "encode_basic_auth"): # aiohttp 3.14+ + headers["Authorization"] = aiohttp.encode_basic_auth(self._auth, "") + else: + headers["Authorization"] = aiohttp.BasicAuth(self._auth).encode() else: x402_headers = await self._auth.get_headers(url, query, headers, post) headers.update(x402_headers) @@ -214,7 +216,6 @@ async def get( "url": url, "json": query, "headers": headers, - **auth_kwargs, } response_stats = []