Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from numbers import Number
from pathlib import Path, PurePosixPath
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

Expand All @@ -29,15 +29,6 @@
VALID_CONFIG_FILE_EXTENSIONS = {".yaml", ".yml", ".json"}


def ensure_config_dir_exists(config_dir: Path) -> None:
    """Make sure the configuration directory is present on disk.

    Missing parent directories are created along the way, and a directory
    that already exists is left as it is.

    Args:
        config_dir: Directory that must exist once this call returns.
    """
    mkdir_options = {"parents": True, "exist_ok": True}
    config_dir.mkdir(**mkdir_options)


def load_config_file(file_path: Path) -> dict:
"""Load a YAML configuration file.

Expand Down Expand Up @@ -177,40 +168,6 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str)
raise InvalidFilePathError(f"πŸ›‘ Path {path!r} does not contain files of type {file_extension!r}.")


def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
    """Load a dataframe from file if a path is given, otherwise return the dataframe.

    Args:
        dataframe: A pandas DataFrame object (returned unchanged), a string
            starting with ``http`` (treated as a URL), or a local path to a
            ``csv``, ``json`` or ``parquet`` file. JSON input is read with
            ``lines=True``, i.e. one JSON object per line.

    Returns:
        A pandas DataFrame object.

    Raises:
        FileNotFoundError: If a local path is given and it does not exist.
        ValueError: If the file extension is not ``csv``, ``json`` or ``parquet``.
    """
    if isinstance(dataframe, lazy.pd.DataFrame):
        return dataframe

    # Work out the file extension, normalised to lower case WITHOUT the leading dot.
    if isinstance(dataframe, str) and dataframe.startswith("http"):
        dataframe = _maybe_rewrite_url(dataframe)
        # Parse extension from the URL path to avoid query-string contamination (e.g. "csv?token=…").
        ext = PurePosixPath(urlparse(dataframe).path).suffix.lstrip(".").lower()
    else:
        dataframe = Path(dataframe)
        # ``Path.suffix`` keeps the leading dot (".csv"); strip it so local paths
        # use the same extension form as URLs and match the comparisons below.
        ext = dataframe.suffix.lower().lstrip(".")
        if not dataframe.exists():
            raise FileNotFoundError(f"File not found: {dataframe}")

    # Load the dataframe based on the file extension.
    if ext == "csv":
        return lazy.pd.read_csv(dataframe)
    if ext == "json":
        return lazy.pd.read_json(dataframe, lines=True)
    if ext == "parquet":
        return lazy.pd.read_parquet(dataframe)
    raise ValueError(f"Unsupported file format: {dataframe}")


def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
"""Return the yaml config as a dict given flexible input types.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import tempfile
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -19,66 +17,9 @@
_maybe_rewrite_url,
is_http_url,
serialize_data,
smart_load_dataframe,
smart_load_yaml,
)

if TYPE_CHECKING:
import pandas as pd


@patch("data_designer.config.utils.io_helpers.Path", autospec=True)
@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True)
@patch("data_designer.config.utils.io_helpers.lazy.pd.read_json", autospec=True)
@patch("data_designer.config.utils.io_helpers.lazy.pd.read_parquet", autospec=True)
def test_smart_load_dataframe(mock_read_parquet, mock_read_json, mock_read_csv, mock_path_cls, stub_dataframe):
    """Exercise smart_load_dataframe for DataFrame, URL and local-path inputs.

    All pandas readers and the ``Path`` class used by the module under test
    are patched, so no file or network access takes place.
    """
    mock_read_parquet.return_value = stub_dataframe
    mock_read_json.return_value = stub_dataframe
    mock_read_csv.return_value = stub_dataframe

    # dataframe objects are passed through
    assert smart_load_dataframe(stub_dataframe).size == stub_dataframe.size

    # url based
    stub_base_url = "https://example.com/data.{extention}"
    url_csv = stub_base_url.format(extention="csv")
    smart_load_dataframe(url_csv)
    mock_read_csv.assert_called_once_with(url_csv)

    url_json = stub_base_url.format(extention="json")
    smart_load_dataframe(url_json)
    mock_read_json.assert_called_once_with(url_json, lines=True)

    url_parquet = stub_base_url.format(extention="parquet")
    smart_load_dataframe(url_parquet)
    mock_read_parquet.assert_called_once_with(url_parquet)

    url_unknown = stub_base_url.format(extention="unknown")
    with pytest.raises(ValueError):
        smart_load_dataframe(url_unknown)

    # local file based
    mock_read_csv.reset_mock()
    mock_read_json.reset_mock()
    mock_read_parquet.reset_mock()

    # ``spec=Path`` restricts the mock to attributes that exist on Path, so a
    # misspelled attribute fails loudly. (``MagicMock(autospec=...)`` would only
    # set an attribute named ``autospec`` and apply no spec at all.)
    mock_path = MagicMock(spec=Path)
    mock_path.exists.return_value = True
    # NOTE(review): a real ``Path.suffix`` includes the leading dot (".csv");
    # the dot-less value mocked here does not reflect that — confirm the
    # function under test handles the dotted form as well.
    mock_path.suffix.lower.return_value = "csv"
    mock_path_cls.return_value = mock_path

    stub_base_path_str = "/some/path/to/data.{extension}"
    path_csv = stub_base_path_str.format(extension="csv")
    _ = smart_load_dataframe(path_csv)
    mock_read_csv.assert_called_once_with(mock_path)

    # A path that does not exist must raise before any reader is invoked.
    mock_path.reset_mock()
    mock_path.suffix.lower.return_value = "json"
    mock_path.exists.return_value = False
    path_json = stub_base_path_str.format(extension="json")
    with pytest.raises(FileNotFoundError):
        _ = smart_load_dataframe(Path(path_json))


def test_smart_load_yaml():
stub_dict = {
Expand Down Expand Up @@ -348,37 +289,6 @@ def test_smart_load_yaml_rewrites_huggingface_blob_url(mock_requests: MagicMock)
)


@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True)
def test_smart_load_dataframe_rewrites_github_blob_url(mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame) -> None:
    """A GitHub ``blob`` URL is read from its raw.githubusercontent.com equivalent."""
    blob_url = "https://github.com/org/repo/blob/main/data.csv"
    expected_raw_url = "https://raw.githubusercontent.com/org/repo/main/data.csv"
    mock_read_csv.return_value = stub_dataframe

    smart_load_dataframe(blob_url)

    mock_read_csv.assert_called_once_with(expected_raw_url)


@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True)
def test_smart_load_dataframe_rewrites_github_blob_url_with_token(
    mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame
) -> None:
    """The query string of a GitHub ``blob`` URL is carried over to the rewritten raw URL."""
    blob_url = "https://github.com/org/repo/blob/main/data.csv?token=secret123"
    expected_raw_url = "https://raw.githubusercontent.com/org/repo/main/data.csv?token=secret123"
    mock_read_csv.return_value = stub_dataframe

    smart_load_dataframe(blob_url)

    mock_read_csv.assert_called_once_with(expected_raw_url)


@patch("data_designer.config.utils.io_helpers.lazy.pd.read_csv", autospec=True)
def test_smart_load_dataframe_rewrites_huggingface_blob_url(
    mock_read_csv: MagicMock, stub_dataframe: pd.DataFrame
) -> None:
    """A Hugging Face dataset ``blob`` URL is read from the matching ``raw`` URL."""
    blob_url = "https://huggingface.co/datasets/org/repo/blob/main/data.csv"
    expected_raw_url = "https://huggingface.co/datasets/org/repo/raw/main/data.csv"
    mock_read_csv.return_value = stub_dataframe

    smart_load_dataframe(blob_url)

    mock_read_csv.assert_called_once_with(expected_raw_url)


def test_maybe_rewrite_github_url_log_does_not_leak_query(caplog: pytest.LogCaptureFixture) -> None:
import logging

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
RANDOM_SEED = 42
MAX_PROMPT_SAMPLE_SIZE = 1000
WARNING_PREFIX = "⚠️ Error during column profile calculation: "
TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD = 0.1

logger = logging.getLogger(__name__)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,38 +277,6 @@ def _postprocess_result(

return self._validate_output(result, keys_before, is_dataframe)

def _validate_cell_output(self, row: dict, keys_before: set[str]) -> dict:
    """Check one cell_by_cell row produced by the generator and drop undeclared columns.

    Args:
        row: The row returned by the generator function.
        keys_before: Column names that were present before the generator ran.

    Returns:
        ``row`` itself when it contains no undeclared new columns, otherwise
        a new dict without those columns.

    Raises:
        CustomColumnGenerationError: If the configured column is missing, a
            declared side-effect column is missing, or a pre-existing column
            was removed.
    """
    column_name = self.config.name
    declared_side_effects = set(self.config.side_effect_columns)
    produced_keys = set(row.keys())

    # The generator must at least create the column it is configured for.
    if column_name not in produced_keys:
        raise CustomColumnGenerationError(
            f"Custom generator for column '{column_name}' did not create the expected column. "
            f"The generator_function must add a column named '{column_name}' to the row."
        )

    # Every declared side-effect column has to be present as well.
    absent_side_effects = declared_side_effects - produced_keys
    if absent_side_effects:
        raise CustomColumnGenerationError(
            f"Custom generator for column '{column_name}' did not create declared side_effect_columns: "
            f"{sorted(absent_side_effects)}. Declared side_effect_columns must be added to the row."
        )

    # Columns that existed before the call must still be there afterwards.
    dropped_columns = keys_before - produced_keys
    if dropped_columns:
        raise CustomColumnGenerationError(
            f"Custom generator for column '{column_name}' removed pre-existing columns: "
            f"{sorted(dropped_columns)}. The generator_function must not remove any existing columns."
        )

    # New columns that were never declared are stripped with a warning, not an error.
    allowed_new = {column_name} | declared_side_effects
    unexpected = (produced_keys - keys_before) - allowed_new
    if not unexpected:
        return row

    logger.warning(
        f"⚠️ Custom generator for column '{column_name}' created undeclared columns: "
        f"{sorted(unexpected)}. These columns will be removed. "
        f"To keep additional columns, declare them in @custom_column_generator(side_effect_columns=[...])."
    )
    return {key: value for key, value in row.items() if key not in unexpected}

def _validate_output(
self, result: dict | pd.DataFrame, keys_before: set[str], is_dataframe: bool
) -> dict | pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def __init__(

self._max_concurrent_row_groups = max_concurrent_row_groups
self._max_in_flight_tasks = max_in_flight_tasks
self._max_model_task_admission = max_model_task_admission
self._num_records = num_records
self._buffer_size = buffer_size
self._scheduled_records = self._row_groups.scheduled_total_rows
Expand Down Expand Up @@ -660,9 +659,6 @@ def _apply_frontier_delta(self, delta: FrontierDelta) -> None:
self._discard_ready_task(task)
self._enqueue_ready_tasks(delta.added)

def _enqueue_ready_task(self, task: Task) -> None:
    """Enqueue one ready task by handing it to ``_enqueue_ready_tasks`` as a 1-tuple."""
    single_task_batch = (task,)
    self._enqueue_ready_tasks(single_task_batch)

def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None:
schedulables: list[SchedulableTask] = []
accepted_tasks_by_id: dict[str, Task] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,6 @@ def first_non_retryable_error(self) -> Exception | None:
"""First non-retryable error captured by the scheduler in the most recent run."""
return self._first_non_retryable_error

def set_processor_runner(self, processors: list[Processor]) -> None:
    """Swap in a fresh processor runner built from ``processors``.

    The new runner is given this object's ``artifact_storage``.

    Args:
        processors: Processors the new runner should be constructed with.
    """
    replacement_runner = ProcessorRunner(processors=processors, artifact_storage=self.artifact_storage)
    self._processor_runner = replacement_runner

@functools.cached_property
def single_column_configs(self) -> list[ColumnConfigT]:
configs = []
Expand All @@ -235,10 +228,6 @@ def single_column_configs(self) -> list[ColumnConfigT]:
configs.append(config)
return configs

@functools.cached_property
def single_column_config_by_name(self) -> dict[str, ColumnConfigT]:
    """Map each config's ``name`` to the config, built from ``single_column_configs``.

    When two configs share a name, the later one replaces the earlier one.
    """
    configs_by_name = {}
    for column_config in self.single_column_configs:
        configs_by_name[column_config.name] = column_config
    return configs_by_name

def build(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal
from typing import Literal


@dataclass(frozen=True, order=True)
Expand All @@ -26,17 +26,6 @@ class Task:
task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]


@dataclass
class TaskResult:
    """Outcome of a completed task.

    Attributes:
        task: The task this result belongs to.
        status: Either ``"success"`` or ``"error"``.
        output: Whatever the task produced; defaults to ``None``.
        error: The captured exception, if any; defaults to ``None``.
        retryable: Defaults to ``False``; presumably marks a failed task as
            eligible for another attempt (confirm against the scheduler).
    """

    # Required fields first: dataclasses do not allow a field without a
    # default to follow one that has a default.
    task: Task
    status: Literal["success", "error"]

    # Optional details.
    output: Any = None
    error: Exception | None = None
    retryable: bool = False


@dataclass
class TaskTrace:
"""Timing trace for a single task. Only created when tracing is enabled."""
Expand Down
Loading
Loading