diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs
index 6dfaa731ae7f9..f4b38580bd5a1 100644
--- a/datafusion/datasource-parquet/src/row_filter.rs
+++ b/datafusion/datasource-parquet/src/row_filter.rs
@@ -442,7 +442,7 @@ impl TreeNodeVisitor<'_> for PushdownChecker<'_> {
             && (!DataType::is_nested(return_type)
                 || self.is_nested_type_supported(return_type))
         {
-            // try to resolve all field name arguments to strinrg literals
+            // try to resolve all field name arguments to string literals
             // if any argument is not a string literal, we can not determine the exact
             // leaf path so we fall back to reading the entire struct root column
             let field_path = args[1..]
@@ -766,11 +766,7 @@ fn resolve_struct_field_leaves(
     // A leaf matches if its path starts with our prefix.
     // e.g., prefix=["s", "value"] matches leaf path ["s", "value"]
-    //       prefix=["s", "outer"] matches ["s", "outer", "inner"]
-
-    // a leaf matches if its path starts with our prefix
-    // for example: prefix=["s", "value"] matches leaf path ["s", "value"]
-    //       prefix=["s", "outer"] matches ["s", "outer", "inner"]
+    //       prefix=["s", "outer"] matches ["s", "outer", "inner"]
     let leaf_matches_path = col_path.len() >= prefix.len()
         && col_path.iter().zip(prefix.iter()).all(|(a, b)| a == b);
@@ -1523,9 +1519,8 @@ mod test {
     }

     /// Regression test: when a schema has Struct columns, Arrow field indices diverge
-    /// from Parquet leaf indices (Struct children become separate leaves). The
-    /// `PrimitiveOnly` fast-path in `leaf_indices_for_roots` assumes they are equal,
-    /// so a filter on a primitive column *after* a Struct gets the wrong leaf index.
+    /// from Parquet leaf indices (Struct children become separate leaves).
+    /// A filter on a primitive column *after* a Struct must use the correct leaf index.
     ///
     /// Schema:
     /// Arrow indices: col_a=0 struct_col=1 col_b=2
@@ -2045,7 +2040,7 @@ mod test {
             ),
         );

-        // all3 Parquet leaves should be in the projection mask
+        // all 3 Parquet leaves should be in the projection mask
         let expected_mask = ProjectionMask::leaves(schema_descr, [0, 1, 2]);
         assert_eq!(read_plan.projection_mask, expected_mask,);
     }
diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs
index abbf398d53d89..b6bfff4a40425 100644
--- a/datafusion/spark/src/function/datetime/make_interval.rs
+++ b/datafusion/spark/src/function/datetime/make_interval.rs
@@ -20,18 +20,25 @@ use std::sync::Arc;
 use arrow::array::{Array, ArrayRef, IntervalMonthDayNanoBuilder, PrimitiveArray};
 use arrow::datatypes::DataType::Interval;
 use arrow::datatypes::IntervalUnit::MonthDayNano;
-use arrow::datatypes::{DataType, IntervalMonthDayNano};
+use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano};
+use datafusion_common::config::ConfigOptions;
 use datafusion_common::types::{NativeType, logical_float64, logical_int32};
-use datafusion_common::{DataFusionError, Result, ScalarValue, plan_datafusion_err};
+use datafusion_common::{
+    DataFusionError, Result, ScalarValue, exec_err, plan_datafusion_err,
+};
 use datafusion_expr::{
-    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
-    TypeSignatureClass, Volatility,
+    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
+    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
 };
 use datafusion_functions::utils::make_scalar_function;

 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SparkMakeInterval {
     signature: Signature,
+    /// Mirrors `spark.sql.ansi.enabled` / `enable_ansi_mode`.
+    /// When true (failOnError=true in Spark) arithmetic overflow returns an error;
+    /// when false (default) it returns NULL instead.
+    ansi_mode: bool,
 }

 impl Default for SparkMakeInterval {
@@ -42,6 +49,10 @@ impl SparkMakeInterval {
     pub fn new() -> Self {
+        Self::new_with_config(&ConfigOptions::default())
+    }
+
+    pub fn new_with_config(config: &ConfigOptions) -> Self {
         let int32 = Coercion::new_implicit(
             TypeSignatureClass::Native(logical_int32()),
             vec![TypeSignatureClass::Integer],
@@ -100,6 +111,7 @@ impl SparkMakeInterval {
         Self {
             signature: Signature::one_of(variants, Volatility::Immutable),
+            ansi_mode: config.execution.enable_ansi_mode,
         }
     }
 }
@@ -114,20 +126,50 @@ impl ScalarUDFImpl for SparkMakeInterval {
     }

     fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        // return_field_from_args is the authoritative implementation
         Ok(Interval(MonthDayNano))
     }

+    fn with_updated_config(&self, config: &ConfigOptions) -> Option<ScalarUDF> {
+        Some(ScalarUDF::from(Self::new_with_config(config)))
+    }
+
+    /// Spark nullability rule (mirrors `failOnError` in Spark source):
+    /// - nullary call → never null (always returns zero interval)
+    /// - ANSI mode on → nullable only when any input field is nullable
+    /// - ANSI mode off → always nullable (overflow silently produces NULL)
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let nullable = if args.arg_fields.is_empty() {
+            false
+        } else if self.ansi_mode {
+            args.arg_fields.iter().any(|f| f.is_nullable())
+        } else {
+            true
+        };
+        Ok(Arc::new(Field::new(
+            self.name(),
+            Interval(MonthDayNano),
+            nullable,
+        )))
+    }
+
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         if args.args.is_empty() {
             return Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
                 Some(IntervalMonthDayNano::new(0, 0, 0)),
             )));
         }
-        make_scalar_function(make_interval_kernel, vec![])(&args.args)
+        let ansi_mode = self.ansi_mode;
+        make_scalar_function(move |cols| make_interval_kernel(cols, ansi_mode), vec![])(
+            &args.args,
+        )
     }
 }

-fn make_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef> {
+fn make_interval_kernel(
+    args: &[ArrayRef],
+    ansi_mode: bool,
+) -> Result<ArrayRef> {
     use arrow::array::AsArray;
     use arrow::datatypes::{Float64Type, Int32Type};
@@ -216,6 +258,11 @@ fn make_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef> {
         match make_interval_month_day_nano(y, mo, w, d, h, mi, s) {
             Some(v) => builder.append_value(v),
             None => {
+                if ansi_mode {
+                    return exec_err!(
+                        "Arithmetic overflow in make_interval: result does not fit in IntervalMonthDayNano"
+                    );
+                }
                 builder.append_null();
                 continue;
             }
@@ -274,7 +321,7 @@ mod tests {
     use super::*;

     fn run_make_interval_month_day_nano(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
-        make_interval_kernel(&arrs)
+        make_interval_kernel(&arrs, false)
     }

     #[test]
@@ -537,6 +584,14 @@ mod tests {
     fn invoke_make_interval_with_args(
         args: Vec<ColumnarValue>,
         number_rows: usize,
+    ) -> Result<ColumnarValue> {
+        invoke_make_interval_with_config(args, number_rows, &ConfigOptions::default())
+    }
+
+    fn invoke_make_interval_with_config(
+        args: Vec<ColumnarValue>,
+        number_rows: usize,
+        config: &ConfigOptions,
     ) -> Result<ColumnarValue> {
         let arg_fields = args
             .iter()
@@ -547,9 +602,9 @@ mod tests {
             arg_fields,
             number_rows,
             return_field: Field::new("f", Interval(MonthDayNano), true).into(),
-            config_options: Arc::new(ConfigOptions::default()),
+            config_options: Arc::new(config.clone()),
         };
-        SparkMakeInterval::new().invoke_with_args(args)
+        SparkMakeInterval::new_with_config(config).invoke_with_args(args)
     }

     #[test]
@@ -601,4 +656,114 @@ mod tests {
         Ok(())
     }
+
+    // --- nullability / return_field_from_args tests ---
+
+    fn make_ansi_config() -> ConfigOptions {
+        let mut cfg = ConfigOptions::default();
+        cfg.execution.enable_ansi_mode = true;
+        cfg
+    }
+
+    #[test]
+    fn return_field_nullary_is_not_nullable() {
+        let udf = SparkMakeInterval::new();
+        let field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[],
+                scalar_arguments: &[],
+            })
+            .unwrap();
+        assert!(!field.is_nullable(), "nullary call must not be nullable");
+    }
+
+    #[test]
+    fn return_field_non_ansi_always_nullable() {
+        // Even with all non-null inputs, non-ANSI mode is always nullable
+        // because overflow silently returns NULL.
+        let udf = SparkMakeInterval::new(); // ansi_mode = false
+        let non_null_field: FieldRef = Arc::new(Field::new("x", DataType::Int32, false));
+        let field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[non_null_field],
+                scalar_arguments: &[None],
+            })
+            .unwrap();
+        assert!(field.is_nullable(), "non-ANSI must always be nullable");
+    }
+
+    #[test]
+    fn return_field_ansi_mode_not_nullable_when_inputs_not_null() {
+        // ANSI mode: no overflow → null; nullable only if inputs are nullable.
+        let udf = SparkMakeInterval::new_with_config(&make_ansi_config());
+        let non_null_field: FieldRef = Arc::new(Field::new("x", DataType::Int32, false));
+        let field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[non_null_field],
+                scalar_arguments: &[None],
+            })
+            .unwrap();
+        assert!(
+            !field.is_nullable(),
+            "ANSI mode with non-null inputs must not be nullable"
+        );
+    }
+
+    #[test]
+    fn return_field_ansi_mode_nullable_when_any_input_nullable() {
+        let udf = SparkMakeInterval::new_with_config(&make_ansi_config());
+        let nullable_field: FieldRef = Arc::new(Field::new("x", DataType::Int32, true));
+        let field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[nullable_field],
+                scalar_arguments: &[None],
+            })
+            .unwrap();
+        assert!(
+            field.is_nullable(),
+            "ANSI mode with nullable inputs must be nullable"
+        );
+    }
+
+    // --- ANSI mode overflow error tests ---
+
+    #[test]
+    fn ansi_mode_overflow_returns_error() {
+        let ansi_cfg = make_ansi_config();
+        let year = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(i32::MAX)])));
+        let month = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1)])));
+        let week = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(0)])));
+        let day = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(0)])));
+        let hour = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(0)])));
+        let min = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(0)])));
+        let sec = ColumnarValue::Array(Arc::new(Float64Array::from(vec![Some(0.0)])));
+
+        let result = invoke_make_interval_with_config(
+            vec![year, month, week, day, hour, min, sec],
+            1,
+            &ansi_cfg,
+        );
+        assert!(
+            result.is_err(),
+            "ANSI mode overflow must return an error, not NULL"
+        );
+    }
+
+    #[test]
+    fn non_ansi_overflow_returns_null() {
+        // Existing behavior must be preserved: overflow → NULL in non-ANSI mode.
+        let year = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef;
+        let month = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef;
+        let week = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef;
+        let day = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef;
+        let hour = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef;
+        let min = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef;
+        let sec = Arc::new(Float64Array::from(vec![Some(0.0)])) as ArrayRef;
+
+        let out = run_make_interval_month_day_nano(vec![
+            year, month, week, day, hour, min, sec,
+        ])
+        .unwrap();
+        assert_eq!(out.null_count(), 1, "non-ANSI overflow must produce NULL");
+    }
 }
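Not part of the patch: a minimal sketch of how the config plumbing added above might be exercised by a session that toggles ANSI mode at runtime, assuming only the APIs shown in the diff (`SparkMakeInterval::new`, `new_with_config`, `with_updated_config`, `ConfigOptions.execution.enable_ansi_mode`). The import path for `SparkMakeInterval` and the helper name `refresh_udf` are illustrative, not taken from the patch.

// Illustrative sketch only; see the hedges in the paragraph above.
use datafusion_common::config::ConfigOptions;
use datafusion_expr::ScalarUDF;
use datafusion_spark::function::datetime::make_interval::SparkMakeInterval;

/// Hypothetical helper: rebuild a config-sensitive UDF when the session
/// config changes. `with_updated_config` returns `Some(new_udf)` for
/// `SparkMakeInterval` because it captures `enable_ansi_mode` at construction.
fn refresh_udf(udf: &ScalarUDF, config: &ConfigOptions) -> ScalarUDF {
    udf.inner()
        .with_updated_config(config)
        .unwrap_or_else(|| udf.clone())
}

fn main() {
    // Default session: non-ANSI, so overflow in make_interval produces NULL.
    let udf = ScalarUDF::from(SparkMakeInterval::new());

    // Spark-style `SET spark.sql.ansi.enabled=true`: rebuild the UDF so that
    // overflow is reported as an execution error instead of NULL.
    let mut config = ConfigOptions::default();
    config.execution.enable_ansi_mode = true;
    let ansi_udf = refresh_udf(&udf, &config);

    assert_eq!(ansi_udf.name(), udf.name());
}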