From f9f3b2d3337e0b194783c9c25bff0cea2f1f6e0a Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 1 Jul 2026 14:51:24 -0600 Subject: [PATCH 1/2] refactor: remove verified dead code Remove unreachable helpers and production code exercised only by tests after a triple review of static references, runtime registration paths, and extension surfaces. Keep reusable observability fixtures in the explicit engine.testing namespace. Closes #789 Signed-off-by: Nabin Mulepati --- .../data_designer/config/utils/io_helpers.py | 45 +- .../tests/config/utils/test_io_helpers.py | 90 ---- .../utils/column_statistics_calculations.py | 2 - .../src/data_designer/engine/capacity.py | 123 ----- .../column_generators/generators/custom.py | 32 -- .../dataset_builders/async_scheduler.py | 193 ------- .../dataset_builders/dataset_builder.py | 11 - .../dataset_builders/scheduling/completion.py | 42 +- .../dataset_builders/scheduling/task_model.py | 13 +- .../utils/async_concurrency.py | 180 +------ .../dataset_builders/utils/config_compiler.py | 7 - .../engine/dataset_builders/utils/errors.py | 3 - .../dataset_builders/utils/execution_graph.py | 37 -- .../utils/row_group_buffer.py | 4 - .../dataset_builders/utils/skip_tracker.py | 89 ---- .../src/data_designer/engine/errors.py | 3 - .../engine/models/clients/errors.py | 26 - .../engine/models/parsers/postprocessors.py | 36 -- .../models/request_admission/controller.py | 11 - .../src/data_designer/engine/observability.py | 29 -- .../engine/resources/resource_provider.py | 25 - .../engine/storage/media_storage.py | 4 - .../data_designer/engine/testing/__init__.py | 3 + .../engine/testing/observability.py | 39 ++ .../scheduling/test_completion.py | 109 +--- .../scheduling/test_task_model.py | 19 +- .../dataset_builders/test_async_scheduler.py | 139 +---- .../dataset_builders/test_dataset_builder.py | 21 +- .../utils/test_async_concurrency.py | 478 ------------------ .../utils/test_execution_graph.py | 43 -- .../utils/test_row_group_buffer.py | 9 - .../utils/test_skip_tracker.py | 76 --- .../models/clients/test_client_errors.py | 44 -- .../clients/test_model_request_executor.py | 2 +- .../engine/models/parsers/test_parser.py | 21 +- .../models/parsers/test_postprocessors.py | 62 --- .../request_admission/test_controller.py | 2 +- .../resources/test_resource_provider.py | 45 +- .../tests/engine/test_capacity.py | 74 --- .../tests/engine/test_engine_errors.py | 6 +- .../tests/engine/test_observability.py | 3 +- .../cli/repositories/persona_repository.py | 20 - .../cli/services/plugin_install_service.py | 4 - .../data-designer/src/data_designer/cli/ui.py | 7 - .../repositories/test_persona_repository.py | 20 +- .../services/test_plugin_install_service.py | 4 +- .../data-designer/tests/cli/test_cli_utils.py | 11 - 47 files changed, 108 insertions(+), 2158 deletions(-) delete mode 100644 packages/data-designer-engine/src/data_designer/engine/capacity.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/testing/observability.py delete mode 100644 packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py delete mode 100644 packages/data-designer-engine/tests/engine/test_capacity.py diff --git a/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py index 8247a36b7..e71ec3a10 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py @@ -9,7 +9,7 @@ from datetime import date, datetime, timedelta from decimal import Decimal from numbers import Number -from pathlib import Path, PurePosixPath +from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -29,15 +29,6 @@ VALID_CONFIG_FILE_EXTENSIONS = {".yaml", ".yml", ".json"} -def ensure_config_dir_exists(config_dir: Path) -> None: - """Create configuration directory if it doesn't exist. - - Args: - config_dir: Directory path to create - """ - config_dir.mkdir(parents=True, exist_ok=True) - - def load_config_file(file_path: Path) -> dict: """Load a YAML configuration file. @@ -177,40 +168,6 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str) raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.") -def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame: - """Load a dataframe from file if a path is given, otherwise return the dataframe. - - Args: - dataframe: A path to a file or a pandas DataFrame object. - - Returns: - A pandas DataFrame object. - """ - if isinstance(dataframe, lazy.pd.DataFrame): - return dataframe - - # Get the file extension. - if isinstance(dataframe, str) and dataframe.startswith("http"): - dataframe = _maybe_rewrite_url(dataframe) - # Parse extension from the URL path to avoid query-string contamination (e.g. "csv?token=…"). - ext = PurePosixPath(urlparse(dataframe).path).suffix.lstrip(".").lower() - else: - dataframe = Path(dataframe) - ext = dataframe.suffix.lower() - if not dataframe.exists(): - raise FileNotFoundError(f"File not found: {dataframe}") - - # Load the dataframe based on the file extension. - if ext == "csv": - return lazy.pd.read_csv(dataframe) - elif ext == "json": - return lazy.pd.read_json(dataframe, lines=True) - elif ext == "parquet": - return lazy.pd.read_parquet(dataframe) - else: - raise ValueError(f"Unsupported file format: {dataframe}") - - def smart_load_yaml(yaml_in: str | Path | dict) -> dict: """Return the yaml config as a dict given flexible input types. diff --git a/packages/data-designer-config/tests/config/utils/test_io_helpers.py b/packages/data-designer-config/tests/config/utils/test_io_helpers.py index 5b2b8b98a..1d65d945d 100644 --- a/packages/data-designer-config/tests/config/utils/test_io_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_io_helpers.py @@ -6,8 +6,6 @@ import tempfile from datetime import date, datetime, timedelta from decimal import Decimal -from pathlib import Path -from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest @@ -19,66 +17,9 @@ _maybe_rewrite_url, is_http_url, serialize_data, - smart_load_dataframe, smart_load_yaml, ) -if TYPE_CHECKING: - import pandas as pd - - -@patch("data_designer.config.utils.io_helpers.Path", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_json", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_parquet", autospec=True) -def test_smart_load_dataframe(mock_read_parquet, mock_read_json, mock_read_csv, mock_path_cls, stub_dataframe): - mock_read_parquet.return_value = stub_dataframe - mock_read_json.return_value = stub_dataframe - mock_read_csv.return_value = stub_dataframe - - # dataframe objects are passed through - assert smart_load_dataframe(stub_dataframe).size == stub_dataframe.size - - # url based - stub_base_url = "https://example.com/data.{extention}" - url_csv = stub_base_url.format(extention="csv") - smart_load_dataframe(url_csv) - mock_read_csv.assert_called_once_with(url_csv) - - url_json = stub_base_url.format(extention="json") - smart_load_dataframe(url_json) - mock_read_json.assert_called_once_with(url_json, lines=True) - - url_parquet = stub_base_url.format(extention="parquet") - smart_load_dataframe(url_parquet) - mock_read_parquet.assert_called_once_with(url_parquet) - - url_unknown = stub_base_url.format(extention="unknown") - with pytest.raises(ValueError): - smart_load_dataframe(url_unknown) - - # local file based - mock_read_csv.reset_mock() - mock_read_json.reset_mock() - mock_read_parquet.reset_mock() - - mock_path = MagicMock(autospec=Path) - mock_path.exists.return_value = True - mock_path.suffix.lower.return_value = "csv" - mock_path_cls.return_value = mock_path - - stub_base_path_str = "/some/path/to/data.{extension}" - path_csv = stub_base_path_str.format(extension="csv") - _ = smart_load_dataframe(path_csv) - mock_read_csv.assert_called_once_with(mock_path) - - mock_path.reset_mock() - mock_path.suffix.lower.return_value = "json" - mock_path.exists.return_value = False - path_json = stub_base_path_str.format(extension="json") - with pytest.raises(FileNotFoundError): - _ = smart_load_dataframe(Path(path_json)) - def test_smart_load_yaml(): stub_dict = { @@ -348,37 +289,6 @@ def test_smart_load_yaml_rewrites_huggingface_blob_url(mock_requests: MagicMock) ) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_github_blob_url(mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://github.com/org/repo/blob/main/data.csv") - - mock_read_csv.assert_called_once_with("https://raw.githubusercontent.com/org/repo/main/data.csv") - - -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_github_blob_url_with_token( - mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame -) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://github.com/org/repo/blob/main/data.csv?token=secret123") - - mock_read_csv.assert_called_once_with("https://raw.githubusercontent.com/org/repo/main/data.csv?token=secret123") - - -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_huggingface_blob_url( - mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame -) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://huggingface.co/datasets/org/repo/blob/main/data.csv") - - mock_read_csv.assert_called_once_with("https://huggingface.co/datasets/org/repo/raw/main/data.csv") - - def test_maybe_rewrite_github_url_log_does_not_leak_query(caplog: pytest.LogCaptureFixture) -> None: import logging diff --git a/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 14f65a781..2c85e5d25 100644 --- a/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -32,8 +32,6 @@ RANDOM_SEED = 42 MAX_PROMPT_SAMPLE_SIZE = 1000 WARNING_PREFIX = "⚠️ Error during column profile calculation: " -TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD = 0.1 - logger = logging.getLogger(__name__) diff --git a/packages/data-designer-engine/src/data_designer/engine/capacity.py b/packages/data-designer-engine/src/data_designer/engine/capacity.py deleted file mode 100644 index e10a729e7..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/capacity.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field -from typing import Generic, Literal, TypeVar - -from data_designer.engine.dataset_builders.scheduling.resources import SchedulerResourceKey, TaskGroupKey -from data_designer.engine.models.request_admission.config import RequestAdmissionConfig -from data_designer.engine.models.request_admission.resources import RequestResourceKey -from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap - -_T = TypeVar("_T") - -CapacityValueSource = Literal[ - "default", - "run_config", - "dataset_builder", - "model_metadata", - "engine_internal_config", - "adapter_config", - "environment", - "runtime_snapshot", - "benchmark_override", -] - - -@dataclass(frozen=True) -class CapacityValue(Generic[_T]): - value: _T | None - source: CapacityValueSource - fallback_from: str | None = None - missing_reason: str | None = None - - -@dataclass(frozen=True) -class RowGroupAdmission: - row_group_concurrency: CapacityValue[int] - observed_in_flight: int | None = None - mode: Literal["fixed", "adaptive"] = "fixed" - target_in_flight: int | None = None - observed_max_target: int | None = None - max_admitted_rows: int | None = None - blocked_reasons: Mapping[str, int] = field(default_factory=dict) - - -@dataclass(frozen=True) -class RequestAdmissionConfigSnapshot: - resources: Sequence[RequestResourceKey] - initial_limits: Mapping[RequestResourceKey, int] - max_limit_clamps: Mapping[RequestResourceKey, int | None] - cooldown_seconds: float - multiplicative_decrease_factor: float - additive_increase_step: int - successes_until_increase: int - startup_ramp_seconds: float - default_queue_wait_timeout_seconds: float | None - - @classmethod - def from_config(cls, config: RequestAdmissionConfig) -> RequestAdmissionConfigSnapshot: - resources = tuple(sorted({*config.initial_limits, *config.max_limit_clamps})) - return cls( - resources=resources, - initial_limits=dict(config.initial_limits), - max_limit_clamps=dict(config.max_limit_clamps), - cooldown_seconds=config.cooldown_seconds, - multiplicative_decrease_factor=config.multiplicative_decrease_factor, - additive_increase_step=config.additive_increase_step, - successes_until_increase=config.successes_until_increase, - startup_ramp_seconds=config.startup_ramp_seconds, - default_queue_wait_timeout_seconds=config.default_queue_wait_timeout_seconds, - ) - - -@dataclass(frozen=True) -class AsyncCapacityConfigured: - buffer_size: CapacityValue[int] - row_group_admission: RowGroupAdmission - submission_capacity: CapacityValue[int] - task_resource_limits: CapacityValue[Mapping[SchedulerResourceKey, int]] - request_resources: CapacityValue[Sequence[RequestResourceKey]] - provider_model_static_caps: CapacityValue[Mapping[ProviderModelKey, ProviderModelStaticCap]] - request_domain_initial_limits: CapacityValue[Mapping[RequestResourceKey, int]] - request_admission_config: CapacityValue[RequestAdmissionConfigSnapshot] - transport_pool_limits: CapacityValue[Mapping[ProviderModelKey, int]] - - -@dataclass(frozen=True) -class AsyncCapacityRuntimeSnapshot: - request_domain_current_limits: Mapping[RequestResourceKey, int] | None = None - request_domain_effective_max: Mapping[RequestResourceKey, int] | None = None - request_domain_blocked_until: Mapping[RequestResourceKey, float | None] | None = None - provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] | None = None - - -@dataclass(frozen=True) -class AsyncCapacityObservedMaxima: - row_groups_in_flight: int = 0 - queued_tasks_by_group: Mapping[TaskGroupKey | str, int] = field(default_factory=dict) - task_leases_by_resource: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) - request_waiters_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) - request_in_flight_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) - provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] = field(default_factory=dict) - request_domain_current_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) - transport_pool_utilization: Mapping[ProviderModelKey, int] | None = None - - -@dataclass(frozen=True) -class AsyncCapacityPlan: - configured: AsyncCapacityConfigured - runtime_snapshot: AsyncCapacityRuntimeSnapshot - observed_maxima: AsyncCapacityObservedMaxima - - -def missing_capacity_value( - *, - source: CapacityValueSource, - missing_reason: str, - fallback_from: str | None = None, -) -> CapacityValue[object]: - return CapacityValue(value=None, source=source, fallback_from=fallback_from, missing_reason=missing_reason) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index 49e7d5b03..8cbd9833b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -277,38 +277,6 @@ def _postprocess_result( return self._validate_output(result, keys_before, is_dataframe) - def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict: - """Validate a single row output (dict) for cell_by_cell; strip undeclared columns.""" - expected_new = {self.config.name} | set(self.config.side_effect_columns) - result_keys = set(row.keys()) - - if self.config.name not in result_keys: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' did not create the expected column. " - f"The generator_function must add a column named '{self.config.name}' to the row." - ) - missing = set(self.config.side_effect_columns) - result_keys - if missing: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' did not create declared side_effect_columns: " - f"{sorted(missing)}. Declared side_effect_columns must be added to the row." - ) - removed = keys_before - result_keys - if removed: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' removed pre-existing columns: " - f"{sorted(removed)}. The generator_function must not remove any existing columns." - ) - undeclared = (result_keys - keys_before) - expected_new - if undeclared: - logger.warning( - f"⚠️ Custom generator for column '{self.config.name}' created undeclared columns: " - f"{sorted(undeclared)}. These columns will be removed. " - f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])." - ) - row = {k: v for k, v in row.items() if k not in undeclared} - return row - def _validate_output( self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool ) -> dict | pd.DataFrame: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index c7e848b5a..8adde98a8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -16,15 +16,6 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import ExpressionColumnConfig, GenerationStrategy -from data_designer.engine.capacity import ( - AsyncCapacityConfigured, - AsyncCapacityObservedMaxima, - AsyncCapacityPlan, - AsyncCapacityRuntimeSnapshot, - CapacityValue, - RequestAdmissionConfigSnapshot, - RowGroupAdmission, -) from data_designer.engine.context import current_row_group, current_row_group_start_offset from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig @@ -71,9 +62,6 @@ ModelRateLimitError, ModelRequestAdmissionTimeoutError, ) -from data_designer.engine.models.request_admission.config import RequestAdmissionConfig -from data_designer.engine.models.request_admission.resources import RequestResourceKey -from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap from data_designer.engine.observability import ( RuntimeCorrelation, SchedulerAdmissionEvent, @@ -303,20 +291,11 @@ def __init__( self._first_non_retryable_error: Exception | None = None self._fatal_worker_error: BaseException | None = None - self._max_concurrent_row_groups = max_concurrent_row_groups self._max_in_flight_tasks = max_in_flight_tasks - self._max_model_task_admission = max_model_task_admission self._num_records = num_records self._buffer_size = buffer_size self._scheduled_records = self._row_groups.scheduled_total_rows self._initial_completed_records = initial_completed_records - self._observed_max_row_groups_in_flight = 0 - self._observed_max_task_leases_by_resource: dict[str, int] = {} - self._observed_max_queued_by_group: dict[str, int] = {} - self._observed_max_request_waiters_by_resource: dict[RequestResourceKey, int] = {} - self._observed_max_request_in_flight_by_resource: dict[RequestResourceKey, int] = {} - self._observed_max_provider_model_aggregate_in_flight: dict[ProviderModelKey, int] = {} - self._observed_max_request_domain_current_limits: dict[RequestResourceKey, int] = {} self._adaptive_row_group_admission = adaptive_row_group_admission self._row_group_admission_hard_cap = max(1, max_concurrent_row_groups) self._row_group_admission_target = ( @@ -324,7 +303,6 @@ def __init__( if adaptive_row_group_admission else self._row_group_admission_hard_cap ) - self._observed_max_row_group_admission_target = self._row_group_admission_target self._row_group_admission_event = asyncio.Event() self._row_group_admission_event.set() self._row_group_admission_pressure_ticks = 0 @@ -466,39 +444,6 @@ def _emit_scheduler_event( logger.warning("Scheduler admission event sink raised; dropping event.", exc_info=True) return - def _record_observed_task_state(self) -> None: - self._observed_max_row_groups_in_flight = max(self._observed_max_row_groups_in_flight, len(self._rg_states)) - view = self._task_admission.view() - for resource, count in view.leased_resources.items(): - self._observed_max_task_leases_by_resource[resource] = max( - self._observed_max_task_leases_by_resource.get(resource, 0), - count, - ) - queue_view = self._fair_queue.view() - for group, count in queue_view.queued_by_group.items(): - label = f"{group.kind}:{'/'.join(group.identity)}" - self._observed_max_queued_by_group[label] = max(self._observed_max_queued_by_group.get(label, 0), count) - if self._request_pressure_provider is None: - return - for resource, snapshot in self._request_pressure_provider.snapshots().items(): - self._observed_max_request_waiters_by_resource[resource] = max( - self._observed_max_request_waiters_by_resource.get(resource, 0), - snapshot.waiters, - ) - self._observed_max_request_in_flight_by_resource[resource] = max( - self._observed_max_request_in_flight_by_resource.get(resource, 0), - snapshot.in_flight_count, - ) - self._observed_max_request_domain_current_limits[resource] = max( - self._observed_max_request_domain_current_limits.get(resource, 0), - snapshot.current_limit, - ) - for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items(): - self._observed_max_provider_model_aggregate_in_flight[provider_model] = max( - self._observed_max_provider_model_aggregate_in_flight.get(provider_model, 0), - snapshot.aggregate_in_flight, - ) - def _emit_scheduler_health_snapshot(self, reason: str) -> None: if self._scheduler_event_sink is None: return @@ -660,9 +605,6 @@ def _apply_frontier_delta(self, delta: FrontierDelta) -> None: self._discard_ready_task(task) self._enqueue_ready_tasks(delta.added) - def _enqueue_ready_task(self, task: Task) -> None: - self._enqueue_ready_tasks((task,)) - def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: schedulables: list[SchedulableTask] = [] accepted_tasks_by_id: dict[str, Task] = {} @@ -689,7 +631,6 @@ def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: self._tracker.mark_enqueued(accepted) for task_id in accepted: self._emit_scheduler_event("ready_enqueued", task=accepted_tasks_by_id[task_id]) - self._record_observed_task_state() self._wake_event.set() def _discard_ready_task(self, task: Task) -> None: @@ -760,7 +701,6 @@ def _dispatch_queued_tasks(self) -> _DispatchOutcome: self._dispatch_selected_task(committed, decision) dispatched = True - self._record_observed_task_state() if dispatched: self._emit_scheduler_event("queue_drained") @@ -932,10 +872,6 @@ def _maybe_update_adaptive_row_group_target(self) -> None: return old_target = self._row_group_admission_target self._row_group_admission_target = min(self._row_group_admission_hard_cap, old_target + 1) - self._observed_max_row_group_admission_target = max( - self._observed_max_row_group_admission_target, - self._row_group_admission_target, - ) self._row_group_admission_pressure_ticks = 0 if self._row_group_admission_target != old_target: self._emit_scheduler_event( @@ -1791,7 +1727,6 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id=task_execution_id, reason_or_result=release_result.reason, ) - self._record_observed_task_state() self._wake_event.set() async def _run_generator_call(self, task: Task, operation: str, call: Coroutine[Any, Any, Any]) -> Any: @@ -2160,134 +2095,6 @@ def task_admission_config(self) -> TaskAdmissionConfig: """Return the effective scheduler task-admission config.""" return self._task_admission_config - def capacity_plan(self) -> AsyncCapacityPlan: - """Return the scheduler-side async capacity explanation for this run.""" - task_view = self._task_admission.view() - request_snapshots = ( - dict(self._request_pressure_provider.snapshots()) if self._request_pressure_provider is not None else {} - ) - provider_snapshots = ( - dict(self._request_pressure_provider.global_snapshots()) - if self._request_pressure_provider is not None - else {} - ) - request_resources = tuple(sorted(request_snapshots)) - provider_model_static_caps = { - provider_model: ProviderModelStaticCap( - cap=snapshot.static_cap, - aliases=snapshot.aliases, - raw_caps=snapshot.raw_caps, - ) - for provider_model, snapshot in provider_snapshots.items() - } - request_config = self._request_pressure_provider.config if self._request_pressure_provider is not None else None - request_config_snapshot = ( - RequestAdmissionConfigSnapshot.from_config(request_config) - if isinstance(request_config, RequestAdmissionConfig) - else None - ) - request_domain_initial_limits: dict[RequestResourceKey, int] = {} - if request_config_snapshot is not None: - request_domain_initial_limits.update(request_config_snapshot.initial_limits) - for resource, snapshot in request_snapshots.items(): - configured_initial = ( - request_config_snapshot.initial_limits.get(resource) if request_config_snapshot is not None else None - ) - request_domain_initial_limits[resource] = ( - max(1, min(configured_initial, snapshot.effective_max)) - if configured_initial is not None - else snapshot.effective_max - ) - request_domain_current_limits = { - resource: snapshot.current_limit for resource, snapshot in request_snapshots.items() - } - request_domain_effective_max = { - resource: snapshot.effective_max for resource, snapshot in request_snapshots.items() - } - request_domain_blocked_until = { - resource: snapshot.blocked_until_monotonic for resource, snapshot in request_snapshots.items() - } - provider_model_aggregate_in_flight = { - provider_model: snapshot.aggregate_in_flight for provider_model, snapshot in provider_snapshots.items() - } - return AsyncCapacityPlan( - configured=AsyncCapacityConfigured( - buffer_size=CapacityValue(value=self._buffer_size, source="run_config"), - row_group_admission=RowGroupAdmission( - row_group_concurrency=CapacityValue( - value=self._max_concurrent_row_groups, - source="dataset_builder", - ), - observed_in_flight=len(self._rg_states), - mode="adaptive" if self._adaptive_row_group_admission else "fixed", - target_in_flight=self._row_group_admission_target, - observed_max_target=self._observed_max_row_group_admission_target, - max_admitted_rows=self._adaptive_max_admitted_rows, - blocked_reasons=dict(self._row_group_admission_blocked_reasons), - ), - submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"), - task_resource_limits=CapacityValue( - value=dict(self._task_admission_config.resource_limits), - source="engine_internal_config", - ), - request_resources=CapacityValue( - value=request_resources, - source="runtime_snapshot", - missing_reason=None if request_resources else "request admission has not observed any resources", - ), - provider_model_static_caps=CapacityValue( - value=provider_model_static_caps, - source="model_metadata", - missing_reason=None if provider_model_static_caps else "request admission has no registered models", - ), - request_domain_initial_limits=CapacityValue( - value=request_domain_initial_limits, - source="engine_internal_config" if request_config_snapshot is not None else "runtime_snapshot", - missing_reason=None - if request_domain_initial_limits - else "request admission has not observed any domain limits", - ), - request_admission_config=CapacityValue( - value=request_config_snapshot, - source="engine_internal_config", - missing_reason=None - if request_config_snapshot is not None - else "request admission config is not exposed by the pressure provider", - ), - transport_pool_limits=CapacityValue( - value={}, - source="adapter_config", - missing_reason="transport pool utilization is adapter-specific", - ), - ), - runtime_snapshot=AsyncCapacityRuntimeSnapshot( - request_domain_current_limits=request_domain_current_limits, - request_domain_effective_max=request_domain_effective_max, - request_domain_blocked_until=request_domain_blocked_until, - provider_model_aggregate_in_flight=provider_model_aggregate_in_flight, - ), - observed_maxima=AsyncCapacityObservedMaxima( - row_groups_in_flight=self._observed_max_row_groups_in_flight, - queued_tasks_by_group=dict(self._observed_max_queued_by_group), - task_leases_by_resource=dict(self._observed_max_task_leases_by_resource or task_view.leased_resources), - request_waiters_by_resource=dict( - self._observed_max_request_waiters_by_resource - or {resource: snapshot.waiters for resource, snapshot in request_snapshots.items()} - ), - request_in_flight_by_resource=dict( - self._observed_max_request_in_flight_by_resource - or {resource: snapshot.in_flight_count for resource, snapshot in request_snapshots.items()} - ), - provider_model_aggregate_in_flight=dict( - self._observed_max_provider_model_aggregate_in_flight or provider_model_aggregate_in_flight - ), - request_domain_current_limits=dict( - self._observed_max_request_domain_current_limits or request_domain_current_limits - ), - transport_pool_utilization=None, - ), - ) - @staticmethod def _is_retryable(exc: BaseException) -> bool: """Classify whether an exception is retryable.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 118e242a1..81dd19416 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -218,13 +218,6 @@ def first_non_retryable_error(self) -> Exception | None: """First non-retryable error captured by the scheduler in the most recent run.""" return self._first_non_retryable_error - def set_processor_runner(self, processors: list[Processor]) -> None: - """Replace the processor runner with a new one using the given processors.""" - self._processor_runner = ProcessorRunner( - processors=processors, - artifact_storage=self.artifact_storage, - ) - @functools.cached_property def single_column_configs(self) -> list[ColumnConfigT]: configs = [] @@ -235,10 +228,6 @@ def single_column_configs(self) -> list[ColumnConfigT]: configs.append(config) return configs - @functools.cached_property - def single_column_config_by_name(self) -> dict[str, ColumnConfigT]: - return {config.name: config for config in self.single_column_configs} - def build( self, *, diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index 855c91642..8c7839358 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -37,9 +37,8 @@ class CompletionTracker: Row indices are local to their row group (0-based). - Use ``with_graph`` to create a frontier-enabled tracker where - ``get_ready_tasks`` returns in O(frontier) instead of scanning all - columns x rows x row groups. + Use ``with_graph`` to create a frontier-enabled tracker that incrementally + maintains dependency-ready tasks. """ def __init__(self) -> None: @@ -93,20 +92,6 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i def is_complete(self, ref: SliceRef) -> bool: return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set()) - def is_all_complete(self, cells: list[SliceRef]) -> bool: - """Check whether all the given cells are done. - - A ``row_index`` of ``None`` means the entire batch for that column must - have been completed via ``mark_row_range_complete``. - """ - for ref in cells: - if ref.row_index is None: - if ref.column not in self._batch_complete.get(ref.row_group, set()): - return False - elif not self.is_complete(ref): - return False - return True - def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: """Check if *column* has been fully completed for *row_group_index*.""" if column in self._batch_complete.get(row_group_index, set()): @@ -173,33 +158,10 @@ def add_ready_tasks(self, tasks: list[Task] | tuple[Task, ...]) -> FrontierDelta added.append(task) return self._record_delta(added=added, removed=[]) - def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: - """Return all currently dispatchable tasks from the frontier.""" - return [ - t - for t in self.ready_frontier() - if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) - ] - def is_frontier_task(self, task: Task) -> bool: """Return whether *task* is still in the ready frontier.""" return task in self._frontier - def seed_frontier(self) -> None: - """Populate the frontier with root tasks (columns with no upstream deps). - - Not called automatically - the scheduler manages root dispatch directly - to handle stateful locks and multi-column dedup. Call this explicitly - for static introspection (e.g., capacity planning, task enumeration). - """ - if self._graph is None: - raise RuntimeError("This method requires a graph to be set.") - for col in self._graph.get_root_columns(): - if self._row_group_plan is None: - raise RuntimeError("This method requires row groups to be set.") - for rg_id, rg_size in self._row_group_plan: - self.add_root_tasks(rg_id, rg_size, columns=(col,)) - def add_root_tasks( self, row_group: int, diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py index 574c594c1..e94b3fa17 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal +from typing import Literal @dataclass(frozen=True, order=True) @@ -26,17 +26,6 @@ class Task: task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] -@dataclass -class TaskResult: - """Outcome of a completed task.""" - - task: Task - status: Literal["success", "error"] - output: Any = None - error: Exception | None = None - retryable: bool = False - - @dataclass class TaskTrace: """Timing trace for a single task. Only created when tracing is enabled.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py index a7d1883c6..eb9e2104a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -1,33 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Async batch execution with bounded concurrency and early-shutdown semantics. +"""Process-wide event-loop management for async engine work. -Async counterpart to ``concurrency.py``. Same operational contract (callbacks -with optional context, error aggregation, shutdown thresholds), different -runtime model. The sync module runs callables in a ``ThreadPoolExecutor``; -this module runs coroutines via ``asyncio.gather`` on a dedicated loop -thread. Callers stay synchronous. - -Architecture: - ``AsyncConcurrentExecutor.run()`` is a blocking call that submits - coroutines to a shared background event loop via - ``run_coroutine_threadsafe``. Bounded concurrency is enforced with an - ``asyncio.Semaphore``. Success/error counts use the same - ``ExecutorResults`` model as the sync executor. - - Caller Thread ──► run() ──► run_coroutine_threadsafe ──► Background Loop - (gather) - -Singleton Event Loop: +Singleton event loop: The background loop is a process-wide singleton. Async-stateful resources (connection pools, semaphores) bind internal state to a specific event loop, so creating per-call or per-instance loops breaks connection reuse and triggers cross-loop errors. - ``ensure_async_engine_loop()`` creates one daemon loop thread and - reuses it for all executor instances. + ``ensure_async_engine_loop()`` creates one daemon loop thread and reuses + it for all async engine work. -Startup Handshake: +Startup handshake: Loop creation uses a ``threading.Event`` readiness handshake. The background thread signals readiness via ``loop.call_soon(ready.set)``, and the creating thread holds the lock until that event fires (or a @@ -40,40 +24,11 @@ from __future__ import annotations import asyncio -import json import logging import threading -from collections.abc import Coroutine -from dataclasses import dataclass -from typing import Any, Generic, TypeVar - -from data_designer.engine.dataset_builders.utils.concurrency import ( - CallbackWithContext, - ErrorCallbackWithContext, - ExecutorResults, -) -from data_designer.engine.errors import DataDesignerRuntimeError -from data_designer.logging import LOG_INDENT logger = logging.getLogger(__name__) -T = TypeVar("T") - - -@dataclass(frozen=True, slots=True) -class Success(Generic[T]): - index: int - value: T - - -@dataclass(frozen=True, slots=True) -class Failure: - index: int - error: Exception - - -TaskResult = Success[T] | Failure - _loop: asyncio.AbstractEventLoop | None = None _thread: threading.Thread | None = None _lock = threading.Lock() @@ -90,9 +45,8 @@ def _run_loop(loop: asyncio.AbstractEventLoop, ready: threading.Event) -> None: def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: """Get or create a persistent event loop for async engine work. - A single event loop is shared across all AsyncConcurrentExecutor instances - to avoid breaking async-stateful resources (connection pools, semaphores) - that bind internal state to a specific event loop. + A single event loop is shared across async engine work to avoid breaking + async-stateful resources that bind internal state to a specific event loop. """ global _loop, _thread with _lock: @@ -118,123 +72,3 @@ def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: raise RuntimeError("AsyncEngine event loop failed to start within timeout") return _loop - - -class AsyncConcurrentExecutor: - """Async equivalent of ConcurrentThreadExecutor. - - Executes a batch of coroutines with bounded concurrency, error rate - monitoring, and early shutdown semantics. Callers remain synchronous — - the ``run()`` method submits work to a persistent background event loop. - - No locks are needed because asyncio tasks run cooperatively on a - single thread — mutations to ``_results`` are always sequential. - """ - - def __init__( - self, - *, - max_workers: int, - column_name: str, - result_callback: CallbackWithContext | None = None, - error_callback: ErrorCallbackWithContext | None = None, - shutdown_error_rate: float = 0.50, - shutdown_error_window: int = 10, - disable_early_shutdown: bool = False, - ) -> None: - self._column_name = column_name - self._max_workers = max_workers - self._result_callback = result_callback - self._error_callback = error_callback - self._shutdown_error_rate = shutdown_error_rate - self._shutdown_window_size = shutdown_error_window - self._disable_early_shutdown = disable_early_shutdown - self._results = ExecutorResults(failure_threshold=shutdown_error_rate) - - @property - def results(self) -> ExecutorResults: - return self._results - - @property - def max_workers(self) -> int: - return self._max_workers - - @property - def shutdown_error_rate(self) -> float: - return self._shutdown_error_rate - - @property - def shutdown_window_size(self) -> int: - return self._shutdown_window_size - - def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: - """Execute all work items concurrently. Callers remain synchronous.""" - logger.debug( - f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " - f"with max_workers={self._max_workers} for column '{self._column_name}'" - ) - loop = ensure_async_engine_loop() - future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) - future.result() - - async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: - self._semaphore = asyncio.Semaphore(self._max_workers) - self._shutdown_event = asyncio.Event() - - # gather-with-explicit-cancel: equivalent to asyncio.TaskGroup but available on 3.10. - # _run_task swallows its own exceptions into error_trap, so children don't raise into - # gather under normal operation. The except-block preserves TaskGroup's "cancel siblings - # on parent cancellation or unexpected child raise" semantics for safety. - tasks = [asyncio.create_task(self._run_task(i, coro, ctx)) for i, (coro, ctx) in enumerate(work_items)] - try: - await asyncio.gather(*tasks) - except BaseException: - for t in tasks: - if not t.done(): - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - raise - - if not self._disable_early_shutdown and self._results.early_shutdown: - self._raise_task_error() - - async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None: - if self._shutdown_event.is_set(): - coro.close() - return - - async with self._semaphore: - if self._shutdown_event.is_set(): - coro.close() - return - - try: - result = await coro - if self._result_callback is not None: - self._result_callback(result, context=context) - self._results.completed_count += 1 - self._results.success_count += 1 - except Exception as err: - self._results.completed_count += 1 - self._results.error_trap.handle_error(err) - if not self._disable_early_shutdown and self._results.is_error_rate_exceeded( - self._shutdown_window_size - ): - if not self._results.early_shutdown: - self._results.early_shutdown = True - self._shutdown_event.set() - if self._error_callback is not None: - try: - self._error_callback(err, context=context) - except Exception: - logger.warning("error_callback raised an exception", exc_info=True) - - def _raise_task_error(self) -> None: - raise DataDesignerRuntimeError( - "\n".join( - [ - f"{LOG_INDENT}Data generation was terminated early due to error rate exceeding threshold.", - f"{LOG_INDENT}The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}", - ] - ) - ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py index 208fa6d80..d1c48fd00 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -3,7 +3,6 @@ from __future__ import annotations -from data_designer.config.base import ProcessorConfig from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -54,9 +53,3 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D compiled_column_configs.extend(generated_column_configs) return compiled_column_configs - - -def compile_dataset_builder_processor_configs( - config: DataDesignerConfig, -) -> list[ProcessorConfig]: - return config.processors or [] diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py index 4cf59697a..04774379c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py @@ -6,9 +6,6 @@ from data_designer.engine.errors import DataDesignerError -class DatasetBatchManagementError(DataDesignerError): ... - - class ConfigCompilationError(DataDesignerError): ... diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 29b7d99bc..50c4dbb96 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -15,7 +15,6 @@ DatasetBuilderColumnConfigT, MultiColumnConfig, ) -from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.logging import LOG_INDENT @@ -283,42 +282,6 @@ def compute_task_count(self, num_records: int, buffer_size: int) -> dict[str, in counts[col] = num_row_groups return counts - def compute_cell_dependencies( - self, - column: str, - row_group: int, - row_index: int | None, - row_group_size: int, - ) -> list[SliceRef]: - """Derive cell-level deps on demand from column-level DAG + strategy. - - Returns a list of ``SliceRef`` that must be complete before this task can run. - """ - deps: list[SliceRef] = [] - for up_col in self.get_upstream_columns(column): - up_strategy = self._strategies[up_col] - if up_strategy == GenerationStrategy.CELL_BY_CELL: - if row_index is not None: - deps.append(SliceRef(up_col, row_group, row_index)) - else: - for ri in range(row_group_size): - deps.append(SliceRef(up_col, row_group, ri)) - else: - deps.append(SliceRef(up_col, row_group, None)) - return deps - - def to_mermaid(self) -> str: - """Mermaid diagram string with strategy annotations.""" - lines = ["graph TD"] - for col in self._columns: - strat = self._strategies[col] - label = f"{col} [{strat.value}]" - lines.append(f' {col}["{label}"]') - for col in self._columns: - for dep in sorted(self._upstream.get(col, set())): - lines.append(f" {dep} --> {col}") - return "\n".join(lines) - def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]: """Return column configs in dependency order using Kahn's algorithm. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index 98220e35f..e3cc28255 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -50,10 +50,6 @@ def update_cell(self, row_group: int, row_index: int, column: str, value: Any) - """Write a single cell value. Thread-safe within the asyncio event loop.""" self._buffers[row_group][row_index][column] = value - def update_cells(self, row_group: int, row_index: int, values: dict[str, Any]) -> None: - """Write multiple cell values for a single row.""" - self._buffers[row_group][row_index].update(values) - def update_batch(self, row_group: int, column: str, values: list[Any]) -> None: """Write a full column for all rows in a row group.""" buf = self._buffers[row_group] diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py index d827d0568..ed91d7047 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py @@ -10,20 +10,9 @@ from __future__ import annotations from collections.abc import Sequence -from dataclasses import dataclass from typing import Final SKIPPED_COLUMNS_RECORD_KEY: Final[str] = "__internal_skipped_columns" -SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX: Final[str] = "__internal_skip_restore_id" - - -@dataclass(frozen=True, slots=True) -class SkipMetadataRestoreContext: - """Metadata needed to restore skip provenance after a DataFrame round-trip.""" - - restore_id_column: str - source_ids: set[str] - skipped_columns_by_source_id: dict[str, set[str]] def apply_skip_to_record( @@ -56,81 +45,3 @@ def strip_skip_metadata_for_dataframe_row(record: dict) -> dict: def strip_skip_metadata_from_records(records: Sequence[dict]) -> list[dict]: """Map :func:`strip_skip_metadata_for_dataframe_row` over *records*.""" return [strip_skip_metadata_for_dataframe_row(r) for r in records] - - -def prepare_records_for_skip_metadata_round_trip( - records: Sequence[dict], -) -> tuple[list[dict], SkipMetadataRestoreContext | None]: - """Prepare records for a DataFrame round-trip while preserving skip metadata. - - Returns stripped records ready for ``pd.DataFrame(...)``. If any record has - skip metadata, injects a hidden restore-ID column and returns a context that - can later be passed to :func:`restore_skip_metadata`. - """ - if not any(SKIPPED_COLUMNS_RECORD_KEY in record for record in records): - return strip_skip_metadata_from_records(records), None - - restore_id_column = _choose_restore_id_column(records) - prepared_records: list[dict] = [] - source_ids: set[str] = set() - skipped_columns_by_source_id: dict[str, set[str]] = {} - - for index, record in enumerate(records): - source_id = str(index) - source_ids.add(source_id) - prepared_record = strip_skip_metadata_for_dataframe_row(record) - prepared_record[restore_id_column] = source_id - prepared_records.append(prepared_record) - - meta = record.get(SKIPPED_COLUMNS_RECORD_KEY) - if meta is not None: - skipped_columns_by_source_id[source_id] = set(meta) - - return prepared_records, SkipMetadataRestoreContext( - restore_id_column=restore_id_column, - source_ids=source_ids, - skipped_columns_by_source_id=skipped_columns_by_source_id, - ) - - -def restore_skip_metadata( - records: Sequence[dict], - *, - context: SkipMetadataRestoreContext, -) -> None: - """Restore skip provenance using hidden restore IDs instead of row position.""" - restored_source_ids: list[str] = [] - for record in records: - if context.restore_id_column not in record: - raise ValueError( - f"Records returned from the DataFrame round-trip must preserve " - f"the internal column {context.restore_id_column!r} so skip " - "provenance can be restored." - ) - - source_id = str(record.pop(context.restore_id_column)) - if source_id not in context.source_ids: - raise ValueError( - f"Record returned unknown restore ID {source_id!r}. Skip provenance " - "can only be restored for rows derived from the original input." - ) - - restored_source_ids.append(source_id) - meta = context.skipped_columns_by_source_id.get(source_id) - if meta is not None: - record[SKIPPED_COLUMNS_RECORD_KEY] = set(meta) - - if len(restored_source_ids) != len(context.source_ids) or set(restored_source_ids) != context.source_ids: - raise ValueError( - "Full-column generation changed the row identity mapping. Returned rows must preserve " - "a 1:1 mapping to the original input so skip provenance can be restored." - ) - - -def _choose_restore_id_column(records: Sequence[dict]) -> str: - candidate = SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX - suffix = 0 - while any(candidate in record for record in records): - suffix += 1 - candidate = f"{SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX}_{suffix}" - return candidate diff --git a/packages/data-designer-engine/src/data_designer/engine/errors.py b/packages/data-designer-engine/src/data_designer/engine/errors.py index 3aee0544f..5c455e42a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/errors.py @@ -11,9 +11,6 @@ class DataDesignerRuntimeError(DataDesignerError): ... -class UnknownModelAliasError(DataDesignerError): ... - - class UnknownProviderError(DataDesignerError): ... diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index f0b79c28b..8c459f848 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -5,7 +5,6 @@ import calendar import email.utils -import json import time from enum import Enum @@ -134,31 +133,6 @@ def map_http_error_to_provider_error( ) -def extract_message_from_exception_string(raw: str) -> str: - """Extract a human-readable message from a stringified provider exception. - - Some providers format errors as ``"Error code: 400 - {json}"``. This - mirrors the structured-key lookup in ``_extract_structured_message`` but - operates on a raw string instead of an ``HttpResponse``. - """ - json_start = raw.find("{") - if json_start != -1: - try: - payload = json.loads(raw[json_start:]) - except (json.JSONDecodeError, ValueError): - return raw - if isinstance(payload, dict): - for key in ("message", "error", "detail"): - value = payload.get(key) - if isinstance(value, str) and value.strip(): - return value.strip() - if isinstance(value, dict): - nested = value.get("message") - if isinstance(nested, str) and nested.strip(): - return nested.strip() - return raw - - def _extract_response_text(response: HttpResponse) -> str: # Try structured JSON extraction first — most providers return structured error # bodies and we want the human-readable message, not raw JSON. diff --git a/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py b/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py index 4635562cd..85b41f4a7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py @@ -4,12 +4,10 @@ from __future__ import annotations import json_repair -from pydantic import BaseModel, ValidationError from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -57,37 +55,3 @@ def deserialize_json_code( processed_response.parsed.append(block) return processed_response - - -class RealizePydanticTypes: - types: list[type[BaseModel]] - - def __init__(self, types: list[type[BaseModel]]): - self.types = types - - def _fit_types(self, obj: dict) -> BaseModel | None: - final_obj = None - - for t in self.types: - try: - final_obj = t.model_validate(obj) - except ValidationError: - pass - - return final_obj - - def __call__(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse: - processed_response = structured_response.model_copy() - processed_response.parsed = [] - - for block in structured_response.parsed: - if isinstance(block, StructuredDataBlock): - new_block = block - pydantic_obj = self._fit_types(block.obj) - if pydantic_obj: - new_block = PydanticTypeBlock(serialized=block.serialized, obj=pydantic_obj) - processed_response.parsed.append(new_block) - else: - processed_response.parsed.append(block) - - return processed_response diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py index 38bbd0598..d0273ebca 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py @@ -635,17 +635,6 @@ def _apply_outcome( if state.in_flight == 0 and outcome.kind not in {"local_cancelled", "local_timeout"}: state.consecutive_rate_limits = 0 - def _increment_waiter(self, item: RequestAdmissionItem) -> None: - with self._lock: - self._get_or_create_state(item.resource).waiters += 1 - self._sequence += 1 - - def _decrement_waiter(self, item: RequestAdmissionItem) -> None: - with self._lock: - state = self._get_or_create_state(item.resource) - state.waiters = max(0, state.waiters - 1) - self._sequence += 1 - def _get_or_create_state(self, resource: RequestResourceKey) -> AdaptiveRequestLimitState: state = self._domains.get(resource) if state is None: diff --git a/packages/data-designer-engine/src/data_designer/engine/observability.py b/packages/data-designer-engine/src/data_designer/engine/observability.py index a7a28c41b..36453013b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/observability.py +++ b/packages/data-designer-engine/src/data_designer/engine/observability.py @@ -213,32 +213,3 @@ def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: ... class RequestAdmissionEventSink(Protocol): def emit_request_event(self, event: RequestAdmissionEvent) -> None: ... - - -class InMemoryAdmissionEventSink: - """Small sink used by tests, diagnostics, and benchmark smoke runs.""" - - def __init__(self) -> None: - self.scheduler_events: list[SchedulerAdmissionEvent] = [] - self.request_events: list[RequestAdmissionEvent] = [] - - def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: - self.scheduler_events.append(event) - - def emit_request_event(self, event: RequestAdmissionEvent) -> None: - self.request_events.append(event) - - -@dataclass(frozen=True) -class CorrelatedRuntimeView: - scheduler_events: tuple[SchedulerAdmissionEvent, ...] - request_events: tuple[RequestAdmissionEvent, ...] - - @property - def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: - return tuple( - sorted( - (*self.scheduler_events, *self.request_events), - key=lambda event: (event.captured_at_monotonic, event.sequence), - ) - ) diff --git a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py index 1a969e2b1..b6fa89411 100644 --- a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py +++ b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py @@ -56,31 +56,6 @@ def get_dataset_metadata(self) -> DatasetMetadata: return DatasetMetadata(seed_column_names=seed_column_names) -def _validate_tool_configs_against_providers( - tool_configs: list[ToolConfig], - mcp_providers: list[MCPProviderT], -) -> None: - """Validate that all providers referenced in tool configs exist. - - Args: - tool_configs: List of tool configurations to validate. - mcp_providers: List of available MCP provider configurations. - - Raises: - ValueError: If a tool config references a provider that doesn't exist. - """ - available_providers = {p.name for p in mcp_providers} - - for tc in tool_configs: - missing_providers = [p for p in tc.providers if p not in available_providers] - if missing_providers: - available_list = sorted(available_providers) if available_providers else ["(none configured)"] - raise ValueError( - f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} " - f"which are not registered. Available providers: {available_list}" - ) - - def create_resource_provider( *, artifact_storage: ArtifactStorage, diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index 6bd6e3dd9..993e9f536 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -143,10 +143,6 @@ def _validate_image(self, image_path: Path) -> None: image_path.unlink(missing_ok=True) raise - def _ensure_images_directory(self) -> None: - """Create images directory if it doesn't exist (lazy initialization).""" - self.images_dir.mkdir(parents=True, exist_ok=True) - def _sanitize_subfolder_name(self, name: str) -> str: """Sanitize subfolder name to prevent path traversal and filesystem issues.""" # Replace path separators and parent directory references with underscores diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py index 2b5b62d42..9050b41ea 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations +from data_designer.engine.testing.observability import CorrelatedRuntimeView, InMemoryAdmissionEventSink from data_designer.engine.testing.seed_readers import LineFanoutDirectorySeedReader from data_designer.engine.testing.stubs import ( StubChoice, @@ -16,6 +17,8 @@ from data_designer.engine.testing.utils import assert_valid_plugin __all__ = [ + "CorrelatedRuntimeView", + "InMemoryAdmissionEventSink", LineFanoutDirectorySeedReader.__name__, "StubChoice", "StubHuggingFaceSeedReader", diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/observability.py b/packages/data-designer-engine/src/data_designer/engine/testing/observability.py new file mode 100644 index 000000000..cc3237ae8 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/testing/observability.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from data_designer.engine.observability import RequestAdmissionEvent, SchedulerAdmissionEvent + + +class InMemoryAdmissionEventSink: + """In-memory admission-event sink for tests and benchmark smoke runs.""" + + def __init__(self) -> None: + self.scheduler_events: list[SchedulerAdmissionEvent] = [] + self.request_events: list[RequestAdmissionEvent] = [] + + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: + self.scheduler_events.append(event) + + def emit_request_event(self, event: RequestAdmissionEvent) -> None: + self.request_events.append(event) + + +@dataclass(frozen=True) +class CorrelatedRuntimeView: + """Combined chronological view of scheduler and request events for tests.""" + + scheduler_events: tuple[SchedulerAdmissionEvent, ...] + request_events: tuple[RequestAdmissionEvent, ...] + + @property + def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: + return tuple( + sorted( + (*self.scheduler_events, *self.request_events), + key=lambda event: (event.captured_at_monotonic, event.sequence), + ) + ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py index e647d4ac6..bbae04d57 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py @@ -40,7 +40,6 @@ def _build_simple_graph() -> ExecutionGraph: @dataclass class ReadyTasksFixture: tracker: CompletionTracker - dispatched: set[Task] @pytest.fixture() @@ -49,7 +48,6 @@ def ready_ctx() -> ReadyTasksFixture: graph = _build_simple_graph() return ReadyTasksFixture( tracker=CompletionTracker.with_graph(graph, [(0, 3)]), - dispatched=set(), ) @@ -86,43 +84,6 @@ def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFix ready_ctx.tracker.mark_cell_complete("question", row_group=999, row_index=0) -# -- is_all_complete ----------------------------------------------------------- - - -def test_all_complete_cell_level() -> None: - tracker = CompletionTracker() - tracker.mark_cell_complete("col_a", 0, 0) - tracker.mark_cell_complete("col_a", 0, 1) - - assert tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 1)]) - assert not tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 2)]) - - -def test_all_complete_batch_level() -> None: - tracker = CompletionTracker() - tracker.mark_row_range_complete("col_a", 0, 3) - - assert tracker.is_all_complete([SliceRef("col_a", 0, None)]) - - -def test_all_complete_batch_single_cell_not_sufficient() -> None: - """mark_cell_complete on one row must NOT make is_all_complete return True for batch check.""" - tracker = CompletionTracker() - tracker.mark_cell_complete("col_a", 0, 0) - - assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) - - -def test_all_complete_batch_not_present() -> None: - tracker = CompletionTracker() - assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) - - -def test_all_complete_empty_list() -> None: - tracker = CompletionTracker() - assert tracker.is_all_complete([]) - - # -- drop_row / is_dropped ------------------------------------------------- @@ -174,19 +135,17 @@ def test_row_group_not_complete_missing_non_dropped() -> None: assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) -# -- get_ready_tasks -------------------------------------------------------- +# -- ready frontier --------------------------------------------------------- -def test_get_ready_tasks_frontier_empty_without_seed(ready_ctx: ReadyTasksFixture) -> None: - """Frontier starts empty - seed_frontier() must be called explicitly.""" - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) +def test_ready_frontier_starts_empty(ready_ctx: ReadyTasksFixture) -> None: + ready = ready_ctx.tracker.ready_frontier() assert len(ready) == 0 -def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: - """seed_frontier() populates the frontier with root tasks.""" - ready_ctx.tracker.seed_frontier() - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) +def test_add_root_tasks_populates_frontier(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.add_root_tasks(0, 3, columns=("topic",)) + ready = ready_ctx.tracker.ready_frontier() assert len(ready) == 1 assert ready[0].column == "topic" @@ -194,7 +153,7 @@ def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.seed_frontier() + ready_ctx.tracker.add_root_tasks(0, 3, columns=("topic",)) task = ready_ctx.tracker.ready_frontier()[0] ready_ctx.tracker.mark_enqueued({stable_task_id(task)}) @@ -202,10 +161,10 @@ def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixtur assert ready_ctx.tracker.ready_frontier() == () -def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: +def test_ready_frontier_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: delta = ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 3 @@ -233,7 +192,7 @@ def test_fan_out_cell_completion_readies_all_children_for_same_row() -> None: assert {task.column for task in delta.added} == {"child_a", "child_b", "child_c"} assert {task.row_index for task in delta.added} == {0} - ready = tracker.get_ready_tasks(set()) + ready = tracker.ready_frontier() assert not any(task.column.startswith("child_") and task.row_index == 1 for task in ready) @@ -258,26 +217,16 @@ def test_fan_in_cell_downstream_waits_for_all_same_row_upstreams() -> None: assert not any(task.column == "judge" for task in first_delta.added) assert not any(task.column == "judge" for task in second_delta.added) assert final_delta.added == (Task(column="judge", row_group=0, row_index=0, task_type="cell"),) - ready = tracker.get_ready_tasks(set()) + ready = tracker.ready_frontier() assert not any(task.column == "judge" and task.row_index == 1 for task in ready) -def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - - ready1 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) - ready_ctx.dispatched.update(ready1) - - ready2 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) - assert len(ready2) == 0 - - -def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: +def test_ready_frontier_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) removed = Task(column="question", row_group=0, row_index=1, task_type="cell") delta = ready_ctx.tracker.drop_row(0, 1) - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 2 @@ -294,32 +243,32 @@ def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) # question[2] never completes -- drop it instead delta = ready_ctx.tracker.drop_row(0, 2) - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 assert score_tasks[0].task_type == "batch" assert score_tasks[0] in delta.added -def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: +def test_ready_frontier_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready_ctx.tracker.mark_cell_complete("question", 0, 0) ready_ctx.tracker.mark_cell_complete("question", 0, 1) # question[0,2] not done yet - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 -def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: +def test_ready_frontier_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) delta = None for ri in range(3): delta = ready_ctx.tracker.mark_cell_complete("question", 0, ri) - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 @@ -328,15 +277,13 @@ def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyT assert delta.added == (score_tasks[0],) -def test_get_ready_tasks_multiple_row_groups() -> None: +def test_ready_frontier_multiple_row_groups() -> None: graph = _build_simple_graph() tracker = CompletionTracker.with_graph(graph, [(0, 3), (1, 2)]) - dispatched: set[Task] = set() - tracker.mark_row_range_complete("topic", 0, 3) tracker.mark_row_range_complete("topic", 1, 2) - ready = tracker.get_ready_tasks(dispatched) + ready = tracker.ready_frontier() question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 @@ -350,10 +297,10 @@ def test_frontier_delta_return_is_empty_when_frontier_does_not_change(ready_ctx: assert delta.empty -def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: +def test_ready_frontier_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready = ready_ctx.tracker.ready_frontier() topic_tasks = [t for t in ready if t.column == "topic"] assert len(topic_tasks) == 0 @@ -380,8 +327,6 @@ def test_completed_cell_not_reenqueued_after_later_upstream() -> None: """A → B → C chain: completing C then firing a late upstream event must not re-enqueue C.""" graph = _build_simple_graph() tracker = CompletionTracker.with_graph(graph, [(0, 2)]) - dispatched: set[Task] = set() - # Complete the full pipeline tracker.mark_row_range_complete("topic", 0, 2) tracker.mark_cell_complete("question", 0, 0) @@ -391,7 +336,7 @@ def test_completed_cell_not_reenqueued_after_later_upstream() -> None: # Fire a late upstream cell event after score is already done tracker.mark_cell_complete("question", 0, 0) - ready = tracker.get_ready_tasks(dispatched) + ready = tracker.ready_frontier() score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 @@ -410,21 +355,19 @@ def test_completed_batch_not_reenqueued_by_upstream_cell() -> None: } graph = ExecutionGraph.create(configs, strategies) tracker = CompletionTracker.with_graph(graph, [(0, 2)]) - dispatched: set[Task] = set() - # Complete seed and gen[0] — agg not ready yet tracker.mark_row_range_complete("seed", 0, 2) tracker.mark_cell_complete("gen", 0, 0) - ready = tracker.get_ready_tasks(dispatched) + ready = tracker.ready_frontier() assert not any(t.column == "agg" for t in ready) # Complete gen[1] — agg becomes ready tracker.mark_cell_complete("gen", 0, 1) - ready = tracker.get_ready_tasks(dispatched) + ready = tracker.ready_frontier() assert any(t.column == "agg" for t in ready) # Complete agg, then verify it doesn't reappear tracker.mark_row_range_complete("agg", 0, 2) - ready = tracker.get_ready_tasks(dispatched) + ready = tracker.ready_frontier() assert not any(t.column == "agg" for t in ready) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py index cdc5e6c6a..79f00cbcc 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py @@ -5,7 +5,7 @@ import pytest -from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskResult, TaskTrace +from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskTrace def test_task_is_frozen() -> None: @@ -41,23 +41,6 @@ def test_task_types(task_type: str) -> None: assert task.task_type == task_type -def test_task_result_success() -> None: - task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") - result = TaskResult(task=task, status="success", output={"col_a": "value"}) - assert result.status == "success" - assert result.error is None - assert result.retryable is False - - -def test_task_result_error() -> None: - task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") - exc = ValueError("bad input") - result = TaskResult(task=task, status="error", error=exc, retryable=True) - assert result.status == "error" - assert result.error is exc - assert result.retryable is True - - def test_task_trace_from_task() -> None: task = Task(column="col_a", row_group=1, row_index=2, task_type="cell") trace = TaskTrace.from_task(task) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index bb5cf5685..c2ba039e5 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -57,7 +57,6 @@ from data_designer.engine.models.request_admission.config import RequestAdmissionConfig from data_designer.engine.models.request_admission.controller import ( AdaptiveRequestAdmissionController, - RequestAdmissionLease, ) from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome from data_designer.engine.models.request_admission.pressure import RequestPressureSnapshot @@ -68,8 +67,8 @@ RequestResourceKey, ) from data_designer.engine.models.resources import ProviderModelKey -from data_designer.engine.observability import InMemoryAdmissionEventSink from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.testing import InMemoryAdmissionEventSink MODEL_ALIAS = "stub" @@ -2303,40 +2302,6 @@ async def record_sleep(delay: float) -> None: assert yielded_delays == [0] -@pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_dispatch_does_not_scan_ready_frontier(monkeypatch: pytest.MonkeyPatch) -> None: - provider = _mock_provider() - configs = [ - SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), - LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), - ] - strategies = { - "seed": GenerationStrategy.FULL_COLUMN, - "cell_out": GenerationStrategy.CELL_BY_CELL, - } - generators = { - "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), - "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), - } - graph = ExecutionGraph.create(configs, strategies) - tracker = CompletionTracker.with_graph(graph, [(0, 3)]) - - def fail_get_ready_tasks(*args: Any, **kwargs: Any) -> list[Task]: - raise AssertionError("scheduler should apply returned frontier deltas instead of scanning ready tasks") - - monkeypatch.setattr(tracker, "get_ready_tasks", fail_get_ready_tasks) - scheduler = AsyncTaskScheduler( - generators=generators, - graph=graph, - tracker=tracker, - row_groups=[(0, 3)], - ) - - await asyncio.wait_for(scheduler.run(), timeout=10.0) - - assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) - - @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_pre_batch_drop_removes_pending_ready_task() -> None: provider = _mock_provider() @@ -3542,102 +3507,6 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ) -@pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None: - provider = _mock_provider() - gen_names = ["gen_a", "gen_b", "gen_c"] - configs = [ - SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), - *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], - ] - strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} - strategies.update({column: GenerationStrategy.CELL_BY_CELL for column in gen_names}) - generators: dict[str, ColumnGenerator] = { - "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), - **{ - name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.02) - for name in gen_names - }, - } - graph = ExecutionGraph.create(configs, strategies) - row_groups = [(0, 3), (1, 3), (2, 3), (3, 3)] - tracker = CompletionTracker.with_graph(graph, row_groups) - scheduler = AsyncTaskScheduler( - generators=generators, - graph=graph, - tracker=tracker, - row_groups=row_groups, - max_concurrent_row_groups=2, - max_in_flight_tasks=2, - trace=True, - num_records=12, - buffer_size=3, - ) - - await asyncio.wait_for(scheduler.run(), timeout=10.0) - - plan = scheduler.capacity_plan() - for row_group_index, row_count in row_groups: - assert tracker.is_row_group_complete(row_group_index, row_count, ["topic", *gen_names]) - assert plan.configured.row_group_admission.observed_in_flight == 0 - assert plan.observed_maxima.row_groups_in_flight == 2 - assert plan.observed_maxima.queued_tasks_by_group - assert max(plan.observed_maxima.task_leases_by_resource.values()) <= 2 - - -def test_scheduler_capacity_plan_reports_request_admission_state() -> None: - resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) - request_admission = AdaptiveRequestAdmissionController( - RequestAdmissionConfig(initial_limits={resource: 2}, max_limit_clamps={resource: 3}) - ) - request_admission.register( - provider_name="provider", - model_id="model", - alias="primary", - max_parallel_requests=4, - ) - lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) - assert isinstance(lease, RequestAdmissionLease) - - scheduler, _tracker = _build_simple_pipeline() - scheduler._request_pressure_provider = request_admission - scheduler._record_observed_task_state() - plan = scheduler.capacity_plan() - - assert plan.configured.request_resources.value == (resource,) - assert plan.configured.request_domain_initial_limits.value[resource] == 2 - assert plan.configured.request_admission_config.value is not None - assert plan.configured.provider_model_static_caps.value[ProviderModelKey("provider", "model")].cap == 4 - assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 - assert plan.runtime_snapshot.request_domain_effective_max[resource] == 3 - assert plan.runtime_snapshot.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 - assert plan.observed_maxima.request_in_flight_by_resource[resource] == 1 - assert plan.observed_maxima.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 - request_admission.release(lease, RequestReleaseOutcome(kind="success")) - - -def test_scheduler_capacity_plan_reports_default_request_initial_limit_after_aimd_drop() -> None: - resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) - request_admission = AdaptiveRequestAdmissionController() - request_admission.register( - provider_name="provider", - model_id="model", - alias="primary", - max_parallel_requests=4, - ) - lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) - assert isinstance(lease, RequestAdmissionLease) - request_admission.release(lease, RequestReleaseOutcome(kind="rate_limited")) - - scheduler, _tracker = _build_simple_pipeline() - scheduler._request_pressure_provider = request_admission - plan = scheduler.capacity_plan() - - assert plan.configured.request_domain_initial_limits.value[resource] == 4 - assert plan.runtime_snapshot.request_domain_effective_max[resource] == 4 - assert plan.runtime_snapshot.request_domain_current_limits[resource] == 3 - - @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None: provider = _mock_provider() @@ -3744,12 +3613,7 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon await asyncio.wait_for(scheduler.run(), timeout=10.0) - plan = scheduler.capacity_plan() assert tracker.is_row_group_complete(0, 1, ["topic", "model_col"]) - assert plan.configured.row_group_admission.mode == "adaptive" - assert plan.configured.row_group_admission.observed_max_target is not None - assert plan.configured.row_group_admission.observed_max_target > 1 - assert plan.observed_maxima.row_groups_in_flight > 1 assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) @@ -3978,7 +3842,6 @@ def test_scheduler_adaptive_row_group_target_stays_blocked_after_llm_lease_boots def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None: scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1) scheduler._max_in_flight_tasks = 2 - scheduler._max_model_task_admission = 100 scheduler._fair_queue = SimpleNamespace( view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={}) ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 6b257a062..90f4dba3b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -33,6 +33,7 @@ from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder, build_row_group_resume_plan from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan +from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum from data_designer.engine.models.usage import ModelUsageStats, TokenUsageStats from data_designer.engine.processing.processors.base import Processor @@ -44,6 +45,10 @@ import pandas as pd +def _replace_processors(builder: DatasetBuilder, processors: list[Processor]) -> None: + builder._processor_runner = ProcessorRunner(processors=processors, artifact_storage=builder.artifact_storage) + + @pytest.fixture def stub_test_column_configs(): return [ @@ -501,7 +506,7 @@ def test_run_after_generation( mock_processor = create_mock_processor("proc", ["process_after_generation"]) mock_processor.process_after_generation.side_effect = processor_fn - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) simple_builder._processor_runner.run_after_generation(batch_size) mock_processor.process_after_generation.assert_called_once() @@ -524,7 +529,7 @@ def test_all_processor_stages_run_in_order(builder_with_seed, mode): df, )[1] - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) if mode == "preview": raw_dataset = builder_with_seed.build_preview(num_records=3) @@ -544,7 +549,7 @@ def test_processor_exception_in_process_after_batch_raises_error(simple_builder) mock_processor = create_mock_processor("failing_processor", ["process_after_batch"]) mock_processor.process_after_batch.side_effect = ValueError("Post-batch processing failed") - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) with pytest.raises(DatasetProcessingError, match="Failed in process_after_batch"): simple_builder._processor_runner.run_post_batch(lazy.pd.DataFrame({"id": [1, 2, 3]}), current_batch_number=0) @@ -553,7 +558,7 @@ def test_processor_exception_in_process_after_batch_raises_error(simple_builder) def test_processor_with_no_implemented_stages_is_skipped(builder_with_seed): """Test that a processor implementing no stages doesn't cause errors.""" mock_processor = create_mock_processor("noop_processor", []) - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) result = builder_with_seed.build_preview(num_records=3) @@ -573,7 +578,7 @@ def test_multiple_processors_run_in_definition_order(builder_with_seed): p.process_before_batch.side_effect = lambda df, lbl=label: (call_order.append(lbl), df)[1] processors.append(p) - builder_with_seed.set_processor_runner(processors) + _replace_processors(builder_with_seed, processors) builder_with_seed.build(num_records=3) assert call_order == ["a", "b", "c"] @@ -582,7 +587,7 @@ def test_multiple_processors_run_in_definition_order(builder_with_seed): def test_pre_batch_processor_row_count_change_rejected(builder_with_seed, caplog): mock_processor = create_mock_processor("filtering_processor", ["process_before_batch"]) mock_processor.process_before_batch.side_effect = lambda df: df.iloc[:2].reset_index(drop=True) - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) with caplog.at_level(logging.INFO): with pytest.raises(DatasetGenerationError, match="Pre-batch processor changed row count"): @@ -594,7 +599,7 @@ def test_pre_batch_processor_row_count_change_rejected(builder_with_seed, caplog def test_process_preview_with_empty_dataframe(simple_builder): """Test that process_preview handles empty DataFrames gracefully.""" mock_processor = create_mock_processor("test_processor", ["process_after_batch", "process_after_generation"]) - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) result = simple_builder.process_preview(lazy.pd.DataFrame()) @@ -1604,7 +1609,7 @@ def test_build_resume_complete_dataset_runs_after_generation_when_no_marker( builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) after_gen_processor = create_mock_processor("after_gen", ["process_after_generation"]) - builder.set_processor_runner([after_gen_processor]) + _replace_processors(builder, [after_gen_processor]) builder.build(num_records=4, resume=ResumeMode.ALWAYS) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py deleted file mode 100644 index 4baa829af..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py +++ /dev/null @@ -1,478 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import asyncio - -import pytest - -from data_designer.engine.dataset_builders.utils.async_concurrency import ( - AsyncConcurrentExecutor, -) -from data_designer.engine.dataset_builders.utils.concurrency import ExecutorResults -from data_designer.engine.errors import DataDesignerRuntimeError - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -async def _succeed(value: int) -> int: - """Simple coroutine that returns its input doubled.""" - return value * 2 - - -async def _fail(msg: str = "Test error") -> None: - """Simple coroutine that always raises.""" - raise ValueError(msg) - - -async def _succeed_slow(value: int, delay: float = 0.05) -> int: - """Coroutine with a small delay to simulate work.""" - await asyncio.sleep(delay) - return value * 2 - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -def test_basic_creation(): - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.3, - shutdown_error_window=5, - ) - assert executor.max_workers == 4 - assert executor.shutdown_error_rate == 0.3 - assert executor.shutdown_window_size == 5 - assert isinstance(executor.results, ExecutorResults) - assert executor.results.completed_count == 0 - assert executor.results.success_count == 0 - assert executor.results.early_shutdown is False - assert executor.results.failure_threshold == 0.3 - - -def test_successful_execution(): - executor = AsyncConcurrentExecutor(max_workers=4, column_name="test_column") - work_items = [(_succeed(i), None) for i in range(10)] - executor.run(work_items) - - assert executor.results.completed_count == 10 - assert executor.results.success_count == 10 - assert executor.results.error_trap.error_count == 0 - assert executor.results.early_shutdown is False - - -def test_successful_execution_with_context(): - executor = AsyncConcurrentExecutor(max_workers=2, column_name="test_column") - work_items = [(_succeed(i), {"index": i}) for i in range(5)] - executor.run(work_items) - - assert executor.results.completed_count == 5 - assert executor.results.success_count == 5 - - -def test_result_callback(): - results = [] - - def result_callback(result, *, context=None): - results.append((result, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=result_callback, - ) - work_items = [ - (_succeed(5), {"key": "a"}), - (_succeed(10), {"key": "b"}), - ] - executor.run(work_items) - - assert len(results) == 2 - values = sorted(results, key=lambda r: r[0]) - assert values[0] == (10, {"key": "a"}) - assert values[1] == (20, {"key": "b"}) - - -def test_result_callback_with_none_context(): - results = [] - - def result_callback(result, *, context=None): - results.append((result, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=result_callback, - ) - work_items = [(_succeed(7), None)] - executor.run(work_items) - - assert len(results) == 1 - assert results[0] == (14, None) - - -def test_error_callback(): - errors = [] - - def error_callback(exc, *, context=None): - errors.append((exc, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - error_callback=error_callback, - disable_early_shutdown=True, - ) - work_items = [(_fail("boom"), {"task": "first"})] - executor.run(work_items) - - assert len(errors) == 1 - assert isinstance(errors[0][0], ValueError) - assert str(errors[0][0]) == "boom" - assert errors[0][1] == {"task": "first"} - - -def test_early_shutdown_when_threshold_exceeded(): - """Error rate exceeds threshold -- should raise DataDesignerRuntimeError.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=2, - ) - # All tasks fail -> 100% error rate, well above 50% threshold - work_items = [(_fail(f"err-{i}"), None) for i in range(10)] - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work_items) - - assert executor.results.early_shutdown is True - - -def test_no_early_shutdown_below_threshold(): - """Error rate stays below threshold -- should NOT raise.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=20, - ) - # 2 failures + 18 successes = 10% error rate, well below 50% - work_items = [(_fail(f"err-{i}"), None) for i in range(2)] + [(_succeed(i), None) for i in range(18)] - executor.run(work_items) - - assert executor.results.early_shutdown is False - assert executor.results.completed_count == 20 - assert executor.results.success_count == 18 - assert executor.results.error_trap.error_count == 2 - - -def test_disable_early_shutdown(): - """All tasks fail but disable_early_shutdown=True -- no DataDesignerRuntimeError.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.0, - shutdown_error_window=0, - disable_early_shutdown=True, - ) - work_items = [(_fail(f"err-{i}"), None) for i in range(10)] - # Should not raise - executor.run(work_items) - - assert executor.results.error_trap.error_count == 10 - assert executor.results.success_count == 0 - assert executor.results.completed_count == 10 - assert executor.results.early_shutdown is False - - -def test_result_callback_raises_counts_as_failure(): - """When result_callback raises, the task should count as a failure, not a success. - - This validates the fix where a callback exception was previously - double-counted or misattributed. The corrected behavior: if the - coroutine succeeds but the callback raises, completed_count is - incremented once and the error is trapped (success_count is NOT - incremented). - """ - - def bad_callback(result, *, context=None): - raise RuntimeError("callback exploded") - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=bad_callback, - disable_early_shutdown=True, - ) - work_items = [(_succeed(i), None) for i in range(5)] - executor.run(work_items) - - # Each task's coroutine succeeds, but callback raises -> counted as failure - assert executor.results.completed_count == 5 - assert executor.results.success_count == 0 - assert executor.results.error_trap.error_count == 5 - - -def test_error_callback_raises_safely(): - """error_callback raising should not crash the executor -- just log a warning.""" - - def bad_error_callback(exc, *, context=None): - raise RuntimeError("error callback also broke") - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - error_callback=bad_error_callback, - disable_early_shutdown=True, - ) - work_items = [(_fail(f"err-{i}"), None) for i in range(5)] - # Should not raise despite error_callback raising - executor.run(work_items) - - assert executor.results.completed_count == 5 - assert executor.results.error_trap.error_count == 5 - - -def test_semaphore_bounding(): - """Verify concurrency is bounded by max_workers.""" - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_work(index: int) -> int: - nonlocal max_concurrent, current_concurrent - async with lock: - current_concurrent += 1 - if current_concurrent > max_concurrent: - max_concurrent = current_concurrent - # Yield control so other tasks can run concurrently - await asyncio.sleep(0.02) - async with lock: - current_concurrent -= 1 - return index - - max_workers = 3 - executor = AsyncConcurrentExecutor(max_workers=max_workers, column_name="test_column") - work_items = [(tracked_work(i), None) for i in range(20)] - executor.run(work_items) - - assert executor.results.completed_count == 20 - assert executor.results.success_count == 20 - assert max_concurrent <= max_workers, f"Max concurrent was {max_concurrent}, expected <= {max_workers}" - # Also confirm the semaphore was actually exercised (more tasks than workers) - assert max_concurrent >= 1 - - -@pytest.mark.parametrize( - "shutdown_error_rate,num_errors,num_successes,shutdown_window,should_raise", - [ - (0.5, 60, 40, 20, True), # 60% errors > 50% threshold - (0.3, 40, 60, 20, True), # 40% errors > 30% threshold - (0.0, 5, 5, 10, True), # Any error > 0% threshold - (1.0, 20, 0, 10, True), # 100% errors >= 100% threshold - (0.5, 10, 90, 20, False), # 10% errors < 50% threshold - (0.3, 10, 90, 20, False), # 10% errors < 30% threshold - (1.0, 50, 50, 20, False), # 50% errors < 100% threshold - ], -) -def test_early_shutdown_parametric(shutdown_error_rate, num_errors, num_successes, shutdown_window, should_raise): - executor = AsyncConcurrentExecutor( - max_workers=10, - column_name="test_column", - shutdown_error_rate=shutdown_error_rate, - shutdown_error_window=shutdown_window, - ) - - # Interleave errors and successes to keep error rate relatively stable - total = num_errors + num_successes - work_items = [] - err_idx = 0 - suc_idx = 0 - if num_errors > 0: - tasks_per_error = total / num_errors - else: - tasks_per_error = float("inf") - - for i in range(total): - if num_errors > 0 and err_idx < num_errors and i >= int(err_idx * tasks_per_error): - work_items.append((_fail(f"err-{err_idx}"), None)) - err_idx += 1 - elif suc_idx < num_successes: - work_items.append((_succeed(suc_idx), None)) - suc_idx += 1 - - if should_raise: - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work_items) - assert executor.results.early_shutdown is True - else: - executor.run(work_items) - assert executor.results.early_shutdown is False - assert executor.results.completed_count == total - assert executor.results.success_count == num_successes - assert executor.results.error_trap.error_count == num_errors - - -def test_mixed_success_and_failure_with_callbacks(): - """Stress test: mix of successes and failures with both callbacks.""" - results_list = [] - errors_list = [] - - def result_callback(result, *, context=None): - results_list.append(result) - - def error_callback(exc, *, context=None): - errors_list.append(exc) - - executor = AsyncConcurrentExecutor( - max_workers=8, - column_name="test_column", - result_callback=result_callback, - error_callback=error_callback, - shutdown_error_rate=0.9, - shutdown_error_window=50, - ) - - async def variable_task(x: int) -> int: - if x % 7 == 0: - raise ValueError(f"Error {x}") - if x % 3 == 0: - await asyncio.sleep(0.001) - return x * 2 - - num_tasks = 100 - work_items = [(variable_task(i), None) for i in range(num_tasks)] - executor.run(work_items) - - expected_errors = sum(1 for i in range(num_tasks) if i % 7 == 0) - expected_successes = num_tasks - expected_errors - - assert executor.results.completed_count == num_tasks - assert executor.results.success_count == expected_successes - assert executor.results.error_trap.error_count == expected_errors - assert len(results_list) == expected_successes - assert len(errors_list) == expected_errors - assert executor.results.early_shutdown is False - - -# --------------------------------------------------------------------------- -# Edge cases (mirroring sync test_concurrency.py) -# --------------------------------------------------------------------------- - - -def test_edge_cases_invalid_max_workers_negative(): - """asyncio.Semaphore(-1) raises ValueError, propagated through future.result().""" - - async def ok() -> int: - return 1 - - coro = ok() - executor = AsyncConcurrentExecutor(max_workers=-1, column_name="test_column") - with pytest.raises(ValueError, match="must be >= 0"): - executor.run([(coro, None)]) - coro.close() # prevent "coroutine was never awaited" warning - - -def test_edge_cases_zero_error_window(): - """With shutdown_error_window=0, the first failure triggers immediate shutdown. - - get_error_rate returns 0.0 only when completed_count < window. With window=0, - that guard never fires, so the first error's rate (1/1 = 100%) exceeds any - non-zero threshold. - """ - executor = AsyncConcurrentExecutor( - max_workers=1, # deterministic ordering - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=0, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> str: - return "ok" - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run([(fail(), None), (succeed(), None), (succeed(), None)]) - - assert executor.results.early_shutdown is True - assert executor.results.completed_count == 1 - assert executor.results.error_trap.error_count == 1 - assert executor.results.success_count == 0 - - -def test_edge_cases_multiple_early_shutdown_skips_pending(): - """After shutdown fires, remaining tasks are skipped via _shutdown_event check.""" - executor = AsyncConcurrentExecutor( - max_workers=1, # sequential execution for deterministic counts - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=2, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> int: - return 1 - - # 2 failures then 28 successes — shutdown should fire after the 2 failures - work = [(fail(), None), (fail(), None)] + [(succeed(), None) for _ in range(28)] - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work) - - assert executor.results.early_shutdown is True - # Only the tasks that actually executed get counted - assert executor.results.completed_count <= 3 # at most 2 failures + maybe 1 success - assert executor.results.error_trap.error_count == 2 - # Skipped tasks should NOT inflate completed_count - assert executor.results.completed_count < 30 - - -def test_edge_cases_semaphore_release_on_exception(): - """Verify semaphore is released after a failing task, allowing the next task to run. - - With max_workers=1, if the semaphore weren't released on exception, the second - task would deadlock. - """ - results = [] - errors = [] - - def result_cb(result, *, context=None): - results.append((result, context)) - - def error_cb(exc, *, context=None): - errors.append((type(exc).__name__, str(exc), context)) - - executor = AsyncConcurrentExecutor( - max_workers=1, - column_name="test_column", - result_callback=result_cb, - error_callback=error_cb, - shutdown_error_rate=1.0, # high threshold to avoid early shutdown - shutdown_error_window=10, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> str: - return "ok" - - executor.run([(fail(), {"id": "fail"}), (succeed(), {"id": "ok"})]) - - assert executor.results.early_shutdown is False - assert executor.results.completed_count == 2 - assert executor.results.error_trap.error_count == 1 - assert executor.results.success_count == 1 - assert len(errors) == 1 - assert errors[0] == ("ValueError", "boom", {"id": "fail"}) - assert len(results) == 1 - assert results[0] == ("ok", {"id": "ok"}) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 6a5b31a51..c46502aad 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -20,7 +20,6 @@ from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig -from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -297,48 +296,6 @@ def test_add_column_duplicate_raises() -> None: graph.add_column("col_a", GenerationStrategy.FULL_COLUMN) -# -- Cell dependencies ------------------------------------------------------ - - -def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: - """question depends on topic (full-column); answer depends on question (cell-by-cell).""" - # answer[rg=0, row=2] should depend on question[rg=0, row=2] - deps = simple_graph.compute_cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) - assert deps == [SliceRef("question", 0, 2)] - - -def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: - """question depends on topic (full-column).""" - deps = simple_graph.compute_cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) - assert deps == [SliceRef("topic", 0, None)] - - -def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: - """topic has no upstream.""" - deps = simple_graph.compute_cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) - assert deps == [] - - -def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: - """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" - deps = simple_graph.compute_cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) - assert sorted(deps) == [SliceRef("answer", 0, 0), SliceRef("answer", 0, 1), SliceRef("answer", 0, 2)] - - -# -- Mermaid output ---------------------------------------------------------- - - -def test_to_mermaid(simple_graph: ExecutionGraph) -> None: - mermaid = simple_graph.to_mermaid() - - assert "graph TD" in mermaid - assert 'topic["topic [full_column]"]' in mermaid - assert 'question["question [cell_by_cell]"]' in mermaid - assert "topic --> question" in mermaid - assert "question --> answer" in mermaid - assert "answer --> score" in mermaid - - # -- MultiColumnConfig ------------------------------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 08dad3730..a80b160f7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -49,15 +49,6 @@ def test_update_cell() -> None: assert mgr.get_row(0, 1) == {"col_a": "val_1"} -def test_update_cells() -> None: - mgr = RowGroupBufferManager(_mock_artifact_storage()) - mgr.init_row_group(0, 1) - - mgr.update_cells(0, 0, {"col_a": "a", "col_b": "b"}) - - assert mgr.get_row(0, 0) == {"col_a": "a", "col_b": "b"} - - def test_update_batch() -> None: mgr = RowGroupBufferManager(_mock_artifact_storage()) mgr.init_row_group(0, 3) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py index cf833c8e7..1f2944e81 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py @@ -9,8 +9,6 @@ from data_designer.engine.dataset_builders.utils.skip_tracker import ( SKIPPED_COLUMNS_RECORD_KEY, apply_skip_to_record, - prepare_records_for_skip_metadata_round_trip, - restore_skip_metadata, strip_skip_metadata_for_dataframe_row, strip_skip_metadata_from_records, ) @@ -131,77 +129,3 @@ def test_strip_skip_metadata_for_dataframe_row_no_metadata() -> None: ) def test_strip_skip_metadata_from_records(rows: list[dict], expected: list[dict]) -> None: assert strip_skip_metadata_from_records(rows) == expected - - -def test_prepare_records_for_skip_metadata_round_trip_without_metadata() -> None: - rows = [{"a": 1}, {"a": 2}] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows) - assert restore_context is None - assert prepared_rows == rows - assert prepared_rows is not rows - - -def test_prepare_records_for_skip_metadata_round_trip_injects_restore_ids() -> None: - rows = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - {"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_y", "col_z"}}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows) - assert restore_context is not None - assert SKIPPED_COLUMNS_RECORD_KEY not in prepared_rows[0] - assert restore_context.restore_id_column in prepared_rows[0] - assert restore_context.skipped_columns_by_source_id == { - "0": {"col_x"}, - "2": {"col_y", "col_z"}, - } - - -def test_restore_skip_metadata_uses_restore_ids_after_reorder() -> None: - old = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - {"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_z"}}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - restore_id_column = restore_context.restore_id_column - - new = [ - {"a": 30, restore_id_column: prepared_rows[2][restore_id_column]}, - {"a": 10, restore_id_column: prepared_rows[0][restore_id_column]}, - {"a": 20, restore_id_column: prepared_rows[1][restore_id_column]}, - ] - restore_skip_metadata(new, context=restore_context) - - assert new[0][SKIPPED_COLUMNS_RECORD_KEY] == {"col_z"} - assert new[1][SKIPPED_COLUMNS_RECORD_KEY] == {"col_x"} - assert SKIPPED_COLUMNS_RECORD_KEY not in new[2] - - -def test_restore_skip_metadata_rejects_filtered_rows() -> None: - old = [{"a": 1}, {"a": 2}] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is None - - old = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - restore_id_column = restore_context.restore_id_column - - new = [{"a": 20, restore_id_column: prepared_rows[1][restore_id_column]}] - - with pytest.raises(ValueError, match="1:1 mapping"): - restore_skip_metadata(new, context=restore_context) - - -def test_restore_skip_metadata_rejects_missing_restore_id_column() -> None: - old = [{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}] - _prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - - with pytest.raises(ValueError, match="must preserve the internal column"): - restore_skip_metadata([{"a": 10}], context=restore_context) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index dc08a55ff..6e0820078 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -10,7 +10,6 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, - extract_message_from_exception_string, map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) @@ -244,46 +243,3 @@ def test_map_http_error_retry_after_returns_none_for_garbage() -> None: ) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") assert error.retry_after is None - - -@pytest.mark.parametrize( - "raw,expected", - [ - ( - "Error code: 400 - {'error': {'message': 'Context length exceeded', 'type': 'invalid_request_error'}}".replace( - "'", '"' - ), - "Context length exceeded", - ), - ( - 'Error code: 403 - {"error": "Insufficient permissions"}', - "Insufficient permissions", - ), - ( - 'Error code: 500 - {"message": "Internal failure"}', - "Internal failure", - ), - ( - 'Error code: 422 - {"detail": "Unprocessable entity"}', - "Unprocessable entity", - ), - ( - "Connection timed out", - "Connection timed out", - ), - ( - "Error code: 400 - {not valid json", - "Error code: 400 - {not valid json", - ), - ], - ids=[ - "nested-error-message", - "top-level-error-string", - "top-level-message-string", - "top-level-detail-string", - "no-json-passthrough", - "malformed-json-passthrough", - ], -) -def test_extract_message_from_exception_string(raw: str, expected: str) -> None: - assert extract_message_from_exception_string(raw) == expected diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py index 4dbae62e2..e193617ed 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py @@ -28,7 +28,7 @@ RequestAdmissionError, ) from data_designer.engine.models.request_admission.resources import RequestAdmissionItem, RequestDomain -from data_designer.engine.observability import InMemoryAdmissionEventSink +from data_designer.engine.testing import InMemoryAdmissionEventSink class _Client: diff --git a/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py b/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py index 613194dfe..9e57b08eb 100644 --- a/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py +++ b/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py @@ -3,11 +3,8 @@ from __future__ import annotations -from pydantic import BaseModel - from data_designer.engine.models.parsers.parser import LLMResponseParser from data_designer.engine.models.parsers.postprocessors import ( - RealizePydanticTypes, deserialize_json_code, merge_text_blocks, ) @@ -15,7 +12,6 @@ from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -82,7 +78,7 @@ def test_llm_response_parser_markup_passthrough(): assert block.text == text -def test_llm_response_parser_full_pipeline(): +def test_llm_response_parser_deserializes_json_blocks(): text = """\ Test prompt return. The return has some `code` included. ```json @@ -97,17 +93,10 @@ def test_llm_response_parser_full_pipeline(): That is all there is at the moment.\ """ - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - parser = LLMResponseParser( postprocessors=[ merge_text_blocks, deserialize_json_code, - RealizePydanticTypes([Foo, Bar]), ] ) result = parser.parse(text) @@ -119,12 +108,12 @@ class Bar(BaseModel): assert result.parsed[0] == TextBlock(text="Test prompt return. The return has some `code` included.") assert isinstance(result.parsed[1], StructuredDataBlock) - assert isinstance(result.parsed[2], PydanticTypeBlock) - assert isinstance(result.parsed[3], PydanticTypeBlock) + assert isinstance(result.parsed[2], StructuredDataBlock) + assert isinstance(result.parsed[3], StructuredDataBlock) assert result.parsed[1].obj == {"asdf": 42} - assert result.parsed[2].obj == Foo(baz=3) - assert result.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) + assert result.parsed[2].obj == {"baz": 3} + assert result.parsed[3].obj == {"foos": [{"baz": 1}, {"baz": 2}]} assert result.parsed[4] == TextBlock(text="That is all there is at the moment.") diff --git a/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py b/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py index 88cdd0296..67096594b 100644 --- a/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py +++ b/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py @@ -1,14 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from pydantic import BaseModel - import data_designer.engine.models.parsers.postprocessors as post from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, PostProcessor, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -16,7 +13,6 @@ KNOWN_POSTPROCESSORS = [ post.merge_text_blocks, post.deserialize_json_code, - post.RealizePydanticTypes, ] @@ -62,61 +58,3 @@ def test_deserialize_json_code(): assert isinstance(response.parsed[2], StructuredDataBlock) assert response.parsed[2].obj == {"bar": 43, "baz": [1, 2, 3]} - - -def test_realize_pydantic_types(): - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - - blocks = [ - TextBlock(text="abc"), - StructuredDataBlock(serialized="", obj={"asdf": "q"}), - StructuredDataBlock(serialized="", obj={"baz": 42}), - StructuredDataBlock(serialized="", obj={"foos": [{"baz": 1}, {"baz": 2}]}), - TextBlock(text="cba"), - ] - - response = LLMStructuredResponse(response="", markup="", parsed=blocks) - - parser = post.RealizePydanticTypes([Foo, Bar]) - response = parser(response) - - assert len(response.parsed) == 5 - assert isinstance(response.parsed[1], StructuredDataBlock) - assert isinstance(response.parsed[2], PydanticTypeBlock) - assert isinstance(response.parsed[3], PydanticTypeBlock) - - assert response.parsed[2].obj == Foo(baz=42) - assert response.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) - - -def test_deserialize_realize_pipeline(): - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - - blocks = [ - TextBlock(text="abc"), - CodeBlock(code='{"asdf": "q"}', code_lang="json"), - CodeBlock(code='{"baz": 42}', code_lang="json"), - CodeBlock(code='{"foos": [{"baz": 1}, {"baz": 2}]}', code_lang="json"), - TextBlock(text="cba"), - ] - - response = LLMStructuredResponse(response="", markup="", parsed=blocks) - - parser = post.RealizePydanticTypes([Foo, Bar]) - response = parser(post.deserialize_json_code(response)) - - assert len(response.parsed) == 5 - assert isinstance(response.parsed[1], StructuredDataBlock) - assert isinstance(response.parsed[2], PydanticTypeBlock) - assert isinstance(response.parsed[3], PydanticTypeBlock) - - assert response.parsed[2].obj == Foo(baz=42) - assert response.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) diff --git a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py index af77f8c40..5b165666c 100644 --- a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py +++ b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py @@ -25,7 +25,7 @@ RequestGroupSpec, RequestResourceKey, ) -from data_designer.engine.observability import InMemoryAdmissionEventSink +from data_designer.engine.testing import InMemoryAdmissionEventSink def _item(domain: RequestDomain = RequestDomain.CHAT, timeout: float | None = None) -> RequestAdmissionItem: diff --git a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py index cb5d569f9..1d9c86530 100644 --- a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py +++ b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py @@ -5,11 +5,10 @@ import pytest -from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig +from data_designer.config.mcp import LocalStdioMCPProvider, ToolConfig from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.resources.resource_provider import ( ResourceProvider, - _validate_tool_configs_against_providers, create_resource_provider, ) from data_designer.engine.storage.artifact_storage import ArtifactStorage @@ -58,48 +57,6 @@ def test_create_resource_provider_error_cases(test_case, expected_error, tmp_pat class TestToolConfigValidation: """Tests for ToolConfig validation against MCP providers.""" - def test_valid_tool_config_with_existing_providers(self) -> None: - """Valid tool config passes when all providers exist.""" - providers = [ - MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse"), - LocalStdioMCPProvider(name="mcp-2", command="python", args=["-m", "server"]), - ] - tool_configs = [ - ToolConfig(tool_alias="tools-1", providers=["mcp-1"]), - ToolConfig(tool_alias="tools-2", providers=["mcp-1", "mcp-2"]), - ] - - # Should not raise - _validate_tool_configs_against_providers(tool_configs, providers) - - def test_tool_config_with_missing_provider_raises_error(self) -> None: - """Tool config referencing non-existent provider raises ValueError.""" - providers = [ - MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse"), - ] - tool_configs = [ - ToolConfig(tool_alias="search-tools", providers=["mcp-1", "nonexistent-mcp"]), - ] - - with pytest.raises(ValueError, match="ToolConfig 'search-tools' references provider"): - _validate_tool_configs_against_providers(tool_configs, providers) - - def test_tool_config_with_no_providers_available(self) -> None: - """Tool config fails when no MCP providers are configured.""" - tool_configs = [ - ToolConfig(tool_alias="search-tools", providers=["some-mcp"]), - ] - - with pytest.raises(ValueError, match="not registered.*none configured"): - _validate_tool_configs_against_providers(tool_configs, []) - - def test_empty_tool_configs_passes(self) -> None: - """Empty tool configs list passes validation.""" - providers = [MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse")] - - # Should not raise - _validate_tool_configs_against_providers([], providers) - def test_tool_config_validation_happens_during_health_check(self, tmp_path: str) -> None: """Tool config validation is deferred to health checks.""" artifact_storage = ArtifactStorage(artifact_path=str(tmp_path), dataset_name="test") diff --git a/packages/data-designer-engine/tests/engine/test_capacity.py b/packages/data-designer-engine/tests/engine/test_capacity.py deleted file mode 100644 index 856aeba09..000000000 --- a/packages/data-designer-engine/tests/engine/test_capacity.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from data_designer.engine.capacity import ( - AsyncCapacityConfigured, - AsyncCapacityObservedMaxima, - AsyncCapacityPlan, - AsyncCapacityRuntimeSnapshot, - CapacityValue, - RequestAdmissionConfigSnapshot, - RowGroupAdmission, -) -from data_designer.engine.models.request_admission.config import RequestAdmissionConfig -from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey -from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap - - -def test_request_admission_config_snapshot_records_resources() -> None: - resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) - config = RequestAdmissionConfig( - initial_limits={resource: 2}, - max_limit_clamps={resource: 4}, - startup_ramp_seconds=30.0, - ) - - snapshot = RequestAdmissionConfigSnapshot.from_config(config) - - assert snapshot.resources == (resource,) - assert snapshot.initial_limits[resource] == 2 - assert snapshot.max_limit_clamps[resource] == 4 - assert snapshot.startup_ramp_seconds == 30.0 - - -def test_async_capacity_plan_records_configured_runtime_and_maxima() -> None: - resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) - provider_model = ProviderModelKey("nvidia", "nemotron") - static_cap = ProviderModelStaticCap(cap=4, aliases=("default",), raw_caps={"default": 4}) - - plan = AsyncCapacityPlan( - configured=AsyncCapacityConfigured( - buffer_size=CapacityValue(value=16, source="run_config"), - row_group_admission=RowGroupAdmission( - row_group_concurrency=CapacityValue(value=2, source="dataset_builder"), - observed_in_flight=1, - ), - submission_capacity=CapacityValue(value=8, source="engine_internal_config"), - task_resource_limits=CapacityValue(value={"submission": 8, "llm_wait": 4}, source="engine_internal_config"), - request_resources=CapacityValue(value=(resource,), source="runtime_snapshot"), - provider_model_static_caps=CapacityValue(value={provider_model: static_cap}, source="model_metadata"), - request_domain_initial_limits=CapacityValue(value={resource: 2}, source="engine_internal_config"), - request_admission_config=CapacityValue( - value=RequestAdmissionConfigSnapshot.from_config(RequestAdmissionConfig(initial_limits={resource: 2})), - source="engine_internal_config", - ), - transport_pool_limits=CapacityValue(value={provider_model: 8}, source="adapter_config"), - ), - runtime_snapshot=AsyncCapacityRuntimeSnapshot( - request_domain_current_limits={resource: 2}, - request_domain_effective_max={resource: 4}, - request_domain_blocked_until={resource: None}, - provider_model_aggregate_in_flight={provider_model: 0}, - ), - observed_maxima=AsyncCapacityObservedMaxima( - row_groups_in_flight=1, - request_in_flight_by_resource={resource: 2}, - provider_model_aggregate_in_flight={provider_model: 2}, - ), - ) - - assert plan.configured.provider_model_static_caps.value[provider_model].merge_rule == "min_same_endpoint" - assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 - assert plan.observed_maxima.provider_model_aggregate_in_flight[provider_model] == 2 diff --git a/packages/data-designer-engine/tests/engine/test_engine_errors.py b/packages/data-designer-engine/tests/engine/test_engine_errors.py index 9a9e794f2..a84e99540 100644 --- a/packages/data-designer-engine/tests/engine/test_engine_errors.py +++ b/packages/data-designer-engine/tests/engine/test_engine_errors.py @@ -6,7 +6,6 @@ ErrorTrap, RemoteValidationSchemaError, SecretResolutionError, - UnknownModelAliasError, UnknownProviderError, ) @@ -14,7 +13,6 @@ def test_error_message(): test_cases = [ (DataDesignerRuntimeError, "Runtime error occurred"), - (UnknownModelAliasError, "Unknown model alias"), (UnknownProviderError, "Unknown provider"), (SecretResolutionError, "Secret resolution failed"), (RemoteValidationSchemaError, "Remote validation schema error"), @@ -41,7 +39,7 @@ def test_error_trap_track_error(): error1 = DataDesignerRuntimeError("Error 1") error2 = DataDesignerRuntimeError("Error 2") - error3 = UnknownModelAliasError("Error 3") + error3 = SecretResolutionError("Error 3") error_trap.handle_error(error1) error_trap.handle_error(error2) @@ -49,7 +47,7 @@ def test_error_trap_track_error(): assert error_trap.error_count == 3 assert error_trap.task_errors["DataDesignerRuntimeError"] == 2 - assert error_trap.task_errors["UnknownModelAliasError"] == 1 + assert error_trap.task_errors["SecretResolutionError"] == 1 def test_error_trap_model_dump(): diff --git a/packages/data-designer-engine/tests/engine/test_observability.py b/packages/data-designer-engine/tests/engine/test_observability.py index e7d9ce21b..78ec66424 100644 --- a/packages/data-designer-engine/tests/engine/test_observability.py +++ b/packages/data-designer-engine/tests/engine/test_observability.py @@ -8,13 +8,12 @@ from enum import Enum from data_designer.engine.observability import ( - CorrelatedRuntimeView, - InMemoryAdmissionEventSink, RequestAdmissionEvent, RuntimeCorrelation, RuntimeCorrelationProvider, SchedulerAdmissionEvent, ) +from data_designer.engine.testing import CorrelatedRuntimeView, InMemoryAdmissionEventSink class _DiagnosticMode(Enum): diff --git a/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py b/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py index d1ec32f65..65541467f 100644 --- a/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py +++ b/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py @@ -68,23 +68,3 @@ def get_by_code(self, code: str) -> PersonaLocale | None: PersonaLocale if found, None otherwise """ return next((locale for locale in self._registry.locales if locale.code == code), None) - - def get_dataset_name(self, code: str) -> str | None: - """Get the NGC dataset name for a locale. - - Args: - code: Locale code (e.g., 'en_US', 'ja_JP') - - Returns: - Dataset name if locale exists, None otherwise - """ - locale = self.get_by_code(code) - return locale.dataset_name if locale else None - - def get_dataset_prefix(self) -> str: - """Get the dataset prefix for all persona datasets. - - Returns: - Dataset prefix string - """ - return self._registry.dataset_prefix diff --git a/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py b/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py index ea2fa5236..159152f1b 100644 --- a/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py +++ b/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py @@ -166,10 +166,6 @@ def uninstall(self, plan: UninstallPlan) -> None: if return_code != 0: raise RuntimeError(f"Plugin package uninstaller exited with status {return_code}") - def verify_entry_point(self, entry: PluginCatalogEntry) -> bool: - """Verify the runtime plugin's declared entry point is installed and loadable.""" - return self.verify_entry_points([entry]) - def verify_entry_points(self, entries: list[PluginCatalogEntry]) -> bool: """Verify every declared runtime entry point for an installed catalog package can load.""" if not entries: diff --git a/packages/data-designer/src/data_designer/cli/ui.py b/packages/data-designer/src/data_designer/cli/ui.py index e65cbafd2..717f18fe5 100644 --- a/packages/data-designer/src/data_designer/cli/ui.py +++ b/packages/data-designer/src/data_designer/cli/ui.py @@ -634,13 +634,6 @@ def print_header(text: str) -> None: _console.print() -def print_navigation_tip() -> None: - """Display a concise navigation tip for interactive prompts.""" - tip = "[dim]Tip: Use arrow keys to navigate menus, type [bold]'back'[/bold] to edit previous entries, press [bold]Tab[/bold] for completions[/dim]" - _print_with_padding(tip) - _console.print() - - def _print_with_padding(content: str | Panel) -> None: """Internal helper to print with left padding. diff --git a/packages/data-designer/tests/cli/repositories/test_persona_repository.py b/packages/data-designer/tests/cli/repositories/test_persona_repository.py index 905a26c91..06ffe9038 100644 --- a/packages/data-designer/tests/cli/repositories/test_persona_repository.py +++ b/packages/data-designer/tests/cli/repositories/test_persona_repository.py @@ -92,18 +92,6 @@ def test_get_by_code_case_sensitive(repository: PersonaRepository) -> None: assert locale is None -def test_get_dataset_name_valid_locale(repository: PersonaRepository) -> None: - """Test getting dataset name for valid locale.""" - dataset_name = repository.get_dataset_name("en_US") - assert dataset_name == "nemotron-personas-dataset-en_us" - - -def test_get_dataset_name_invalid_locale(repository: PersonaRepository) -> None: - """Test getting dataset name for invalid locale returns None.""" - dataset_name = repository.get_dataset_name("invalid_locale") - assert dataset_name is None - - def test_get_dataset_name_lowercase_conversion(repository: PersonaRepository) -> None: """Test that dataset names use lowercase locale codes.""" # Verify that mixed-case locale codes result in lowercase dataset names @@ -113,12 +101,6 @@ def test_get_dataset_name_lowercase_conversion(repository: PersonaRepository) -> assert locale.dataset_name.islower() or "_" in locale.dataset_name -def test_get_dataset_prefix(repository: PersonaRepository) -> None: - """Test getting dataset prefix.""" - prefix = repository.get_dataset_prefix() - assert prefix == "nemotron-personas-dataset-" - - def test_persona_locale_model() -> None: """Test PersonaLocale Pydantic model.""" locale = PersonaLocale( @@ -172,7 +154,7 @@ def test_locale_size_formats(repository: PersonaRepository) -> None: def test_dataset_name_consistency(repository: PersonaRepository) -> None: """Test that all dataset names follow consistent pattern.""" locales = repository.list_all() - prefix = repository.get_dataset_prefix() + prefix = "nemotron-personas-dataset-" for locale in locales: # All dataset names should start with the prefix diff --git a/packages/data-designer/tests/cli/services/test_plugin_install_service.py b/packages/data-designer/tests/cli/services/test_plugin_install_service.py index ea0744a91..c63639297 100644 --- a/packages/data-designer/tests/cli/services/test_plugin_install_service.py +++ b/packages/data-designer/tests/cli/services/test_plugin_install_service.py @@ -1007,7 +1007,7 @@ def runner(command: list[str], stdin_text: str | None) -> int: @patch("data_designer.cli.services.plugin_install_service.importlib.metadata.entry_points") @patch("data_designer.cli.services.plugin_install_service.importlib.invalidate_caches") -def test_verify_entry_point_invalidates_caches_and_checks_declared_entry_point( +def test_verify_entry_points_invalidates_caches_and_checks_declared_entry_point( mock_invalidate_caches: Mock, mock_entry_points: Mock, ) -> None: @@ -1023,7 +1023,7 @@ def test_verify_entry_point_invalidates_caches_and_checks_declared_entry_point( ] service = PluginInstallService() - assert service.verify_entry_point(entry) is True + assert service.verify_entry_points([entry]) is True mock_invalidate_caches.assert_called_once_with() mock_entry_points.assert_called_once_with(group="data_designer.plugins") diff --git a/packages/data-designer/tests/cli/test_cli_utils.py b/packages/data-designer/tests/cli/test_cli_utils.py index 7bb8eabdf..f7a72030f 100644 --- a/packages/data-designer/tests/cli/test_cli_utils.py +++ b/packages/data-designer/tests/cli/test_cli_utils.py @@ -14,23 +14,12 @@ ) from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError from data_designer.config.utils.io_helpers import ( - ensure_config_dir_exists, is_http_url, load_config_file, save_config_file, ) -def test_ensure_config_dir_exists(tmp_path: Path) -> None: - """Test creating config directory.""" - test_dir = tmp_path / "test_config" - assert not test_dir.exists() - - ensure_config_dir_exists(test_dir) - assert test_dir.exists() - assert test_dir.is_dir() - - def test_save_and_load_config_file(tmp_path: Path) -> None: """Test saving and loading config files.""" config_file = tmp_path / "test_config.yaml" From 7ad5acc311c1ea0266e3a3f3e3af4c28f00af459 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 1 Jul 2026 15:07:42 -0600 Subject: [PATCH 2/2] fix: preserve active capacity diagnostics Restore the capacity-plan module, scheduler observation state, completion APIs, and their tests because open PRs #745 and #743 actively extend or assert these artifacts. Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/capacity.py | 123 ++++++++++++ .../dataset_builders/async_scheduler.py | 189 ++++++++++++++++++ .../dataset_builders/scheduling/completion.py | 42 +++- .../scheduling/test_completion.py | 109 +++++++--- .../dataset_builders/test_async_scheduler.py | 102 ++++++++++ .../tests/engine/test_capacity.py | 74 +++++++ 6 files changed, 611 insertions(+), 28 deletions(-) create mode 100644 packages/data-designer-engine/src/data_designer/engine/capacity.py create mode 100644 packages/data-designer-engine/tests/engine/test_capacity.py diff --git a/packages/data-designer-engine/src/data_designer/engine/capacity.py b/packages/data-designer-engine/src/data_designer/engine/capacity.py new file mode 100644 index 000000000..e10a729e7 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/capacity.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Generic, Literal, TypeVar + +from data_designer.engine.dataset_builders.scheduling.resources import SchedulerResourceKey, TaskGroupKey +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + +_T = TypeVar("_T") + +CapacityValueSource = Literal[ + "default", + "run_config", + "dataset_builder", + "model_metadata", + "engine_internal_config", + "adapter_config", + "environment", + "runtime_snapshot", + "benchmark_override", +] + + +@dataclass(frozen=True) +class CapacityValue(Generic[_T]): + value: _T | None + source: CapacityValueSource + fallback_from: str | None = None + missing_reason: str | None = None + + +@dataclass(frozen=True) +class RowGroupAdmission: + row_group_concurrency: CapacityValue[int] + observed_in_flight: int | None = None + mode: Literal["fixed", "adaptive"] = "fixed" + target_in_flight: int | None = None + observed_max_target: int | None = None + max_admitted_rows: int | None = None + blocked_reasons: Mapping[str, int] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RequestAdmissionConfigSnapshot: + resources: Sequence[RequestResourceKey] + initial_limits: Mapping[RequestResourceKey, int] + max_limit_clamps: Mapping[RequestResourceKey, int | None] + cooldown_seconds: float + multiplicative_decrease_factor: float + additive_increase_step: int + successes_until_increase: int + startup_ramp_seconds: float + default_queue_wait_timeout_seconds: float | None + + @classmethod + def from_config(cls, config: RequestAdmissionConfig) -> RequestAdmissionConfigSnapshot: + resources = tuple(sorted({*config.initial_limits, *config.max_limit_clamps})) + return cls( + resources=resources, + initial_limits=dict(config.initial_limits), + max_limit_clamps=dict(config.max_limit_clamps), + cooldown_seconds=config.cooldown_seconds, + multiplicative_decrease_factor=config.multiplicative_decrease_factor, + additive_increase_step=config.additive_increase_step, + successes_until_increase=config.successes_until_increase, + startup_ramp_seconds=config.startup_ramp_seconds, + default_queue_wait_timeout_seconds=config.default_queue_wait_timeout_seconds, + ) + + +@dataclass(frozen=True) +class AsyncCapacityConfigured: + buffer_size: CapacityValue[int] + row_group_admission: RowGroupAdmission + submission_capacity: CapacityValue[int] + task_resource_limits: CapacityValue[Mapping[SchedulerResourceKey, int]] + request_resources: CapacityValue[Sequence[RequestResourceKey]] + provider_model_static_caps: CapacityValue[Mapping[ProviderModelKey, ProviderModelStaticCap]] + request_domain_initial_limits: CapacityValue[Mapping[RequestResourceKey, int]] + request_admission_config: CapacityValue[RequestAdmissionConfigSnapshot] + transport_pool_limits: CapacityValue[Mapping[ProviderModelKey, int]] + + +@dataclass(frozen=True) +class AsyncCapacityRuntimeSnapshot: + request_domain_current_limits: Mapping[RequestResourceKey, int] | None = None + request_domain_effective_max: Mapping[RequestResourceKey, int] | None = None + request_domain_blocked_until: Mapping[RequestResourceKey, float | None] | None = None + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityObservedMaxima: + row_groups_in_flight: int = 0 + queued_tasks_by_group: Mapping[TaskGroupKey | str, int] = field(default_factory=dict) + task_leases_by_resource: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) + request_waiters_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) + request_in_flight_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] = field(default_factory=dict) + request_domain_current_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) + transport_pool_utilization: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityPlan: + configured: AsyncCapacityConfigured + runtime_snapshot: AsyncCapacityRuntimeSnapshot + observed_maxima: AsyncCapacityObservedMaxima + + +def missing_capacity_value( + *, + source: CapacityValueSource, + missing_reason: str, + fallback_from: str | None = None, +) -> CapacityValue[object]: + return CapacityValue(value=None, source=source, fallback_from=fallback_from, missing_reason=missing_reason) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 8adde98a8..685b8dbb8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -16,6 +16,15 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import ExpressionColumnConfig, GenerationStrategy +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) from data_designer.engine.context import current_row_group, current_row_group_start_offset from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig @@ -62,6 +71,9 @@ ModelRateLimitError, ModelRequestAdmissionTimeoutError, ) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap from data_designer.engine.observability import ( RuntimeCorrelation, SchedulerAdmissionEvent, @@ -291,11 +303,19 @@ def __init__( self._first_non_retryable_error: Exception | None = None self._fatal_worker_error: BaseException | None = None + self._max_concurrent_row_groups = max_concurrent_row_groups self._max_in_flight_tasks = max_in_flight_tasks self._num_records = num_records self._buffer_size = buffer_size self._scheduled_records = self._row_groups.scheduled_total_rows self._initial_completed_records = initial_completed_records + self._observed_max_row_groups_in_flight = 0 + self._observed_max_task_leases_by_resource: dict[str, int] = {} + self._observed_max_queued_by_group: dict[str, int] = {} + self._observed_max_request_waiters_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_request_in_flight_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_provider_model_aggregate_in_flight: dict[ProviderModelKey, int] = {} + self._observed_max_request_domain_current_limits: dict[RequestResourceKey, int] = {} self._adaptive_row_group_admission = adaptive_row_group_admission self._row_group_admission_hard_cap = max(1, max_concurrent_row_groups) self._row_group_admission_target = ( @@ -303,6 +323,7 @@ def __init__( if adaptive_row_group_admission else self._row_group_admission_hard_cap ) + self._observed_max_row_group_admission_target = self._row_group_admission_target self._row_group_admission_event = asyncio.Event() self._row_group_admission_event.set() self._row_group_admission_pressure_ticks = 0 @@ -444,6 +465,39 @@ def _emit_scheduler_event( logger.warning("Scheduler admission event sink raised; dropping event.", exc_info=True) return + def _record_observed_task_state(self) -> None: + self._observed_max_row_groups_in_flight = max(self._observed_max_row_groups_in_flight, len(self._rg_states)) + view = self._task_admission.view() + for resource, count in view.leased_resources.items(): + self._observed_max_task_leases_by_resource[resource] = max( + self._observed_max_task_leases_by_resource.get(resource, 0), + count, + ) + queue_view = self._fair_queue.view() + for group, count in queue_view.queued_by_group.items(): + label = f"{group.kind}:{'/'.join(group.identity)}" + self._observed_max_queued_by_group[label] = max(self._observed_max_queued_by_group.get(label, 0), count) + if self._request_pressure_provider is None: + return + for resource, snapshot in self._request_pressure_provider.snapshots().items(): + self._observed_max_request_waiters_by_resource[resource] = max( + self._observed_max_request_waiters_by_resource.get(resource, 0), + snapshot.waiters, + ) + self._observed_max_request_in_flight_by_resource[resource] = max( + self._observed_max_request_in_flight_by_resource.get(resource, 0), + snapshot.in_flight_count, + ) + self._observed_max_request_domain_current_limits[resource] = max( + self._observed_max_request_domain_current_limits.get(resource, 0), + snapshot.current_limit, + ) + for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items(): + self._observed_max_provider_model_aggregate_in_flight[provider_model] = max( + self._observed_max_provider_model_aggregate_in_flight.get(provider_model, 0), + snapshot.aggregate_in_flight, + ) + def _emit_scheduler_health_snapshot(self, reason: str) -> None: if self._scheduler_event_sink is None: return @@ -631,6 +685,7 @@ def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: self._tracker.mark_enqueued(accepted) for task_id in accepted: self._emit_scheduler_event("ready_enqueued", task=accepted_tasks_by_id[task_id]) + self._record_observed_task_state() self._wake_event.set() def _discard_ready_task(self, task: Task) -> None: @@ -701,6 +756,7 @@ def _dispatch_queued_tasks(self) -> _DispatchOutcome: self._dispatch_selected_task(committed, decision) dispatched = True + self._record_observed_task_state() if dispatched: self._emit_scheduler_event("queue_drained") @@ -872,6 +928,10 @@ def _maybe_update_adaptive_row_group_target(self) -> None: return old_target = self._row_group_admission_target self._row_group_admission_target = min(self._row_group_admission_hard_cap, old_target + 1) + self._observed_max_row_group_admission_target = max( + self._observed_max_row_group_admission_target, + self._row_group_admission_target, + ) self._row_group_admission_pressure_ticks = 0 if self._row_group_admission_target != old_target: self._emit_scheduler_event( @@ -1727,6 +1787,7 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id=task_execution_id, reason_or_result=release_result.reason, ) + self._record_observed_task_state() self._wake_event.set() async def _run_generator_call(self, task: Task, operation: str, call: Coroutine[Any, Any, Any]) -> Any: @@ -2095,6 +2156,134 @@ def task_admission_config(self) -> TaskAdmissionConfig: """Return the effective scheduler task-admission config.""" return self._task_admission_config + def capacity_plan(self) -> AsyncCapacityPlan: + """Return the scheduler-side async capacity explanation for this run.""" + task_view = self._task_admission.view() + request_snapshots = ( + dict(self._request_pressure_provider.snapshots()) if self._request_pressure_provider is not None else {} + ) + provider_snapshots = ( + dict(self._request_pressure_provider.global_snapshots()) + if self._request_pressure_provider is not None + else {} + ) + request_resources = tuple(sorted(request_snapshots)) + provider_model_static_caps = { + provider_model: ProviderModelStaticCap( + cap=snapshot.static_cap, + aliases=snapshot.aliases, + raw_caps=snapshot.raw_caps, + ) + for provider_model, snapshot in provider_snapshots.items() + } + request_config = self._request_pressure_provider.config if self._request_pressure_provider is not None else None + request_config_snapshot = ( + RequestAdmissionConfigSnapshot.from_config(request_config) + if isinstance(request_config, RequestAdmissionConfig) + else None + ) + request_domain_initial_limits: dict[RequestResourceKey, int] = {} + if request_config_snapshot is not None: + request_domain_initial_limits.update(request_config_snapshot.initial_limits) + for resource, snapshot in request_snapshots.items(): + configured_initial = ( + request_config_snapshot.initial_limits.get(resource) if request_config_snapshot is not None else None + ) + request_domain_initial_limits[resource] = ( + max(1, min(configured_initial, snapshot.effective_max)) + if configured_initial is not None + else snapshot.effective_max + ) + request_domain_current_limits = { + resource: snapshot.current_limit for resource, snapshot in request_snapshots.items() + } + request_domain_effective_max = { + resource: snapshot.effective_max for resource, snapshot in request_snapshots.items() + } + request_domain_blocked_until = { + resource: snapshot.blocked_until_monotonic for resource, snapshot in request_snapshots.items() + } + provider_model_aggregate_in_flight = { + provider_model: snapshot.aggregate_in_flight for provider_model, snapshot in provider_snapshots.items() + } + return AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=self._buffer_size, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue( + value=self._max_concurrent_row_groups, + source="dataset_builder", + ), + observed_in_flight=len(self._rg_states), + mode="adaptive" if self._adaptive_row_group_admission else "fixed", + target_in_flight=self._row_group_admission_target, + observed_max_target=self._observed_max_row_group_admission_target, + max_admitted_rows=self._adaptive_max_admitted_rows, + blocked_reasons=dict(self._row_group_admission_blocked_reasons), + ), + submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"), + task_resource_limits=CapacityValue( + value=dict(self._task_admission_config.resource_limits), + source="engine_internal_config", + ), + request_resources=CapacityValue( + value=request_resources, + source="runtime_snapshot", + missing_reason=None if request_resources else "request admission has not observed any resources", + ), + provider_model_static_caps=CapacityValue( + value=provider_model_static_caps, + source="model_metadata", + missing_reason=None if provider_model_static_caps else "request admission has no registered models", + ), + request_domain_initial_limits=CapacityValue( + value=request_domain_initial_limits, + source="engine_internal_config" if request_config_snapshot is not None else "runtime_snapshot", + missing_reason=None + if request_domain_initial_limits + else "request admission has not observed any domain limits", + ), + request_admission_config=CapacityValue( + value=request_config_snapshot, + source="engine_internal_config", + missing_reason=None + if request_config_snapshot is not None + else "request admission config is not exposed by the pressure provider", + ), + transport_pool_limits=CapacityValue( + value={}, + source="adapter_config", + missing_reason="transport pool utilization is adapter-specific", + ), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits=request_domain_current_limits, + request_domain_effective_max=request_domain_effective_max, + request_domain_blocked_until=request_domain_blocked_until, + provider_model_aggregate_in_flight=provider_model_aggregate_in_flight, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=self._observed_max_row_groups_in_flight, + queued_tasks_by_group=dict(self._observed_max_queued_by_group), + task_leases_by_resource=dict(self._observed_max_task_leases_by_resource or task_view.leased_resources), + request_waiters_by_resource=dict( + self._observed_max_request_waiters_by_resource + or {resource: snapshot.waiters for resource, snapshot in request_snapshots.items()} + ), + request_in_flight_by_resource=dict( + self._observed_max_request_in_flight_by_resource + or {resource: snapshot.in_flight_count for resource, snapshot in request_snapshots.items()} + ), + provider_model_aggregate_in_flight=dict( + self._observed_max_provider_model_aggregate_in_flight or provider_model_aggregate_in_flight + ), + request_domain_current_limits=dict( + self._observed_max_request_domain_current_limits or request_domain_current_limits + ), + transport_pool_utilization=None, + ), + ) + @staticmethod def _is_retryable(exc: BaseException) -> bool: """Classify whether an exception is retryable.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index 8c7839358..855c91642 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -37,8 +37,9 @@ class CompletionTracker: Row indices are local to their row group (0-based). - Use ``with_graph`` to create a frontier-enabled tracker that incrementally - maintains dependency-ready tasks. + Use ``with_graph`` to create a frontier-enabled tracker where + ``get_ready_tasks`` returns in O(frontier) instead of scanning all + columns x rows x row groups. """ def __init__(self) -> None: @@ -92,6 +93,20 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i def is_complete(self, ref: SliceRef) -> bool: return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set()) + def is_all_complete(self, cells: list[SliceRef]) -> bool: + """Check whether all the given cells are done. + + A ``row_index`` of ``None`` means the entire batch for that column must + have been completed via ``mark_row_range_complete``. + """ + for ref in cells: + if ref.row_index is None: + if ref.column not in self._batch_complete.get(ref.row_group, set()): + return False + elif not self.is_complete(ref): + return False + return True + def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: """Check if *column* has been fully completed for *row_group_index*.""" if column in self._batch_complete.get(row_group_index, set()): @@ -158,10 +173,33 @@ def add_ready_tasks(self, tasks: list[Task] | tuple[Task, ...]) -> FrontierDelta added.append(task) return self._record_delta(added=added, removed=[]) + def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: + """Return all currently dispatchable tasks from the frontier.""" + return [ + t + for t in self.ready_frontier() + if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) + ] + def is_frontier_task(self, task: Task) -> bool: """Return whether *task* is still in the ready frontier.""" return task in self._frontier + def seed_frontier(self) -> None: + """Populate the frontier with root tasks (columns with no upstream deps). + + Not called automatically - the scheduler manages root dispatch directly + to handle stateful locks and multi-column dedup. Call this explicitly + for static introspection (e.g., capacity planning, task enumeration). + """ + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") + for col in self._graph.get_root_columns(): + if self._row_group_plan is None: + raise RuntimeError("This method requires row groups to be set.") + for rg_id, rg_size in self._row_group_plan: + self.add_root_tasks(rg_id, rg_size, columns=(col,)) + def add_root_tasks( self, row_group: int, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py index bbae04d57..e647d4ac6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py @@ -40,6 +40,7 @@ def _build_simple_graph() -> ExecutionGraph: @dataclass class ReadyTasksFixture: tracker: CompletionTracker + dispatched: set[Task] @pytest.fixture() @@ -48,6 +49,7 @@ def ready_ctx() -> ReadyTasksFixture: graph = _build_simple_graph() return ReadyTasksFixture( tracker=CompletionTracker.with_graph(graph, [(0, 3)]), + dispatched=set(), ) @@ -84,6 +86,43 @@ def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFix ready_ctx.tracker.mark_cell_complete("question", row_group=999, row_index=0) +# -- is_all_complete ----------------------------------------------------------- + + +def test_all_complete_cell_level() -> None: + tracker = CompletionTracker() + tracker.mark_cell_complete("col_a", 0, 0) + tracker.mark_cell_complete("col_a", 0, 1) + + assert tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 1)]) + assert not tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 2)]) + + +def test_all_complete_batch_level() -> None: + tracker = CompletionTracker() + tracker.mark_row_range_complete("col_a", 0, 3) + + assert tracker.is_all_complete([SliceRef("col_a", 0, None)]) + + +def test_all_complete_batch_single_cell_not_sufficient() -> None: + """mark_cell_complete on one row must NOT make is_all_complete return True for batch check.""" + tracker = CompletionTracker() + tracker.mark_cell_complete("col_a", 0, 0) + + assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) + + +def test_all_complete_batch_not_present() -> None: + tracker = CompletionTracker() + assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) + + +def test_all_complete_empty_list() -> None: + tracker = CompletionTracker() + assert tracker.is_all_complete([]) + + # -- drop_row / is_dropped ------------------------------------------------- @@ -135,17 +174,19 @@ def test_row_group_not_complete_missing_non_dropped() -> None: assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) -# -- ready frontier --------------------------------------------------------- +# -- get_ready_tasks -------------------------------------------------------- -def test_ready_frontier_starts_empty(ready_ctx: ReadyTasksFixture) -> None: - ready = ready_ctx.tracker.ready_frontier() +def test_get_ready_tasks_frontier_empty_without_seed(ready_ctx: ReadyTasksFixture) -> None: + """Frontier starts empty - seed_frontier() must be called explicitly.""" + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) assert len(ready) == 0 -def test_add_root_tasks_populates_frontier(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.add_root_tasks(0, 3, columns=("topic",)) - ready = ready_ctx.tracker.ready_frontier() +def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: + """seed_frontier() populates the frontier with root tasks.""" + ready_ctx.tracker.seed_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) assert len(ready) == 1 assert ready[0].column == "topic" @@ -153,7 +194,7 @@ def test_add_root_tasks_populates_frontier(ready_ctx: ReadyTasksFixture) -> None def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.add_root_tasks(0, 3, columns=("topic",)) + ready_ctx.tracker.seed_frontier() task = ready_ctx.tracker.ready_frontier()[0] ready_ctx.tracker.mark_enqueued({stable_task_id(task)}) @@ -161,10 +202,10 @@ def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixtur assert ready_ctx.tracker.ready_frontier() == () -def test_ready_frontier_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: +def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: delta = ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 3 @@ -192,7 +233,7 @@ def test_fan_out_cell_completion_readies_all_children_for_same_row() -> None: assert {task.column for task in delta.added} == {"child_a", "child_b", "child_c"} assert {task.row_index for task in delta.added} == {0} - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(set()) assert not any(task.column.startswith("child_") and task.row_index == 1 for task in ready) @@ -217,16 +258,26 @@ def test_fan_in_cell_downstream_waits_for_all_same_row_upstreams() -> None: assert not any(task.column == "judge" for task in first_delta.added) assert not any(task.column == "judge" for task in second_delta.added) assert final_delta.added == (Task(column="judge", row_group=0, row_index=0, task_type="cell"),) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(set()) assert not any(task.column == "judge" and task.row_index == 1 for task in ready) -def test_ready_frontier_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: +def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + + ready1 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready_ctx.dispatched.update(ready1) + + ready2 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + assert len(ready2) == 0 + + +def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) removed = Task(column="question", row_group=0, row_index=1, task_type="cell") delta = ready_ctx.tracker.drop_row(0, 1) - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 2 @@ -243,32 +294,32 @@ def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) # question[2] never completes -- drop it instead delta = ready_ctx.tracker.drop_row(0, 2) - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 assert score_tasks[0].task_type == "batch" assert score_tasks[0] in delta.added -def test_ready_frontier_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: +def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready_ctx.tracker.mark_cell_complete("question", 0, 0) ready_ctx.tracker.mark_cell_complete("question", 0, 1) # question[0,2] not done yet - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 -def test_ready_frontier_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: +def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) delta = None for ri in range(3): delta = ready_ctx.tracker.mark_cell_complete("question", 0, ri) - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 @@ -277,13 +328,15 @@ def test_ready_frontier_full_column_ready_when_all_cells_done(ready_ctx: ReadyTa assert delta.added == (score_tasks[0],) -def test_ready_frontier_multiple_row_groups() -> None: +def test_get_ready_tasks_multiple_row_groups() -> None: graph = _build_simple_graph() tracker = CompletionTracker.with_graph(graph, [(0, 3), (1, 2)]) + dispatched: set[Task] = set() + tracker.mark_row_range_complete("topic", 0, 3) tracker.mark_row_range_complete("topic", 1, 2) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 @@ -297,10 +350,10 @@ def test_frontier_delta_return_is_empty_when_frontier_does_not_change(ready_ctx: assert delta.empty -def test_ready_frontier_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: +def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - ready = ready_ctx.tracker.ready_frontier() + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) topic_tasks = [t for t in ready if t.column == "topic"] assert len(topic_tasks) == 0 @@ -327,6 +380,8 @@ def test_completed_cell_not_reenqueued_after_later_upstream() -> None: """A → B → C chain: completing C then firing a late upstream event must not re-enqueue C.""" graph = _build_simple_graph() tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + dispatched: set[Task] = set() + # Complete the full pipeline tracker.mark_row_range_complete("topic", 0, 2) tracker.mark_cell_complete("question", 0, 0) @@ -336,7 +391,7 @@ def test_completed_cell_not_reenqueued_after_later_upstream() -> None: # Fire a late upstream cell event after score is already done tracker.mark_cell_complete("question", 0, 0) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 @@ -355,19 +410,21 @@ def test_completed_batch_not_reenqueued_by_upstream_cell() -> None: } graph = ExecutionGraph.create(configs, strategies) tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + dispatched: set[Task] = set() + # Complete seed and gen[0] — agg not ready yet tracker.mark_row_range_complete("seed", 0, 2) tracker.mark_cell_complete("gen", 0, 0) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(dispatched) assert not any(t.column == "agg" for t in ready) # Complete gen[1] — agg becomes ready tracker.mark_cell_complete("gen", 0, 1) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(dispatched) assert any(t.column == "agg" for t in ready) # Complete agg, then verify it doesn't reappear tracker.mark_row_range_complete("agg", 0, 2) - ready = tracker.ready_frontier() + ready = tracker.get_ready_tasks(dispatched) assert not any(t.column == "agg" for t in ready) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index c2ba039e5..cddcf31a0 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -57,6 +57,7 @@ from data_designer.engine.models.request_admission.config import RequestAdmissionConfig from data_designer.engine.models.request_admission.controller import ( AdaptiveRequestAdmissionController, + RequestAdmissionLease, ) from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome from data_designer.engine.models.request_admission.pressure import RequestPressureSnapshot @@ -3507,6 +3508,102 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None: + provider = _mock_provider() + gen_names = ["gen_a", "gen_b", "gen_c"] + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({column: GenerationStrategy.CELL_BY_CELL for column in gen_names}) + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + **{ + name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.02) + for name in gen_names + }, + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 3), (2, 3), (3, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=2, + trace=True, + num_records=12, + buffer_size=3, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + for row_group_index, row_count in row_groups: + assert tracker.is_row_group_complete(row_group_index, row_count, ["topic", *gen_names]) + assert plan.configured.row_group_admission.observed_in_flight == 0 + assert plan.observed_maxima.row_groups_in_flight == 2 + assert plan.observed_maxima.queued_tasks_by_group + assert max(plan.observed_maxima.task_leases_by_resource.values()) <= 2 + + +def test_scheduler_capacity_plan_reports_request_admission_state() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController( + RequestAdmissionConfig(initial_limits={resource: 2}, max_limit_clamps={resource: 3}) + ) + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + scheduler._record_observed_task_state() + plan = scheduler.capacity_plan() + + assert plan.configured.request_resources.value == (resource,) + assert plan.configured.request_domain_initial_limits.value[resource] == 2 + assert plan.configured.request_admission_config.value is not None + assert plan.configured.provider_model_static_caps.value[ProviderModelKey("provider", "model")].cap == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 3 + assert plan.runtime_snapshot.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + assert plan.observed_maxima.request_in_flight_by_resource[resource] == 1 + assert plan.observed_maxima.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + request_admission.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_scheduler_capacity_plan_reports_default_request_initial_limit_after_aimd_drop() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController() + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + request_admission.release(lease, RequestReleaseOutcome(kind="rate_limited")) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + plan = scheduler.capacity_plan() + + assert plan.configured.request_domain_initial_limits.value[resource] == 4 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 3 + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None: provider = _mock_provider() @@ -3613,7 +3710,12 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon await asyncio.wait_for(scheduler.run(), timeout=10.0) + plan = scheduler.capacity_plan() assert tracker.is_row_group_complete(0, 1, ["topic", "model_col"]) + assert plan.configured.row_group_admission.mode == "adaptive" + assert plan.configured.row_group_admission.observed_max_target is not None + assert plan.configured.row_group_admission.observed_max_target > 1 + assert plan.observed_maxima.row_groups_in_flight > 1 assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) diff --git a/packages/data-designer-engine/tests/engine/test_capacity.py b/packages/data-designer-engine/tests/engine/test_capacity.py new file mode 100644 index 000000000..856aeba09 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/test_capacity.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + + +def test_request_admission_config_snapshot_records_resources() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + config = RequestAdmissionConfig( + initial_limits={resource: 2}, + max_limit_clamps={resource: 4}, + startup_ramp_seconds=30.0, + ) + + snapshot = RequestAdmissionConfigSnapshot.from_config(config) + + assert snapshot.resources == (resource,) + assert snapshot.initial_limits[resource] == 2 + assert snapshot.max_limit_clamps[resource] == 4 + assert snapshot.startup_ramp_seconds == 30.0 + + +def test_async_capacity_plan_records_configured_runtime_and_maxima() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + provider_model = ProviderModelKey("nvidia", "nemotron") + static_cap = ProviderModelStaticCap(cap=4, aliases=("default",), raw_caps={"default": 4}) + + plan = AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=16, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue(value=2, source="dataset_builder"), + observed_in_flight=1, + ), + submission_capacity=CapacityValue(value=8, source="engine_internal_config"), + task_resource_limits=CapacityValue(value={"submission": 8, "llm_wait": 4}, source="engine_internal_config"), + request_resources=CapacityValue(value=(resource,), source="runtime_snapshot"), + provider_model_static_caps=CapacityValue(value={provider_model: static_cap}, source="model_metadata"), + request_domain_initial_limits=CapacityValue(value={resource: 2}, source="engine_internal_config"), + request_admission_config=CapacityValue( + value=RequestAdmissionConfigSnapshot.from_config(RequestAdmissionConfig(initial_limits={resource: 2})), + source="engine_internal_config", + ), + transport_pool_limits=CapacityValue(value={provider_model: 8}, source="adapter_config"), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits={resource: 2}, + request_domain_effective_max={resource: 4}, + request_domain_blocked_until={resource: None}, + provider_model_aggregate_in_flight={provider_model: 0}, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=1, + request_in_flight_by_resource={resource: 2}, + provider_model_aggregate_in_flight={provider_model: 2}, + ), + ) + + assert plan.configured.provider_model_static_caps.value[provider_model].merge_rule == "min_same_endpoint" + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.observed_maxima.provider_model_aggregate_in_flight[provider_model] == 2