From e3fe94bbd2ac328e3e4fbb629355557cfcd6c67a Mon Sep 17 00:00:00 2001 From: Zac Farrell Date: Fri, 8 May 2026 09:07:40 -0700 Subject: [PATCH] feat(arrow): add ResultsApi subclass for Arrow IPC results --- .github/workflows/regenerate.yml | 8 + README.md | 29 +++ hotdata/arrow.py | 190 ++++++++++++++++ pyproject.toml | 3 + tests/integration/test_results_arrow.py | 101 +++++++++ tests/test_arrow.py | 283 ++++++++++++++++++++++++ 6 files changed, 614 insertions(+) create mode 100644 hotdata/arrow.py create mode 100644 tests/integration/test_results_arrow.py create mode 100644 tests/test_arrow.py diff --git a/.github/workflows/regenerate.yml b/.github/workflows/regenerate.yml index c5643f8..f45e2ab 100644 --- a/.github/workflows/regenerate.yml +++ b/.github/workflows/regenerate.yml @@ -84,6 +84,14 @@ jobs: 'keywords = ["hotdata", "api-client", "data-platform"]', re.MULTILINE, ), + # Insert [project.optional-dependencies] (for hotdata.arrow) just + # before [project.urls]. Run before the urls patch so the urls + # anchor is unchanged when this fires. + ( + r'(\ndependencies = \[\n(?:[^\]]|\][^\n])*\]\n)\n(\[project\.urls\])', + r'\1\n[project.optional-dependencies]\narrow = ["pyarrow >= 14"]\n\n\2', + 0, + ), ( r'\[project\.urls\]\nRepository = "[^"]*"\n', '[project.urls]\nHomepage = "https://www.hotdata.dev"\nRepository = "https://github.com/hotdata-dev/sdk-python"\n', diff --git a/README.md b/README.md index 4dc12be..ada8de8 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,35 @@ with hotdata.ApiClient(configuration) as api_client: Each `Api` class groups endpoints by resource. Construct the client, then call the typed methods you need. +## Arrow results + +Query results can be fetched as an [Apache Arrow](https://arrow.apache.org/) IPC stream instead of JSON, which is faster and far more memory-efficient for large result sets. Install the optional extra: + +```sh +pip install 'hotdata[arrow]' +``` + +Use `hotdata.arrow.ResultsApi` (a drop-in subclass of `ResultsApi` that adds Arrow methods): + +```python +from hotdata import ApiClient, Configuration +from hotdata.arrow import ResultsApi + +with ApiClient(Configuration(api_key="...", workspace_id="...")) as client: + results = ResultsApi(client) + + # Buffered: returns a pyarrow.Table. + table = results.get_result_arrow(result_id) + + # Streaming: yields a pyarrow.RecordBatchStreamReader without + # materializing the full table in memory. + with results.stream_result_arrow(result_id) as reader: + for batch in reader: + ... +``` + +Both methods accept `offset` and `limit` for pagination. They raise `hotdata.arrow.ResultNotReadyError` if the result is still pending or processing — poll `results.get_result(result_id)` until `status == "ready"` first. + ## API reference Generated Markdown for every operation and model is in [`docs/`](https://github.com/hotdata-dev/sdk-python/tree/main/docs): diff --git a/hotdata/arrow.py b/hotdata/arrow.py new file mode 100644 index 0000000..3526f4b --- /dev/null +++ b/hotdata/arrow.py @@ -0,0 +1,190 @@ +"""Arrow IPC helpers for ``GET /v1/results/{id}``. + +The auto-generated :class:`hotdata.api.results_api.ResultsApi` understands the +``format=arrow`` query parameter but cannot decode the +``application/vnd.apache.arrow.stream`` response body — openapi-generator picks +the JSON content variant for status 200 and routes Arrow bytes through the +JSON deserializer, which raises ``Unsupported content type``. + +This module wraps the generated client with a thin subclass that: + +* sets ``Accept: application/vnd.apache.arrow.stream`` and ``?format=arrow``, +* uses the generator's ``*_without_preload_content`` plumbing to hold the + underlying ``urllib3.HTTPResponse`` open as a byte stream, +* hands that stream to ``pyarrow.ipc.open_stream`` so callers get a + :class:`pyarrow.Table` (or a :class:`pyarrow.RecordBatchStreamReader` for + the streaming variant). + +Install with ``pip install 'hotdata[arrow]'`` to pull in pyarrow. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional + +from hotdata.api.results_api import ResultsApi as _GeneratedResultsApi +from hotdata.models.results_format_query import ResultsFormatQuery + +if TYPE_CHECKING: # pragma: no cover - import-time only for type checkers + import pyarrow as pa # type: ignore[import-untyped] + + +ARROW_STREAM_MEDIA_TYPE = "application/vnd.apache.arrow.stream" + + +class ResultNotReadyError(Exception): + """Raised when the result exists but is not yet ``ready``. + + The server replies with HTTP 202 while a result is ``pending`` or + ``processing``. Poll :meth:`ResultsApi.get_result` until ``status='ready'`` + before fetching as Arrow. + """ + + def __init__(self, status: str, result_id: str) -> None: + self.status = status + self.result_id = result_id + super().__init__( + f"Result {result_id} is not ready (status={status!r}); " + "poll get_result until status='ready' before fetching as Arrow." + ) + + +def _import_pyarrow() -> Any: + try: + import pyarrow.ipc as ipc # type: ignore[import-untyped] + except ImportError as exc: # pragma: no cover - exercised via tests + raise ImportError( + "pyarrow is required to fetch results as Arrow. " + "Install with: pip install 'hotdata[arrow]'" + ) from exc + return ipc + + +class ResultsApi(_GeneratedResultsApi): + """Drop-in replacement for :class:`hotdata.api.results_api.ResultsApi` + that adds Arrow IPC fetch helpers. + + All methods on the base class continue to work unchanged. + """ + + def get_result_arrow( + self, + id: str, + *, + offset: Optional[int] = None, + limit: Optional[int] = None, + _request_timeout: Any = None, + ) -> "pa.Table": + """Fetch a ready result as a :class:`pyarrow.Table`. + + Buffers the full Arrow IPC stream into memory before returning. Use + :meth:`stream_result_arrow` for large results where you want to + iterate batches without materializing the whole table. + + :param id: Result ID. + :param offset: Rows to skip (default: 0). + :param limit: Maximum rows to return (default: unbounded). + :raises ResultNotReadyError: result is still pending or processing. + :raises hotdata.exceptions.ApiException: for other HTTP errors + (400 invalid params, 404 not found, 409 failed result). + """ + ipc = _import_pyarrow() + response = self._call_arrow(id=id, offset=offset, limit=limit, + _request_timeout=_request_timeout) + try: + return ipc.open_stream(response).read_all() + finally: + response.release_conn() + + @contextmanager + def stream_result_arrow( + self, + id: str, + *, + offset: Optional[int] = None, + limit: Optional[int] = None, + _request_timeout: Any = None, + ) -> Iterator["pa.RecordBatchStreamReader"]: + """Yield a :class:`pyarrow.RecordBatchStreamReader` for a ready result. + + The HTTP connection is released when the context exits. Iterate the + reader to consume :class:`pyarrow.RecordBatch` messages, or call + ``reader.read_all()`` for a full :class:`pyarrow.Table`. + + Example:: + + with results.stream_result_arrow(result_id) as reader: + for batch in reader: + process(batch) + + :raises ResultNotReadyError: result is still pending or processing. + :raises hotdata.exceptions.ApiException: for other HTTP errors. + """ + ipc = _import_pyarrow() + response = self._call_arrow(id=id, offset=offset, limit=limit, + _request_timeout=_request_timeout) + try: + yield ipc.open_stream(response) + finally: + response.release_conn() + + def _call_arrow( + self, + *, + id: str, + offset: Optional[int], + limit: Optional[int], + _request_timeout: Any, + ) -> Any: + # Build the request via the generator's private serialize helper so + # path/query/auth handling stays in lockstep with the generated client. + # Override only what we need: the Accept header and the format query. + headers: Dict[str, Any] = {"Accept": ARROW_STREAM_MEDIA_TYPE} + params = self._get_result_serialize( + id=id, + offset=offset, + limit=limit, + format=ResultsFormatQuery.ARROW, + _request_auth=None, + _content_type=None, + _headers=headers, + _host_index=0, + ) + response_data = self.api_client.call_api( + *params, + _request_timeout=_request_timeout, + ) + + if response_data.status == 200: + # Hand the raw urllib3.HTTPResponse to the caller. preload_content + # was False on the way in, so the body has not been consumed. + return response_data.response + + # Non-200: drain, deserialize as JSON, then raise. response_deserialize + # raises ApiException for status >= 400 itself; only 202 falls through. + try: + response_data.read() + body = self.api_client.response_deserialize( + response_data=response_data, + response_types_map={ + "202": "GetResultResponse", + "400": "ApiErrorResponse", + "404": "ApiErrorResponse", + "409": "GetResultResponse", + }, + ).data + finally: + response_data.response.release_conn() + + raise ResultNotReadyError( + status=getattr(body, "status", "pending"), + result_id=getattr(body, "result_id", id), + ) + + +__all__ = [ + "ARROW_STREAM_MEDIA_TYPE", + "ResultNotReadyError", + "ResultsApi", +] diff --git a/pyproject.toml b/pyproject.toml index d52bf27..467ba48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,9 @@ dependencies = [ "typing-extensions (>=4.7.1)", ] +[project.optional-dependencies] +arrow = ["pyarrow >= 14"] + [project.urls] Homepage = "https://www.hotdata.dev" Repository = "https://github.com/hotdata-dev/sdk-python" diff --git a/tests/integration/test_results_arrow.py b/tests/integration/test_results_arrow.py new file mode 100644 index 0000000..82d3163 --- /dev/null +++ b/tests/integration/test_results_arrow.py @@ -0,0 +1,101 @@ +"""Scenario: results_arrow. + +Submit a small query, poll until the result is ready, then fetch the result +as a pyarrow.Table via hotdata.arrow.ResultsApi.get_result_arrow. Verifies +that Arrow IPC content negotiation works end-to-end and that the streaming +variant yields the same data. + +Skipped if pyarrow is not installed (the helper requires the ``arrow`` extra). +""" + +from __future__ import annotations + +import time + +import pytest + +pa = pytest.importorskip("pyarrow") + +from hotdata.api.query_api import QueryApi +from hotdata.api.query_runs_api import QueryRunsApi +from hotdata.arrow import ResultsApi +from hotdata.models.query_request import QueryRequest + + +TERMINAL_STATUSES = {"succeeded", "failed", "cancelled"} +POLL_TIMEOUT_S = 60.0 +POLL_INTERVAL_S = 1.0 + + +@pytest.fixture +def query_api(api_client) -> QueryApi: + return QueryApi(api_client) + + +@pytest.fixture +def query_runs_api(api_client) -> QueryRunsApi: + return QueryRunsApi(api_client) + + +@pytest.fixture +def results_api(api_client) -> ResultsApi: + return ResultsApi(api_client) + + +def test_results_arrow( + query_api: QueryApi, + query_runs_api: QueryRunsApi, + results_api: ResultsApi, +) -> None: + submitted = query_api.query( + QueryRequest( + var_async=True, + async_after_ms=1000, + sql="SELECT 1 AS x, 'hello' AS msg UNION ALL SELECT 2, 'world'", + ) + ) + query_run_id = submitted.query_run_id + assert query_run_id + + deadline = time.monotonic() + POLL_TIMEOUT_S + run = None + while time.monotonic() < deadline: + run = query_runs_api.get_query_run(query_run_id) + if run.status in TERMINAL_STATUSES: + break + time.sleep(POLL_INTERVAL_S) + assert run is not None + assert run.status == "succeeded", ( + f"expected succeeded, got {run.status}: {run.error_message}" + ) + assert run.result_id, "succeeded run must expose a result_id" + result_id = run.result_id + + # Wait for ready before fetching as Arrow — get_result_arrow raises + # ResultNotReadyError on 202. + deadline = time.monotonic() + POLL_TIMEOUT_S + while time.monotonic() < deadline: + result = results_api.get_result(result_id) + if result.status == "ready": + break + time.sleep(POLL_INTERVAL_S) + else: + pytest.fail(f"result {result_id} never became ready") + + # Buffered: returns a full pyarrow.Table. + table = results_api.get_result_arrow(result_id) + assert isinstance(table, pa.Table) + assert table.num_rows == 2 + assert set(table.column_names) == {"x", "msg"} + assert table.column("x").to_pylist() == [1, 2] + assert table.column("msg").to_pylist() == ["hello", "world"] + + # Streaming: same data via RecordBatchStreamReader. + with results_api.stream_result_arrow(result_id) as reader: + streamed = pa.Table.from_batches(list(reader), schema=reader.schema) + assert streamed.equals(table) + + # Pagination forwards correctly. + sliced = results_api.get_result_arrow(result_id, offset=1, limit=1) + assert sliced.num_rows == 1 + assert sliced.column("x").to_pylist() == [2] diff --git a/tests/test_arrow.py b/tests/test_arrow.py new file mode 100644 index 0000000..27b7dda --- /dev/null +++ b/tests/test_arrow.py @@ -0,0 +1,283 @@ +"""Unit tests for hotdata.arrow. + +These tests stub out the underlying urllib3 transport so they don't require a +running server. They build real Arrow IPC byte streams with pyarrow and feed +them through the SDK plumbing to verify: + +* the Arrow Table round-trips schema and values, +* the request carries ``Accept: application/vnd.apache.arrow.stream`` and + ``?format=arrow``, +* offset / limit are forwarded as query params, +* non-200 responses raise the right exceptions, +* the streaming variant yields RecordBatches and releases the connection. +""" + +from __future__ import annotations + +import io +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import patch + +import pytest + +pa = pytest.importorskip("pyarrow") +pa_ipc = pytest.importorskip("pyarrow.ipc") + +from hotdata import ApiClient, Configuration +from hotdata.arrow import ( + ARROW_STREAM_MEDIA_TYPE, + ResultNotReadyError, + ResultsApi, +) +from hotdata.exceptions import ApiException + + +def _arrow_bytes(table: Any) -> bytes: + sink = io.BytesIO() + with pa_ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + return sink.getvalue() + + +class _FakeUrllib3Response(io.RawIOBase): + """Minimal stand-in for urllib3.HTTPResponse. + + pyarrow.ipc.open_stream wants a real file-like object (it checks + ``closed`` and ``readable()``); the SDK's RESTResponse needs ``status``, + ``reason``, ``data``, and ``headers``. release_conn is recorded. + """ + + def __init__(self, status: int, body: bytes, headers: Dict[str, str]): + super().__init__() + self.status = status + self.reason = "OK" if 200 <= status < 300 else "Error" + self._body = body + self._buf = io.BytesIO(body) + self.headers = headers + self.release_conn_called = False + + @property + def data(self) -> bytes: + # urllib3.HTTPResponse.data preloads the body. SDK RESTResponse.read() + # reads from this attribute. + return self._body + + def read(self, amt: Optional[int] = -1) -> bytes: + if amt is None or amt < 0: + return self._buf.read() + return self._buf.read(amt) + + def readable(self) -> bool: + return True + + def release_conn(self) -> None: + self.release_conn_called = True + + +def _install_fake_response( + monkeypatch: pytest.MonkeyPatch, + response: _FakeUrllib3Response, + captured: List[Dict[str, Any]], +) -> None: + """Replace RESTClientObject.request with a stub that records the call.""" + + from hotdata import rest + + def fake_request( + self: Any, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + body: Any = None, + post_params: Any = None, + _request_timeout: Any = None, + ) -> rest.RESTResponse: + captured.append( + { + "method": method, + "url": url, + "headers": dict(headers or {}), + } + ) + return rest.RESTResponse(response) + + monkeypatch.setattr(rest.RESTClientObject, "request", fake_request) + + +def _make_results_api() -> Tuple[ResultsApi, ApiClient]: + config = Configuration( + host="https://api.hotdata.test", + api_key="test-key", + workspace_id="ws_test", + ) + client = ApiClient(config) + return ResultsApi(client), client + + +def _sample_table() -> Any: + return pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "label": pa.array(["a", "b", "c"], type=pa.string()), + } + ) + + +# --- Happy path ----------------------------------------------------------- + + +def test_get_result_arrow_returns_table(monkeypatch: pytest.MonkeyPatch) -> None: + table = _sample_table() + fake = _FakeUrllib3Response( + status=200, + body=_arrow_bytes(table), + headers={"content-type": ARROW_STREAM_MEDIA_TYPE}, + ) + captured: List[Dict[str, Any]] = [] + _install_fake_response(monkeypatch, fake, captured) + + results, _ = _make_results_api() + got = results.get_result_arrow("res_123") + + assert got.equals(table) + assert fake.release_conn_called + + # Single request was made. + assert len(captured) == 1 + call = captured[0] + assert call["method"] == "GET" + # Path templating preserved. + assert "/v1/results/res_123" in call["url"] + # format=arrow query param sent. + assert "format=arrow" in call["url"] + # Accept header overrides the JSON default. + assert call["headers"]["Accept"] == ARROW_STREAM_MEDIA_TYPE + + +def test_get_result_arrow_forwards_offset_and_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake = _FakeUrllib3Response( + status=200, + body=_arrow_bytes(_sample_table()), + headers={"content-type": ARROW_STREAM_MEDIA_TYPE}, + ) + captured: List[Dict[str, Any]] = [] + _install_fake_response(monkeypatch, fake, captured) + + results, _ = _make_results_api() + results.get_result_arrow("res_123", offset=10, limit=100) + + url = captured[0]["url"] + assert "offset=10" in url + assert "limit=100" in url + + +def test_stream_result_arrow_yields_reader( + monkeypatch: pytest.MonkeyPatch, +) -> None: + table = _sample_table() + fake = _FakeUrllib3Response( + status=200, + body=_arrow_bytes(table), + headers={"content-type": ARROW_STREAM_MEDIA_TYPE}, + ) + _install_fake_response(monkeypatch, fake, []) + + results, _ = _make_results_api() + with results.stream_result_arrow("res_123") as reader: + batches = list(reader) + assert batches, "expected at least one RecordBatch" + roundtrip = pa.Table.from_batches(batches, schema=reader.schema) + assert roundtrip.equals(table) + + # Connection is released after the context exits. + assert fake.release_conn_called + + +# --- Non-200 paths -------------------------------------------------------- + + +def test_get_result_arrow_raises_when_not_ready( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = json.dumps( + { + "result_id": "res_pending", + "status": "processing", + "row_count": 0, + } + ).encode() + fake = _FakeUrllib3Response( + status=202, + body=body, + headers={"content-type": "application/json"}, + ) + _install_fake_response(monkeypatch, fake, []) + + results, _ = _make_results_api() + with pytest.raises(ResultNotReadyError) as ei: + results.get_result_arrow("res_pending") + + assert ei.value.status == "processing" + assert ei.value.result_id == "res_pending" + assert fake.release_conn_called + + +def test_get_result_arrow_raises_api_exception_on_404( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = json.dumps({"error": {"message": "not found", "code": "not_found"}}).encode() + fake = _FakeUrllib3Response( + status=404, + body=body, + headers={"content-type": "application/json"}, + ) + _install_fake_response(monkeypatch, fake, []) + + results, _ = _make_results_api() + with pytest.raises(ApiException) as ei: + results.get_result_arrow("res_missing") + + assert ei.value.status == 404 + assert fake.release_conn_called + + +def test_get_result_arrow_raises_api_exception_on_409_failed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = json.dumps( + { + "result_id": "res_failed", + "status": "failed", + "error_message": "boom", + "row_count": 0, + } + ).encode() + fake = _FakeUrllib3Response( + status=409, + body=body, + headers={"content-type": "application/json"}, + ) + _install_fake_response(monkeypatch, fake, []) + + results, _ = _make_results_api() + with pytest.raises(ApiException) as ei: + results.get_result_arrow("res_failed") + + assert ei.value.status == 409 + + +# --- Missing pyarrow ------------------------------------------------------ + + +def test_helpful_error_when_pyarrow_missing() -> None: + """``_import_pyarrow`` re-raises ImportError with an install hint.""" + from hotdata import arrow as arrow_module + + with patch.dict("sys.modules", {"pyarrow.ipc": None}): + # patch.dict with None value forces ImportError on `import pyarrow.ipc`. + with pytest.raises(ImportError) as ei: + arrow_module._import_pyarrow() + assert "hotdata[arrow]" in str(ei.value)