Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowPrimitiveType;

use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
use datafusion_expr_common::groups_accumulator::EmitTo;

/// If the input has nulls, then the accumulator must potentially
Expand Down Expand Up @@ -471,7 +472,7 @@ pub fn accumulate<T, F>(
///
/// This method assumes that for any input record index, if any of the value column
/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
/// (won't be accumulated by `value_fn`)
/// (Won't be accumulated by `value_fn`)
///
/// # Arguments
///
Expand All @@ -491,35 +492,28 @@ pub fn accumulate_multiple<T, F>(
T: ArrowPrimitiveType + Send,
F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
{
// Calculate `valid_indices` to accumulate, non-valid indices are ignored.
// `valid_indices` is a bit mask corresponding to the `group_indices`. An index
// is considered valid if:
// 1. All columns are non-null at this index.
// 2. Not filtered out by `opt_filter`

// Take AND from all null buffers of `value_columns`.
let combined_nulls = value_columns
.iter()
.map(|arr| arr.logical_nulls())
.fold(None, |acc, nulls| {
NullBuffer::union(acc.as_ref(), nulls.as_ref())
});

// Take AND from previous combined nulls and `opt_filter`.
let valid_indices = match (combined_nulls, opt_filter) {
(None, None) => None,
(None, Some(filter)) => Some(filter.clone()),
(Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
(Some(nulls), Some(filter)) => {
let combined = nulls.inner() & filter.values();
Some(BooleanArray::new(combined, None))
}
};

for col in value_columns.iter() {
debug_assert_eq!(col.len(), group_indices.len());
}

// Start with rows where all value columns are non-null.
let mut valid_indices =
NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
.map(NullBuffer::into_inner);

// Restrict to rows where the optional filter is Some(true). Keep the filter
// as a raw BooleanBuffer to avoid computing a NullBuffer null_count just to
// test row validity below.
if let Some(filter) = opt_filter {
debug_assert_eq!(filter.len(), group_indices.len());
let filter_validity = filter_to_validity(filter);
if let Some(valid_indices) = valid_indices.as_mut() {
*valid_indices &= &filter_validity;
} else {
valid_indices = Some(filter_validity);
}
}

match valid_indices {
None => {
for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
Expand Down Expand Up @@ -562,7 +556,8 @@ pub fn accumulate_indices<F>(
(None, Some(filter)) => {
debug_assert_eq!(filter.len(), group_indices.len());
let group_indices_chunks = group_indices.chunks_exact(64);
let bit_chunks = filter.values().bit_chunks();
let filter_validity = filter_to_validity(filter);
let bit_chunks = filter_validity.bit_chunks();

let group_indices_remainder = group_indices_chunks.remainder();

Expand Down Expand Up @@ -636,7 +631,8 @@ pub fn accumulate_indices<F>(

let group_indices_chunks = group_indices.chunks_exact(64);
let valid_bit_chunks = valids.inner().bit_chunks();
let filter_bit_chunks = filter.values().bit_chunks();
let filter_validity = filter_to_validity(filter);
let filter_bit_chunks = filter_validity.bit_chunks();

let group_indices_remainder = group_indices_chunks.remainder();

Expand Down Expand Up @@ -1188,6 +1184,68 @@ mod test {
assert_eq!(accumulated, expected);
}

#[test]
fn test_accumulate_indices_with_null_filter() {
    let groups = vec![0, 1, 0, 1];
    // Filter whose raw value bits are [T, T, T, F] but whose validity masks
    // out row 1, so its logical value is NULL even though the bit is true.
    let filter = BooleanArray::new(
        BooleanBuffer::from(vec![true, true, true, false]),
        Some(NullBuffer::from(vec![true, false, true, true])),
    );

    // A NULL filter entry must behave exactly like `false`, regardless of
    // the underlying BooleanBuffer value: only rows 0 and 2 survive.
    let mut seen = vec![];
    accumulate_indices(&groups, None, Some(&filter), |group_idx| {
        seen.push(group_idx)
    });
    assert_eq!(seen, vec![0, 0]);

    // Intersect the filter with a value-column validity mask that knocks
    // out row 2; only row 0 passes both.
    let value_validity = NullBuffer::from(vec![true, true, false, true]);
    let mut seen = vec![];
    accumulate_indices(
        &groups,
        Some(&value_validity),
        Some(&filter),
        |group_idx| seen.push(group_idx),
    );
    assert_eq!(seen, vec![0]);
}

#[test]
fn test_accumulate_multiple_with_null_filter() {
    let groups = vec![0, 1, 0, 1];
    let col_a = Int32Array::from(vec![1, 2, 3, 4]);
    let col_b = Int32Array::from(vec![10, 20, 30, 40]);
    let columns = [col_a, col_b];

    // Row 1 is logically NULL (value bit true but validity bit false) and
    // row 3 is Some(false); both must be skipped by accumulate_multiple.
    let filter = BooleanArray::new(
        BooleanBuffer::from(vec![true, true, true, false]),
        Some(NullBuffer::from(vec![true, false, true, true])),
    );

    let column_refs: Vec<_> = columns.iter().collect();
    let mut seen = vec![];
    accumulate_multiple(
        &groups,
        &column_refs,
        Some(&filter),
        |group_idx, row_idx, cols| {
            let row_values = cols.iter().map(|col| col.value(row_idx)).collect();
            seen.push((group_idx, row_values));
        },
    );

    // Only rows 0 and 2 (both group 0) pass: a NULL filter value is treated
    // the same as false, even when the raw BooleanBuffer bit is true.
    assert_eq!(seen, vec![(0, vec![1, 10]), (0, vec![3, 30])]);
}

#[test]
fn test_accumulate_multiple_with_nulls_and_filter() {
let group_indices = vec![0, 1, 0, 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::array::{
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, StructArray,
};
use arrow::buffer::NullBuffer;
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::{Result, not_impl_err};
use std::sync::Arc;
Expand All @@ -39,15 +39,24 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
}

/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer`.
/// Converts an aggregate filter expression to a validity bitmap.
///
/// The output is `true` for rows where the filter is `Some(true)`, and `false`
/// for rows where the filter is `Some(false)` or `None`.
pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
    match filter.nulls() {
        // No null entries: the raw value bits already are the validity map.
        None => filter.values().clone(),
        // NULL filter entries must count as "not selected", so AND the
        // value bits with the validity (non-null) bits.
        Some(filter_nulls) => filter.values() & filter_nulls.inner(),
    }
}

/// Converts an aggregate filter expression to a `NullBuffer`.
///
/// The `NullBuffer` is
/// * `true` (representing valid) for values that were `true` in filter
/// * `false` (representing null) for values that were `false` or `null` in filter
pub fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
let (filter_bools, filter_nulls) = filter.clone().into_parts();
let filter_bools = NullBuffer::from(filter_bools);
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
/// * `true` (representing valid) for filter values that were `Some(true)`
/// * `false` (representing null) for filter values that were `Some(false)` or `None`
pub fn filter_to_nulls(filter: &BooleanArray) -> NullBuffer {
    // Delegate to filter_to_validity and wrap the resulting bitmap, so the
    // "NULL behaves like false" semantics live in a single place.
    let validity = filter_to_validity(filter);
    NullBuffer::new(validity)
}

/// Compute an output validity mask for an array that has been filtered
Expand Down Expand Up @@ -97,7 +106,7 @@ pub fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
let opt_filter = opt_filter.and_then(filter_to_nulls);
let opt_filter = opt_filter.map(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
let offsets = OffsetBuffer::from_repeated_length(1, input.len());

// Filtered rows become null list entries, which merge_batch will skip.
let filter_nulls = opt_filter.and_then(filter_to_nulls);
let filter_nulls = opt_filter.map(filter_to_nulls);

// With ignore_nulls, null values also become null list entries. Without
// ignore_nulls, null values stay as [NULL] so merge_batch retains them.
Expand Down
20 changes: 20 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,18 @@ from data
----
1

# correlation_with_group_by_and_nullable_filter
query IR rowsort
SELECT g, corr(x, y) FILTER (WHERE b < 1) AS r
FROM (VALUES
(0, 1.0, 1.0, CAST(NULL AS INT)),
(0, 2.0, 2.0, CAST(NULL AS INT)),
(0, 3.0, 4.0, 2)
) AS t(g, x, y, b)
GROUP BY g
----
0 NULL

# group correlation_query_with_nans_f32
query IR
select id, corr(f, b)
Expand Down Expand Up @@ -6177,6 +6189,14 @@ FROM test_table
----
2

# count_with_group_by_and_nullable_filter
query II rowsort
SELECT g, COUNT(a) FILTER (WHERE b < 1) AS count_a
FROM (VALUES (0, 1, CAST(NULL AS INT)), (0, 2, 2)) AS t(g, a, b)
GROUP BY g
----
0 0

# query_with_and_without_filter
query III rowsort
SELECT
Expand Down
Loading