diff --git a/dpsynth/text/bulk_inference.py b/dpsynth/text/bulk_inference.py index 13a436e..1206be9 100644 --- a/dpsynth/text/bulk_inference.py +++ b/dpsynth/text/bulk_inference.py @@ -14,11 +14,15 @@ """Bulk LLM inference for synthetic text generation.""" -from collections.abc import Sequence +from collections.abc import Callable, Sequence +import concurrent.futures import dataclasses import enum +import functools +import random import re -from typing import Protocol +import time +from typing import Protocol, TypeVar from absl import logging from google import genai @@ -63,10 +67,10 @@ def annotate( Args: texts: Input texts to annotate. - schema: Pydantic model class defining the features to extract. The model's - field names, ``Literal`` type annotations, and field descriptions guide - the LLM. This same class is used as the ``response_schema`` for - constrained decoding in supported backends. + schema: Pydantic model class defining the features to extract. Fields may + use ``Literal`` type annotations for constrained decoding or plain types + such as ``str`` for open-ended annotation. Field names and descriptions + guide the LLM. system_prompt: System-level instructions for the LLM describing how to annotate the texts. @@ -90,7 +94,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]: ... -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class GenAIBackend: """TextGenerationBackend using the google.genai API. @@ -100,65 +104,179 @@ class GenAIBackend: Attributes: model: Model name string (e.g., ``'gemini-2.5-flash-lite'``). Accepts any ``ModelName`` enum value or arbitrary string for unlisted models. - api_key: API key for authentication. If None, uses Application Default - Credentials (ADC). + api_key: API key for authentication. + poll_interval_seconds: How often to poll for batch job completion. + chunk_size: Number of texts per batch job. + max_concurrent_jobs: Maximum number of active parallel batch jobs. """ model: str = ModelName.GEMINI_2_5_FLASH_LITE api_key: str | None = None + poll_interval_seconds: int = 60 + chunk_size: int = 100 + max_concurrent_jobs: int = 8 - def _make_client(self) -> genai.Client: - """Creates a genai client.""" - kwargs: dict[str, object] = { - 'http_options': types.HttpOptions(api_version='v1alpha'), - } - if self.api_key is not None: + @functools.cached_property + def client(self) -> genai.Client: + """Creates and caches a genai.Client.""" + kwargs = {'http_options': types.HttpOptions(api_version='v1alpha')} + + if self.api_key: kwargs['api_key'] = self.api_key return genai.Client(**kwargs) + def _parse_job_responses( + self, + batch_job: types.BatchJob, + schema: type[pydantic.BaseModel], + ) -> list[dict[str, str | None]]: + """Parses responses from a completed BatchJob.""" + if batch_job.state != types.JobState.JOB_STATE_SUCCEEDED: + error_msg = ( + f'Batch job {batch_job.name} ended with state={batch_job.state}.' + ) + if batch_job.error: + error_msg += f' Error: {batch_job.error}' + raise RuntimeError(error_msg) + + null_row = {f: None for f in schema.model_fields.keys()} + + inlined_responses = ( + batch_job.dest.inlined_responses if batch_job.dest else [] + ) or [] + + chunk_rows = [] + for i, inlined_resp in enumerate(inlined_responses): + row = dict(null_row) # Default to null row + try: + if inlined_resp.error: + logging.warning( + 'Batch result %d in job %s had error: %s', + i, + batch_job.name, + inlined_resp.error, + ) + else: + response_text = ( + inlined_resp.response.text if inlined_resp.response else None + ) + if response_text: + row = schema.model_validate_json( + _strip_markdown_fences(response_text) + ).model_dump() + else: + logging.warning( + 'Empty batch response in job %s for text %d.', + batch_job.name, + i, + ) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Failed to parse batch result %d in job %s: %s', + i, + batch_job.name, + e, + ) + chunk_rows.append(row) + return chunk_rows + + def _submit_and_poll_chunk( + self, + chunk_texts: Sequence[str], + config: types.GenerateContentConfig | None = None, + ) -> types.BatchJob: + """Submit a batch job for one chunk and poll until done.""" + + inlined_requests = [ + types.InlinedRequest(contents=text, config=config) + for text in chunk_texts + ] + + job = _call_with_retry( + lambda: self.client.batches.create( + model=self.model, src=inlined_requests + ), + 'create', + ) + logging.info('Batch annotate: job %s created.', job.name) + + while not job.done: + time.sleep(self.poll_interval_seconds) + job = _call_with_retry( + lambda: self.client.batches.get(name=job.name), 'get' + ) + + logging.info( + 'Batch annotate: job %s completed with state=%s', + job.name, + job.state, + ) + return job + def annotate( self, texts: Sequence[str], schema: type[pydantic.BaseModel], system_prompt: str, ) -> pd.DataFrame: - """Extract structured features via constrained decoding. + """Extract structured features via the GenAI Batch API. + + Submits texts as inlined requests to the batch prediction endpoint, + polls for completion, and parses the inlined responses. Args: texts: Input texts to annotate. - schema: Pydantic model used as the ``response_schema`` for constrained - decoding. + schema: Pydantic model used as the ``response_schema``. system_prompt: System-level instructions for the LLM. Returns: DataFrame with exactly ``len(texts)`` rows. Failed rows have ``None``. + + Raises: + RuntimeError: If the batch job fails or is cancelled. """ - client = self._make_client() - field_names = list(schema.model_fields.keys()) - null_row = {f: None for f in field_names} - rows: list[dict[str, str | None]] = [] - for i, text in enumerate(texts): - try: - response = client.models.generate_content( - model=self.model, - contents=text, - config=types.GenerateContentConfig( - system_instruction=system_prompt, - response_mime_type='application/json', - response_schema=schema, - ), + config = types.GenerateContentConfig( + system_instruction=system_prompt, + response_mime_type='application/json', + response_schema=schema, + ) + + chunks = [ + texts[i : i + self.chunk_size] + for i in range(0, len(texts), self.chunk_size) + ] + + logging.info( + 'Batch annotate: processing %d chunks with concurrency limit %d...', + len(chunks), + self.max_concurrent_jobs, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_concurrent_jobs + ) as pool: + completed_jobs = list( + pool.map( + functools.partial(self._submit_and_poll_chunk, config=config), + chunks, + ) + ) + + logging.info('Batch annotate: all jobs completed. Parsing responses...') + + all_rows = [] + for batch_job, chunk_texts in zip(completed_jobs, chunks, strict=True): + chunk_rows = self._parse_job_responses(batch_job, schema) + + if len(chunk_rows) != len(chunk_texts): + raise ValueError( + f'Batch annotate: job {batch_job.name} got {len(chunk_rows)}' + f' results for {len(chunk_texts)} inputs.' ) - if response.text: - cleaned = _strip_markdown_fences(response.text) - parsed = schema.model_validate_json(cleaned) - rows.append(parsed.model_dump()) - else: - logging.warning('Empty annotation response for text %d.', i) - rows.append(null_row) - except Exception: # pylint: disable=broad-except - logging.warning('Annotation failed for text %d.', i) - rows.append(null_row) - return pd.DataFrame(rows) + + all_rows.extend(chunk_rows) + + return pd.DataFrame(all_rows) def generate(self, prompts: Sequence[str]) -> list[str]: """Generate free-form text via google.genai. @@ -169,7 +287,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]: Returns: List of exactly ``len(prompts)`` strings. Empty string on failure. """ - client = self._make_client() + client = self.client results: list[str] = [] for i, prompt in enumerate(prompts): try: @@ -178,8 +296,10 @@ def generate(self, prompts: Sequence[str]) -> list[str]: contents=prompt, ) results.append(response.text or '') - except Exception: # pylint: disable=broad-except - logging.warning('Generation failed for prompt %d.', i) + except Exception as e: # pylint: disable=broad-except + logging.warning( + 'Generation failed for prompt %d. Error: %s', i, e, exc_info=True + ) results.append('') return results @@ -189,3 +309,37 @@ def _strip_markdown_fences(text): regex = r'^\s*```(?:json)?\s*\n(.*?)\n\s*```\s*$' m = re.compile(regex, re.DOTALL).match(text) return m.group(1).strip() if m else text.strip() + + +T = TypeVar('T') + + +def _call_with_retry( + func: Callable[[], T], + op_name: str, + max_retries: int = 10, + initial_delay: float = 5.0, +) -> T: + """Calls `func` with exponential backoff on exceptions.""" + delay = initial_delay + for attempt in range(1, max_retries + 1): + try: + return func() + except Exception as e: # pylint: disable=broad-except + if attempt == max_retries: + logging.error( + 'Batch %s failed after %d attempts.', op_name, max_retries + ) + raise + + sleep_time = delay + random.uniform(0, 5) + logging.warning( + 'Batch %s failed (attempt %d/%d): %s. Retrying in %.1f sec...', + op_name, + attempt, + max_retries, + e, + sleep_time, + ) + time.sleep(sleep_time) + delay *= 2 diff --git a/tests/text/bulk_inference_test.py b/tests/text/bulk_inference_test.py index 995b647..d06622b 100644 --- a/tests/text/bulk_inference_test.py +++ b/tests/text/bulk_inference_test.py @@ -72,63 +72,185 @@ def test_strips_whitespace(self): ) +@mock.patch('google.genai.Client', autospec=True) class GenAIBackendAnnotateTest(absltest.TestCase): + """Tests for GenAIBackend.annotate.""" - @mock.patch('google.genai.Client') - def test_annotate_index_aligned_on_success(self, mock_client_cls): - mock_client = mock_client_cls.return_value + @staticmethod + def _create_inlined_response( + text: str | None = None, error: str | None = None + ) -> mock.MagicMock: + """Creates a mock inlined batch response.""" mock_response = mock.MagicMock() - mock_response.text = '{"topic": "Science", "complexity": "High"}' - mock_client.models.generate_content.return_value = mock_response + mock_response.text = text + return mock.MagicMock(error=error, response=mock_response if text else None) + + @staticmethod + def _create_mock_job( + state, + inlined_responses: list[mock.MagicMock] | None = None, + done_side_effect: list[bool] | None = None, + inlined_responses_side_effect: list[list[mock.MagicMock]] | None = None, + error: str | None = None, + name: str = 'job', + ) -> mock.MagicMock: + """Creates a mock batch job.""" + job = mock.MagicMock() + job.name = name + if done_side_effect is not None: + type(job).done = mock.PropertyMock(side_effect=done_side_effect) + else: + type(job).done = mock.PropertyMock(return_value=True) + job.state = state + job.error = error + + mock_dest = mock.MagicMock() + if inlined_responses_side_effect is not None: + type(mock_dest).inlined_responses = mock.PropertyMock( + side_effect=inlined_responses_side_effect + ) + elif inlined_responses is not None: + mock_dest.inlined_responses = inlined_responses + job.dest = mock_dest + return job + + def test_annotate_success(self, mock_client_cls): + mock_client = mock_client_cls.return_value + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp, inlined_resp], + ) - backend = bulk_inference.GenAIBackend() + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job - df = backend.annotate(['text1', 'text2'], SimpleFeatures, 'Annotate.') + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + df = backend.annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') self.assertLen(df, 2) - self.assertListEqual(list(df.columns), ['topic', 'complexity']) + self.assertEqual(df.iloc[0]['topic'], 'Science') + self.assertEqual(df.iloc[1]['topic'], 'Science') - @mock.patch('google.genai.Client') - def test_annotate_fills_none_on_failure(self, mock_client_cls): + def test_annotate_raises_on_length_mismatch(self, mock_client_cls): mock_client = mock_client_cls.return_value - mock_client.models.generate_content.side_effect = RuntimeError('API down') + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp], + ) - backend = bulk_inference.GenAIBackend() + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job - df = backend.annotate(['text1', 'text2', 'text3'], SimpleFeatures, 'Sys.') - self.assertLen(df, 3) - self.assertTrue(df.iloc[0].isna().all()) - self.assertTrue(df.iloc[2].isna().all()) + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + with self.assertRaisesRegex(ValueError, 'got 1 results for 2 inputs'): + backend.annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') - @mock.patch('google.genai.Client') - def test_annotate_mixed_success_and_failure(self, mock_client_cls): + def test_annotate_fills_none_on_item_error(self, mock_client_cls): mock_client = mock_client_cls.return_value - good = mock.MagicMock() - good.text = '{"topic": "Technology", "complexity": "Low"}' - bad = RuntimeError('fail') + inlined_error = self._create_inlined_response(error='Failed item') + inlined_success = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_error, inlined_success], + ) - mock_client.models.generate_content.side_effect = [good, bad, good] + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job - backend = bulk_inference.GenAIBackend() + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + df = backend.annotate(['text1', 'text2'], SimpleFeatures, 'Sys.') + self.assertLen(df, 2) + self.assertTrue(pd.isna(df.iloc[0]['topic'])) + self.assertEqual(df.iloc[1]['topic'], 'Science') - df = backend.annotate(['a', 'b', 'c'], SimpleFeatures, 'Sys.') - self.assertLen(df, 3) - self.assertEqual(df.iloc[0]['topic'], 'Technology') - self.assertTrue(pd.isna(df.iloc[1]['topic'])) - self.assertEqual(df.iloc[2]['topic'], 'Technology') + def test_annotate_raises_on_failed_job(self, mock_client_cls): + mock_client = mock_client_cls.return_value + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_FAILED, + error='Something went wrong', + ) - @mock.patch('google.genai.Client') - def test_annotate_handles_markdown_fenced_json(self, mock_client_cls): + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job + + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0 + ) + with self.assertRaisesRegex( + RuntimeError, + 'Batch job .* ended with state.* Error: Something went wrong', + ): + backend.annotate(['text1'], SimpleFeatures, 'Sys.') + + def test_annotate_respects_class_chunk_size(self, mock_client_cls): mock_client = mock_client_cls.return_value - fenced = mock.MagicMock() - fenced.text = '```json\n{"topic": "Science", "complexity": "High"}\n```' - mock_client.models.generate_content.return_value = fenced + inlined_resp = self._create_inlined_response( + '{"topic": "Science", "complexity": "High"}' + ) + mock_batch_job = self._create_mock_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp, inlined_resp], + ) - backend = bulk_inference.GenAIBackend() + mock_client.batches.create.return_value = mock_batch_job + mock_client.batches.get.return_value = mock_batch_job - df = backend.annotate(['text1'], SimpleFeatures, 'Sys.') - self.assertLen(df, 1) - self.assertEqual(df.iloc[0]['topic'], 'Science') - self.assertEqual(df.iloc[0]['complexity'], 'High') + backend = bulk_inference.GenAIBackend( + api_key='fake', poll_interval_seconds=0, chunk_size=2 + ) + # 4 texts, chunk_size=2 -> 2 chunks/jobs + backend.annotate(['t1', 't2', 't3', 't4'], SimpleFeatures, 'Sys.') + self.assertEqual(mock_client.batches.create.call_count, 2) + + def test_annotate_respects_max_concurrent_jobs(self, mock_client_cls): + mock_client = mock_client_cls.return_value + create_resp = GenAIBackendAnnotateTest._create_inlined_response + create_job = GenAIBackendAnnotateTest._create_mock_job + + jobs = [] + + def create_side_effect(model, src): + del model # Unused. + inlined_resp = create_resp('{"topic": "Science", "complexity": "High"}') + job = create_job( + state=bulk_inference.types.JobState.JOB_STATE_SUCCEEDED, + inlined_responses=[inlined_resp] * len(src), + done_side_effect=[False, True], + name=f'job-{len(jobs)}', + ) + jobs.append(job) + return job + + mock_client.batches.create.side_effect = create_side_effect + + def get_side_effect(name): + idx = int(name.split('-')[1]) + return jobs[idx] + + mock_client.batches.get.side_effect = get_side_effect + + backend = bulk_inference.GenAIBackend( + api_key='fake', + poll_interval_seconds=0, + chunk_size=2, + max_concurrent_jobs=1, + ) + backend.annotate(['t1', 't2', 't3', 't4', 't5'], SimpleFeatures, 'Sys.') + self.assertEqual(mock_client.batches.create.call_count, 3) class GenAIBackendGenerateTest(absltest.TestCase):