Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 200 additions & 46 deletions dpsynth/text/bulk_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

"""Bulk LLM inference for synthetic text generation."""

from collections.abc import Sequence
from collections.abc import Callable, Sequence
import concurrent.futures
import dataclasses
import enum
import functools
import random
import re
from typing import Protocol
import time
from typing import Protocol, TypeVar

from absl import logging
from google import genai
Expand Down Expand Up @@ -63,10 +67,10 @@ def annotate(

Args:
texts: Input texts to annotate.
schema: Pydantic model class defining the features to extract. The model's
field names, ``Literal`` type annotations, and field descriptions guide
the LLM. This same class is used as the ``response_schema`` for
constrained decoding in supported backends.
schema: Pydantic model class defining the features to extract. Fields may
use ``Literal`` type annotations for constrained decoding or plain types
such as ``str`` for open-ended annotation. Field names and descriptions
guide the LLM.
system_prompt: System-level instructions for the LLM describing how to
annotate the texts.

Expand All @@ -90,7 +94,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
...


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class GenAIBackend:
"""TextGenerationBackend using the google.genai API.

Expand All @@ -100,65 +104,179 @@ class GenAIBackend:
Attributes:
model: Model name string (e.g., ``'gemini-2.5-flash-lite'``). Accepts any
``ModelName`` enum value or arbitrary string for unlisted models.
api_key: API key for authentication. If None, uses Application Default
Credentials (ADC).
api_key: API key for authentication.
poll_interval_seconds: How often to poll for batch job completion.
chunk_size: Number of texts per batch job.
max_concurrent_jobs: Maximum number of active parallel batch jobs.
"""

model: str = ModelName.GEMINI_2_5_FLASH_LITE
api_key: str | None = None
poll_interval_seconds: int = 60
chunk_size: int = 100
max_concurrent_jobs: int = 8

def _make_client(self) -> genai.Client:
"""Creates a genai client."""
kwargs: dict[str, object] = {
'http_options': types.HttpOptions(api_version='v1alpha'),
}
if self.api_key is not None:
@functools.cached_property
def client(self) -> genai.Client:
"""Creates and caches a genai.Client."""
kwargs = {'http_options': types.HttpOptions(api_version='v1alpha')}

if self.api_key:
kwargs['api_key'] = self.api_key
return genai.Client(**kwargs)

def _parse_job_responses(
self,
batch_job: types.BatchJob,
schema: type[pydantic.BaseModel],
) -> list[dict[str, str | None]]:
"""Parses responses from a completed BatchJob."""
if batch_job.state != types.JobState.JOB_STATE_SUCCEEDED:
error_msg = (
f'Batch job {batch_job.name} ended with state={batch_job.state}.'
)
if batch_job.error:
error_msg += f' Error: {batch_job.error}'
raise RuntimeError(error_msg)

null_row = {f: None for f in schema.model_fields.keys()}

inlined_responses = (
batch_job.dest.inlined_responses if batch_job.dest else []
) or []

chunk_rows = []
for i, inlined_resp in enumerate(inlined_responses):
row = dict(null_row) # Default to null row
try:
if inlined_resp.error:
logging.warning(
'Batch result %d in job %s had error: %s',
i,
batch_job.name,
inlined_resp.error,
)
else:
response_text = (
inlined_resp.response.text if inlined_resp.response else None
)
if response_text:
row = schema.model_validate_json(
_strip_markdown_fences(response_text)
).model_dump()
else:
logging.warning(
'Empty batch response in job %s for text %d.',
batch_job.name,
i,
)
except Exception as e: # pylint: disable=broad-except
logging.warning(
'Failed to parse batch result %d in job %s: %s',
i,
batch_job.name,
e,
)
chunk_rows.append(row)
return chunk_rows

def _submit_and_poll_chunk(
self,
chunk_texts: Sequence[str],
config: types.GenerateContentConfig | None = None,
) -> types.BatchJob:
"""Submit a batch job for one chunk and poll until done."""

inlined_requests = [
types.InlinedRequest(contents=text, config=config)
for text in chunk_texts
]

job = _call_with_retry(
lambda: self.client.batches.create(
model=self.model, src=inlined_requests
),
'create',
)
logging.info('Batch annotate: job %s created.', job.name)

while not job.done:
time.sleep(self.poll_interval_seconds)
job = _call_with_retry(
lambda: self.client.batches.get(name=job.name), 'get'
)

logging.info(
'Batch annotate: job %s completed with state=%s',
job.name,
job.state,
)
return job

def annotate(
self,
texts: Sequence[str],
schema: type[pydantic.BaseModel],
system_prompt: str,
) -> pd.DataFrame:
"""Extract structured features via constrained decoding.
"""Extract structured features via the GenAI Batch API.

Submits texts as inlined requests to the batch prediction endpoint,
polls for completion, and parses the inlined responses.

Args:
texts: Input texts to annotate.
schema: Pydantic model used as the ``response_schema`` for constrained
decoding.
schema: Pydantic model used as the ``response_schema``.
system_prompt: System-level instructions for the LLM.

Returns:
DataFrame with exactly ``len(texts)`` rows. Failed rows have ``None``.

Raises:
RuntimeError: If the batch job fails or is cancelled.
"""
client = self._make_client()
field_names = list(schema.model_fields.keys())
null_row = {f: None for f in field_names}
rows: list[dict[str, str | None]] = []
for i, text in enumerate(texts):
try:
response = client.models.generate_content(
model=self.model,
contents=text,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type='application/json',
response_schema=schema,
),
config = types.GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type='application/json',
response_schema=schema,
)

chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]

logging.info(
'Batch annotate: processing %d chunks with concurrency limit %d...',
len(chunks),
self.max_concurrent_jobs,
)

with concurrent.futures.ThreadPoolExecutor(
max_workers=self.max_concurrent_jobs
) as pool:
completed_jobs = list(
pool.map(
functools.partial(self._submit_and_poll_chunk, config=config),
chunks,
)
)

logging.info('Batch annotate: all jobs completed. Parsing responses...')

all_rows = []
for batch_job, chunk_texts in zip(completed_jobs, chunks, strict=True):
chunk_rows = self._parse_job_responses(batch_job, schema)

if len(chunk_rows) != len(chunk_texts):
raise ValueError(
f'Batch annotate: job {batch_job.name} got {len(chunk_rows)}'
f' results for {len(chunk_texts)} inputs.'
)
if response.text:
cleaned = _strip_markdown_fences(response.text)
parsed = schema.model_validate_json(cleaned)
rows.append(parsed.model_dump())
else:
logging.warning('Empty annotation response for text %d.', i)
rows.append(null_row)
except Exception: # pylint: disable=broad-except
logging.warning('Annotation failed for text %d.', i)
rows.append(null_row)
return pd.DataFrame(rows)

all_rows.extend(chunk_rows)

return pd.DataFrame(all_rows)

def generate(self, prompts: Sequence[str]) -> list[str]:
"""Generate free-form text via google.genai.
Expand All @@ -169,7 +287,7 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
Returns:
List of exactly ``len(prompts)`` strings. Empty string on failure.
"""
client = self._make_client()
client = self.client
results: list[str] = []
for i, prompt in enumerate(prompts):
try:
Expand All @@ -178,8 +296,10 @@ def generate(self, prompts: Sequence[str]) -> list[str]:
contents=prompt,
)
results.append(response.text or '')
except Exception: # pylint: disable=broad-except
logging.warning('Generation failed for prompt %d.', i)
except Exception as e: # pylint: disable=broad-except
logging.warning(
'Generation failed for prompt %d. Error: %s', i, e, exc_info=True
)
results.append('')
return results

Expand All @@ -189,3 +309,37 @@ def _strip_markdown_fences(text):
regex = r'^\s*```(?:json)?\s*\n(.*?)\n\s*```\s*$'
m = re.compile(regex, re.DOTALL).match(text)
return m.group(1).strip() if m else text.strip()


T = TypeVar('T')


def _call_with_retry(
func: Callable[[], T],
op_name: str,
max_retries: int = 10,
initial_delay: float = 5.0,
) -> T:
"""Calls `func` with exponential backoff on exceptions."""
delay = initial_delay
for attempt in range(1, max_retries + 1):
try:
return func()
except Exception as e: # pylint: disable=broad-except
if attempt == max_retries:
logging.error(
'Batch %s failed after %d attempts.', op_name, max_retries
)
raise

sleep_time = delay + random.uniform(0, 5)
logging.warning(
'Batch %s failed (attempt %d/%d): %s. Retrying in %.1f sec...',
op_name,
attempt,
max_retries,
e,
sleep_time,
)
time.sleep(sleep_time)
delay *= 2
Loading
Loading