Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ pub fn start(
instance_id: instance_id.clone(),
label: label.map(|s| s.to_string()),
vars,
loop_iteration: 0,
};
let input_json = serde_json::to_string(&input).unwrap_or(instance_id.clone());

Expand Down
78 changes: 78 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2764,6 +2764,84 @@ mod tests {
"Should produce a VALUES clause"
);
}

// --- M7: Loop iteration counter persisted across continue_as_new ---

#[pg_test]
fn test_function_input_loop_iteration_serialization() {
    use crate::types::FunctionInput;

    // Start from an input whose counter is a recognisable non-zero value.
    let original = FunctionInput {
        instance_id: String::from("test123"),
        label: Some(String::from("test")),
        vars: std::collections::HashMap::new(),
        loop_iteration: 42,
    };

    // Encode to JSON text, then parse that text back into a fresh value.
    let encoded = serde_json::to_string(&original).unwrap();
    let round_tripped = serde_json::from_str::<FunctionInput>(&encoded).unwrap();

    // The counter must come back exactly as it went in.
    assert_eq!(
        round_tripped.loop_iteration, 42,
        "loop_iteration must survive serialization round-trip"
    );
}

#[pg_test]
fn test_function_input_loop_iteration_defaults_to_zero() {
    use crate::types::FunctionInput;

    // Backward compatibility: a payload serialized before `loop_iteration`
    // existed carries no such key, and parsing it must yield a counter of 0.
    let legacy_payload = r#"{"instance_id":"abc12345","label":"test","vars":{}}"#;
    let parsed = serde_json::from_str::<FunctionInput>(legacy_payload).unwrap();

    assert_eq!(
        parsed.loop_iteration, 0,
        "Missing loop_iteration should default to 0 for backward compatibility"
    );
}

// --- M8: Malformed loop condition config detection ---

#[pg_test]
fn test_malformed_loop_condition_detected_at_validate() {
    // Both scenarios share one shape: a LOOP node whose body (`left_node`) is a
    // trivial SQL node. Only the config string stored in `query` differs.
    let loop_with_config = |config: &str| Durofut {
        node_type: "LOOP".to_string(),
        left_node: Some(Box::new(Durofut {
            node_type: "SQL".to_string(),
            query: Some("SELECT 1".to_string()),
            ..Default::default()
        })),
        query: Some(config.to_string()),
        ..Default::default()
    };

    // Scenario 1: the config is valid JSON, but `condition_node` is a plain
    // string rather than a Durofut object. validate_recursive is expected to
    // reject it, since for_each_config_child needs condition_node to
    // deserialize as a valid Durofut.
    let malformed = loop_with_config(r#"{"condition_node": "nonexist"}"#);
    let err = malformed.validate_recursive().unwrap_err();
    assert!(
        err.contains("condition_node"),
        "Error should mention condition_node, got: {err}"
    );

    // Scenario 2: the config is not JSON at all. for_each_config_child treats
    // it as a plain query string and skips it, so DSL validation passes.
    let non_json = loop_with_config("this is not json at all!!!");
    assert!(
        non_json.validate_recursive().is_ok(),
        "LOOP with non-JSON config passes DSL validation (caught at execution time)"
    );
}
}

/// Required by `cargo pgrx test`
Expand Down
76 changes: 53 additions & 23 deletions src/orchestrations/execute_function_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub const SUBTREE_NAME: &str = "pg_durable::orchestration::execute-subtree";
struct ExecutionContext {
/// String-keyed variables. Populated from the function input's `vars` and
/// forwarded to the next generation's input when a loop continues as new.
vars: HashMap<String, String>,
/// Optional label copied from the function input.
label: Option<String>,
/// Loop iteration counter (persisted across continue_as_new generations).
/// Seeded from the input's `loop_iteration`; subtree executions start at 0.
loop_iteration: u64,
}

/// Control-flow-aware error type returned by every node handler.
Expand Down Expand Up @@ -181,6 +183,7 @@ pub async fn execute(ctx: OrchestrationContext, input_json: String) -> Result<St
let exec_ctx = ExecutionContext {
vars: input.vars.clone(),
label: input.label.clone(),
loop_iteration: input.loop_iteration,
};

let function_outcome =
Expand Down Expand Up @@ -265,7 +268,11 @@ pub async fn execute_subtree(

ctx.trace_info(format!("ExecuteSubtree: executing node {node_id}"));

let exec_ctx = ExecutionContext { vars, label };
let exec_ctx = ExecutionContext {
vars,
label,
loop_iteration: 0,
};

// Build the envelope carrying the result, the updated named-results map, and a typed
// control signal. A `Break` inside the subtree is re-encoded as `control: Break` (not a
Expand Down Expand Up @@ -526,6 +533,10 @@ async fn execute_wait_schedule_node(
/// deficit so an empty-bodied loop can't busy-spin via continue_as_new.
const LOOP_MIN_ITER_DURATION: Duration = Duration::from_secs(1);

/// Maximum loop iterations before the orchestration is forcibly terminated.
/// This prevents runaway infinite loops from consuming resources indefinitely.
/// At the minimum 1-second-per-iteration rate limit (`LOOP_MIN_ITER_DURATION`),
/// reaching this cap takes at least 100,000 seconds, i.e. roughly 27.8 hours.
const MAX_LOOP_ITERATIONS: u64 = 100_000;
async fn execute_loop_node(
ctx: &OrchestrationContext,
graph: &FunctionGraph,
Expand Down Expand Up @@ -567,35 +578,53 @@ async fn execute_loop_node(

// Check while-condition if present
if let Some(ref config_str) = node.query {
if let Ok(config) = serde_json::from_str::<serde_json::Value>(config_str) {
if let Some(condition_node_id) = config["condition_node"].as_str() {
ctx.trace_info("Evaluating loop condition");
let condition_result = Box::pin(execute_function_node_with_vars(
ctx,
graph,
condition_node_id,
results,
exec_ctx,
))
.await?;

// Parse condition result to check truthiness (uses evaluate_condition to extract boolean from SQL result)
let should_continue = evaluate_condition(&condition_result).unwrap_or(false);
ctx.trace_info(format!(
"Loop condition evaluated to: {condition_result} (continue={should_continue})"
));

if !should_continue {
ctx.trace_info("Loop condition false, exiting loop");
store_named_result(ctx, node, &body_result, results, "LOOP");
return Ok(body_result);
match serde_json::from_str::<serde_json::Value>(config_str) {
Ok(config) => {
if let Some(condition_node_id) = config["condition_node"].as_str() {
ctx.trace_info("Evaluating loop condition");
let condition_result = Box::pin(execute_function_node_with_vars(
ctx,
graph,
condition_node_id,
results,
exec_ctx,
))
.await?;

// Parse condition result to check truthiness (uses evaluate_condition to extract boolean from SQL result)
let should_continue = evaluate_condition(&condition_result).unwrap_or(false);
ctx.trace_info(format!(
"Loop condition evaluated to: {condition_result} (continue={should_continue})"
));

if !should_continue {
ctx.trace_info("Loop condition false, exiting loop");
store_named_result(ctx, node, &body_result, results, "LOOP");
return Ok(body_result);
}
}
}
Err(e) => {
// M8: Malformed condition config should fail the loop rather than
// silently creating an infinite loop without exit condition.
return Err(NodeError::Failure(format!(
"LOOP node {node_id}: failed to parse condition config: {e}"
)));
}
}
}

ctx.trace_info("Continuing as new for next loop iteration");

// M7: Enforce maximum iteration count to prevent runaway infinite loops
let next_iteration = exec_ctx.loop_iteration + 1;
if next_iteration >= MAX_LOOP_ITERATIONS {
return Err(NodeError::Failure(format!(
"Loop exceeded maximum iteration count of {MAX_LOOP_ITERATIONS}. \
Use df.break() to exit the loop or restructure the workflow."
)));
}

// Enforce a minimum per-iteration wall-clock duration to prevent
// busy-looping (e.g. `df.loop(df.sleep(0))`). Compute the elapsed time
// from the deterministic clock; if the iteration finished faster than
Expand All @@ -620,6 +649,7 @@ async fn execute_loop_node(
instance_id: graph.instance_id.clone(),
label: exec_ctx.label.clone(),
vars: exec_ctx.vars.clone(),
loop_iteration: next_iteration,
};

// duroxide 0.1.1: continue_as_new returns an awaitable future - return it directly
Expand Down
4 changes: 4 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,10 @@ pub struct FunctionInput {
pub label: Option<String>,
#[serde(default)]
pub vars: std::collections::HashMap<String, String>,
/// Loop iteration counter, incremented on each `continue_as_new`.
/// Used to enforce a maximum iteration safeguard.
#[serde(default)]
pub loop_iteration: u64,
}

/// Configuration for HTTP requests
Expand Down
Loading