diff --git a/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py index 8247a36b7..e71ec3a10 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/io_helpers.py @@ -9,7 +9,7 @@ from datetime import date, datetime, timedelta from decimal import Decimal from numbers import Number -from pathlib import Path, PurePosixPath +from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -29,15 +29,6 @@ VALID_CONFIG_FILE_EXTENSIONS = {".yaml", ".yml", ".json"} -def ensure_config_dir_exists(config_dir: Path) -> None: - """Create configuration directory if it doesn't exist. - - Args: - config_dir: Directory path to create - """ - config_dir.mkdir(parents=True, exist_ok=True) - - def load_config_file(file_path: Path) -> dict: """Load a YAML configuration file. @@ -177,40 +168,6 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str) raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.") -def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame: - """Load a dataframe from file if a path is given, otherwise return the dataframe. - - Args: - dataframe: A path to a file or a pandas DataFrame object. - - Returns: - A pandas DataFrame object. - """ - if isinstance(dataframe, lazy.pd.DataFrame): - return dataframe - - # Get the file extension. - if isinstance(dataframe, str) and dataframe.startswith("http"): - dataframe = _maybe_rewrite_url(dataframe) - # Parse extension from the URL path to avoid query-string contamination (e.g. "csv?token=…"). - ext = PurePosixPath(urlparse(dataframe).path).suffix.lstrip(".").lower() - else: - dataframe = Path(dataframe) - ext = dataframe.suffix.lower() - if not dataframe.exists(): - raise FileNotFoundError(f"File not found: {dataframe}") - - # Load the dataframe based on the file extension. - if ext == "csv": - return lazy.pd.read_csv(dataframe) - elif ext == "json": - return lazy.pd.read_json(dataframe, lines=True) - elif ext == "parquet": - return lazy.pd.read_parquet(dataframe) - else: - raise ValueError(f"Unsupported file format: {dataframe}") - - def smart_load_yaml(yaml_in: str | Path | dict) -> dict: """Return the yaml config as a dict given flexible input types. diff --git a/packages/data-designer-config/tests/config/utils/test_io_helpers.py b/packages/data-designer-config/tests/config/utils/test_io_helpers.py index 5b2b8b98a..1d65d945d 100644 --- a/packages/data-designer-config/tests/config/utils/test_io_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_io_helpers.py @@ -6,8 +6,6 @@ import tempfile from datetime import date, datetime, timedelta from decimal import Decimal -from pathlib import Path -from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest @@ -19,66 +17,9 @@ _maybe_rewrite_url, is_http_url, serialize_data, - smart_load_dataframe, smart_load_yaml, ) -if TYPE_CHECKING: - import pandas as pd - - -@patch("data_designer.config.utils.io_helpers.Path", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_json", autospec=True) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_parquet", autospec=True) -def test_smart_load_dataframe(mock_read_parquet, mock_read_json, mock_read_csv, mock_path_cls, stub_dataframe): - mock_read_parquet.return_value = stub_dataframe - mock_read_json.return_value = stub_dataframe - mock_read_csv.return_value = stub_dataframe - - # dataframe objects are passed through - assert smart_load_dataframe(stub_dataframe).size == stub_dataframe.size - - # url based - stub_base_url = "https://example.com/data.{extention}" - url_csv = stub_base_url.format(extention="csv") - smart_load_dataframe(url_csv) - mock_read_csv.assert_called_once_with(url_csv) - - url_json = stub_base_url.format(extention="json") - smart_load_dataframe(url_json) - mock_read_json.assert_called_once_with(url_json, lines=True) - - url_parquet = stub_base_url.format(extention="parquet") - smart_load_dataframe(url_parquet) - mock_read_parquet.assert_called_once_with(url_parquet) - - url_unknown = stub_base_url.format(extention="unknown") - with pytest.raises(ValueError): - smart_load_dataframe(url_unknown) - - # local file based - mock_read_csv.reset_mock() - mock_read_json.reset_mock() - mock_read_parquet.reset_mock() - - mock_path = MagicMock(autospec=Path) - mock_path.exists.return_value = True - mock_path.suffix.lower.return_value = "csv" - mock_path_cls.return_value = mock_path - - stub_base_path_str = "/some/path/to/data.{extension}" - path_csv = stub_base_path_str.format(extension="csv") - _ = smart_load_dataframe(path_csv) - mock_read_csv.assert_called_once_with(mock_path) - - mock_path.reset_mock() - mock_path.suffix.lower.return_value = "json" - mock_path.exists.return_value = False - path_json = stub_base_path_str.format(extension="json") - with pytest.raises(FileNotFoundError): - _ = smart_load_dataframe(Path(path_json)) - def test_smart_load_yaml(): stub_dict = { @@ -348,37 +289,6 @@ def test_smart_load_yaml_rewrites_huggingface_blob_url(mock_requests: MagicMock) ) -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_github_blob_url(mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://github.com/org/repo/blob/main/data.csv") - - mock_read_csv.assert_called_once_with("https://raw.githubusercontent.com/org/repo/main/data.csv") - - -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_github_blob_url_with_token( - mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame -) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://github.com/org/repo/blob/main/data.csv?token=secret123") - - mock_read_csv.assert_called_once_with("https://raw.githubusercontent.com/org/repo/main/data.csv?token=secret123") - - -@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True) -def test_smart_load_dataframe_rewrites_huggingface_blob_url( - mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame -) -> None: - mock_read_csv.return_value = stub_dataframe - - smart_load_dataframe("https://huggingface.co/datasets/org/repo/blob/main/data.csv") - - mock_read_csv.assert_called_once_with("https://huggingface.co/datasets/org/repo/raw/main/data.csv") - - def test_maybe_rewrite_github_url_log_does_not_leak_query(caplog: pytest.LogCaptureFixture) -> None: import logging diff --git a/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 14f65a781..2c85e5d25 100644 --- a/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -32,8 +32,6 @@ RANDOM_SEED = 42 MAX_PROMPT_SAMPLE_SIZE = 1000 WARNING_PREFIX = "⚠️ Error during column profile calculation: " -TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD = 0.1 - logger = logging.getLogger(__name__) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index 49e7d5b03..8cbd9833b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -277,38 +277,6 @@ def _postprocess_result( return self._validate_output(result, keys_before, is_dataframe) - def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict: - """Validate a single row output (dict) for cell_by_cell; strip undeclared columns.""" - expected_new = {self.config.name} | set(self.config.side_effect_columns) - result_keys = set(row.keys()) - - if self.config.name not in result_keys: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' did not create the expected column. " - f"The generator_function must add a column named '{self.config.name}' to the row." - ) - missing = set(self.config.side_effect_columns) - result_keys - if missing: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' did not create declared side_effect_columns: " - f"{sorted(missing)}. Declared side_effect_columns must be added to the row." - ) - removed = keys_before - result_keys - if removed: - raise CustomColumnGenerationError( - f"Custom generator for column '{self.config.name}' removed pre-existing columns: " - f"{sorted(removed)}. The generator_function must not remove any existing columns." - ) - undeclared = (result_keys - keys_before) - expected_new - if undeclared: - logger.warning( - f"⚠️ Custom generator for column '{self.config.name}' created undeclared columns: " - f"{sorted(undeclared)}. These columns will be removed. " - f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])." - ) - row = {k: v for k, v in row.items() if k not in undeclared} - return row - def _validate_output( self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool ) -> dict | pd.DataFrame: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index c7e848b5a..685b8dbb8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -305,7 +305,6 @@ def __init__( self._max_concurrent_row_groups = max_concurrent_row_groups self._max_in_flight_tasks = max_in_flight_tasks - self._max_model_task_admission = max_model_task_admission self._num_records = num_records self._buffer_size = buffer_size self._scheduled_records = self._row_groups.scheduled_total_rows @@ -660,9 +659,6 @@ def _apply_frontier_delta(self, delta: FrontierDelta) -> None: self._discard_ready_task(task) self._enqueue_ready_tasks(delta.added) - def _enqueue_ready_task(self, task: Task) -> None: - self._enqueue_ready_tasks((task,)) - def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: schedulables: list[SchedulableTask] = [] accepted_tasks_by_id: dict[str, Task] = {} diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 118e242a1..81dd19416 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -218,13 +218,6 @@ def first_non_retryable_error(self) -> Exception | None: """First non-retryable error captured by the scheduler in the most recent run.""" return self._first_non_retryable_error - def set_processor_runner(self, processors: list[Processor]) -> None: - """Replace the processor runner with a new one using the given processors.""" - self._processor_runner = ProcessorRunner( - processors=processors, - artifact_storage=self.artifact_storage, - ) - @functools.cached_property def single_column_configs(self) -> list[ColumnConfigT]: configs = [] @@ -235,10 +228,6 @@ def single_column_configs(self) -> list[ColumnConfigT]: configs.append(config) return configs - @functools.cached_property - def single_column_config_by_name(self) -> dict[str, ColumnConfigT]: - return {config.name: config for config in self.single_column_configs} - def build( self, *, diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py index 574c594c1..e94b3fa17 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal +from typing import Literal @dataclass(frozen=True, order=True) @@ -26,17 +26,6 @@ class Task: task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] -@dataclass -class TaskResult: - """Outcome of a completed task.""" - - task: Task - status: Literal["success", "error"] - output: Any = None - error: Exception | None = None - retryable: bool = False - - @dataclass class TaskTrace: """Timing trace for a single task. Only created when tracing is enabled.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py index a7d1883c6..eb9e2104a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -1,33 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Async batch execution with bounded concurrency and early-shutdown semantics. +"""Process-wide event-loop management for async engine work. -Async counterpart to ``concurrency.py``. Same operational contract (callbacks -with optional context, error aggregation, shutdown thresholds), different -runtime model. The sync module runs callables in a ``ThreadPoolExecutor``; -this module runs coroutines via ``asyncio.gather`` on a dedicated loop -thread. Callers stay synchronous. - -Architecture: - ``AsyncConcurrentExecutor.run()`` is a blocking call that submits - coroutines to a shared background event loop via - ``run_coroutine_threadsafe``. Bounded concurrency is enforced with an - ``asyncio.Semaphore``. Success/error counts use the same - ``ExecutorResults`` model as the sync executor. - - Caller Thread ──► run() ──► run_coroutine_threadsafe ──► Background Loop - (gather) - -Singleton Event Loop: +Singleton event loop: The background loop is a process-wide singleton. Async-stateful resources (connection pools, semaphores) bind internal state to a specific event loop, so creating per-call or per-instance loops breaks connection reuse and triggers cross-loop errors. - ``ensure_async_engine_loop()`` creates one daemon loop thread and - reuses it for all executor instances. + ``ensure_async_engine_loop()`` creates one daemon loop thread and reuses + it for all async engine work. -Startup Handshake: +Startup handshake: Loop creation uses a ``threading.Event`` readiness handshake. The background thread signals readiness via ``loop.call_soon(ready.set)``, and the creating thread holds the lock until that event fires (or a @@ -40,40 +24,11 @@ from __future__ import annotations import asyncio -import json import logging import threading -from collections.abc import Coroutine -from dataclasses import dataclass -from typing import Any, Generic, TypeVar - -from data_designer.engine.dataset_builders.utils.concurrency import ( - CallbackWithContext, - ErrorCallbackWithContext, - ExecutorResults, -) -from data_designer.engine.errors import DataDesignerRuntimeError -from data_designer.logging import LOG_INDENT logger = logging.getLogger(__name__) -T = TypeVar("T") - - -@dataclass(frozen=True, slots=True) -class Success(Generic[T]): - index: int - value: T - - -@dataclass(frozen=True, slots=True) -class Failure: - index: int - error: Exception - - -TaskResult = Success[T] | Failure - _loop: asyncio.AbstractEventLoop | None = None _thread: threading.Thread | None = None _lock = threading.Lock() @@ -90,9 +45,8 @@ def _run_loop(loop: asyncio.AbstractEventLoop, ready: threading.Event) -> None: def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: """Get or create a persistent event loop for async engine work. - A single event loop is shared across all AsyncConcurrentExecutor instances - to avoid breaking async-stateful resources (connection pools, semaphores) - that bind internal state to a specific event loop. + A single event loop is shared across async engine work to avoid breaking + async-stateful resources that bind internal state to a specific event loop. """ global _loop, _thread with _lock: @@ -118,123 +72,3 @@ def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: raise RuntimeError("AsyncEngine event loop failed to start within timeout") return _loop - - -class AsyncConcurrentExecutor: - """Async equivalent of ConcurrentThreadExecutor. - - Executes a batch of coroutines with bounded concurrency, error rate - monitoring, and early shutdown semantics. Callers remain synchronous — - the ``run()`` method submits work to a persistent background event loop. - - No locks are needed because asyncio tasks run cooperatively on a - single thread — mutations to ``_results`` are always sequential. - """ - - def __init__( - self, - *, - max_workers: int, - column_name: str, - result_callback: CallbackWithContext | None = None, - error_callback: ErrorCallbackWithContext | None = None, - shutdown_error_rate: float = 0.50, - shutdown_error_window: int = 10, - disable_early_shutdown: bool = False, - ) -> None: - self._column_name = column_name - self._max_workers = max_workers - self._result_callback = result_callback - self._error_callback = error_callback - self._shutdown_error_rate = shutdown_error_rate - self._shutdown_window_size = shutdown_error_window - self._disable_early_shutdown = disable_early_shutdown - self._results = ExecutorResults(failure_threshold=shutdown_error_rate) - - @property - def results(self) -> ExecutorResults: - return self._results - - @property - def max_workers(self) -> int: - return self._max_workers - - @property - def shutdown_error_rate(self) -> float: - return self._shutdown_error_rate - - @property - def shutdown_window_size(self) -> int: - return self._shutdown_window_size - - def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: - """Execute all work items concurrently. Callers remain synchronous.""" - logger.debug( - f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " - f"with max_workers={self._max_workers} for column '{self._column_name}'" - ) - loop = ensure_async_engine_loop() - future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) - future.result() - - async def _run_all(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> None: - self._semaphore = asyncio.Semaphore(self._max_workers) - self._shutdown_event = asyncio.Event() - - # gather-with-explicit-cancel: equivalent to asyncio.TaskGroup but available on 3.10. - # _run_task swallows its own exceptions into error_trap, so children don't raise into - # gather under normal operation. The except-block preserves TaskGroup's "cancel siblings - # on parent cancellation or unexpected child raise" semantics for safety. - tasks = [asyncio.create_task(self._run_task(i, coro, ctx)) for i, (coro, ctx) in enumerate(work_items)] - try: - await asyncio.gather(*tasks) - except BaseException: - for t in tasks: - if not t.done(): - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - raise - - if not self._disable_early_shutdown and self._results.early_shutdown: - self._raise_task_error() - - async def _run_task(self, index: int, coro: Coroutine[Any, Any, Any], context: dict | None) -> None: - if self._shutdown_event.is_set(): - coro.close() - return - - async with self._semaphore: - if self._shutdown_event.is_set(): - coro.close() - return - - try: - result = await coro - if self._result_callback is not None: - self._result_callback(result, context=context) - self._results.completed_count += 1 - self._results.success_count += 1 - except Exception as err: - self._results.completed_count += 1 - self._results.error_trap.handle_error(err) - if not self._disable_early_shutdown and self._results.is_error_rate_exceeded( - self._shutdown_window_size - ): - if not self._results.early_shutdown: - self._results.early_shutdown = True - self._shutdown_event.set() - if self._error_callback is not None: - try: - self._error_callback(err, context=context) - except Exception: - logger.warning("error_callback raised an exception", exc_info=True) - - def _raise_task_error(self) -> None: - raise DataDesignerRuntimeError( - "\n".join( - [ - f"{LOG_INDENT}Data generation was terminated early due to error rate exceeding threshold.", - f"{LOG_INDENT}The summary of encountered errors is: \n{json.dumps(self._results.summary, indent=4)}", - ] - ) - ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py index 208fa6d80..d1c48fd00 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -3,7 +3,6 @@ from __future__ import annotations -from data_designer.config.base import ProcessorConfig from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -54,9 +53,3 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D compiled_column_configs.extend(generated_column_configs) return compiled_column_configs - - -def compile_dataset_builder_processor_configs( - config: DataDesignerConfig, -) -> list[ProcessorConfig]: - return config.processors or [] diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py index 4cf59697a..04774379c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/errors.py @@ -6,9 +6,6 @@ from data_designer.engine.errors import DataDesignerError -class DatasetBatchManagementError(DataDesignerError): ... - - class ConfigCompilationError(DataDesignerError): ... diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 29b7d99bc..50c4dbb96 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -15,7 +15,6 @@ DatasetBuilderColumnConfigT, MultiColumnConfig, ) -from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.logging import LOG_INDENT @@ -283,42 +282,6 @@ def compute_task_count(self, num_records: int, buffer_size: int) -> dict[str, in counts[col] = num_row_groups return counts - def compute_cell_dependencies( - self, - column: str, - row_group: int, - row_index: int | None, - row_group_size: int, - ) -> list[SliceRef]: - """Derive cell-level deps on demand from column-level DAG + strategy. - - Returns a list of ``SliceRef`` that must be complete before this task can run. - """ - deps: list[SliceRef] = [] - for up_col in self.get_upstream_columns(column): - up_strategy = self._strategies[up_col] - if up_strategy == GenerationStrategy.CELL_BY_CELL: - if row_index is not None: - deps.append(SliceRef(up_col, row_group, row_index)) - else: - for ri in range(row_group_size): - deps.append(SliceRef(up_col, row_group, ri)) - else: - deps.append(SliceRef(up_col, row_group, None)) - return deps - - def to_mermaid(self) -> str: - """Mermaid diagram string with strategy annotations.""" - lines = ["graph TD"] - for col in self._columns: - strat = self._strategies[col] - label = f"{col} [{strat.value}]" - lines.append(f' {col}["{label}"]') - for col in self._columns: - for dep in sorted(self._upstream.get(col, set())): - lines.append(f" {dep} --> {col}") - return "\n".join(lines) - def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]: """Return column configs in dependency order using Kahn's algorithm. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index 98220e35f..e3cc28255 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -50,10 +50,6 @@ def update_cell(self, row_group: int, row_index: int, column: str, value: Any) - """Write a single cell value. Thread-safe within the asyncio event loop.""" self._buffers[row_group][row_index][column] = value - def update_cells(self, row_group: int, row_index: int, values: dict[str, Any]) -> None: - """Write multiple cell values for a single row.""" - self._buffers[row_group][row_index].update(values) - def update_batch(self, row_group: int, column: str, values: list[Any]) -> None: """Write a full column for all rows in a row group.""" buf = self._buffers[row_group] diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py index d827d0568..ed91d7047 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/skip_tracker.py @@ -10,20 +10,9 @@ from __future__ import annotations from collections.abc import Sequence -from dataclasses import dataclass from typing import Final SKIPPED_COLUMNS_RECORD_KEY: Final[str] = "__internal_skipped_columns" -SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX: Final[str] = "__internal_skip_restore_id" - - -@dataclass(frozen=True, slots=True) -class SkipMetadataRestoreContext: - """Metadata needed to restore skip provenance after a DataFrame round-trip.""" - - restore_id_column: str - source_ids: set[str] - skipped_columns_by_source_id: dict[str, set[str]] def apply_skip_to_record( @@ -56,81 +45,3 @@ def strip_skip_metadata_for_dataframe_row(record: dict) -> dict: def strip_skip_metadata_from_records(records: Sequence[dict]) -> list[dict]: """Map :func:`strip_skip_metadata_for_dataframe_row` over *records*.""" return [strip_skip_metadata_for_dataframe_row(r) for r in records] - - -def prepare_records_for_skip_metadata_round_trip( - records: Sequence[dict], -) -> tuple[list[dict], SkipMetadataRestoreContext | None]: - """Prepare records for a DataFrame round-trip while preserving skip metadata. - - Returns stripped records ready for ``pd.DataFrame(...)``. If any record has - skip metadata, injects a hidden restore-ID column and returns a context that - can later be passed to :func:`restore_skip_metadata`. - """ - if not any(SKIPPED_COLUMNS_RECORD_KEY in record for record in records): - return strip_skip_metadata_from_records(records), None - - restore_id_column = _choose_restore_id_column(records) - prepared_records: list[dict] = [] - source_ids: set[str] = set() - skipped_columns_by_source_id: dict[str, set[str]] = {} - - for index, record in enumerate(records): - source_id = str(index) - source_ids.add(source_id) - prepared_record = strip_skip_metadata_for_dataframe_row(record) - prepared_record[restore_id_column] = source_id - prepared_records.append(prepared_record) - - meta = record.get(SKIPPED_COLUMNS_RECORD_KEY) - if meta is not None: - skipped_columns_by_source_id[source_id] = set(meta) - - return prepared_records, SkipMetadataRestoreContext( - restore_id_column=restore_id_column, - source_ids=source_ids, - skipped_columns_by_source_id=skipped_columns_by_source_id, - ) - - -def restore_skip_metadata( - records: Sequence[dict], - *, - context: SkipMetadataRestoreContext, -) -> None: - """Restore skip provenance using hidden restore IDs instead of row position.""" - restored_source_ids: list[str] = [] - for record in records: - if context.restore_id_column not in record: - raise ValueError( - f"Records returned from the DataFrame round-trip must preserve " - f"the internal column {context.restore_id_column!r} so skip " - "provenance can be restored." - ) - - source_id = str(record.pop(context.restore_id_column)) - if source_id not in context.source_ids: - raise ValueError( - f"Record returned unknown restore ID {source_id!r}. Skip provenance " - "can only be restored for rows derived from the original input." - ) - - restored_source_ids.append(source_id) - meta = context.skipped_columns_by_source_id.get(source_id) - if meta is not None: - record[SKIPPED_COLUMNS_RECORD_KEY] = set(meta) - - if len(restored_source_ids) != len(context.source_ids) or set(restored_source_ids) != context.source_ids: - raise ValueError( - "Full-column generation changed the row identity mapping. Returned rows must preserve " - "a 1:1 mapping to the original input so skip provenance can be restored." - ) - - -def _choose_restore_id_column(records: Sequence[dict]) -> str: - candidate = SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX - suffix = 0 - while any(candidate in record for record in records): - suffix += 1 - candidate = f"{SKIP_METADATA_RESTORE_ID_COLUMN_PREFIX}_{suffix}" - return candidate diff --git a/packages/data-designer-engine/src/data_designer/engine/errors.py b/packages/data-designer-engine/src/data_designer/engine/errors.py index 3aee0544f..5c455e42a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/errors.py @@ -11,9 +11,6 @@ class DataDesignerRuntimeError(DataDesignerError): ... -class UnknownModelAliasError(DataDesignerError): ... - - class UnknownProviderError(DataDesignerError): ... diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index f0b79c28b..8c459f848 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -5,7 +5,6 @@ import calendar import email.utils -import json import time from enum import Enum @@ -134,31 +133,6 @@ def map_http_error_to_provider_error( ) -def extract_message_from_exception_string(raw: str) -> str: - """Extract a human-readable message from a stringified provider exception. - - Some providers format errors as ``"Error code: 400 - {json}"``. This - mirrors the structured-key lookup in ``_extract_structured_message`` but - operates on a raw string instead of an ``HttpResponse``. - """ - json_start = raw.find("{") - if json_start != -1: - try: - payload = json.loads(raw[json_start:]) - except (json.JSONDecodeError, ValueError): - return raw - if isinstance(payload, dict): - for key in ("message", "error", "detail"): - value = payload.get(key) - if isinstance(value, str) and value.strip(): - return value.strip() - if isinstance(value, dict): - nested = value.get("message") - if isinstance(nested, str) and nested.strip(): - return nested.strip() - return raw - - def _extract_response_text(response: HttpResponse) -> str: # Try structured JSON extraction first — most providers return structured error # bodies and we want the human-readable message, not raw JSON. diff --git a/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py b/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py index 4635562cd..85b41f4a7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/parsers/postprocessors.py @@ -4,12 +4,10 @@ from __future__ import annotations import json_repair -from pydantic import BaseModel, ValidationError from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -57,37 +55,3 @@ def deserialize_json_code( processed_response.parsed.append(block) return processed_response - - -class RealizePydanticTypes: - types: list[type[BaseModel]] - - def __init__(self, types: list[type[BaseModel]]): - self.types = types - - def _fit_types(self, obj: dict) -> BaseModel | None: - final_obj = None - - for t in self.types: - try: - final_obj = t.model_validate(obj) - except ValidationError: - pass - - return final_obj - - def __call__(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse: - processed_response = structured_response.model_copy() - processed_response.parsed = [] - - for block in structured_response.parsed: - if isinstance(block, StructuredDataBlock): - new_block = block - pydantic_obj = self._fit_types(block.obj) - if pydantic_obj: - new_block = PydanticTypeBlock(serialized=block.serialized, obj=pydantic_obj) - processed_response.parsed.append(new_block) - else: - processed_response.parsed.append(block) - - return processed_response diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py index 38bbd0598..d0273ebca 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py @@ -635,17 +635,6 @@ def _apply_outcome( if state.in_flight == 0 and outcome.kind not in {"local_cancelled", "local_timeout"}: state.consecutive_rate_limits = 0 - def _increment_waiter(self, item: RequestAdmissionItem) -> None: - with self._lock: - self._get_or_create_state(item.resource).waiters += 1 - self._sequence += 1 - - def _decrement_waiter(self, item: RequestAdmissionItem) -> None: - with self._lock: - state = self._get_or_create_state(item.resource) - state.waiters = max(0, state.waiters - 1) - self._sequence += 1 - def _get_or_create_state(self, resource: RequestResourceKey) -> AdaptiveRequestLimitState: state = self._domains.get(resource) if state is None: diff --git a/packages/data-designer-engine/src/data_designer/engine/observability.py b/packages/data-designer-engine/src/data_designer/engine/observability.py index a7a28c41b..36453013b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/observability.py +++ b/packages/data-designer-engine/src/data_designer/engine/observability.py @@ -213,32 +213,3 @@ def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: ... class RequestAdmissionEventSink(Protocol): def emit_request_event(self, event: RequestAdmissionEvent) -> None: ... - - -class InMemoryAdmissionEventSink: - """Small sink used by tests, diagnostics, and benchmark smoke runs.""" - - def __init__(self) -> None: - self.scheduler_events: list[SchedulerAdmissionEvent] = [] - self.request_events: list[RequestAdmissionEvent] = [] - - def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: - self.scheduler_events.append(event) - - def emit_request_event(self, event: RequestAdmissionEvent) -> None: - self.request_events.append(event) - - -@dataclass(frozen=True) -class CorrelatedRuntimeView: - scheduler_events: tuple[SchedulerAdmissionEvent, ...] - request_events: tuple[RequestAdmissionEvent, ...] - - @property - def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: - return tuple( - sorted( - (*self.scheduler_events, *self.request_events), - key=lambda event: (event.captured_at_monotonic, event.sequence), - ) - ) diff --git a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py index 1a969e2b1..b6fa89411 100644 --- a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py +++ b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py @@ -56,31 +56,6 @@ def get_dataset_metadata(self) -> DatasetMetadata: return DatasetMetadata(seed_column_names=seed_column_names) -def _validate_tool_configs_against_providers( - tool_configs: list[ToolConfig], - mcp_providers: list[MCPProviderT], -) -> None: - """Validate that all providers referenced in tool configs exist. - - Args: - tool_configs: List of tool configurations to validate. - mcp_providers: List of available MCP provider configurations. - - Raises: - ValueError: If a tool config references a provider that doesn't exist. - """ - available_providers = {p.name for p in mcp_providers} - - for tc in tool_configs: - missing_providers = [p for p in tc.providers if p not in available_providers] - if missing_providers: - available_list = sorted(available_providers) if available_providers else ["(none configured)"] - raise ValueError( - f"ToolConfig '{tc.tool_alias}' references provider(s) {missing_providers!r} " - f"which are not registered. Available providers: {available_list}" - ) - - def create_resource_provider( *, artifact_storage: ArtifactStorage, diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index 6bd6e3dd9..993e9f536 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -143,10 +143,6 @@ def _validate_image(self, image_path: Path) -> None: image_path.unlink(missing_ok=True) raise - def _ensure_images_directory(self) -> None: - """Create images directory if it doesn't exist (lazy initialization).""" - self.images_dir.mkdir(parents=True, exist_ok=True) - def _sanitize_subfolder_name(self, name: str) -> str: """Sanitize subfolder name to prevent path traversal and filesystem issues.""" # Replace path separators and parent directory references with underscores diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py index 2b5b62d42..9050b41ea 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations +from data_designer.engine.testing.observability import CorrelatedRuntimeView, InMemoryAdmissionEventSink from data_designer.engine.testing.seed_readers import LineFanoutDirectorySeedReader from data_designer.engine.testing.stubs import ( StubChoice, @@ -16,6 +17,8 @@ from data_designer.engine.testing.utils import assert_valid_plugin __all__ = [ + "CorrelatedRuntimeView", + "InMemoryAdmissionEventSink", LineFanoutDirectorySeedReader.__name__, "StubChoice", "StubHuggingFaceSeedReader", diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/observability.py b/packages/data-designer-engine/src/data_designer/engine/testing/observability.py new file mode 100644 index 000000000..cc3237ae8 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/testing/observability.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from data_designer.engine.observability import RequestAdmissionEvent, SchedulerAdmissionEvent + + +class InMemoryAdmissionEventSink: + """In-memory admission-event sink for tests and benchmark smoke runs.""" + + def __init__(self) -> None: + self.scheduler_events: list[SchedulerAdmissionEvent] = [] + self.request_events: list[RequestAdmissionEvent] = [] + + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: + self.scheduler_events.append(event) + + def emit_request_event(self, event: RequestAdmissionEvent) -> None: + self.request_events.append(event) + + +@dataclass(frozen=True) +class CorrelatedRuntimeView: + """Combined chronological view of scheduler and request events for tests.""" + + scheduler_events: tuple[SchedulerAdmissionEvent, ...] + request_events: tuple[RequestAdmissionEvent, ...] + + @property + def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: + return tuple( + sorted( + (*self.scheduler_events, *self.request_events), + key=lambda event: (event.captured_at_monotonic, event.sequence), + ) + ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py index cdc5e6c6a..79f00cbcc 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py @@ -5,7 +5,7 @@ import pytest -from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskResult, TaskTrace +from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskTrace def test_task_is_frozen() -> None: @@ -41,23 +41,6 @@ def test_task_types(task_type: str) -> None: assert task.task_type == task_type -def test_task_result_success() -> None: - task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") - result = TaskResult(task=task, status="success", output={"col_a": "value"}) - assert result.status == "success" - assert result.error is None - assert result.retryable is False - - -def test_task_result_error() -> None: - task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") - exc = ValueError("bad input") - result = TaskResult(task=task, status="error", error=exc, retryable=True) - assert result.status == "error" - assert result.error is exc - assert result.retryable is True - - def test_task_trace_from_task() -> None: task = Task(column="col_a", row_group=1, row_index=2, task_type="cell") trace = TaskTrace.from_task(task) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index bb5cf5685..cddcf31a0 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -68,8 +68,8 @@ RequestResourceKey, ) from data_designer.engine.models.resources import ProviderModelKey -from data_designer.engine.observability import InMemoryAdmissionEventSink from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.testing import InMemoryAdmissionEventSink MODEL_ALIAS = "stub" @@ -2303,40 +2303,6 @@ async def record_sleep(delay: float) -> None: assert yielded_delays == [0] -@pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_dispatch_does_not_scan_ready_frontier(monkeypatch: pytest.MonkeyPatch) -> None: - provider = _mock_provider() - configs = [ - SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), - LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), - ] - strategies = { - "seed": GenerationStrategy.FULL_COLUMN, - "cell_out": GenerationStrategy.CELL_BY_CELL, - } - generators = { - "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), - "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), - } - graph = ExecutionGraph.create(configs, strategies) - tracker = CompletionTracker.with_graph(graph, [(0, 3)]) - - def fail_get_ready_tasks(*args: Any, **kwargs: Any) -> list[Task]: - raise AssertionError("scheduler should apply returned frontier deltas instead of scanning ready tasks") - - monkeypatch.setattr(tracker, "get_ready_tasks", fail_get_ready_tasks) - scheduler = AsyncTaskScheduler( - generators=generators, - graph=graph, - tracker=tracker, - row_groups=[(0, 3)], - ) - - await asyncio.wait_for(scheduler.run(), timeout=10.0) - - assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) - - @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_pre_batch_drop_removes_pending_ready_task() -> None: provider = _mock_provider() @@ -3978,7 +3944,6 @@ def test_scheduler_adaptive_row_group_target_stays_blocked_after_llm_lease_boots def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None: scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1) scheduler._max_in_flight_tasks = 2 - scheduler._max_model_task_admission = 100 scheduler._fair_queue = SimpleNamespace( view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={}) ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 6b257a062..90f4dba3b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -33,6 +33,7 @@ from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder, build_row_group_resume_plan from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan +from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum from data_designer.engine.models.usage import ModelUsageStats, TokenUsageStats from data_designer.engine.processing.processors.base import Processor @@ -44,6 +45,10 @@ import pandas as pd +def _replace_processors(builder: DatasetBuilder, processors: list[Processor]) -> None: + builder._processor_runner = ProcessorRunner(processors=processors, artifact_storage=builder.artifact_storage) + + @pytest.fixture def stub_test_column_configs(): return [ @@ -501,7 +506,7 @@ def test_run_after_generation( mock_processor = create_mock_processor("proc", ["process_after_generation"]) mock_processor.process_after_generation.side_effect = processor_fn - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) simple_builder._processor_runner.run_after_generation(batch_size) mock_processor.process_after_generation.assert_called_once() @@ -524,7 +529,7 @@ def test_all_processor_stages_run_in_order(builder_with_seed, mode): df, )[1] - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) if mode == "preview": raw_dataset = builder_with_seed.build_preview(num_records=3) @@ -544,7 +549,7 @@ def test_processor_exception_in_process_after_batch_raises_error(simple_builder) mock_processor = create_mock_processor("failing_processor", ["process_after_batch"]) mock_processor.process_after_batch.side_effect = ValueError("Post-batch processing failed") - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) with pytest.raises(DatasetProcessingError, match="Failed in process_after_batch"): simple_builder._processor_runner.run_post_batch(lazy.pd.DataFrame({"id": [1, 2, 3]}), current_batch_number=0) @@ -553,7 +558,7 @@ def test_processor_exception_in_process_after_batch_raises_error(simple_builder) def test_processor_with_no_implemented_stages_is_skipped(builder_with_seed): """Test that a processor implementing no stages doesn't cause errors.""" mock_processor = create_mock_processor("noop_processor", []) - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) result = builder_with_seed.build_preview(num_records=3) @@ -573,7 +578,7 @@ def test_multiple_processors_run_in_definition_order(builder_with_seed): p.process_before_batch.side_effect = lambda df, lbl=label: (call_order.append(lbl), df)[1] processors.append(p) - builder_with_seed.set_processor_runner(processors) + _replace_processors(builder_with_seed, processors) builder_with_seed.build(num_records=3) assert call_order == ["a", "b", "c"] @@ -582,7 +587,7 @@ def test_multiple_processors_run_in_definition_order(builder_with_seed): def test_pre_batch_processor_row_count_change_rejected(builder_with_seed, caplog): mock_processor = create_mock_processor("filtering_processor", ["process_before_batch"]) mock_processor.process_before_batch.side_effect = lambda df: df.iloc[:2].reset_index(drop=True) - builder_with_seed.set_processor_runner([mock_processor]) + _replace_processors(builder_with_seed, [mock_processor]) with caplog.at_level(logging.INFO): with pytest.raises(DatasetGenerationError, match="Pre-batch processor changed row count"): @@ -594,7 +599,7 @@ def test_pre_batch_processor_row_count_change_rejected(builder_with_seed, caplog def test_process_preview_with_empty_dataframe(simple_builder): """Test that process_preview handles empty DataFrames gracefully.""" mock_processor = create_mock_processor("test_processor", ["process_after_batch", "process_after_generation"]) - simple_builder.set_processor_runner([mock_processor]) + _replace_processors(simple_builder, [mock_processor]) result = simple_builder.process_preview(lazy.pd.DataFrame()) @@ -1604,7 +1609,7 @@ def test_build_resume_complete_dataset_runs_after_generation_when_no_marker( builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) after_gen_processor = create_mock_processor("after_gen", ["process_after_generation"]) - builder.set_processor_runner([after_gen_processor]) + _replace_processors(builder, [after_gen_processor]) builder.build(num_records=4, resume=ResumeMode.ALWAYS) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py deleted file mode 100644 index 4baa829af..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_async_concurrency.py +++ /dev/null @@ -1,478 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import asyncio - -import pytest - -from data_designer.engine.dataset_builders.utils.async_concurrency import ( - AsyncConcurrentExecutor, -) -from data_designer.engine.dataset_builders.utils.concurrency import ExecutorResults -from data_designer.engine.errors import DataDesignerRuntimeError - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -async def _succeed(value: int) -> int: - """Simple coroutine that returns its input doubled.""" - return value * 2 - - -async def _fail(msg: str = "Test error") -> None: - """Simple coroutine that always raises.""" - raise ValueError(msg) - - -async def _succeed_slow(value: int, delay: float = 0.05) -> int: - """Coroutine with a small delay to simulate work.""" - await asyncio.sleep(delay) - return value * 2 - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -def test_basic_creation(): - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.3, - shutdown_error_window=5, - ) - assert executor.max_workers == 4 - assert executor.shutdown_error_rate == 0.3 - assert executor.shutdown_window_size == 5 - assert isinstance(executor.results, ExecutorResults) - assert executor.results.completed_count == 0 - assert executor.results.success_count == 0 - assert executor.results.early_shutdown is False - assert executor.results.failure_threshold == 0.3 - - -def test_successful_execution(): - executor = AsyncConcurrentExecutor(max_workers=4, column_name="test_column") - work_items = [(_succeed(i), None) for i in range(10)] - executor.run(work_items) - - assert executor.results.completed_count == 10 - assert executor.results.success_count == 10 - assert executor.results.error_trap.error_count == 0 - assert executor.results.early_shutdown is False - - -def test_successful_execution_with_context(): - executor = AsyncConcurrentExecutor(max_workers=2, column_name="test_column") - work_items = [(_succeed(i), {"index": i}) for i in range(5)] - executor.run(work_items) - - assert executor.results.completed_count == 5 - assert executor.results.success_count == 5 - - -def test_result_callback(): - results = [] - - def result_callback(result, *, context=None): - results.append((result, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=result_callback, - ) - work_items = [ - (_succeed(5), {"key": "a"}), - (_succeed(10), {"key": "b"}), - ] - executor.run(work_items) - - assert len(results) == 2 - values = sorted(results, key=lambda r: r[0]) - assert values[0] == (10, {"key": "a"}) - assert values[1] == (20, {"key": "b"}) - - -def test_result_callback_with_none_context(): - results = [] - - def result_callback(result, *, context=None): - results.append((result, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=result_callback, - ) - work_items = [(_succeed(7), None)] - executor.run(work_items) - - assert len(results) == 1 - assert results[0] == (14, None) - - -def test_error_callback(): - errors = [] - - def error_callback(exc, *, context=None): - errors.append((exc, context)) - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - error_callback=error_callback, - disable_early_shutdown=True, - ) - work_items = [(_fail("boom"), {"task": "first"})] - executor.run(work_items) - - assert len(errors) == 1 - assert isinstance(errors[0][0], ValueError) - assert str(errors[0][0]) == "boom" - assert errors[0][1] == {"task": "first"} - - -def test_early_shutdown_when_threshold_exceeded(): - """Error rate exceeds threshold -- should raise DataDesignerRuntimeError.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=2, - ) - # All tasks fail -> 100% error rate, well above 50% threshold - work_items = [(_fail(f"err-{i}"), None) for i in range(10)] - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work_items) - - assert executor.results.early_shutdown is True - - -def test_no_early_shutdown_below_threshold(): - """Error rate stays below threshold -- should NOT raise.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=20, - ) - # 2 failures + 18 successes = 10% error rate, well below 50% - work_items = [(_fail(f"err-{i}"), None) for i in range(2)] + [(_succeed(i), None) for i in range(18)] - executor.run(work_items) - - assert executor.results.early_shutdown is False - assert executor.results.completed_count == 20 - assert executor.results.success_count == 18 - assert executor.results.error_trap.error_count == 2 - - -def test_disable_early_shutdown(): - """All tasks fail but disable_early_shutdown=True -- no DataDesignerRuntimeError.""" - executor = AsyncConcurrentExecutor( - max_workers=4, - column_name="test_column", - shutdown_error_rate=0.0, - shutdown_error_window=0, - disable_early_shutdown=True, - ) - work_items = [(_fail(f"err-{i}"), None) for i in range(10)] - # Should not raise - executor.run(work_items) - - assert executor.results.error_trap.error_count == 10 - assert executor.results.success_count == 0 - assert executor.results.completed_count == 10 - assert executor.results.early_shutdown is False - - -def test_result_callback_raises_counts_as_failure(): - """When result_callback raises, the task should count as a failure, not a success. - - This validates the fix where a callback exception was previously - double-counted or misattributed. The corrected behavior: if the - coroutine succeeds but the callback raises, completed_count is - incremented once and the error is trapped (success_count is NOT - incremented). - """ - - def bad_callback(result, *, context=None): - raise RuntimeError("callback exploded") - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - result_callback=bad_callback, - disable_early_shutdown=True, - ) - work_items = [(_succeed(i), None) for i in range(5)] - executor.run(work_items) - - # Each task's coroutine succeeds, but callback raises -> counted as failure - assert executor.results.completed_count == 5 - assert executor.results.success_count == 0 - assert executor.results.error_trap.error_count == 5 - - -def test_error_callback_raises_safely(): - """error_callback raising should not crash the executor -- just log a warning.""" - - def bad_error_callback(exc, *, context=None): - raise RuntimeError("error callback also broke") - - executor = AsyncConcurrentExecutor( - max_workers=2, - column_name="test_column", - error_callback=bad_error_callback, - disable_early_shutdown=True, - ) - work_items = [(_fail(f"err-{i}"), None) for i in range(5)] - # Should not raise despite error_callback raising - executor.run(work_items) - - assert executor.results.completed_count == 5 - assert executor.results.error_trap.error_count == 5 - - -def test_semaphore_bounding(): - """Verify concurrency is bounded by max_workers.""" - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_work(index: int) -> int: - nonlocal max_concurrent, current_concurrent - async with lock: - current_concurrent += 1 - if current_concurrent > max_concurrent: - max_concurrent = current_concurrent - # Yield control so other tasks can run concurrently - await asyncio.sleep(0.02) - async with lock: - current_concurrent -= 1 - return index - - max_workers = 3 - executor = AsyncConcurrentExecutor(max_workers=max_workers, column_name="test_column") - work_items = [(tracked_work(i), None) for i in range(20)] - executor.run(work_items) - - assert executor.results.completed_count == 20 - assert executor.results.success_count == 20 - assert max_concurrent <= max_workers, f"Max concurrent was {max_concurrent}, expected <= {max_workers}" - # Also confirm the semaphore was actually exercised (more tasks than workers) - assert max_concurrent >= 1 - - -@pytest.mark.parametrize( - "shutdown_error_rate,num_errors,num_successes,shutdown_window,should_raise", - [ - (0.5, 60, 40, 20, True), # 60% errors > 50% threshold - (0.3, 40, 60, 20, True), # 40% errors > 30% threshold - (0.0, 5, 5, 10, True), # Any error > 0% threshold - (1.0, 20, 0, 10, True), # 100% errors >= 100% threshold - (0.5, 10, 90, 20, False), # 10% errors < 50% threshold - (0.3, 10, 90, 20, False), # 10% errors < 30% threshold - (1.0, 50, 50, 20, False), # 50% errors < 100% threshold - ], -) -def test_early_shutdown_parametric(shutdown_error_rate, num_errors, num_successes, shutdown_window, should_raise): - executor = AsyncConcurrentExecutor( - max_workers=10, - column_name="test_column", - shutdown_error_rate=shutdown_error_rate, - shutdown_error_window=shutdown_window, - ) - - # Interleave errors and successes to keep error rate relatively stable - total = num_errors + num_successes - work_items = [] - err_idx = 0 - suc_idx = 0 - if num_errors > 0: - tasks_per_error = total / num_errors - else: - tasks_per_error = float("inf") - - for i in range(total): - if num_errors > 0 and err_idx < num_errors and i >= int(err_idx * tasks_per_error): - work_items.append((_fail(f"err-{err_idx}"), None)) - err_idx += 1 - elif suc_idx < num_successes: - work_items.append((_succeed(suc_idx), None)) - suc_idx += 1 - - if should_raise: - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work_items) - assert executor.results.early_shutdown is True - else: - executor.run(work_items) - assert executor.results.early_shutdown is False - assert executor.results.completed_count == total - assert executor.results.success_count == num_successes - assert executor.results.error_trap.error_count == num_errors - - -def test_mixed_success_and_failure_with_callbacks(): - """Stress test: mix of successes and failures with both callbacks.""" - results_list = [] - errors_list = [] - - def result_callback(result, *, context=None): - results_list.append(result) - - def error_callback(exc, *, context=None): - errors_list.append(exc) - - executor = AsyncConcurrentExecutor( - max_workers=8, - column_name="test_column", - result_callback=result_callback, - error_callback=error_callback, - shutdown_error_rate=0.9, - shutdown_error_window=50, - ) - - async def variable_task(x: int) -> int: - if x % 7 == 0: - raise ValueError(f"Error {x}") - if x % 3 == 0: - await asyncio.sleep(0.001) - return x * 2 - - num_tasks = 100 - work_items = [(variable_task(i), None) for i in range(num_tasks)] - executor.run(work_items) - - expected_errors = sum(1 for i in range(num_tasks) if i % 7 == 0) - expected_successes = num_tasks - expected_errors - - assert executor.results.completed_count == num_tasks - assert executor.results.success_count == expected_successes - assert executor.results.error_trap.error_count == expected_errors - assert len(results_list) == expected_successes - assert len(errors_list) == expected_errors - assert executor.results.early_shutdown is False - - -# --------------------------------------------------------------------------- -# Edge cases (mirroring sync test_concurrency.py) -# --------------------------------------------------------------------------- - - -def test_edge_cases_invalid_max_workers_negative(): - """asyncio.Semaphore(-1) raises ValueError, propagated through future.result().""" - - async def ok() -> int: - return 1 - - coro = ok() - executor = AsyncConcurrentExecutor(max_workers=-1, column_name="test_column") - with pytest.raises(ValueError, match="must be >= 0"): - executor.run([(coro, None)]) - coro.close() # prevent "coroutine was never awaited" warning - - -def test_edge_cases_zero_error_window(): - """With shutdown_error_window=0, the first failure triggers immediate shutdown. - - get_error_rate returns 0.0 only when completed_count < window. With window=0, - that guard never fires, so the first error's rate (1/1 = 100%) exceeds any - non-zero threshold. - """ - executor = AsyncConcurrentExecutor( - max_workers=1, # deterministic ordering - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=0, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> str: - return "ok" - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run([(fail(), None), (succeed(), None), (succeed(), None)]) - - assert executor.results.early_shutdown is True - assert executor.results.completed_count == 1 - assert executor.results.error_trap.error_count == 1 - assert executor.results.success_count == 0 - - -def test_edge_cases_multiple_early_shutdown_skips_pending(): - """After shutdown fires, remaining tasks are skipped via _shutdown_event check.""" - executor = AsyncConcurrentExecutor( - max_workers=1, # sequential execution for deterministic counts - column_name="test_column", - shutdown_error_rate=0.5, - shutdown_error_window=2, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> int: - return 1 - - # 2 failures then 28 successes — shutdown should fire after the 2 failures - work = [(fail(), None), (fail(), None)] + [(succeed(), None) for _ in range(28)] - - with pytest.raises(DataDesignerRuntimeError, match="Data generation was terminated early"): - executor.run(work) - - assert executor.results.early_shutdown is True - # Only the tasks that actually executed get counted - assert executor.results.completed_count <= 3 # at most 2 failures + maybe 1 success - assert executor.results.error_trap.error_count == 2 - # Skipped tasks should NOT inflate completed_count - assert executor.results.completed_count < 30 - - -def test_edge_cases_semaphore_release_on_exception(): - """Verify semaphore is released after a failing task, allowing the next task to run. - - With max_workers=1, if the semaphore weren't released on exception, the second - task would deadlock. - """ - results = [] - errors = [] - - def result_cb(result, *, context=None): - results.append((result, context)) - - def error_cb(exc, *, context=None): - errors.append((type(exc).__name__, str(exc), context)) - - executor = AsyncConcurrentExecutor( - max_workers=1, - column_name="test_column", - result_callback=result_cb, - error_callback=error_cb, - shutdown_error_rate=1.0, # high threshold to avoid early shutdown - shutdown_error_window=10, - ) - - async def fail() -> None: - raise ValueError("boom") - - async def succeed() -> str: - return "ok" - - executor.run([(fail(), {"id": "fail"}), (succeed(), {"id": "ok"})]) - - assert executor.results.early_shutdown is False - assert executor.results.completed_count == 2 - assert executor.results.error_trap.error_count == 1 - assert executor.results.success_count == 1 - assert len(errors) == 1 - assert errors[0] == ("ValueError", "boom", {"id": "fail"}) - assert len(results) == 1 - assert results[0] == ("ok", {"id": "ok"}) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 6a5b31a51..c46502aad 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -20,7 +20,6 @@ from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig -from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -297,48 +296,6 @@ def test_add_column_duplicate_raises() -> None: graph.add_column("col_a", GenerationStrategy.FULL_COLUMN) -# -- Cell dependencies ------------------------------------------------------ - - -def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: - """question depends on topic (full-column); answer depends on question (cell-by-cell).""" - # answer[rg=0, row=2] should depend on question[rg=0, row=2] - deps = simple_graph.compute_cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) - assert deps == [SliceRef("question", 0, 2)] - - -def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: - """question depends on topic (full-column).""" - deps = simple_graph.compute_cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) - assert deps == [SliceRef("topic", 0, None)] - - -def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: - """topic has no upstream.""" - deps = simple_graph.compute_cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) - assert deps == [] - - -def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: - """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" - deps = simple_graph.compute_cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) - assert sorted(deps) == [SliceRef("answer", 0, 0), SliceRef("answer", 0, 1), SliceRef("answer", 0, 2)] - - -# -- Mermaid output ---------------------------------------------------------- - - -def test_to_mermaid(simple_graph: ExecutionGraph) -> None: - mermaid = simple_graph.to_mermaid() - - assert "graph TD" in mermaid - assert 'topic["topic [full_column]"]' in mermaid - assert 'question["question [cell_by_cell]"]' in mermaid - assert "topic --> question" in mermaid - assert "question --> answer" in mermaid - assert "answer --> score" in mermaid - - # -- MultiColumnConfig ------------------------------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 08dad3730..a80b160f7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -49,15 +49,6 @@ def test_update_cell() -> None: assert mgr.get_row(0, 1) == {"col_a": "val_1"} -def test_update_cells() -> None: - mgr = RowGroupBufferManager(_mock_artifact_storage()) - mgr.init_row_group(0, 1) - - mgr.update_cells(0, 0, {"col_a": "a", "col_b": "b"}) - - assert mgr.get_row(0, 0) == {"col_a": "a", "col_b": "b"} - - def test_update_batch() -> None: mgr = RowGroupBufferManager(_mock_artifact_storage()) mgr.init_row_group(0, 3) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py index cf833c8e7..1f2944e81 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_skip_tracker.py @@ -9,8 +9,6 @@ from data_designer.engine.dataset_builders.utils.skip_tracker import ( SKIPPED_COLUMNS_RECORD_KEY, apply_skip_to_record, - prepare_records_for_skip_metadata_round_trip, - restore_skip_metadata, strip_skip_metadata_for_dataframe_row, strip_skip_metadata_from_records, ) @@ -131,77 +129,3 @@ def test_strip_skip_metadata_for_dataframe_row_no_metadata() -> None: ) def test_strip_skip_metadata_from_records(rows: list[dict], expected: list[dict]) -> None: assert strip_skip_metadata_from_records(rows) == expected - - -def test_prepare_records_for_skip_metadata_round_trip_without_metadata() -> None: - rows = [{"a": 1}, {"a": 2}] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows) - assert restore_context is None - assert prepared_rows == rows - assert prepared_rows is not rows - - -def test_prepare_records_for_skip_metadata_round_trip_injects_restore_ids() -> None: - rows = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - {"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_y", "col_z"}}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(rows) - assert restore_context is not None - assert SKIPPED_COLUMNS_RECORD_KEY not in prepared_rows[0] - assert restore_context.restore_id_column in prepared_rows[0] - assert restore_context.skipped_columns_by_source_id == { - "0": {"col_x"}, - "2": {"col_y", "col_z"}, - } - - -def test_restore_skip_metadata_uses_restore_ids_after_reorder() -> None: - old = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - {"a": 3, SKIPPED_COLUMNS_RECORD_KEY: {"col_z"}}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - restore_id_column = restore_context.restore_id_column - - new = [ - {"a": 30, restore_id_column: prepared_rows[2][restore_id_column]}, - {"a": 10, restore_id_column: prepared_rows[0][restore_id_column]}, - {"a": 20, restore_id_column: prepared_rows[1][restore_id_column]}, - ] - restore_skip_metadata(new, context=restore_context) - - assert new[0][SKIPPED_COLUMNS_RECORD_KEY] == {"col_z"} - assert new[1][SKIPPED_COLUMNS_RECORD_KEY] == {"col_x"} - assert SKIPPED_COLUMNS_RECORD_KEY not in new[2] - - -def test_restore_skip_metadata_rejects_filtered_rows() -> None: - old = [{"a": 1}, {"a": 2}] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is None - - old = [ - {"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}, - {"a": 2}, - ] - prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - restore_id_column = restore_context.restore_id_column - - new = [{"a": 20, restore_id_column: prepared_rows[1][restore_id_column]}] - - with pytest.raises(ValueError, match="1:1 mapping"): - restore_skip_metadata(new, context=restore_context) - - -def test_restore_skip_metadata_rejects_missing_restore_id_column() -> None: - old = [{"a": 1, SKIPPED_COLUMNS_RECORD_KEY: {"col_x"}}] - _prepared_rows, restore_context = prepare_records_for_skip_metadata_round_trip(old) - assert restore_context is not None - - with pytest.raises(ValueError, match="must preserve the internal column"): - restore_skip_metadata([{"a": 10}], context=restore_context) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index dc08a55ff..6e0820078 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -10,7 +10,6 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, - extract_message_from_exception_string, map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) @@ -244,46 +243,3 @@ def test_map_http_error_retry_after_returns_none_for_garbage() -> None: ) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") assert error.retry_after is None - - -@pytest.mark.parametrize( - "raw,expected", - [ - ( - "Error code: 400 - {'error': {'message': 'Context length exceeded', 'type': 'invalid_request_error'}}".replace( - "'", '"' - ), - "Context length exceeded", - ), - ( - 'Error code: 403 - {"error": "Insufficient permissions"}', - "Insufficient permissions", - ), - ( - 'Error code: 500 - {"message": "Internal failure"}', - "Internal failure", - ), - ( - 'Error code: 422 - {"detail": "Unprocessable entity"}', - "Unprocessable entity", - ), - ( - "Connection timed out", - "Connection timed out", - ), - ( - "Error code: 400 - {not valid json", - "Error code: 400 - {not valid json", - ), - ], - ids=[ - "nested-error-message", - "top-level-error-string", - "top-level-message-string", - "top-level-detail-string", - "no-json-passthrough", - "malformed-json-passthrough", - ], -) -def test_extract_message_from_exception_string(raw: str, expected: str) -> None: - assert extract_message_from_exception_string(raw) == expected diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py index 4dbae62e2..e193617ed 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py @@ -28,7 +28,7 @@ RequestAdmissionError, ) from data_designer.engine.models.request_admission.resources import RequestAdmissionItem, RequestDomain -from data_designer.engine.observability import InMemoryAdmissionEventSink +from data_designer.engine.testing import InMemoryAdmissionEventSink class _Client: diff --git a/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py b/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py index 613194dfe..9e57b08eb 100644 --- a/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py +++ b/packages/data-designer-engine/tests/engine/models/parsers/test_parser.py @@ -3,11 +3,8 @@ from __future__ import annotations -from pydantic import BaseModel - from data_designer.engine.models.parsers.parser import LLMResponseParser from data_designer.engine.models.parsers.postprocessors import ( - RealizePydanticTypes, deserialize_json_code, merge_text_blocks, ) @@ -15,7 +12,6 @@ from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -82,7 +78,7 @@ def test_llm_response_parser_markup_passthrough(): assert block.text == text -def test_llm_response_parser_full_pipeline(): +def test_llm_response_parser_deserializes_json_blocks(): text = """\ Test prompt return. The return has some `code` included. ```json @@ -97,17 +93,10 @@ def test_llm_response_parser_full_pipeline(): That is all there is at the moment.\ """ - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - parser = LLMResponseParser( postprocessors=[ merge_text_blocks, deserialize_json_code, - RealizePydanticTypes([Foo, Bar]), ] ) result = parser.parse(text) @@ -119,12 +108,12 @@ class Bar(BaseModel): assert result.parsed[0] == TextBlock(text="Test prompt return. The return has some `code` included.") assert isinstance(result.parsed[1], StructuredDataBlock) - assert isinstance(result.parsed[2], PydanticTypeBlock) - assert isinstance(result.parsed[3], PydanticTypeBlock) + assert isinstance(result.parsed[2], StructuredDataBlock) + assert isinstance(result.parsed[3], StructuredDataBlock) assert result.parsed[1].obj == {"asdf": 42} - assert result.parsed[2].obj == Foo(baz=3) - assert result.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) + assert result.parsed[2].obj == {"baz": 3} + assert result.parsed[3].obj == {"foos": [{"baz": 1}, {"baz": 2}]} assert result.parsed[4] == TextBlock(text="That is all there is at the moment.") diff --git a/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py b/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py index 88cdd0296..67096594b 100644 --- a/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py +++ b/packages/data-designer-engine/tests/engine/models/parsers/test_postprocessors.py @@ -1,14 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from pydantic import BaseModel - import data_designer.engine.models.parsers.postprocessors as post from data_designer.engine.models.parsers.types import ( CodeBlock, LLMStructuredResponse, PostProcessor, - PydanticTypeBlock, StructuredDataBlock, TextBlock, ) @@ -16,7 +13,6 @@ KNOWN_POSTPROCESSORS = [ post.merge_text_blocks, post.deserialize_json_code, - post.RealizePydanticTypes, ] @@ -62,61 +58,3 @@ def test_deserialize_json_code(): assert isinstance(response.parsed[2], StructuredDataBlock) assert response.parsed[2].obj == {"bar": 43, "baz": [1, 2, 3]} - - -def test_realize_pydantic_types(): - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - - blocks = [ - TextBlock(text="abc"), - StructuredDataBlock(serialized="", obj={"asdf": "q"}), - StructuredDataBlock(serialized="", obj={"baz": 42}), - StructuredDataBlock(serialized="", obj={"foos": [{"baz": 1}, {"baz": 2}]}), - TextBlock(text="cba"), - ] - - response = LLMStructuredResponse(response="", markup="", parsed=blocks) - - parser = post.RealizePydanticTypes([Foo, Bar]) - response = parser(response) - - assert len(response.parsed) == 5 - assert isinstance(response.parsed[1], StructuredDataBlock) - assert isinstance(response.parsed[2], PydanticTypeBlock) - assert isinstance(response.parsed[3], PydanticTypeBlock) - - assert response.parsed[2].obj == Foo(baz=42) - assert response.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) - - -def test_deserialize_realize_pipeline(): - class Foo(BaseModel): - baz: int - - class Bar(BaseModel): - foos: list[Foo] - - blocks = [ - TextBlock(text="abc"), - CodeBlock(code='{"asdf": "q"}', code_lang="json"), - CodeBlock(code='{"baz": 42}', code_lang="json"), - CodeBlock(code='{"foos": [{"baz": 1}, {"baz": 2}]}', code_lang="json"), - TextBlock(text="cba"), - ] - - response = LLMStructuredResponse(response="", markup="", parsed=blocks) - - parser = post.RealizePydanticTypes([Foo, Bar]) - response = parser(post.deserialize_json_code(response)) - - assert len(response.parsed) == 5 - assert isinstance(response.parsed[1], StructuredDataBlock) - assert isinstance(response.parsed[2], PydanticTypeBlock) - assert isinstance(response.parsed[3], PydanticTypeBlock) - - assert response.parsed[2].obj == Foo(baz=42) - assert response.parsed[3].obj == Bar(foos=[Foo(baz=1), Foo(baz=2)]) diff --git a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py index af77f8c40..5b165666c 100644 --- a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py +++ b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py @@ -25,7 +25,7 @@ RequestGroupSpec, RequestResourceKey, ) -from data_designer.engine.observability import InMemoryAdmissionEventSink +from data_designer.engine.testing import InMemoryAdmissionEventSink def _item(domain: RequestDomain = RequestDomain.CHAT, timeout: float | None = None) -> RequestAdmissionItem: diff --git a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py index cb5d569f9..1d9c86530 100644 --- a/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py +++ b/packages/data-designer-engine/tests/engine/resources/test_resource_provider.py @@ -5,11 +5,10 @@ import pytest -from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig +from data_designer.config.mcp import LocalStdioMCPProvider, ToolConfig from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.resources.resource_provider import ( ResourceProvider, - _validate_tool_configs_against_providers, create_resource_provider, ) from data_designer.engine.storage.artifact_storage import ArtifactStorage @@ -58,48 +57,6 @@ def test_create_resource_provider_error_cases(test_case, expected_error, tmp_pat class TestToolConfigValidation: """Tests for ToolConfig validation against MCP providers.""" - def test_valid_tool_config_with_existing_providers(self) -> None: - """Valid tool config passes when all providers exist.""" - providers = [ - MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse"), - LocalStdioMCPProvider(name="mcp-2", command="python", args=["-m", "server"]), - ] - tool_configs = [ - ToolConfig(tool_alias="tools-1", providers=["mcp-1"]), - ToolConfig(tool_alias="tools-2", providers=["mcp-1", "mcp-2"]), - ] - - # Should not raise - _validate_tool_configs_against_providers(tool_configs, providers) - - def test_tool_config_with_missing_provider_raises_error(self) -> None: - """Tool config referencing non-existent provider raises ValueError.""" - providers = [ - MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse"), - ] - tool_configs = [ - ToolConfig(tool_alias="search-tools", providers=["mcp-1", "nonexistent-mcp"]), - ] - - with pytest.raises(ValueError, match="ToolConfig 'search-tools' references provider"): - _validate_tool_configs_against_providers(tool_configs, providers) - - def test_tool_config_with_no_providers_available(self) -> None: - """Tool config fails when no MCP providers are configured.""" - tool_configs = [ - ToolConfig(tool_alias="search-tools", providers=["some-mcp"]), - ] - - with pytest.raises(ValueError, match="not registered.*none configured"): - _validate_tool_configs_against_providers(tool_configs, []) - - def test_empty_tool_configs_passes(self) -> None: - """Empty tool configs list passes validation.""" - providers = [MCPProvider(name="mcp-1", endpoint="http://localhost:8080/sse")] - - # Should not raise - _validate_tool_configs_against_providers([], providers) - def test_tool_config_validation_happens_during_health_check(self, tmp_path: str) -> None: """Tool config validation is deferred to health checks.""" artifact_storage = ArtifactStorage(artifact_path=str(tmp_path), dataset_name="test") diff --git a/packages/data-designer-engine/tests/engine/test_engine_errors.py b/packages/data-designer-engine/tests/engine/test_engine_errors.py index 9a9e794f2..a84e99540 100644 --- a/packages/data-designer-engine/tests/engine/test_engine_errors.py +++ b/packages/data-designer-engine/tests/engine/test_engine_errors.py @@ -6,7 +6,6 @@ ErrorTrap, RemoteValidationSchemaError, SecretResolutionError, - UnknownModelAliasError, UnknownProviderError, ) @@ -14,7 +13,6 @@ def test_error_message(): test_cases = [ (DataDesignerRuntimeError, "Runtime error occurred"), - (UnknownModelAliasError, "Unknown model alias"), (UnknownProviderError, "Unknown provider"), (SecretResolutionError, "Secret resolution failed"), (RemoteValidationSchemaError, "Remote validation schema error"), @@ -41,7 +39,7 @@ def test_error_trap_track_error(): error1 = DataDesignerRuntimeError("Error 1") error2 = DataDesignerRuntimeError("Error 2") - error3 = UnknownModelAliasError("Error 3") + error3 = SecretResolutionError("Error 3") error_trap.handle_error(error1) error_trap.handle_error(error2) @@ -49,7 +47,7 @@ def test_error_trap_track_error(): assert error_trap.error_count == 3 assert error_trap.task_errors["DataDesignerRuntimeError"] == 2 - assert error_trap.task_errors["UnknownModelAliasError"] == 1 + assert error_trap.task_errors["SecretResolutionError"] == 1 def test_error_trap_model_dump(): diff --git a/packages/data-designer-engine/tests/engine/test_observability.py b/packages/data-designer-engine/tests/engine/test_observability.py index e7d9ce21b..78ec66424 100644 --- a/packages/data-designer-engine/tests/engine/test_observability.py +++ b/packages/data-designer-engine/tests/engine/test_observability.py @@ -8,13 +8,12 @@ from enum import Enum from data_designer.engine.observability import ( - CorrelatedRuntimeView, - InMemoryAdmissionEventSink, RequestAdmissionEvent, RuntimeCorrelation, RuntimeCorrelationProvider, SchedulerAdmissionEvent, ) +from data_designer.engine.testing import CorrelatedRuntimeView, InMemoryAdmissionEventSink class _DiagnosticMode(Enum): diff --git a/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py b/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py index d1ec32f65..65541467f 100644 --- a/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py +++ b/packages/data-designer/src/data_designer/cli/repositories/persona_repository.py @@ -68,23 +68,3 @@ def get_by_code(self, code: str) -> PersonaLocale | None: PersonaLocale if found, None otherwise """ return next((locale for locale in self._registry.locales if locale.code == code), None) - - def get_dataset_name(self, code: str) -> str | None: - """Get the NGC dataset name for a locale. - - Args: - code: Locale code (e.g., 'en_US', 'ja_JP') - - Returns: - Dataset name if locale exists, None otherwise - """ - locale = self.get_by_code(code) - return locale.dataset_name if locale else None - - def get_dataset_prefix(self) -> str: - """Get the dataset prefix for all persona datasets. - - Returns: - Dataset prefix string - """ - return self._registry.dataset_prefix diff --git a/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py b/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py index ea2fa5236..159152f1b 100644 --- a/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py +++ b/packages/data-designer/src/data_designer/cli/services/plugin_install_service.py @@ -166,10 +166,6 @@ def uninstall(self, plan: UninstallPlan) -> None: if return_code != 0: raise RuntimeError(f"Plugin package uninstaller exited with status {return_code}") - def verify_entry_point(self, entry: PluginCatalogEntry) -> bool: - """Verify the runtime plugin's declared entry point is installed and loadable.""" - return self.verify_entry_points([entry]) - def verify_entry_points(self, entries: list[PluginCatalogEntry]) -> bool: """Verify every declared runtime entry point for an installed catalog package can load.""" if not entries: diff --git a/packages/data-designer/src/data_designer/cli/ui.py b/packages/data-designer/src/data_designer/cli/ui.py index e65cbafd2..717f18fe5 100644 --- a/packages/data-designer/src/data_designer/cli/ui.py +++ b/packages/data-designer/src/data_designer/cli/ui.py @@ -634,13 +634,6 @@ def print_header(text: str) -> None: _console.print() -def print_navigation_tip() -> None: - """Display a concise navigation tip for interactive prompts.""" - tip = "[dim]Tip: Use arrow keys to navigate menus, type [bold]'back'[/bold] to edit previous entries, press [bold]Tab[/bold] for completions[/dim]" - _print_with_padding(tip) - _console.print() - - def _print_with_padding(content: str | Panel) -> None: """Internal helper to print with left padding. diff --git a/packages/data-designer/tests/cli/repositories/test_persona_repository.py b/packages/data-designer/tests/cli/repositories/test_persona_repository.py index 905a26c91..06ffe9038 100644 --- a/packages/data-designer/tests/cli/repositories/test_persona_repository.py +++ b/packages/data-designer/tests/cli/repositories/test_persona_repository.py @@ -92,18 +92,6 @@ def test_get_by_code_case_sensitive(repository: PersonaRepository) -> None: assert locale is None -def test_get_dataset_name_valid_locale(repository: PersonaRepository) -> None: - """Test getting dataset name for valid locale.""" - dataset_name = repository.get_dataset_name("en_US") - assert dataset_name == "nemotron-personas-dataset-en_us" - - -def test_get_dataset_name_invalid_locale(repository: PersonaRepository) -> None: - """Test getting dataset name for invalid locale returns None.""" - dataset_name = repository.get_dataset_name("invalid_locale") - assert dataset_name is None - - def test_get_dataset_name_lowercase_conversion(repository: PersonaRepository) -> None: """Test that dataset names use lowercase locale codes.""" # Verify that mixed-case locale codes result in lowercase dataset names @@ -113,12 +101,6 @@ def test_get_dataset_name_lowercase_conversion(repository: PersonaRepository) -> assert locale.dataset_name.islower() or "_" in locale.dataset_name -def test_get_dataset_prefix(repository: PersonaRepository) -> None: - """Test getting dataset prefix.""" - prefix = repository.get_dataset_prefix() - assert prefix == "nemotron-personas-dataset-" - - def test_persona_locale_model() -> None: """Test PersonaLocale Pydantic model.""" locale = PersonaLocale( @@ -172,7 +154,7 @@ def test_locale_size_formats(repository: PersonaRepository) -> None: def test_dataset_name_consistency(repository: PersonaRepository) -> None: """Test that all dataset names follow consistent pattern.""" locales = repository.list_all() - prefix = repository.get_dataset_prefix() + prefix = "nemotron-personas-dataset-" for locale in locales: # All dataset names should start with the prefix diff --git a/packages/data-designer/tests/cli/services/test_plugin_install_service.py b/packages/data-designer/tests/cli/services/test_plugin_install_service.py index ea0744a91..c63639297 100644 --- a/packages/data-designer/tests/cli/services/test_plugin_install_service.py +++ b/packages/data-designer/tests/cli/services/test_plugin_install_service.py @@ -1007,7 +1007,7 @@ def runner(command: list[str], stdin_text: str | None) -> int: @patch("data_designer.cli.services.plugin_install_service.importlib.metadata.entry_points") @patch("data_designer.cli.services.plugin_install_service.importlib.invalidate_caches") -def test_verify_entry_point_invalidates_caches_and_checks_declared_entry_point( +def test_verify_entry_points_invalidates_caches_and_checks_declared_entry_point( mock_invalidate_caches: Mock, mock_entry_points: Mock, ) -> None: @@ -1023,7 +1023,7 @@ def test_verify_entry_point_invalidates_caches_and_checks_declared_entry_point( ] service = PluginInstallService() - assert service.verify_entry_point(entry) is True + assert service.verify_entry_points([entry]) is True mock_invalidate_caches.assert_called_once_with() mock_entry_points.assert_called_once_with(group="data_designer.plugins") diff --git a/packages/data-designer/tests/cli/test_cli_utils.py b/packages/data-designer/tests/cli/test_cli_utils.py index 7bb8eabdf..f7a72030f 100644 --- a/packages/data-designer/tests/cli/test_cli_utils.py +++ b/packages/data-designer/tests/cli/test_cli_utils.py @@ -14,23 +14,12 @@ ) from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError from data_designer.config.utils.io_helpers import ( - ensure_config_dir_exists, is_http_url, load_config_file, save_config_file, ) -def test_ensure_config_dir_exists(tmp_path: Path) -> None: - """Test creating config directory.""" - test_dir = tmp_path / "test_config" - assert not test_dir.exists() - - ensure_config_dir_exists(test_dir) - assert test_dir.exists() - assert test_dir.is_dir() - - def test_save_and_load_config_file(tmp_path: Path) -> None: """Test saving and loading config files.""" config_file = tmp_path / "test_config.yaml"