Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 93 additions & 80 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,51 +1001,30 @@ impl OptimizerRule for PushDownFilter {
})
.collect::<HashSet<_>>();

let predicates = split_conjunction_owned(filter.predicate);

let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.column_refs();
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}

// As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
// After push, we need to replace `a+b` with Column(a)+Column(b)
// So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
let mut replace_map = HashMap::new();
for expr in &agg.group_expr {
replace_map.insert(expr.schema_name().to_string(), expr.clone());
}
let replaced_push_predicates = push_predicates
.into_iter()
.map(|expr| replace_cols_by_name(expr, &replace_map))
.collect::<Result<Vec<_>>>()?;

let agg_input = Arc::clone(&agg.input);
Transformed::yes(LogicalPlan::Aggregate(agg))
.transform_data(|new_plan| {
// If we have a filter to push, we push it down to the input of the aggregate
if let Some(predicate) = conjunction(replaced_push_predicates) {
let new_filter = make_filter(predicate, agg_input)?;
insert_below(new_plan, new_filter)
} else {
Ok(Transformed::no(new_plan))
}
})?
.map_data(|child_plan| {
// if there are any remaining predicates we can't push, add them
// back as a filter
if let Some(predicate) = conjunction(keep_predicates) {
make_filter(predicate, Arc::new(child_plan))
} else {
Ok(child_plan)
}
})
let replace_map = agg
.group_expr
.iter()
.map(|expr| (expr.schema_name().to_string(), expr.clone()))
.collect::<HashMap<_, _>>();

push_down_filter_through_unary(
filter.predicate,
|expr| {
expr.column_refs()
.iter()
.all(|c| group_expr_columns.contains(c))
},
LogicalPlan::Aggregate(agg),
|push_predicates| {
push_predicates
.into_iter()
.map(|expr| replace_cols_by_name(expr, &replace_map))
.collect::<Result<Vec<_>>>()
},
)
}
// Tries to push filters based on the partition key(s) of the window function(s) used.
// Example:
Expand Down Expand Up @@ -1102,18 +1081,6 @@ impl OptimizerRule for PushDownFilter {
.reduce(|a, b| &a & &b)
.unwrap_or_default();

let predicates = split_conjunction_owned(filter.predicate);
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.column_refs();
if cols.iter().all(|c| potential_partition_keys.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}

// Unlike with aggregations, there are no cases where we have to replace, e.g.,
// `a+b` with Column(a)+Column(b). This is because partition expressions are not
// available as standalone columns to the user. For example, while an aggregation on
Expand All @@ -1122,26 +1089,16 @@ impl OptimizerRule for PushDownFilter {
// place, so we can use `push_predicates` directly. This is consistent with other
// optimizers, such as the one used by Postgres.

let window_input = Arc::clone(&window.input);
Transformed::yes(LogicalPlan::Window(window))
.transform_data(|new_plan| {
// If we have a filter to push, we push it down to the input of the window
if let Some(predicate) = conjunction(push_predicates) {
let new_filter = make_filter(predicate, window_input)?;
insert_below(new_plan, new_filter)
} else {
Ok(Transformed::no(new_plan))
}
})?
.map_data(|child_plan| {
// if there are any remaining predicates we can't push, add them
// back as a filter
if let Some(predicate) = conjunction(keep_predicates) {
make_filter(predicate, Arc::new(child_plan))
} else {
Ok(child_plan)
}
})
push_down_filter_through_unary(
filter.predicate,
|expr| {
expr.column_refs()
.iter()
.all(|c| potential_partition_keys.contains(c))
},
LogicalPlan::Window(window),
Ok,
)
}
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
LogicalPlan::TableScan(scan) => {
Expand Down Expand Up @@ -1378,6 +1335,52 @@ fn rewrite_projection(
}
}

/// Splits `predicate` into its conjuncts, pushes the eligible ones below the
/// given single-input plan node, and re-applies the remaining ones above it.
fn push_down_filter_through_unary<C, R>(
    predicate: Expr,
    mut can_push: C,
    unary_plan: LogicalPlan,
    rewrite_push_predicates: R,
) -> Result<Transformed<LogicalPlan>>
where
    C: FnMut(&Expr) -> bool,
    R: FnOnce(Vec<Expr>) -> Result<Vec<Expr>>,
{
    // Partition the conjuncts by whether they may move below the node.
    let mut pushable = vec![];
    let mut retained = vec![];
    for conjunct in split_conjunction_owned(predicate) {
        if can_push(&conjunct) {
            pushable.push(conjunct);
        } else {
            retained.push(conjunct);
        }
    }
    // Give the caller a chance to rewrite the pushed predicates first
    // (e.g. replacing schema names with the underlying group expressions).
    let pushable = rewrite_push_predicates(pushable)?;

    let transformed = match conjunction(pushable) {
        // Push the combined predicate down to the unary node's input.
        Some(below) => Transformed::yes(unary_plan)
            .transform_data(|plan| insert_filter_below_unary(plan, below))?,
        None => Transformed::yes(unary_plan),
    };

    transformed.map_data(|plan| match conjunction(retained) {
        // Anything that could not be pushed stays as a filter on top.
        Some(above) => make_filter(above, Arc::new(plan)),
        None => Ok(plan),
    })
}

/// Wraps the sole input of a single-input plan node in a new filter over
/// `predicate`, keeping the node itself on top.
fn insert_filter_below_unary(
    plan: LogicalPlan,
    predicate: Expr,
) -> Result<Transformed<LogicalPlan>> {
    map_single_child(plan, move |input| {
        let filtered = make_filter(predicate, Arc::new(input))?;
        Ok(filtered)
    })
}

/// Creates a new LogicalPlan::Filter node.
pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
Filter::try_new(predicate, input).map(LogicalPlan::Filter)
Expand All @@ -1400,18 +1403,28 @@ fn insert_below(
plan: LogicalPlan,
new_child: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
let mut new_child = Some(new_child);
let transformed_plan = plan.map_children(|_child| {
if let Some(new_child) = new_child.take() {
Ok(Transformed::yes(new_child))
map_single_child(plan, |_child| Ok(new_child))
}

fn map_single_child<F>(
plan: LogicalPlan,
replace_child: F,
) -> Result<Transformed<LogicalPlan>>
where
F: FnOnce(LogicalPlan) -> Result<LogicalPlan>,
{
let mut replace_child = Some(replace_child);
let transformed_plan = plan.map_children(|child| {
if let Some(replace_child) = replace_child.take() {
replace_child(child).map(Transformed::yes)
} else {
// already took the new child
// already replaced the child
internal_err!("node had more than one input")
}
})?;

// make sure we did the actual replacement
assert_or_internal_err!(new_child.is_none(), "node had no inputs");
assert_or_internal_err!(replace_child.is_none(), "node had no inputs");

Ok(transformed_plan)
}
Expand Down
Loading