diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 9c69276fa1db..b52a06b4e98d 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -1001,51 +1001,30 @@ impl OptimizerRule for PushDownFilter {
                     })
                     .collect::<HashSet<_>>();
 
-                let predicates = split_conjunction_owned(filter.predicate);
-
-                let mut keep_predicates = vec![];
-                let mut push_predicates = vec![];
-                for expr in predicates {
-                    let cols = expr.column_refs();
-                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
-                        push_predicates.push(expr);
-                    } else {
-                        keep_predicates.push(expr);
-                    }
-                }
-
                 // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
                 // After push, we need to replace `a+b` with Column(a)+Column(b)
                 // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
-                let mut replace_map = HashMap::new();
-                for expr in &agg.group_expr {
-                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
-                }
-                let replaced_push_predicates = push_predicates
-                    .into_iter()
-                    .map(|expr| replace_cols_by_name(expr, &replace_map))
-                    .collect::<Result<Vec<_>>>()?;
-
-                let agg_input = Arc::clone(&agg.input);
-                Transformed::yes(LogicalPlan::Aggregate(agg))
-                    .transform_data(|new_plan| {
-                        // If we have a filter to push, we push it down to the input of the aggregate
-                        if let Some(predicate) = conjunction(replaced_push_predicates) {
-                            let new_filter = make_filter(predicate, agg_input)?;
-                            insert_below(new_plan, new_filter)
-                        } else {
-                            Ok(Transformed::no(new_plan))
-                        }
-                    })?
-                    .map_data(|child_plan| {
-                        // if there are any remaining predicates we can't push, add them
-                        // back as a filter
-                        if let Some(predicate) = conjunction(keep_predicates) {
-                            make_filter(predicate, Arc::new(child_plan))
-                        } else {
-                            Ok(child_plan)
-                        }
-                    })
+                let replace_map = agg
+                    .group_expr
+                    .iter()
+                    .map(|expr| (expr.schema_name().to_string(), expr.clone()))
+                    .collect::<HashMap<_, _>>();
+
+                push_down_filter_through_unary(
+                    filter.predicate,
+                    |expr| {
+                        expr.column_refs()
+                            .iter()
+                            .all(|c| group_expr_columns.contains(c))
+                    },
+                    LogicalPlan::Aggregate(agg),
+                    |push_predicates| {
+                        push_predicates
+                            .into_iter()
+                            .map(|expr| replace_cols_by_name(expr, &replace_map))
+                            .collect::<Result<Vec<_>>>()
+                    },
+                )
             }
             // Tries to push filters based on the partition key(s) of the window function(s) used.
             // Example:
@@ -1102,18 +1081,6 @@ impl OptimizerRule for PushDownFilter {
                     .reduce(|a, b| &a & &b)
                     .unwrap_or_default();
 
-                let predicates = split_conjunction_owned(filter.predicate);
-                let mut keep_predicates = vec![];
-                let mut push_predicates = vec![];
-                for expr in predicates {
-                    let cols = expr.column_refs();
-                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
-                        push_predicates.push(expr);
-                    } else {
-                        keep_predicates.push(expr);
-                    }
-                }
-
                 // Unlike with aggregations, there are no cases where we have to replace, e.g.,
                 // `a+b` with Column(a)+Column(b). This is because partition expressions are not
                 // available as standalone columns to the user. For example, while an aggregation on
@@ -1122,26 +1089,16 @@ impl OptimizerRule for PushDownFilter {
                 // place, so we can use `push_predicates` directly. This is consistent with other
                 // optimizers, such as the one used by Postgres.
 
-                let window_input = Arc::clone(&window.input);
-                Transformed::yes(LogicalPlan::Window(window))
-                    .transform_data(|new_plan| {
-                        // If we have a filter to push, we push it down to the input of the window
-                        if let Some(predicate) = conjunction(push_predicates) {
-                            let new_filter = make_filter(predicate, window_input)?;
-                            insert_below(new_plan, new_filter)
-                        } else {
-                            Ok(Transformed::no(new_plan))
-                        }
-                    })?
-                    .map_data(|child_plan| {
-                        // if there are any remaining predicates we can't push, add them
-                        // back as a filter
-                        if let Some(predicate) = conjunction(keep_predicates) {
-                            make_filter(predicate, Arc::new(child_plan))
-                        } else {
-                            Ok(child_plan)
-                        }
-                    })
+                push_down_filter_through_unary(
+                    filter.predicate,
+                    |expr| {
+                        expr.column_refs()
+                            .iter()
+                            .all(|c| potential_partition_keys.contains(c))
+                    },
+                    LogicalPlan::Window(window),
+                    Ok,
+                )
             }
             LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
             LogicalPlan::TableScan(scan) => {
@@ -1378,6 +1335,52 @@ fn rewrite_projection(
     }
 }
 
+/// Pushes eligible conjunctive predicates below a single-input plan node and
+/// re-applies ineligible predicates above it.
+fn push_down_filter_through_unary<C, R>(
+    predicate: Expr,
+    mut can_push: C,
+    unary_plan: LogicalPlan,
+    rewrite_push_predicates: R,
+) -> Result<Transformed<LogicalPlan>>
+where
+    C: FnMut(&Expr) -> bool,
+    R: FnOnce(Vec<Expr>) -> Result<Vec<Expr>>,
+{
+    let (push_predicates, keep_predicates): (Vec<_>, Vec<_>) =
+        split_conjunction_owned(predicate)
+            .into_iter()
+            .partition(|expr| can_push(expr));
+    let push_predicates = rewrite_push_predicates(push_predicates)?;
+
+    Transformed::yes(unary_plan)
+        .transform_data(|new_plan| {
+            // If we have a filter to push, push it down to the unary node's input.
+            if let Some(predicate) = conjunction(push_predicates) {
+                insert_filter_below_unary(new_plan, predicate)
+            } else {
+                Ok(Transformed::no(new_plan))
+            }
+        })?
+        .map_data(|child_plan| {
+            // If any predicates remain, add them back as a filter above the unary node.
+            if let Some(predicate) = conjunction(keep_predicates) {
+                make_filter(predicate, Arc::new(child_plan))
+            } else {
+                Ok(child_plan)
+            }
+        })
+}
+
+/// Inserts a filter below a single-input plan node, using that node's existing
+/// child as the filter input.
+fn insert_filter_below_unary(
+    plan: LogicalPlan,
+    predicate: Expr,
+) -> Result<Transformed<LogicalPlan>> {
+    map_single_child(plan, |child| make_filter(predicate, Arc::new(child)))
+}
+
 /// Creates a new LogicalPlan::Filter node.
 pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
     Filter::try_new(predicate, input).map(LogicalPlan::Filter)
@@ -1400,18 +1403,28 @@ fn insert_below(
     plan: LogicalPlan,
     new_child: LogicalPlan,
 ) -> Result<Transformed<LogicalPlan>> {
-    let mut new_child = Some(new_child);
-    let transformed_plan = plan.map_children(|_child| {
-        if let Some(new_child) = new_child.take() {
-            Ok(Transformed::yes(new_child))
+    map_single_child(plan, |_child| Ok(new_child))
+}
+
+fn map_single_child<F>(
+    plan: LogicalPlan,
+    replace_child: F,
+) -> Result<Transformed<LogicalPlan>>
+where
+    F: FnOnce(LogicalPlan) -> Result<LogicalPlan>,
+{
+    let mut replace_child = Some(replace_child);
+    let transformed_plan = plan.map_children(|child| {
+        if let Some(replace_child) = replace_child.take() {
+            replace_child(child).map(Transformed::yes)
         } else {
-            // already took the new child
+            // already replaced the child
             internal_err!("node had more than one input")
         }
    })?;
     // make sure we did the actual replacement
-    assert_or_internal_err!(new_child.is_none(), "node had no inputs");
+    assert_or_internal_err!(replace_child.is_none(), "node had no inputs");
     Ok(transformed_plan)
 }