From c0567f841ef1da25b53fcf09c8387e9796910979 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 25 Jun 2026 18:54:10 -0300 Subject: [PATCH 1/3] feat: add repeat until workflow stages --- .../pages/concepts/workflow-chaining.mdx | 27 ++ .../src/data_designer/interface/__init__.py | 6 + .../interface/composite_workflow.py | 434 ++++++++++++++++-- .../interface/test_composite_workflow.py | 173 ++++++- 4 files changed, 603 insertions(+), 37 deletions(-) diff --git a/fern/versions/latest/pages/concepts/workflow-chaining.mdx b/fern/versions/latest/pages/concepts/workflow-chaining.mdx index ed41363b2..f215ead74 100644 --- a/fern/versions/latest/pages/concepts/workflow-chaining.mdx +++ b/fern/versions/latest/pages/concepts/workflow-chaining.mdx @@ -138,6 +138,33 @@ workflow.add_stage("enriched", enriched) `on_success_version` is part of the stage resume identity. Change it when the callback's output semantics change. If a callback returns zero rows, the workflow raises by default; set `allow_empty=True` to mark that stage as completed empty and skip downstream stages. +## Repeating until a filtered count + +Use `repeat_until` when a stage should keep generating candidates until its selected output reaches a target row count. This is useful for bounded rejection sampling, such as generating many candidates and keeping only rows that pass a judge or quality gate. + +```python +from data_designer.interface import RepeatUntil + +workflow = data_designer.compose_workflow(name="judge-disagreements") +workflow.add_stage( + "judged", + judges, + num_records=1_000, + on_success=keep_disagreements, + on_success_version="disagreements-v1", + repeat_until=RepeatUntil( + output_records=5_000, + max_iterations=10, + max_generated_records=20_000, + ), +) +workflow.add_stage("enriched", enriched) +``` + +`num_records` is the per-attempt size. In the default `mode="append"`, Data Designer extends the stage by another attempt each iteration, reruns `on_success`, and feeds exactly `output_records` selected rows downstream. Set `on_exhausted="return_partial"` to keep the best partial output when the bounds are reached; otherwise the workflow raises. + +Use `mode="discard"` when each attempt should replace the previous selected output instead of accumulating it. Keep bounded limits in place: a low acceptance rate is often a signal to inspect the recipe, not just to run indefinitely. + ## Changing row counts between stages Each stage has a fixed requested row count while it runs. To resize a workflow, change the selected output at a stage boundary and let the next stage seed from that output. diff --git a/packages/data-designer/src/data_designer/interface/__init__.py b/packages/data-designer/src/data_designer/interface/__init__.py index febf02f55..a64119abb 100644 --- a/packages/data-designer/src/data_designer/interface/__init__.py +++ b/packages/data-designer/src/data_designer/interface/__init__.py @@ -11,6 +11,9 @@ from data_designer.interface.composite_workflow import ( # noqa: F401 CompositeWorkflow, CompositeWorkflowResults, + RepeatUntil, + RepeatUntilExhaustion, + RepeatUntilMode, SkippedStageResult, SkippedStageStatus, ) @@ -33,6 +36,9 @@ "DataDesignerWorkflowError": ("data_designer.interface.errors", "DataDesignerWorkflowError"), "DatasetCreationResults": ("data_designer.interface.results", "DatasetCreationResults"), "ResumeMode": ("data_designer.engine.storage.artifact_storage", "ResumeMode"), + "RepeatUntil": ("data_designer.interface.composite_workflow", "RepeatUntil"), + "RepeatUntilExhaustion": ("data_designer.interface.composite_workflow", "RepeatUntilExhaustion"), + "RepeatUntilMode": ("data_designer.interface.composite_workflow", "RepeatUntilMode"), "SkippedStageResult": ("data_designer.interface.composite_workflow", "SkippedStageResult"), "SkippedStageStatus": ("data_designer.interface.composite_workflow", "SkippedStageStatus"), } diff --git a/packages/data-designer/src/data_designer/interface/composite_workflow.py b/packages/data-designer/src/data_designer/interface/composite_workflow.py index 408083be5..db31b7655 100644 --- a/packages/data-designer/src/data_designer/interface/composite_workflow.py +++ b/packages/data-designer/src/data_designer/interface/composite_workflow.py @@ -60,9 +60,69 @@ "callback_output_path", "output_processor_output_path", "stage_output_override_path", + "repeat_until_output_path", ) +class RepeatUntilMode(StrEnum): + APPEND = "append" + DISCARD = "discard" + + +class RepeatUntilExhaustion(StrEnum): + RAISE = "raise" + RETURN_PARTIAL = "return_partial" + + +@dataclass(frozen=True) +class RepeatUntil: + """Bounded stage-level retry policy for exact selected-output counts.""" + + output_records: int + max_iterations: int + mode: RepeatUntilMode | str = RepeatUntilMode.APPEND + max_generated_records: int | None = None + on_exhausted: RepeatUntilExhaustion | str = RepeatUntilExhaustion.RAISE + trim: bool = True + + def __post_init__(self) -> None: + if self.output_records < 1: + raise DataDesignerWorkflowError("repeat_until.output_records must be at least 1.") + if self.max_iterations < 1: + raise DataDesignerWorkflowError("repeat_until.max_iterations must be at least 1.") + if self.max_generated_records is not None and self.max_generated_records < 1: + raise DataDesignerWorkflowError("repeat_until.max_generated_records must be at least 1.") + try: + mode = RepeatUntilMode(self.mode) + except ValueError as exc: + raise DataDesignerWorkflowError( + f"repeat_until.mode must be one of: {_enum_values(RepeatUntilMode)}." + ) from exc + try: + on_exhausted = RepeatUntilExhaustion(self.on_exhausted) + except ValueError as exc: + raise DataDesignerWorkflowError( + f"repeat_until.on_exhausted must be one of: {_enum_values(RepeatUntilExhaustion)}." + ) from exc + object.__setattr__(self, "mode", mode) + object.__setattr__(self, "on_exhausted", on_exhausted) + + +@dataclass(frozen=True) +class _StageRunResult: + output_result: DatasetCreationResults + actual_records: int + output_seed_path: Path + output_records: int + callback_output_path: Path | None + output_processor_output_path: Path | None + num_records_requested: int + repeat_iterations: int | None = None + repeat_generated_records: int | None = None + repeat_satisfied: bool | None = None + repeat_until_output_path: Path | None = None + + @dataclass(frozen=True) class _WorkflowStage: name: str @@ -76,6 +136,7 @@ class _WorkflowStage: allow_empty: bool sampling_strategy: SamplingStrategy selection_strategy: IndexRange | PartitionBlock | None + repeat_until: RepeatUntil | None class SkippedStageStatus(StrEnum): @@ -207,6 +268,7 @@ def add_stage( allow_empty: bool = False, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: IndexRange | PartitionBlock | None = None, + repeat_until: RepeatUntil | None = None, ) -> CompositeWorkflow: """Add a stage to the workflow. @@ -214,7 +276,8 @@ def add_stage( ``output_processors`` for stage-boundary transforms whose output should feed downstream stages by default. ``output="processor:"`` selects a named processor artifact, and ``on_success`` can override the selected - output by returning a parquet file or directory. + output by returning a parquet file or directory. Use ``repeat_until`` to + rerun the stage until the selected output reaches an exact row count. """ _validate_dir_name(name, "stage name") if any(stage.name == name for stage in self._stages): @@ -238,6 +301,7 @@ def add_stage( allow_empty=allow_empty, sampling_strategy=sampling_strategy, selection_strategy=selection_strategy, + repeat_until=repeat_until, ) ) return self @@ -409,39 +473,17 @@ def run( start_time = time.monotonic() try: - result = self._data_designer.create( - stage_builder, + run_result = self._run_stage( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, num_records=num_records, - dataset_name=stage_dir_name, - artifact_path=workflow_path, resume=stage_resume, + prior_stage_metadata=prior_stage_metadata, ) - actual_records = result.count_records() - output_result = result - output_source_result = result - if stage.output_processors: - output_processor_path = stage_path / "output-processors" - if output_processor_path.exists(): - shutil.rmtree(output_processor_path) - output_processor_builder = _output_processor_config_builder( - stage_builder=stage_builder, - seed_path=result.artifact_storage.final_dataset_path, - output_processors=stage.output_processors, - ) - output_result = self._data_designer.create( - output_processor_builder, - num_records=actual_records, - dataset_name="output-processors", - artifact_path=workflow_path / stage_dir_name, - ) - output_source_result = _select_output_result(stage, result, output_result) - - callback_output_path = None - if stage.on_success is not None: - callback_output_path = Path(stage.on_success(result.artifact_storage.base_dataset_path)) - output_seed_path = callback_output_path - else: - output_seed_path = _resolve_stage_output_path(output_source_result, stage.output) + output_seed_path = run_result.output_seed_path override_path = _stage_output_override(stage.name, stage_output_overrides) if override_path is not None: output_seed_path = override_path @@ -458,29 +500,45 @@ def run( stage_metadata.update( { "status": status, - "num_records_actual": actual_records, + "num_records_requested": run_result.num_records_requested, + "num_records_actual": run_result.actual_records, "output_records": output_records, "output_seed_path": _metadata_path_value(workflow_path, output_seed_path), "callback_output_path": ( - _metadata_path_value(workflow_path, callback_output_path) if callback_output_path else None + _metadata_path_value(workflow_path, run_result.callback_output_path) + if run_result.callback_output_path + else None ), "stage_output_override_path": ( _metadata_path_value(workflow_path, override_path) if override_path else None ), "output_processor_output_path": ( - _metadata_path_value(workflow_path, output_result.artifact_storage.base_dataset_path) - if stage.output_processors + _metadata_path_value(workflow_path, run_result.output_processor_output_path) + if run_result.output_processor_output_path else None ), "duration_sec": time.monotonic() - start_time, } ) + if run_result.repeat_iterations is not None: + stage_metadata.update( + { + "repeat_iterations": run_result.repeat_iterations, + "repeat_generated_records": run_result.repeat_generated_records, + "repeat_satisfied": run_result.repeat_satisfied, + "repeat_until_output_path": ( + _metadata_path_value(workflow_path, run_result.repeat_until_output_path) + if run_result.repeat_until_output_path + else None + ), + } + ) except Exception: stage_metadata.update({"status": "failed", "duration_sec": time.monotonic() - start_time}) _write_workflow_metadata(workflow_path, metadata) raise - stage_results[stage.name] = output_result + stage_results[stage.name] = run_result.output_result stage_output_paths[stage.name] = output_seed_path previous_seed_path = output_seed_path previous_output_records = None if status == "completed_empty" else output_records @@ -496,11 +554,305 @@ def run( stage_output_paths=stage_output_paths, ) + def _run_stage( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + resume: ResumeMode, + prior_stage_metadata: dict[str, Any] | None, + ) -> _StageRunResult: + if stage.repeat_until is None: + return self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=num_records, + resume=resume, + ) + if stage.repeat_until.mode == RepeatUntilMode.DISCARD: + return self._run_stage_until_discard( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, + num_records=num_records, + ) + return self._run_stage_until_append( + stage=stage, + stage_builder=stage_builder, + workflow_path=workflow_path, + stage_dir_name=stage_dir_name, + stage_path=stage_path, + num_records=num_records, + resume=resume, + prior_stage_metadata=prior_stage_metadata, + ) + + def _run_stage_attempt( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + artifact_path: Path, + dataset_name: str, + num_records: int, + resume: ResumeMode, + ) -> _StageRunResult: + result = self._data_designer.create( + stage_builder, + num_records=num_records, + dataset_name=dataset_name, + artifact_path=artifact_path, + resume=resume, + ) + actual_records = result.count_records() + output_result = result + output_source_result = result + stage_path = artifact_path / dataset_name + output_processor_output_path = None + if stage.output_processors: + output_processor_path = stage_path / "output-processors" + if output_processor_path.exists(): + shutil.rmtree(output_processor_path) + output_processor_builder = _output_processor_config_builder( + stage_builder=stage_builder, + seed_path=result.artifact_storage.final_dataset_path, + output_processors=stage.output_processors, + ) + output_result = self._data_designer.create( + output_processor_builder, + num_records=actual_records, + dataset_name="output-processors", + artifact_path=stage_path, + ) + output_source_result = _select_output_result(stage, result, output_result) + output_processor_output_path = output_result.artifact_storage.base_dataset_path + + callback_output_path = None + if stage.on_success is not None: + callback_output_path = Path(stage.on_success(result.artifact_storage.base_dataset_path)) + output_seed_path = callback_output_path + else: + output_seed_path = _resolve_stage_output_path(output_source_result, stage.output) + output_records = _count_parquet_records(output_seed_path) + return _StageRunResult( + output_result=output_result, + actual_records=actual_records, + output_seed_path=output_seed_path, + output_records=output_records, + callback_output_path=callback_output_path, + output_processor_output_path=output_processor_output_path, + num_records_requested=num_records, + ) + + def _run_stage_until_append( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + resume: ResumeMode, + prior_stage_metadata: dict[str, Any] | None, + ) -> _StageRunResult: + repeat_until = _require_repeat_until(stage) + start_iteration = _append_start_iteration(num_records, resume, prior_stage_metadata) + last_result = None + for iteration in range(start_iteration, repeat_until.max_iterations + 1): + requested_records = num_records * iteration + if _exceeds_max_generated_records(repeat_until, requested_records): + break + attempt_resume = resume if iteration == start_iteration else ResumeMode.ALWAYS + last_result = self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=requested_records, + resume=attempt_resume, + ) + if last_result.output_records >= repeat_until.output_records: + return _with_repeat_result( + last_result, + stage_path=stage_path, + repeat_until=repeat_until, + iterations=iteration, + generated_records=last_result.actual_records, + satisfied=True, + ) + + return _handle_repeat_until_exhausted( + stage=stage, + repeat_until=repeat_until, + last_result=last_result, + stage_path=stage_path, + iterations=(last_result.num_records_requested // num_records if last_result else 0), + generated_records=(last_result.actual_records if last_result else 0), + ) + + def _run_stage_until_discard( + self, + *, + stage: _WorkflowStage, + stage_builder: DataDesignerConfigBuilder, + workflow_path: Path, + stage_dir_name: str, + stage_path: Path, + num_records: int, + ) -> _StageRunResult: + repeat_until = _require_repeat_until(stage) + last_result = None + generated_records = 0 + iterations_run = 0 + for iteration in range(1, repeat_until.max_iterations + 1): + if _exceeds_max_generated_records(repeat_until, generated_records + num_records): + break + if stage_path.exists(): + shutil.rmtree(stage_path) + last_result = self._run_stage_attempt( + stage=stage, + stage_builder=stage_builder, + artifact_path=workflow_path, + dataset_name=stage_dir_name, + num_records=num_records, + resume=ResumeMode.NEVER, + ) + iterations_run = iteration + generated_records += last_result.actual_records + if last_result.output_records >= repeat_until.output_records: + return _with_repeat_result( + last_result, + stage_path=stage_path, + repeat_until=repeat_until, + iterations=iteration, + generated_records=generated_records, + satisfied=True, + ) + + return _handle_repeat_until_exhausted( + stage=stage, + repeat_until=repeat_until, + last_result=last_result, + stage_path=stage_path, + iterations=iterations_run, + generated_records=generated_records, + ) + def _stage_indices_by_name(stages: list[_WorkflowStage]) -> dict[str, int]: return {stage.name: index for index, stage in enumerate(stages)} +def _require_repeat_until(stage: _WorkflowStage) -> RepeatUntil: + if stage.repeat_until is None: + raise DataDesignerWorkflowError(f"Stage {stage.name!r} has no repeat_until policy.") + return stage.repeat_until + + +def _append_start_iteration( + num_records: int, + resume: ResumeMode, + prior_stage_metadata: dict[str, Any] | None, +) -> int: + if resume != ResumeMode.ALWAYS or prior_stage_metadata is None: + return 1 + prior_requested = prior_stage_metadata.get("num_records_requested") + if not isinstance(prior_requested, int) or prior_requested <= num_records: + return 1 + return -(-prior_requested // num_records) + + +def _exceeds_max_generated_records(repeat_until: RepeatUntil, generated_records: int) -> bool: + return repeat_until.max_generated_records is not None and generated_records > repeat_until.max_generated_records + + +def _with_repeat_result( + result: _StageRunResult, + *, + stage_path: Path, + repeat_until: RepeatUntil, + iterations: int, + generated_records: int, + satisfied: bool, +) -> _StageRunResult: + output_seed_path = result.output_seed_path + output_records = result.output_records + repeat_until_output_path = None + if output_records > repeat_until.output_records and repeat_until.trim: + repeat_until_output_path = stage_path / "repeat-until" / "selected-output" + _write_parquet_head(output_seed_path, repeat_until_output_path, repeat_until.output_records) + output_seed_path = repeat_until_output_path + output_records = repeat_until.output_records + return _StageRunResult( + output_result=result.output_result, + actual_records=result.actual_records, + output_seed_path=output_seed_path, + output_records=output_records, + callback_output_path=result.callback_output_path, + output_processor_output_path=result.output_processor_output_path, + num_records_requested=result.num_records_requested, + repeat_iterations=iterations, + repeat_generated_records=generated_records, + repeat_satisfied=satisfied, + repeat_until_output_path=repeat_until_output_path, + ) + + +def _handle_repeat_until_exhausted( + *, + stage: _WorkflowStage, + repeat_until: RepeatUntil, + last_result: _StageRunResult | None, + stage_path: Path, + iterations: int, + generated_records: int, +) -> _StageRunResult: + selected_records = last_result.output_records if last_result is not None else 0 + if repeat_until.on_exhausted == RepeatUntilExhaustion.RAISE: + raise DataDesignerWorkflowError( + f"Stage {stage.name!r} repeat_until exhausted after {iterations} iteration(s): " + f"selected {selected_records} of {repeat_until.output_records} requested records." + ) + if last_result is None: + raise DataDesignerWorkflowError( + f"Stage {stage.name!r} repeat_until did not run because max_generated_records was too low." + ) + return _with_repeat_result( + last_result, + stage_path=stage_path, + repeat_until=repeat_until, + iterations=iterations, + generated_records=generated_records, + satisfied=False, + ) + + +def _repeat_until_payload(repeat_until: RepeatUntil | None) -> dict[str, Any] | None: + if repeat_until is None: + return None + return { + "output_records": repeat_until.output_records, + "max_iterations": repeat_until.max_iterations, + "mode": repeat_until.mode.value, + "max_generated_records": repeat_until.max_generated_records, + "on_exhausted": repeat_until.on_exhausted.value, + "trim": repeat_until.trim, + } + + +def _enum_values(enum_type: type[StrEnum]) -> str: + return ", ".join(repr(item.value) for item in enum_type) + + def _normalize_stage_names( stage_names: StageTargets | None, stage_indices: dict[str, int], @@ -758,6 +1110,7 @@ def _base_stage_metadata(index: int, stage: _WorkflowStage, stage_dir_name: str) "output": stage.output, "sampling_strategy": stage.sampling_strategy.value, "selection_strategy": _selection_strategy_payload(stage.selection_strategy), + "repeat_until": _repeat_until_payload(stage.repeat_until), } @@ -777,6 +1130,7 @@ def _stage_fingerprint( "on_success_version": stage.on_success_version, "output_processors": [processor.model_dump(mode="json") for processor in stage.output_processors], "output": stage.output, + "repeat_until": _repeat_until_payload(stage.repeat_until), "library_version": get_library_version(), "upstream_fingerprint": upstream_fingerprint, } @@ -841,6 +1195,14 @@ def _load_parquet_dataset(path: Path) -> pd.DataFrame: raise DataDesignerWorkflowError(f"Failed to read parquet files at {str(path)!r}: {e}") from e +def _write_parquet_head(source_path: Path, output_path: Path, num_records: int) -> None: + df = _load_parquet_dataset(source_path).head(num_records) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df.to_parquet(output_path / "data.parquet", index=False) + + def _export_parquet_dataset(source_path: Path, output_path: Path, *, format: ExportFormat | None = None) -> Path: resolved_format: str = format if format is not None else output_path.suffix.lstrip(".").lower() if resolved_format not in SUPPORTED_EXPORT_FORMATS: diff --git a/packages/data-designer/tests/interface/test_composite_workflow.py b/packages/data-designer/tests/interface/test_composite_workflow.py index eb9a7abad..47f9a5eba 100644 --- a/packages/data-designer/tests/interface/test_composite_workflow.py +++ b/packages/data-designer/tests/interface/test_composite_workflow.py @@ -22,7 +22,7 @@ from data_designer.config.seed_source_dataframe import DataFrameSeedSource from data_designer.engine.secret_resolver import PlaintextResolver from data_designer.engine.storage.artifact_storage import ArtifactStorage, BatchStage, ResumeMode -from data_designer.interface.composite_workflow import SkippedStageResult, SkippedStageStatus +from data_designer.interface.composite_workflow import RepeatUntil, SkippedStageResult, SkippedStageStatus from data_designer.interface.data_designer import DataDesigner from data_designer.interface.errors import DataDesignerWorkflowError from data_designer.interface.results import DatasetCreationResults @@ -1030,6 +1030,177 @@ def expand(stage_path: Path) -> Path: assert results.count_stage_output_records("personas") == 4 +def test_composite_workflow_repeat_until_append_accumulates_and_trims( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [ + {"name": "Ada", "keep": False}, + {"name": "Grace", "keep": True}, + {"name": "Linus", "keep": True}, + {"name": "Margaret", "keep": True}, + ], + ) + stage_1.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + stage_2 = _expression_builder(stub_model_configs, "final", "{{ candidate }} final") + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-append") + workflow.add_stage( + "candidates", + stage_1, + num_records=2, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=2, max_iterations=3), + ) + workflow.add_stage("final", stage_2) + + results = workflow.run() + final = results.load_dataset().sort_values("name").reset_index(drop=True) + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-append") + + assert final[["name", "candidate", "final"]].to_dict(orient="records") == [ + {"name": "Grace", "candidate": "Grace candidate", "final": "Grace candidate final"}, + {"name": "Linus", "candidate": "Linus candidate", "final": "Linus candidate final"}, + ] + assert results["candidates"].count_records() == 4 + assert results.count_stage_output_records("candidates") == 2 + assert metadata["stages"][0]["num_records_requested"] == 4 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 4 + assert metadata["stages"][0]["repeat_satisfied"] is True + assert metadata["stages"][0]["repeat_until_output_path"].endswith("stage-0-candidates/repeat-until/selected-output") + + +def test_composite_workflow_repeat_until_returns_partial_when_exhausted( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [ + {"name": "Ada", "keep": False}, + {"name": "Grace", "keep": True}, + {"name": "Linus", "keep": False}, + {"name": "Margaret", "keep": False}, + ], + ) + stage_1.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-partial") + workflow.add_stage( + "candidates", + stage_1, + num_records=2, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=3, max_iterations=2, on_exhausted="return_partial"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-partial") + + assert results.load_dataset()["name"].tolist() == ["Grace"] + assert results.count_records() == 1 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 4 + assert metadata["stages"][0]["repeat_satisfied"] is False + + +def test_composite_workflow_repeat_until_raises_when_exhausted( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder(stub_model_configs, [{"name": "Ada", "keep": False}]) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-raises") + workflow.add_stage( + "candidates", + stage, + num_records=1, + on_success=keep_rows, + repeat_until=RepeatUntil(output_records=1, max_iterations=2), + ) + + with pytest.raises(DataDesignerWorkflowError, match="repeat_until exhausted"): + workflow.run() + + +def test_composite_workflow_repeat_until_discard_keeps_latest_attempt( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder( + stub_model_configs, + [{"name": "Ada"}, {"name": "Grace"}, {"name": "Linus"}], + ) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + callback_calls = 0 + + def keep_more_each_time(stage_path: Path) -> Path: + nonlocal callback_calls + callback_calls += 1 + df = lazy.pd.read_parquet(stage_path / "parquet-files").head(callback_calls) + output_path = stage_path / "callback-outputs" / "latest" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df.to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow(name="repeat-discard") + workflow.add_stage( + "candidates", + stage, + num_records=3, + on_success=keep_more_each_time, + repeat_until=RepeatUntil(output_records=2, max_iterations=3, mode="discard"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-discard") + + assert callback_calls == 2 + assert results.load_dataset()["name"].tolist() == ["Ada", "Grace"] + assert results["candidates"].count_records() == 3 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["repeat_generated_records"] == 6 + + def test_composite_workflow_does_not_forward_dropped_processor_columns( tmp_path: Path, stub_model_providers: list[ModelProvider], From 773b721b69feb0219f99614c0653235ec8bc7f33 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 26 Jun 2026 12:08:04 -0300 Subject: [PATCH 2/3] fix: address repeat until review feedback --- .../pages/concepts/workflow-chaining.mdx | 6 +- .../interface/composite_workflow.py | 21 ++- .../interface/test_composite_workflow.py | 120 ++++++++++++++++++ 3 files changed, 143 insertions(+), 4 deletions(-) diff --git a/fern/versions/latest/pages/concepts/workflow-chaining.mdx b/fern/versions/latest/pages/concepts/workflow-chaining.mdx index f215ead74..2f7272452 100644 --- a/fern/versions/latest/pages/concepts/workflow-chaining.mdx +++ b/fern/versions/latest/pages/concepts/workflow-chaining.mdx @@ -161,9 +161,11 @@ workflow.add_stage( workflow.add_stage("enriched", enriched) ``` -`num_records` is the per-attempt size. In the default `mode="append"`, Data Designer extends the stage by another attempt each iteration, reruns `on_success`, and feeds exactly `output_records` selected rows downstream. Set `on_exhausted="return_partial"` to keep the best partial output when the bounds are reached; otherwise the workflow raises. +`num_records` is the per-attempt size. In the default `mode="append"`, each iteration requests the cumulative stage size (`num_records`, then `2 * num_records`, and so on), reruns `on_success` over the accumulated stage output, and feeds exactly `output_records` selected rows downstream. -Use `mode="discard"` when each attempt should replace the previous selected output instead of accumulating it. Keep bounded limits in place: a low acceptance rate is often a signal to inspect the recipe, not just to run indefinitely. +Set `on_exhausted="return_partial"` to keep the best partial output when the bounds are reached; otherwise the workflow raises. If no rows pass, the stage completes empty and downstream stages are skipped, matching `allow_empty=True` behavior. + +Use `mode="discard"` when each attempt should replace the previous selected output instead of accumulating it. Discard mode restarts the stage on resume because previous attempts are intentionally replaced. Keep bounded limits in place: a low acceptance rate is often a signal to inspect the recipe, not just to run indefinitely. In append mode, `max_generated_records` caps the cumulative requested stage size; in discard mode, it caps records produced across attempts. ## Changing row counts between stages diff --git a/packages/data-designer/src/data_designer/interface/composite_workflow.py b/packages/data-designer/src/data_designer/interface/composite_workflow.py index db31b7655..2b63cd000 100644 --- a/packages/data-designer/src/data_designer/interface/composite_workflow.py +++ b/packages/data-designer/src/data_designer/interface/composite_workflow.py @@ -490,7 +490,7 @@ def run( output_records = _count_parquet_records(output_seed_path) if output_records == 0: - if not stage.allow_empty: + if not _allows_empty_stage_output(stage, run_result): raise DataDesignerWorkflowError(f"Stage {stage.name!r} produced an empty output.") status = "completed_empty" skipped_upstream_stage = stage.name @@ -583,6 +583,7 @@ def _run_stage( stage_dir_name=stage_dir_name, stage_path=stage_path, num_records=num_records, + resume=resume, ) return self._run_stage_until_append( stage=stage, @@ -708,8 +709,14 @@ def _run_stage_until_discard( stage_dir_name: str, stage_path: Path, num_records: int, + resume: ResumeMode, ) -> _StageRunResult: repeat_until = _require_repeat_until(stage) + if resume == ResumeMode.ALWAYS: + logger.warning( + "Stage %r uses repeat_until mode='discard'; previous attempts cannot be resumed and will be replaced.", + stage.name, + ) last_result = None generated_records = 0 iterations_run = 0 @@ -824,7 +831,7 @@ def _handle_repeat_until_exhausted( ) if last_result is None: raise DataDesignerWorkflowError( - f"Stage {stage.name!r} repeat_until did not run because max_generated_records was too low." + f"Stage {stage.name!r} repeat_until did not run because no iteration fit within the configured limits." ) return _with_repeat_result( last_result, @@ -836,6 +843,16 @@ def _handle_repeat_until_exhausted( ) +def _allows_empty_stage_output(stage: _WorkflowStage, run_result: _StageRunResult) -> bool: + if stage.allow_empty: + return True + return ( + stage.repeat_until is not None + and stage.repeat_until.on_exhausted == RepeatUntilExhaustion.RETURN_PARTIAL + and run_result.repeat_satisfied is False + ) + + def _repeat_until_payload(repeat_until: RepeatUntil | None) -> dict[str, Any] | None: if repeat_until is None: return None diff --git a/packages/data-designer/tests/interface/test_composite_workflow.py b/packages/data-designer/tests/interface/test_composite_workflow.py index 47f9a5eba..2a76b3026 100644 --- a/packages/data-designer/tests/interface/test_composite_workflow.py +++ b/packages/data-designer/tests/interface/test_composite_workflow.py @@ -1129,6 +1129,47 @@ def keep_rows(stage_path: Path) -> Path: assert metadata["stages"][0]["repeat_satisfied"] is False +def test_composite_workflow_repeat_until_returns_empty_partial_when_exhausted( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage = _seeded_builder( + stub_model_configs, + [{"name": "Ada", "keep": False}, {"name": "Grace", "keep": False}], + ) + stage.add_column(ExpressionColumnConfig(name="candidate", expr="{{ name }} candidate")) + + def keep_rows(stage_path: Path) -> Path: + df = lazy.pd.read_parquet(stage_path / "parquet-files") + output_path = stage_path / "callback-outputs" / "kept" + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True) + df[df["keep"]].to_parquet(output_path / "data.parquet", index=False) + return output_path + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow( + name="repeat-empty-partial" + ) + workflow.add_stage( + "candidates", + stage, + num_records=1, + on_success=keep_rows, + on_success_version="kept-v1", + repeat_until=RepeatUntil(output_records=1, max_iterations=2, on_exhausted="return_partial"), + ) + + results = workflow.run() + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-empty-partial") + + assert results.count_records() == 0 + assert results.load_dataset().empty + assert metadata["stages"][0]["status"] == "completed_empty" + assert metadata["stages"][0]["repeat_satisfied"] is False + + def test_composite_workflow_repeat_until_raises_when_exhausted( tmp_path: Path, stub_model_providers: list[ModelProvider], @@ -1201,6 +1242,85 @@ def keep_more_each_time(stage_path: Path) -> Path: assert metadata["stages"][0]["repeat_generated_records"] == 6 +def test_composite_workflow_repeat_until_discard_warns_when_resuming( + stub_artifact_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], + stub_dataset_profiler_results, + caplog: pytest.LogCaptureFixture, +) -> None: + data_designer = _data_designer(stub_artifact_path, stub_model_providers) + _patch_create(data_designer, stub_dataset_profiler_results) + workflow = data_designer.compose_workflow(name="repeat-discard-resume") + workflow.add_stage( + "candidates", + _category_builder(stub_model_configs), + num_records=2, + repeat_until=RepeatUntil(output_records=1, max_iterations=2, mode="discard"), + ) + workflow.run() + + metadata_path = stub_artifact_path / "repeat-discard-resume" / "workflow-metadata.json" + metadata = json.loads(metadata_path.read_text()) + _mark_stage_resumable(metadata, 0, "failed") + metadata_path.write_text(json.dumps(metadata)) + + resumed = data_designer.compose_workflow(name="repeat-discard-resume") + resumed.add_stage( + "candidates", + _category_builder(stub_model_configs), + num_records=2, + repeat_until=RepeatUntil(output_records=1, max_iterations=2, mode="discard"), + ) + + with caplog.at_level("WARNING", logger="data_designer.interface.composite_workflow"): + resumed.run(resume=ResumeMode.IF_POSSIBLE) + + assert "previous attempts cannot be resumed" in caplog.text + + +def test_composite_workflow_repeat_until_uses_processor_output( + tmp_path: Path, + stub_model_providers: list[ModelProvider], + stub_model_configs: list[ModelConfig], +) -> None: + stage_1 = _seeded_builder( + stub_model_configs, + [{"name": "Ada"}, {"name": "Linus"}, {"name": "Grace"}], + ) + stage_1.add_column(ExpressionColumnConfig(name="persona", expr="{{ name }}")) + stage_2 = _expression_builder(stub_model_configs, "final", "{{ compact_name }} final") + + workflow = _real_data_designer(tmp_path / "artifacts", stub_model_providers).compose_workflow( + name="repeat-processor-output" + ) + workflow.add_stage( + "compact", + stage_1, + num_records=1, + output_processors=[SchemaTransformProcessorConfig(name="compact", template={"compact_name": "{{ persona }}"})], + output="processor:compact", + repeat_until=RepeatUntil(output_records=2, max_iterations=3), + ) + workflow.add_stage("final", stage_2) + + results = workflow.run() + final = results.load_dataset().sort_values("compact_name").reset_index(drop=True) + stage_output = results.load_stage_output("compact").sort_values("compact_name").reset_index(drop=True) + metadata = _load_workflow_metadata(tmp_path / "artifacts", "repeat-processor-output") + + assert stage_output.to_dict(orient="records") == [{"compact_name": "Ada"}, {"compact_name": "Linus"}] + assert final.to_dict(orient="records") == [ + {"compact_name": "Ada", "final": "Ada final"}, + {"compact_name": "Linus", "final": "Linus final"}, + ] + assert metadata["stages"][0]["num_records_requested"] == 2 + assert metadata["stages"][0]["repeat_iterations"] == 2 + assert metadata["stages"][0]["output_seed_path"].endswith( + "stage-0-compact/output-processors/processors-files/compact" + ) + + def test_composite_workflow_does_not_forward_dropped_processor_columns( tmp_path: Path, stub_model_providers: list[ModelProvider], From 022cac6a616fa70ffe0ea94ffd6c82908ca0d5a1 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 26 Jun 2026 12:10:24 -0300 Subject: [PATCH 3/3] refactor: simplify repeat until append resume --- .../interface/composite_workflow.py | 22 ++----------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/packages/data-designer/src/data_designer/interface/composite_workflow.py b/packages/data-designer/src/data_designer/interface/composite_workflow.py index 2b63cd000..0dcd26488 100644 --- a/packages/data-designer/src/data_designer/interface/composite_workflow.py +++ b/packages/data-designer/src/data_designer/interface/composite_workflow.py @@ -481,7 +481,6 @@ def run( stage_path=stage_path, num_records=num_records, resume=stage_resume, - prior_stage_metadata=prior_stage_metadata, ) output_seed_path = run_result.output_seed_path override_path = _stage_output_override(stage.name, stage_output_overrides) @@ -564,7 +563,6 @@ def _run_stage( stage_path: Path, num_records: int, resume: ResumeMode, - prior_stage_metadata: dict[str, Any] | None, ) -> _StageRunResult: if stage.repeat_until is None: return self._run_stage_attempt( @@ -593,7 +591,6 @@ def _run_stage( stage_path=stage_path, num_records=num_records, resume=resume, - prior_stage_metadata=prior_stage_metadata, ) def _run_stage_attempt( @@ -663,16 +660,14 @@ def _run_stage_until_append( stage_path: Path, num_records: int, resume: ResumeMode, - prior_stage_metadata: dict[str, Any] | None, ) -> _StageRunResult: repeat_until = _require_repeat_until(stage) - start_iteration = _append_start_iteration(num_records, resume, prior_stage_metadata) last_result = None - for iteration in range(start_iteration, repeat_until.max_iterations + 1): + for iteration in range(1, repeat_until.max_iterations + 1): requested_records = num_records * iteration if _exceeds_max_generated_records(repeat_until, requested_records): break - attempt_resume = resume if iteration == start_iteration else ResumeMode.ALWAYS + attempt_resume = resume if iteration == 1 else ResumeMode.ALWAYS last_result = self._run_stage_attempt( stage=stage, stage_builder=stage_builder, @@ -765,19 +760,6 @@ def _require_repeat_until(stage: _WorkflowStage) -> RepeatUntil: return stage.repeat_until -def _append_start_iteration( - num_records: int, - resume: ResumeMode, - prior_stage_metadata: dict[str, Any] | None, -) -> int: - if resume != ResumeMode.ALWAYS or prior_stage_metadata is None: - return 1 - prior_requested = prior_stage_metadata.get("num_records_requested") - if not isinstance(prior_requested, int) or prior_requested <= num_records: - return 1 - return -(-prior_requested // num_records) - - def _exceeds_max_generated_records(repeat_until: RepeatUntil, generated_records: int) -> bool: return repeat_until.max_generated_records is not None and generated_records > repeat_until.max_generated_records