diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 7d32f2a88fd9c..d143ee228f122 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -456,6 +456,53 @@ impl Display for SpillCompression { } } +/// Policy for handling duplicate keys in Spark-compatible map-construction +/// functions (`map_from_arrays`, `map_from_entries`, `str_to_map`). Mirrors +/// Spark's [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961). +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum MapKeyDedupPolicy { + /// Raise `[DUPLICATED_MAP_KEY]` at runtime on any duplicate key. + #[default] + Exception, + /// Keep the last occurrence of each duplicate key. + LastWin, +} + +impl FromStr for MapKeyDedupPolicy { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_ascii_uppercase().as_str() { + "EXCEPTION" => Ok(Self::Exception), + "LAST_WIN" => Ok(Self::LastWin), + other => Err(DataFusionError::Configuration(format!( + "Invalid MapKeyDedupPolicy: {other}. Expected one of: EXCEPTION, LAST_WIN" + ))), + } + } +} + +impl ConfigField for MapKeyDedupPolicy { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = MapKeyDedupPolicy::from_str(value)?; + Ok(()) + } +} + +impl Display for MapKeyDedupPolicy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = match self { + Self::Exception => "EXCEPTION", + Self::LastWin => "LAST_WIN", + }; + write!(f, "{str}") + } +} + impl From for Option { fn from(c: SpillCompression) -> Self { match c { @@ -1461,6 +1508,24 @@ impl<'a> TryFrom<&'a FormatOptions> for arrow::util::display::FormatOptions<'a> } } +config_namespace! { + /// Options controlling DataFusion's Spark-compatibility layer (functions + /// under `datafusion/spark`). Keys here mirror their `spark.sql.*` + /// equivalents in Apache Spark. + pub struct SparkOptions { + /// Policy for handling duplicate keys in Spark-compatible map-construction + /// functions (`map_from_arrays`, `map_from_entries`, `str_to_map`). + /// + /// Mirrors Spark's + /// [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961): + /// - `EXCEPTION` (default): raise `[DUPLICATED_MAP_KEY]` at runtime on any duplicate key. + /// - `LAST_WIN`: keep the last occurrence of each duplicate key. + /// + /// Values are case-insensitive. + pub map_key_dedup_policy: MapKeyDedupPolicy, default = MapKeyDedupPolicy::Exception + } +} + /// A key value pair, with a corresponding description #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct ConfigEntry { @@ -1492,6 +1557,8 @@ pub struct ConfigOptions { pub extensions: Extensions, /// Formatting options when printing batches pub format: FormatOptions, + /// Spark-compatibility options (functions under `datafusion/spark`) + pub spark: SparkOptions, } impl ConfigField for ConfigOptions { @@ -1502,6 +1569,7 @@ impl ConfigField for ConfigOptions { self.explain.visit(v, "datafusion.explain", ""); self.sql_parser.visit(v, "datafusion.sql_parser", ""); self.format.visit(v, "datafusion.format", ""); + self.spark.visit(v, "datafusion.spark", ""); } fn set(&mut self, key: &str, value: &str) -> Result<()> { @@ -1514,6 +1582,7 @@ impl ConfigField for ConfigOptions { "explain" => self.explain.set(rem, value), "sql_parser" => self.sql_parser.set(rem, value), "format" => self.format.set(rem, value), + "spark" => self.spark.set(rem, value), _ => _config_err!("Config value \"{key}\" not found on ConfigOptions"), } } @@ -1553,6 +1622,7 @@ impl ConfigField for ConfigOptions { "explain" => self.explain.reset(rem), "sql_parser" => self.sql_parser.reset(rem), "format" => self.format.reset(rem), + "spark" => self.spark.reset(rem), other => _config_err!("Config value \"{other}\" not found on ConfigOptions"), } } diff --git a/datafusion/spark/src/function/map/map_from_arrays.rs b/datafusion/spark/src/function/map/map_from_arrays.rs index 692e837d00f5e..92dea2720fbfc 100644 --- a/datafusion/spark/src/function/map/map_from_arrays.rs +++ b/datafusion/spark/src/function/map/map_from_arrays.rs @@ -22,6 +22,7 @@ use crate::function::map::utils::{ use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::kernels::cast; use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::MapKeyDedupPolicy; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, internal_err}; use datafusion_expr::{ @@ -81,11 +82,16 @@ impl ScalarUDFImpl for MapFromArrays { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(map_from_arrays_inner, vec![])(&args.args) + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; + make_scalar_function( + move |args: &[ArrayRef]| map_from_arrays_inner(args, last_value_wins), + vec![], + )(&args.args) } } -fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { +fn map_from_arrays_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { let [keys, values] = take_function_args("map_from_arrays", args)?; if *keys.data_type() == DataType::Null || *values.data_type() == DataType::Null { @@ -105,6 +111,7 @@ fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { &get_list_offsets(values)?, keys.nulls(), values.nulls(), + last_value_wins, ) } diff --git a/datafusion/spark/src/function/map/map_from_entries.rs b/datafusion/spark/src/function/map/map_from_entries.rs index facf9f8c53473..69ce352694bd1 100644 --- a/datafusion/spark/src/function/map/map_from_entries.rs +++ b/datafusion/spark/src/function/map/map_from_entries.rs @@ -24,6 +24,7 @@ use crate::function::map::utils::{ use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray}; use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::MapKeyDedupPolicy; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, exec_err, internal_err}; use datafusion_expr::{ @@ -101,11 +102,16 @@ impl ScalarUDFImpl for MapFromEntries { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(map_from_entries_inner, vec![])(&args.args) + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; + make_scalar_function( + move |args: &[ArrayRef]| map_from_entries_inner(args, last_value_wins), + vec![], + )(&args.args) } } -fn map_from_entries_inner(args: &[ArrayRef]) -> Result { +fn map_from_entries_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { let [entries] = take_function_args("map_from_entries", args)?; let entries_offsets = get_list_offsets(entries)?; let entries_values = get_list_values(entries)?; @@ -148,6 +154,7 @@ fn map_from_entries_inner(args: &[ArrayRef]) -> Result { &entries_offsets, None, res_nulls.as_ref(), + last_value_wins, ) } diff --git a/datafusion/spark/src/function/map/str_to_map.rs b/datafusion/spark/src/function/map/str_to_map.rs index c603e775a6031..8f4feb130bcb0 100644 --- a/datafusion/spark/src/function/map/str_to_map.rs +++ b/datafusion/spark/src/function/map/str_to_map.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow::array::{ @@ -33,6 +33,7 @@ use datafusion_expr::{ }; use crate::function::map::utils::map_type_from_key_value_types; +use datafusion_common::config::MapKeyDedupPolicy; const DEFAULT_PAIR_DELIM: &str = ","; const DEFAULT_KV_DELIM: &str = ":"; @@ -48,11 +49,10 @@ const DEFAULT_KV_DELIM: &str = ":"; /// - keyValueDelim: Delimiter between key and value (default: ':') /// /// # Duplicate Key Handling -/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys. -/// See `spark.sql.mapKeyDedupPolicy`: -/// -/// -/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR. +/// Mirrors Spark's [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4502-L4511), +/// wired through DataFusion's `datafusion.spark.map_key_dedup_policy`: +/// - `EXCEPTION` (default): error on duplicate keys. +/// - `LAST_WIN`: keep the last occurrence of each duplicate key. #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkStrToMap { signature: Signature, @@ -102,22 +102,32 @@ impl ScalarUDFImpl for SparkStrToMap { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; let arrays: Vec = ColumnarValue::values_to_arrays(&args.args)?; - let result = str_to_map_inner(&arrays)?; + let result = str_to_map_inner(&arrays, last_value_wins)?; Ok(ColumnarValue::Array(result)) } } -fn str_to_map_inner(args: &[ArrayRef]) -> Result { +fn str_to_map_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { match args.len() { 1 => match args[0].data_type() { - DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None), - DataType::LargeUtf8 => { - str_to_map_impl(as_large_string_array(&args[0])?, None, None) - } - DataType::Utf8View => { - str_to_map_impl(as_string_view_array(&args[0])?, None, None) + DataType::Utf8 => { + str_to_map_impl(as_string_array(&args[0])?, None, None, last_value_wins) } + DataType::LargeUtf8 => str_to_map_impl( + as_large_string_array(&args[0])?, + None, + None, + last_value_wins, + ), + DataType::Utf8View => str_to_map_impl( + as_string_view_array(&args[0])?, + None, + None, + last_value_wins, + ), other => exec_err!( "Unsupported data type {other:?} for str_to_map, \ expected Utf8, LargeUtf8, or Utf8View" @@ -128,16 +138,19 @@ fn str_to_map_inner(args: &[ArrayRef]) -> Result { as_string_array(&args[0])?, Some(as_string_array(&args[1])?), None, + last_value_wins, ), (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl( as_large_string_array(&args[0])?, Some(as_large_string_array(&args[1])?), None, + last_value_wins, ), (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl( as_string_view_array(&args[0])?, Some(as_string_view_array(&args[1])?), None, + last_value_wins, ), (t1, t2) => exec_err!( "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \ @@ -153,12 +166,14 @@ fn str_to_map_inner(args: &[ArrayRef]) -> Result { as_string_array(&args[0])?, Some(as_string_array(&args[1])?), Some(as_string_array(&args[2])?), + last_value_wins, ), (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { str_to_map_impl( as_large_string_array(&args[0])?, Some(as_large_string_array(&args[1])?), Some(as_large_string_array(&args[2])?), + last_value_wins, ) } (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { @@ -166,6 +181,7 @@ fn str_to_map_inner(args: &[ArrayRef]) -> Result { as_string_view_array(&args[0])?, Some(as_string_view_array(&args[1])?), Some(as_string_view_array(&args[2])?), + last_value_wins, ) } (t1, t2, t3) => exec_err!( @@ -181,6 +197,7 @@ fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( text_array: V, pair_delim_array: Option, kv_delim_array: Option, + last_value_wins: bool, ) -> Result { let num_rows = text_array.len(); @@ -207,6 +224,10 @@ fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( ); let mut seen_keys = HashSet::new(); + // LAST_WIN buffers pairs to support in-place value overwrite at the key's + // first-seen position — matches Spark's `ArrayBasedMapBuilder`. + let mut pairs: Vec<(&str, Option<&str>)> = Vec::new(); + let mut key_positions: HashMap<&str, usize> = HashMap::new(); for row_idx in 0..num_rows { if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) { map_builder.append(false)?; @@ -227,31 +248,56 @@ fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( continue; } - seen_keys.clear(); - for pair in text.split(pair_delim) { - if pair.is_empty() { - continue; + if last_value_wins { + pairs.clear(); + key_positions.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); + match key_positions.get(key) { + Some(&idx) => pairs[idx].1 = value, + None => { + key_positions.insert(key, pairs.len()); + pairs.push((key, value)); + } + } + } + for (key, value) in &pairs { + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } } + } else { + seen_keys.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } - let mut kv_iter = pair.splitn(2, kv_delim); - let key = kv_iter.next().unwrap_or(""); - let value = kv_iter.next(); + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); - // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config - // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default) - if !seen_keys.insert(key) { - return exec_err!( - "Duplicate map key '{key}' was found, please check the input data. \ - If you want to remove the duplicated keys, you can set \ - spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \ - inserted at last takes precedence." - ); - } + if !seen_keys.insert(key) { + return exec_err!( + "[DUPLICATED_MAP_KEY] Duplicate map key '{key}' was found, \ + please check the input data. To allow duplicate keys with \ + last-value-wins semantics, set \ + `datafusion.spark.map_key_dedup_policy` to `LAST_WIN`." + ); + } - map_builder.keys().append_value(key); - match value { - Some(v) => map_builder.values().append_value(v), - None => map_builder.values().append_null(), + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } } } map_builder.append(true)?; diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs index f5fff0c4b4c46..fa6b2a960dabb 100644 --- a/datafusion/spark/src/function/map/utils.rs +++ b/datafusion/spark/src/function/map/utils.rs @@ -16,12 +16,14 @@ // under the License. use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray, BooleanBuilder, MapArray, StructArray}; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanBuilder, Int32Array, MapArray, StructArray, +}; use arrow::buffer::{NullBuffer, OffsetBuffer}; -use arrow::compute::filter; +use arrow::compute::{filter, take}; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{Result, ScalarValue, exec_err}; @@ -111,13 +113,13 @@ pub fn map_type_from_key_value_types( /// So the inputs can be [`ListArray`](`arrow::array::ListArray`)/[`LargeListArray`](`arrow::array::LargeListArray`)/[`FixedSizeListArray`](`arrow::array::FixedSizeListArray`)
/// To preserve the row info, [`offsets`](arrow::array::ListArray::offsets) and [`nulls`](arrow::array::ListArray::nulls) for both keys and values need to be provided
/// [`FixedSizeListArray`](`arrow::array::FixedSizeListArray`) has no `offsets`, so they can be generated as a cumulative sum of it's `Size` -/// 2. Spark provides [spark.sql.mapKeyDedupPolicy](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961) -/// to handle duplicate keys
-/// For now, configurable functions are not supported by Datafusion
-/// So more permissive `LAST_WIN` option is used in this implementation (instead of `EXCEPTION`)
-/// `EXCEPTION` behaviour can still be achieved externally in cost of performance:
-/// `when(array_length(array_distinct(keys)) == array_length(keys), constructed_map)`
-/// `.otherwise(raise_error("duplicate keys occurred during map construction"))` +/// 2. Duplicate-key handling mirrors Spark's +/// [spark.sql.mapKeyDedupPolicy](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961) +/// and is driven by `last_value_wins`: +/// - `false` (Spark's default `EXCEPTION`): raise `[DUPLICATED_MAP_KEY]` on any duplicate. +/// - `true` (`LAST_WIN`): keep the last occurrence of each duplicate key. +/// +/// Callers wire this from `datafusion.spark.map_key_dedup_policy`. pub fn map_from_keys_values_offsets_nulls( flat_keys: &ArrayRef, flat_values: &ArrayRef, @@ -125,6 +127,7 @@ pub fn map_from_keys_values_offsets_nulls( values_offsets: &[i32], keys_nulls: Option<&NullBuffer>, values_nulls: Option<&NullBuffer>, + last_value_wins: bool, ) -> Result { let (keys, values, offsets) = map_deduplicate_keys( flat_keys, @@ -133,6 +136,7 @@ pub fn map_from_keys_values_offsets_nulls( values_offsets, keys_nulls, values_nulls, + last_value_wins, )?; let nulls = NullBuffer::union(keys_nulls, values_nulls); @@ -155,6 +159,7 @@ fn map_deduplicate_keys( values_offsets: &[i32], keys_nulls: Option<&NullBuffer>, values_nulls: Option<&NullBuffer>, + last_value_wins: bool, ) -> Result<(ArrayRef, ArrayRef, OffsetBuffer)> { let offsets_len = keys_offsets.len(); let mut new_offsets = Vec::with_capacity(offsets_len); @@ -171,8 +176,14 @@ fn map_deduplicate_keys( let mut new_last_offset = 0; new_offsets.push(new_last_offset); + // Mirror Spark's `ArrayBasedMapBuilder`: the first occurrence of a key + // fixes its position in the output; under LAST_WIN a later duplicate + // overwrites that slot's value. `keys_mask` selects the first-seen keys, + // `value_indices` records the source index in `flat_values` to materialize + // for each output slot (updated in place on overwrite). let mut keys_mask_builder = BooleanBuilder::new(); - let mut values_mask_builder = BooleanBuilder::new(); + let mut value_indices: Vec = Vec::new(); + let mut key_to_output_idx: HashMap = HashMap::new(); for (row_idx, (next_keys_offset, next_values_offset)) in keys_offsets .iter() .zip(values_offsets.iter()) @@ -182,9 +193,6 @@ fn map_deduplicate_keys( let num_keys_entries = *next_keys_offset as usize - cur_keys_offset; let num_values_entries = *next_values_offset as usize - cur_values_offset; - let mut keys_mask_one = vec![false; num_keys_entries]; - let mut values_mask_one = vec![false; num_values_entries]; - let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx)); let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx)); @@ -193,43 +201,175 @@ fn map_deduplicate_keys( return exec_err!( "map_deduplicate_keys: keys and values lists in the same row must have equal lengths" ); - } else if num_keys_entries != 0 { - let mut seen_keys = HashSet::new(); - - for cur_entry_idx in (0..num_keys_entries).rev() { - let key = ScalarValue::try_from_array( - &flat_keys, - cur_keys_offset + cur_entry_idx, - )? - .compacted(); - if seen_keys.contains(&key) { - // TODO: implement configuration and logic for spark.sql.mapKeyDedupPolicy=EXCEPTION (this is default spark-config) - // exec_err!("invalid argument: duplicate keys in map") - // https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961 - } else { - // This code implements deduplication logic for spark.sql.mapKeyDedupPolicy=LAST_WIN (this is NOT default spark-config) - keys_mask_one[cur_entry_idx] = true; - values_mask_one[cur_entry_idx] = true; - seen_keys.insert(key); - new_last_offset += 1; + } + key_to_output_idx.clear(); + for cur_entry_idx in 0..num_keys_entries { + let key = ScalarValue::try_from_array( + &flat_keys, + cur_keys_offset + cur_entry_idx, + )? + .compacted(); + let abs_value_idx = (cur_values_offset + cur_entry_idx) as i32; + + if let Some(&output_idx) = key_to_output_idx.get(&key) { + if last_value_wins { + value_indices[output_idx] = abs_value_idx; + keys_mask_builder.append_value(false); + continue; } + return exec_err!( + "[DUPLICATED_MAP_KEY] Duplicate map key {key} was found, \ + please check the input data. To allow duplicate keys with \ + last-value-wins semantics, set \ + `datafusion.spark.map_key_dedup_policy` to `LAST_WIN`." + ); } + keys_mask_builder.append_value(true); + key_to_output_idx.insert(key, value_indices.len()); + value_indices.push(abs_value_idx); + new_last_offset += 1; } } else { - // the result entry is NULL - // both current row offsets are skipped - // keys or values in the current row are marked false in the masks + // The result entry is NULL — no keys/values emitted. Still pad the + // mask so it stays aligned with `flat_keys`. + keys_mask_builder.append_n(num_keys_entries, false); } - keys_mask_builder.append_array(&keys_mask_one.into()); - values_mask_builder.append_array(&values_mask_one.into()); new_offsets.push(new_last_offset); cur_keys_offset += num_keys_entries; cur_values_offset += num_values_entries; } let keys_mask = keys_mask_builder.finish(); - let values_mask = values_mask_builder.finish(); let needed_keys = filter(&flat_keys, &keys_mask)?; - let needed_values = filter(&flat_values, &values_mask)?; + let value_indices_array = Int32Array::from(value_indices); + let needed_values = take(&flat_values, &value_indices_array, None)?; let offsets = OffsetBuffer::new(new_offsets.into()); Ok((needed_keys, needed_values, offsets)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + + fn int32_utf8_inputs( + keys: Vec, + values: Vec>, + ) -> (ArrayRef, ArrayRef) { + let keys: ArrayRef = Arc::new(Int32Array::from(keys)); + let values: ArrayRef = Arc::new(StringArray::from(values)); + (keys, values) + } + + #[test] + fn happy_path_two_rows_no_duplicates() { + let (keys, values) = + int32_utf8_inputs(vec![1, 2, 3], vec![Some("a"), Some("b"), Some("c")]); + let offsets = [0i32, 2, 3]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 2); + assert_eq!(map.value_offsets(), &[0, 2, 3]); + } + + #[test] + fn single_row_duplicate_errors_under_exception() { + let (keys, values) = + int32_utf8_inputs(vec![1, 2, 1], vec![Some("a"), Some("b"), Some("c")]); + let offsets = [0i32, 3]; + + let err = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("[DUPLICATED_MAP_KEY]"), "{err}"); + assert!(err.contains("map_key_dedup_policy"), "{err}"); + } + + #[test] + fn last_win_keeps_final_occurrence() { + let (keys, values) = int32_utf8_inputs( + vec![1, 2, 1, 3, 2], + vec![Some("a"), Some("b"), Some("c"), Some("d"), Some("e")], + ); + let offsets = [0i32, 5]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, true, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 1); + // 5 entries in, 3 unique keys -> offsets [0, 3] + assert_eq!(map.value_offsets(), &[0, 3]); + } + + #[test] + fn duplicate_in_later_row_still_errors() { + let (keys, values) = int32_utf8_inputs( + vec![1, 2, 1, 1], + vec![Some("a"), Some("b"), Some("x"), Some("y")], + ); + let offsets = [0i32, 2, 4]; + + let err = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("[DUPLICATED_MAP_KEY]"), "{err}"); + } + + #[test] + fn empty_row_does_not_trigger_dedup() { + let (keys, values) = int32_utf8_inputs(vec![], vec![]); + let offsets = [0i32, 0]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 1); + assert_eq!(map.value_offsets(), &[0, 0]); + } + + #[test] + fn null_row_is_skipped_and_not_checked() { + // Row 0 is NULL (keys null). Its duplicate keys should be ignored; + // row 1 is a clean row. + let (keys, values) = int32_utf8_inputs( + vec![1, 1, 2, 3], + vec![Some("dup-a"), Some("dup-b"), Some("x"), Some("y")], + ); + let offsets = [0i32, 2, 4]; + let keys_nulls = NullBuffer::from(vec![false, true]); + + let result = map_from_keys_values_offsets_nulls( + &keys, + &values, + &offsets, + &offsets, + Some(&keys_nulls), + None, + false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 2); + // First row is NULL (no entries emitted), second row keeps both entries. + assert_eq!(map.value_offsets(), &[0, 0, 2]); + assert!(map.is_null(0)); + assert!(!map.is_null(1)); + } +} diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index b04c78bd2774c..5a758062fe0c2 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -338,6 +338,7 @@ datafusion.runtime.max_temp_directory_size 100G datafusion.runtime.memory_limit unlimited datafusion.runtime.metadata_cache_limit 50M datafusion.runtime.temp_directory NULL +datafusion.spark.map_key_dedup_policy EXCEPTION datafusion.sql_parser.collect_spans false datafusion.sql_parser.default_null_ordering nulls_max datafusion.sql_parser.dialect generic @@ -485,6 +486,7 @@ datafusion.runtime.max_temp_directory_size 100G Maximum temporary file directory datafusion.runtime.memory_limit unlimited Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. datafusion.runtime.metadata_cache_limit 50M Maximum memory to use for file metadata cache such as Parquet metadata. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. datafusion.runtime.temp_directory NULL The path to the temporary file directory. +datafusion.spark.map_key_dedup_policy EXCEPTION Policy for handling duplicate keys in Spark-compatible map-construction functions (`map_from_arrays`, `map_from_entries`, `str_to_map`). Mirrors Spark's [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961): - `EXCEPTION` (default): raise `[DUPLICATED_MAP_KEY]` at runtime on any duplicate key. - `LAST_WIN`: keep the last occurrence of each duplicate key. Values are case-insensitive. datafusion.sql_parser.collect_spans false When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. datafusion.sql_parser.default_null_ordering nulls_max Specifies the default null ordering for query results. There are 4 options: - `nulls_max`: Nulls appear last in ascending order. - `nulls_min`: Nulls appear first in ascending order. - `nulls_first`: Nulls always be first in any order. - `nulls_last`: Nulls always be last in any order. By default, `nulls_max` is used to follow Postgres's behavior. postgres rule: datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. diff --git a/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt b/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt index a26b0435c9291..7e501a31628e1 100644 --- a/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt +++ b/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt @@ -118,11 +118,25 @@ SELECT ---- {outer_key1: {inner_a: 1, inner_b: 2}, outer_key2: {inner_x: 10, inner_y: 20, inner_z: 30}} -# Test with duplicate keys -query ? +# Test with duplicate keys: raises DUPLICATED_MAP_KEY under Spark's default policy +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key true was found SELECT map_from_arrays(array(true, false, true), array('a', NULL, 'b')); ----- -{false: NULL, true: b} + +# Integer keys with a duplicate also raise DUPLICATED_MAP_KEY. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key 1 was found +SELECT map_from_arrays(array(1, 2, 1), array('a', 'b', 'c')); + +# String keys with a duplicate also raise DUPLICATED_MAP_KEY. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key k was found +SELECT map_from_arrays(array('k', 'k', 'k'), array(1, 2, 3)); + +# Multi-row: a clean row and a duplicate row still errors. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key 1 was found +SELECT map_from_arrays(a, b) +FROM values + (array[1, 2], array['a', 'b']), + (array[1, 1], array['x', 'y']) +AS tab(a, b); # Tests with different list types query ? @@ -134,3 +148,40 @@ query ? SELECT map_from_arrays(arrow_cast(array('a', 'b', 'c'), 'FixedSizeList(3, Utf8)'), arrow_cast(array(1, 2, 3), 'LargeList(Int32)')); ---- {a: 1, b: 2, c: 3} + +# LAST_WIN policy: duplicates are allowed; later occurrences overwrite earlier ones. +statement ok +set datafusion.spark.map_key_dedup_policy = 'LAST_WIN'; + +query ? +SELECT map_from_arrays(array(1, 2, 1), array('a', 'b', 'c')); +---- +{1: c, 2: b} + +query ? +SELECT map_from_arrays(array('k', 'k', 'k'), array(1, 2, 3)); +---- +{k: 3} + +query ? +SELECT map_from_arrays(array(true, false, true), array('a', NULL, 'b')); +---- +{true: b, false: NULL} + +# Multi-row mix under LAST_WIN: clean, duplicate, empty and NULL rows all work. +query ? +SELECT map_from_arrays(a, b) +FROM values + (array[1, 2], array['a', 'b']), + (array[1, 1], array['x', 'y']), + (array[], array[]), + (NULL, NULL) +AS tab(a, b); +---- +{1: a, 2: b} +{1: y} +{} +NULL + +statement ok +set datafusion.spark.map_key_dedup_policy = 'EXCEPTION'; diff --git a/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt index 19b46886a027e..21f41f5ad976b 100644 --- a/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt +++ b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt @@ -151,8 +151,8 @@ SELECT ---- {outer_key1: {inner_a: 1, inner_b: 2}, outer_key2: {inner_x: 10, inner_y: 20, inner_z: 30}} -# Test with duplicate keys -query ? +# Test with duplicate keys: raises DUPLICATED_MAP_KEY under Spark's default policy +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key true was found SELECT map_from_entries(array( struct(true, 'a'), struct(false, 'b'), @@ -160,5 +160,58 @@ SELECT map_from_entries(array( struct(false, cast(NULL as string)), struct(true, 'd') )); + +# Integer keys with a duplicate also raise DUPLICATED_MAP_KEY. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key 1 was found +SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'), struct(1, 'c'))); + +# String keys with triple occurrence also raise DUPLICATED_MAP_KEY. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key k was found +SELECT map_from_entries(array(struct('k', 1), struct('k', 2), struct('k', 3))); + +# Multi-row: a clean row followed by a duplicate row still errors. +query error DataFusion error: Execution error: \[DUPLICATED_MAP_KEY\] Duplicate map key 1 was found +SELECT map_from_entries(data) +FROM values + (array[struct(1, 'a'), struct(2, 'b')]), + (array[struct(1, 'x'), struct(1, 'y')]) +AS tab(data); + +# LAST_WIN policy: duplicates are allowed; later occurrences overwrite earlier ones. +statement ok +set datafusion.spark.map_key_dedup_policy = 'LAST_WIN'; + +query ? +SELECT map_from_entries(array( + struct(true, 'a'), + struct(false, 'b'), + struct(true, 'c'), + struct(false, cast(NULL as string)), + struct(true, 'd') +)); ---- -{false: NULL, true: d} +{true: d, false: NULL} + +query ? +SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'), struct(1, 'c'))); +---- +{1: c, 2: b} + +query ? +SELECT map_from_entries(array(struct('k', 1), struct('k', 2), struct('k', 3))); +---- +{k: 3} + +# Multi-row mix under LAST_WIN: clean row + duplicate row both succeed. +query ? +SELECT map_from_entries(data) +FROM values + (array[struct(1, 'a'), struct(2, 'b')]), + (array[struct(1, 'x'), struct(1, 'y')]) +AS tab(data); +---- +{1: a, 2: b} +{1: y} + +statement ok +set datafusion.spark.map_key_dedup_policy = 'EXCEPTION'; diff --git a/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt index 30d1672aef0ae..68d856d8545ae 100644 --- a/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt +++ b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt @@ -64,11 +64,25 @@ SELECT str_to_map('a=1&b=2&c=3', '&', '='); {a: 1, b: 2, c: 3} # Duplicate keys: EXCEPTION policy (Spark 3.0+ default) -# TODO: Add LAST_WIN policy tests when spark.sql.mapKeyDedupPolicy config is supported statement error Duplicate map key SELECT str_to_map('a:1,b:2,a:3'); +# Triple+ occurrences of the same key still raise DUPLICATED_MAP_KEY. +statement error +Duplicate map key 'a' +SELECT str_to_map('a:1,a:2,a:3'); + +# Duplicate where one occurrence is missing the kv_delim (value = NULL) still errors. +statement error +Duplicate map key 'a' +SELECT str_to_map('a,b:2,a:3'); + +# Multi-row input: a clean row followed by a duplicate row fails on the duplicate row. +statement error +Duplicate map key 'a' +SELECT str_to_map(col) FROM (VALUES ('a:1,b:2'), ('a:3,a:4')) AS t(col); + # Additional tests (DataFusion-specific) # NULL input returns NULL @@ -111,4 +125,39 @@ SELECT str_to_map(col1, col2, col3) FROM (VALUES ('a=1,b=2', ',', '='), ('x#9', ---- {a: 1, b: 2} {x: 9} -NULL \ No newline at end of file +NULL + +# LAST_WIN policy: duplicates are allowed; later occurrences overwrite earlier ones. +statement ok +set datafusion.spark.map_key_dedup_policy = 'LAST_WIN'; + +query ? +SELECT str_to_map('a:1,b:2,a:3'); +---- +{a: 3, b: 2} + +query ? +SELECT str_to_map('a:1,a:2,a:3'); +---- +{a: 3} + +# Missing kv_delim: the later occurrence overwrites the value at the key's +# first-seen position. +query ? +SELECT str_to_map('a:1,b:2,a'); +---- +{a: NULL, b: 2} + +# Multi-row: both clean and duplicate rows succeed under LAST_WIN. +query ? +SELECT str_to_map(col) FROM (VALUES ('a:1,b:2'), ('a:3,a:4')) AS t(col); +---- +{a: 1, b: 2} +{a: 4} + +statement ok +set datafusion.spark.map_key_dedup_policy = 'EXCEPTION'; + +# Invalid policy values are rejected at SET time with a clear message. +statement error DataFusion error: Invalid or Unsupported Configuration: Invalid MapKeyDedupPolicy: BOGUS\. Expected one of: EXCEPTION, LAST_WIN +set datafusion.spark.map_key_dedup_policy = 'BOGUS'; \ No newline at end of file diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 46039f3c99c27..58eaffb9bf9f3 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -201,6 +201,7 @@ The following configuration settings are available: | datafusion.format.time_format | %H:%M:%S%.f | Time format for time arrays | | datafusion.format.duration_format | pretty | Duration format. Can be either `"pretty"` or `"ISO8601"` | | datafusion.format.types_info | false | Show types in visual representation batches | +| datafusion.spark.map_key_dedup_policy | EXCEPTION | Policy for handling duplicate keys in Spark-compatible map-construction functions (`map_from_arrays`, `map_from_entries`, `str_to_map`). Mirrors Spark's [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961): - `EXCEPTION` (default): raise `[DUPLICATED_MAP_KEY]` at runtime on any duplicate key. - `LAST_WIN`: keep the last occurrence of each duplicate key. Values are case-insensitive. | You can also reset configuration options to default settings via SQL using the `RESET` command. For example, to set and reset `datafusion.execution.batch_size`: