diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs index d6d8fb7b0ed0c..a0078a56e7888 100644 --- a/datafusion/common/src/metadata.rs +++ b/datafusion/common/src/metadata.rs @@ -17,10 +17,16 @@ use std::{collections::BTreeMap, sync::Arc}; -use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::{ + compute::CastOptions, + datatypes::{DataType, Field, FieldRef}, +}; use hashbrown::HashMap; -use crate::{DataFusionError, ScalarValue, error::_plan_err}; +use crate::{ + DataFusionError, ScalarValue, datatype::DataTypeExt, error::_plan_err, + nested_struct::CastExtension, +}; /// A [`ScalarValue`] with optional [`FieldMetadata`] #[derive(Debug, Clone)] @@ -62,6 +68,38 @@ impl ScalarAndMetadata { let new_value = self.value().cast_to(target_type)?; Ok(Self::new(new_value, self.metadata.clone())) } + + /// Try to cast this value to a ScalarValue of type `target_field` with [`CastOptions`] + pub fn cast_to_with_options( + &self, + target_field: &Field, + cast_extension: Option<&dyn CastExtension>, + cast_options: &CastOptions, + ) -> Result { + let mut source_field = self.value.data_type().into_nullable_field(); + if let Some(metadata) = &self.metadata { + source_field = metadata.add_to_field(source_field); + } + + if let Some(cast_extension) = cast_extension + && cast_extension.can_cast_fields(&source_field, target_field)? 
+ { + let cast_arr = cast_extension.cast_array_fields( + &self.value.to_array()?, + &source_field, + target_field, + cast_options, + )?; + let storage = ScalarValue::try_from_array(&cast_arr, 0)?; + let metadata = FieldMetadata::new_from_field(target_field); + return Ok(Self { + value: storage, + metadata: Some(metadata), + }); + } + + self.cast_storage_to(target_field.data_type()) + } } /// create a new ScalarAndMetadata from a ScalarValue without diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index cdd6215d08e2f..7d0371f18a62c 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{_plan_err, Result}; +use crate::{ + datatype::DataTypeExt, + error::{_internal_err, _plan_err, Result}, + metadata::format_type_and_metadata, +}; use arrow::{ array::{ Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray, @@ -56,6 +60,7 @@ use std::{collections::HashSet, sync::Arc}; fn cast_struct_column( source_col: &ArrayRef, target_fields: &[Arc], + cast_extension: Option<&dyn CastExtension>, cast_options: &CastOptions, ) -> Result { if source_col.data_type() == &DataType::Null @@ -78,14 +83,20 @@ fn cast_struct_column( for target_child_field in target_fields.iter() { fields.push(Arc::clone(target_child_field)); - let source_child_opt = - source_struct.column_by_name(target_child_field.name()); - - match source_child_opt { - Some(source_child_col) => { - let adapted_child = cast_column( - source_child_col, - target_child_field.data_type(), + let source_child_index_opt = source_struct + .column_names() + .iter() + .position(|name| *name == target_child_field.name()); + // let source_child_opt = + // source_struct.column_by_name(target_child_field.name()); + + match source_child_index_opt { + Some(source_child_index) => { + let adapted_child = 
cast_column_fields( + source_struct.column(source_child_index), + source_struct.fields()[source_child_index].as_ref(), + target_child_field.as_ref(), + cast_extension, cast_options, ) .map_err(|e| { @@ -173,22 +184,74 @@ pub fn cast_column( target_type: &DataType, cast_options: &CastOptions, ) -> Result { - match (source_col.data_type(), target_type) { + cast_column_fields( + source_col, + &source_col.data_type().clone().into_nullable_field(), + &target_type.clone().into_nullable_field(), + None, + cast_options, + ) +} + +pub fn cast_column_fields( + source_col: &ArrayRef, + source_field: &Field, + target_field: &Field, + cast_extension: Option<&dyn CastExtension>, + cast_options: &CastOptions, +) -> Result { + if let Some(cast_extension) = cast_extension + && cast_extension.can_cast_fields(source_field, target_field)? + { + return cast_extension.cast_array_fields( + source_col, + source_field, + target_field, + cast_options, + ); + } + + match (source_field.data_type(), target_field.data_type()) { (_, Struct(target_fields)) => { - cast_struct_column(source_col, target_fields, cast_options) + cast_struct_column(source_col, target_fields, cast_extension, cast_options) } - (DataType::List(_), DataType::List(target_inner)) => { - cast_list_column::(source_col, target_inner, cast_options) + (DataType::List(source_inner), DataType::List(target_inner)) => { + cast_list_column::( + source_col, + source_inner, + target_inner, + cast_extension, + cast_options, + ) } - (DataType::LargeList(_), DataType::LargeList(target_inner)) => { - cast_list_column::(source_col, target_inner, cast_options) + (DataType::LargeList(source_inner), DataType::LargeList(target_inner)) => { + cast_list_column::( + source_col, + source_inner, + target_inner, + cast_extension, + cast_options, + ) } - (DataType::ListView(_), DataType::ListView(target_inner)) => { - cast_list_view_column::(source_col, target_inner, cast_options) - } - (DataType::LargeListView(_), 
DataType::LargeListView(target_inner)) => { - cast_list_view_column::(source_col, target_inner, cast_options) + (DataType::ListView(source_inner), DataType::ListView(target_inner)) => { + cast_list_view_column::( + source_col, + source_inner, + target_inner, + cast_extension, + cast_options, + ) } + ( + DataType::LargeListView(source_inner), + DataType::LargeListView(target_inner), + ) => cast_list_view_column::( + source_col, + source_inner, + target_inner, + cast_extension, + cast_options, + ), ( DataType::Dictionary(source_key_type, _), DataType::Dictionary(target_key_type, target_value_type), @@ -197,15 +260,22 @@ pub fn cast_column( source_key_type, target_key_type, target_value_type, + cast_extension, cast_options, ), - _ => Ok(cast_with_options(source_col, target_type, cast_options)?), + _ => Ok(cast_with_options( + source_col, + target_field.data_type(), + cast_options, + )?), } } fn cast_list_column( source_col: &ArrayRef, + source_inner_field: &FieldRef, target_inner_field: &FieldRef, + cast_extension: Option<&dyn CastExtension>, cast_options: &CastOptions, ) -> Result { let source_list = source_col @@ -218,9 +288,11 @@ fn cast_list_column( )) })?; - let cast_values = cast_column( + let cast_values = cast_column_fields( source_list.values(), - target_inner_field.data_type(), + source_inner_field.as_ref(), + target_inner_field.as_ref(), + cast_extension, cast_options, )?; @@ -235,7 +307,9 @@ fn cast_list_column( fn cast_list_view_column( source_col: &ArrayRef, + source_inner_field: &FieldRef, target_inner_field: &FieldRef, + cast_extension: Option<&dyn CastExtension>, cast_options: &CastOptions, ) -> Result { let source_list = source_col @@ -248,9 +322,11 @@ fn cast_list_view_column( )) })?; - let cast_values = cast_column( + let cast_values = cast_column_fields( source_list.values(), - target_inner_field.data_type(), + source_inner_field.as_ref(), + target_inner_field.as_ref(), + cast_extension, cast_options, )?; @@ -269,6 +345,7 @@ fn 
cast_dictionary_column( source_key_type: &DataType, target_key_type: &DataType, target_value_type: &DataType, + cast_extension: Option<&dyn CastExtension>, cast_options: &CastOptions, ) -> Result { // Dispatch on source key type to access keys/values, then recursively @@ -279,8 +356,13 @@ fn cast_dictionary_column( .as_any() .downcast_ref::>() .expect("downcast must succeed"); - let cast_values = - cast_column(source_dict.values(), target_value_type, cast_options)?; + let cast_values = cast_column_fields( + source_dict.values(), + &source_dict.data_type().clone().into_nullable_field(), + &target_value_type.clone().into_nullable_field(), + cast_extension, + cast_options, + )?; Ok(Arc::new(DictionaryArray::<$t>::new( source_dict.keys().clone(), cast_values, @@ -345,6 +427,14 @@ fn cast_dictionary_column( pub fn validate_struct_compatibility( source_fields: &[FieldRef], target_fields: &[FieldRef], +) -> Result<()> { + validate_struct_compatibility_with_extension(source_fields, target_fields, None) +} + +pub fn validate_struct_compatibility_with_extension( + source_fields: &[FieldRef], + target_fields: &[FieldRef], + cast_extension: Option<&dyn CastExtension>, ) -> Result<()> { let has_overlap = has_one_of_more_common_fields(source_fields, target_fields); if !has_overlap { @@ -362,7 +452,7 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { - validate_field_compatibility(source_field, target_field)?; + validate_field_compatibility(source_field, target_field, cast_extension)?; } else { // Target field is missing from source // If it's non-nullable, we cannot fill it with NULL @@ -383,6 +473,7 @@ pub fn validate_struct_compatibility( fn validate_field_compatibility( source_field: &Field, target_field: &Field, + cast_extension: Option<&dyn CastExtension>, ) -> Result<()> { if source_field.data_type() == &DataType::Null { // Validate that target allows nulls before returning early. 
@@ -407,10 +498,17 @@ fn validate_field_compatibility( ); } + if let Some(cast_extension) = cast_extension + && cast_extension.can_cast_fields(source_field, target_field)? + { + return Ok(()); + } + validate_data_type_compatibility( target_field.name(), source_field.data_type(), target_field.data_type(), + cast_extension, ) } @@ -420,6 +518,7 @@ pub fn validate_data_type_compatibility( field_name: &str, source_type: &DataType, target_type: &DataType, + cast_extension: Option<&dyn CastExtension>, ) -> Result<()> { match (source_type, target_type) { (Struct(source_nested), Struct(target_nested)) => { @@ -429,7 +528,7 @@ pub fn validate_data_type_compatibility( | (DataType::LargeList(s), DataType::LargeList(t)) | (DataType::ListView(s), DataType::ListView(t)) | (DataType::LargeListView(s), DataType::LargeListView(t)) => { - validate_field_compatibility(s, t)?; + validate_field_compatibility(s, t, cast_extension)?; } (DataType::Dictionary(s_key, s_val), DataType::Dictionary(t_key, t_val)) => { if !can_cast_types(s_key, t_key) { @@ -440,7 +539,7 @@ pub fn validate_data_type_compatibility( field_name ); } - validate_data_type_compatibility(field_name, s_val, t_val)?; + validate_data_type_compatibility(field_name, s_val, t_val, cast_extension)?; } _ => { if !can_cast_types(source_type, target_type) { @@ -502,6 +601,77 @@ pub fn has_one_of_more_common_fields( .any(|field| source_names.contains(field.name().as_str())) } +pub trait CastExtension: std::fmt::Debug + Send + Sync { + fn can_cast_fields(&self, source_field: &Field, target_field: &Field) + -> Result; + + fn cast_array_fields( + &self, + array: &ArrayRef, + source_field: &Field, + target_field: &Field, + cast_options: &CastOptions, + ) -> Result; +} + +#[derive(Debug)] +pub struct VecCastExtension { + extensions: Vec>, +} + +impl VecCastExtension { + pub fn new(extensions: Vec>) -> Self { + Self { extensions } + } +} + +impl CastExtension for VecCastExtension { + fn can_cast_fields( + &self, + source_field: &Field, 
+ target_field: &Field, + ) -> Result { + for extension in &self.extensions { + if extension.can_cast_fields(source_field, target_field)? { + return Ok(true); + } + } + + Ok(false) + } + + fn cast_array_fields( + &self, + array: &ArrayRef, + source_field: &Field, + target_field: &Field, + cast_options: &CastOptions, + ) -> Result { + for extension in &self.extensions { + if extension.can_cast_fields(source_field, target_field)? { + return extension.cast_array_fields( + array, + source_field, + target_field, + cast_options, + ); + } + } + + let source_display = format_type_and_metadata( + source_field.data_type(), + Some(source_field.metadata()), + ); + let target_display = format_type_and_metadata( + target_field.data_type(), + Some(target_field.metadata()), + ); + _internal_err!( + "Can't resolve extension to cast from {source_display} to {target_display}" + ) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1214,7 +1384,7 @@ mod tests { DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner)); let target = DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner)); - assert!(validate_data_type_compatibility("col", &source, &target).is_ok()); + assert!(validate_data_type_compatibility("col", &source, &target, None).is_ok()); } #[test] diff --git a/datafusion/common/src/types/canonical_extensions/uuid.rs b/datafusion/common/src/types/canonical_extensions/uuid.rs index 8cbcf3f58a80e..fe71cb5ff1183 100644 --- a/datafusion/common/src/types/canonical_extensions/uuid.rs +++ b/datafusion/common/src/types/canonical_extensions/uuid.rs @@ -16,13 +16,21 @@ // under the License. 
use crate::Result; -use crate::error::_internal_err; +use crate::cast::as_string_array; +use crate::error::{_exec_err, _internal_err}; +use crate::nested_struct::CastExtension; +use crate::types::DefaultExtensionCast; use crate::types::extension::DFExtensionType; -use arrow::array::{Array, FixedSizeBinaryArray}; +use arrow::array::{ + Array, ArrayRef, FixedSizeBinaryArray, builder::FixedSizeBinaryBuilder, +}; +use arrow::compute::{CastOptions, cast}; use arrow::datatypes::DataType; use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use arrow_schema::Field; use arrow_schema::extension::{ExtensionType, Uuid}; use std::fmt::Write; +use std::sync::Arc; use uuid::Bytes; /// Defines the extension type logic for the canonical `arrow.uuid` extension type. This extension @@ -45,6 +53,16 @@ impl DFUuid { ) -> Result { Ok(Self(::try_new(data_type, metadata)?)) } + + pub fn cast_extensions() -> Vec> { + vec![ + Arc::new( + DefaultExtensionCast::new(Uuid::NAME) + .with_default_cast_to_string(Some(Arc::new(DFUuid(Uuid)))), + ), + Arc::new(ParseUuid), + ] + } } impl DFExtensionType for DFUuid { @@ -98,6 +116,72 @@ impl DisplayIndex for UuidValueDisplayIndex<'_> { } } +#[derive(Debug)] +pub struct ParseUuid; + +impl CastExtension for ParseUuid { + fn can_cast_fields(&self, from: &Field, to: &Field) -> Result { + if from.extension_type_name().is_some() { + return Ok(false); + } + + if let Some(to_extension_name) = to.extension_type_name() + && to_extension_name == Uuid::NAME + { + Ok(matches!( + from.data_type(), + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + )) + } else { + Ok(false) + } + } + + fn cast_array_fields( + &self, + value: &ArrayRef, + from: &Field, + to: &Field, + options: &CastOptions, + ) -> Result { + if !self.can_cast_fields(from, to)? 
{ + return _internal_err!("Unhandled cast"); + } + + match from.data_type() { + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + if options.safe { + return _exec_err!("Cast from UUID to string must be explicit"); + } + + let string_array_ref = cast(&value, &DataType::Utf8)?; + let string_array = as_string_array(&string_array_ref)?; + let mut builder = FixedSizeBinaryBuilder::new(16); + for string_opt in string_array { + match string_opt { + Some(string) => { + let uuid = uuid::Uuid::try_parse(string).map_err(|_| { + crate::DataFusionError::Execution(format!( + "Failed to parsed string '{string}' as UUID" + )) + })?; + builder.append_value(uuid.as_bytes())?; + } + None => { + builder.append_null(); + } + } + } + + return Ok(Arc::new(builder.finish())); + } + _ => {} + } + + _internal_err!("Unexpected difference between can_cast_from()") + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/common/src/types/extension.rs b/datafusion/common/src/types/extension.rs index 3bcb533dbf9e6..3d575d9936036 100644 --- a/datafusion/common/src/types/extension.rs +++ b/datafusion/common/src/types/extension.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::error::Result; -use arrow::array::Array; +use crate::error::{_exec_err, _internal_err, Result}; +use crate::metadata::format_type_and_metadata; +use crate::nested_struct::CastExtension; +use arrow::array::{Array, ArrayRef, StringBuilder}; +use arrow::compute::CastOptions; use arrow::util::display::{ArrayFormatter, FormatOptions}; -use arrow_schema::DataType; +use arrow_schema::{DataType, Field}; use std::fmt::Debug; use std::sync::Arc; @@ -88,3 +91,150 @@ pub trait DFExtensionType: Debug + Send + Sync { Ok(None) } } + +#[derive(Debug)] +pub struct DefaultExtensionCast { + extension_name: &'static str, + instance: Option>, + can_cast_to_storage: bool, + can_cast_from_storage: bool, + use_default_cast_to_string: bool, +} + +impl DefaultExtensionCast { + pub fn new(extension_name: &'static str) -> Self { + Self { + extension_name, + instance: None, + can_cast_to_storage: true, + can_cast_from_storage: true, + use_default_cast_to_string: false, + } + } + + pub fn with_default_cast_to_string( + mut self, + instance: Option>, + ) -> Self { + self.use_default_cast_to_string = true; + self.instance = instance; + self + } + + fn is_cast_to_storage(&self, from: &Field, to: &Field) -> bool { + self.is_this_extension(from) + && !Self::is_any_extension(to) + && to.data_type() == from.data_type() + } + + fn is_cast_from_storage(&self, from: &Field, to: &Field) -> bool { + self.is_this_extension(to) + && !Self::is_any_extension(from) + && from.data_type() == to.data_type() + } + + fn is_cast_to_string(&self, from: &Field, to: &Field) -> bool { + self.is_this_extension(from) + && !Self::is_any_extension(to) + && matches!( + to.data_type(), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) + } + + fn is_this_extension(&self, field: &Field) -> bool { + if let Some(from_extension_name) = field.extension_type_name() + && from_extension_name == self.extension_name + { + true + } else { + false + } + } + + fn is_any_extension(field: &Field) -> bool { + 
field.extension_type_name().is_some() + } + + fn default_cast_to_string( + &self, + value: &ArrayRef, + to: &DataType, + ) -> Result { + let format_options = FormatOptions::default(); + + // Try to get a custom formatter from the extension type instance, + // otherwise fall back to the default formatter for the storage type + let formatter = if let Some(instance) = &self.instance { + match instance.create_array_formatter(value.as_ref(), &format_options)? { + Some(f) => f, + None => ArrayFormatter::try_new(value.as_ref(), &format_options)?, + } + } else { + ArrayFormatter::try_new(value.as_ref(), &format_options)? + }; + + // Format each value into a string type and cast to the target + let len = value.len(); + let mut builder = StringBuilder::with_capacity(len, len * 16); + for i in 0..len { + if value.is_null(i) { + builder.append_null(); + } else { + builder.append_value(formatter.value(i).to_string()); + } + } + + Ok(arrow::compute::cast(&builder.finish(), to)?) + } +} + +impl CastExtension for DefaultExtensionCast { + fn can_cast_fields(&self, from: &Field, to: &Field) -> Result { + if self.can_cast_to_storage && self.is_cast_to_storage(from, to) { + return Ok(true); + } + + if self.can_cast_from_storage && self.is_cast_from_storage(from, to) { + return Ok(true); + } + + if self.use_default_cast_to_string && self.is_cast_to_string(from, to) { + return Ok(true); + } + + Ok(false) + } + + fn cast_array_fields( + &self, + value: &ArrayRef, + from: &Field, + to: &Field, + options: &CastOptions, + ) -> Result { + if options.safe { + let from_display = + format_type_and_metadata(from.data_type(), Some(from.metadata())); + let to_display = + format_type_and_metadata(to.data_type(), Some(to.metadata())); + return _exec_err!( + "Can't cast from {from_display} to {to_display} with safe = true" + ); + } + + if self.can_cast_to_storage && self.is_cast_to_storage(from, to) { + return Ok(Arc::clone(value)); + } + + if self.can_cast_from_storage && 
self.is_cast_from_storage(from, to) { + return Ok(Arc::clone(value)); + } + + if self.use_default_cast_to_string && self.is_cast_to_string(from, to) { + return self.default_cast_to_string(value, to.data_type()); + } + + _internal_err!("Unhandled cast from {from} to {to} in default extension cast") + } +} diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index de5e6b97c1af9..b18104246304a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -57,8 +57,8 @@ use datafusion_expr::planner::ExprPlanner; #[cfg(feature = "sql")] use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{ - ExtensionTypeRegistryRef, FunctionRegistry, MemoryExtensionTypeRegistry, - SerializerRegistry, + ExtensionTypeRegistry, ExtensionTypeRegistryRef, FunctionRegistry, + MemoryExtensionTypeRegistry, SerializerRegistry, }; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ @@ -1395,7 +1395,7 @@ impl SessionStateBuilder { self } - /// Sets the [`ExtensionTypeRegistry`](datafusion_expr::registry::ExtensionTypeRegistry). 
+ /// Sets the [`ExtensionTypeRegistry`] pub fn with_extension_type_registry( mut self, registry: ExtensionTypeRegistryRef, @@ -1728,6 +1728,10 @@ impl SessionStateBuilder { } } + // Temporary hack while we figure out how to get the extension types where they + // need to go + state.execution_props.extension_types = Some(Arc::clone(&state.extension_types)); + state } @@ -2282,6 +2286,10 @@ impl OptimizerConfig for SessionState { fn function_registry(&self) -> Option<&dyn FunctionRegistry> { Some(self) } + + fn extension_types(&self) -> Option> { + Some(Arc::clone(&self.extension_types)) + } } /// Create a new task context instance from SessionState diff --git a/datafusion/core/tests/extension_types/pretty_printing.rs b/datafusion/core/tests/extension_types/pretty_printing.rs index c0796887b8b6e..aa0eaec3916d2 100644 --- a/datafusion/core/tests/extension_types/pretty_printing.rs +++ b/datafusion/core/tests/extension_types/pretty_printing.rs @@ -17,11 +17,13 @@ use arrow::array::{FixedSizeBinaryArray, RecordBatch}; use arrow_schema::extension::Uuid; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; +use datafusion::assert_batches_eq; use datafusion::dataframe::DataFrame; use datafusion::error::Result; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::SessionContext; +use datafusion_expr::planner::TypePlanner; use datafusion_expr::registry::MemoryExtensionTypeRegistry; use insta::assert_snapshot; use std::sync::Arc; @@ -58,6 +60,8 @@ async fn create_test_table() -> Result { ctx.table("test").await } +// Test here + #[tokio::test] async fn test_pretty_print_extension_type_formatter() -> Result<()> { let result = create_test_table().await?.to_string().await?; @@ -76,3 +80,92 @@ async fn test_pretty_print_extension_type_formatter() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn create_cast_uuid_to_char() -> Result<()> { + let schema = test_schema(); + + // define data. 
+ let batch = RecordBatch::try_new( + schema, + vec![Arc::new(FixedSizeBinaryArray::from(vec![ + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5, 6], + ]))], + )?; + + let state = SessionStateBuilder::default() + .with_type_planner(Arc::new(CustomTypePlanner {})) + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_batch("test", batch)?; + + let df = ctx.sql("SELECT my_uuids::VARCHAR FROM test").await?; + let batches = df.collect().await?; + + assert_batches_eq!( + [ + "+--------------------------------------+", + "| test.my_uuids |", + "+--------------------------------------+", + "| 00000000-0000-0000-0000-000000000000 |", + "| 00010203-0405-0607-0809-000102030506 |", + "+--------------------------------------+", + ], + &batches + ); + + Ok(()) +} + +#[tokio::test] +async fn create_cast_char_to_uuid() -> Result<()> { + let state = SessionStateBuilder::default() + .with_type_planner(Arc::new(CustomTypePlanner {})) + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + + let df = ctx + .sql("SELECT '00010203-0405-0607-0809-000102030506'::UUID AS uuid") + .await?; + let batches = df.collect().await?; + assert_batches_eq!( + [ + "+----------------------------------+", + "| uuid |", + "+----------------------------------+", + "| 00010203040506070809000102030506 |", + "+----------------------------------+", + ], + &batches + ); + + Ok(()) +} + +#[derive(Debug)] +pub struct CustomTypePlanner {} + +impl TypePlanner for CustomTypePlanner { + fn plan_type_field( + &self, + sql_type: &sqlparser::ast::DataType, + ) -> Result> { + match sql_type { + sqlparser::ast::DataType::Uuid => Ok(Some(Arc::new( + Field::new("", DataType::FixedSizeBinary(16), 
true).with_metadata( + [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())] + .into(), + ), + ))), + _ => Ok(None), + } + } +} diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 649f74ed3997c..8f5e10cec3bfb 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::registry::ExtensionTypeRegistry; use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; use datafusion_common::HashMap; @@ -74,6 +75,7 @@ pub struct ExecutionProps { /// during physical planning. Populated by the physical planner for /// each lambda before calling `create_physical_expr`. pub lambda_variable_qualifier: HashMap, + pub extension_types: Option>, } impl Default for ExecutionProps { @@ -93,6 +95,7 @@ impl ExecutionProps { subquery_indexes: HashMap::new(), subquery_results: ScalarSubqueryResults::default(), lambda_variable_qualifier: HashMap::new(), + extension_types: None, } } @@ -274,7 +277,7 @@ mod test { fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [], lambda_variable_qualifier: {} }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [], lambda_variable_qualifier: {}, extension_types: None }", format!("{props:?}") ); } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 32a88ab8cf310..5478607994df1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -22,14 +22,18 @@ use std::collections::HashSet; use std::fmt::Debug; use 
std::sync::Arc; -use crate::expr::{Alias, Sort, Unnest}; +use arrow::compute::can_cast_types; +use arrow::datatypes::FieldRef; + +use crate::expr::{Alias, Cast, Sort, Unnest}; +use crate::expr_schema::cast_subquery; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use datafusion_common::TableReference; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{Column, DFSchema, Result}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, plan_err}; mod guarantees; pub use guarantees::GuaranteeRewriter; @@ -252,11 +256,14 @@ fn coerce_exprs_for_schema( .into_iter() .enumerate() .map(|(idx, expr)| { - let new_type = dst_schema.field(idx).data_type(); + let dst_field = dst_schema.field(idx); + let new_type = dst_field.data_type(); if new_type != &expr.get_type(src_schema)? { match expr { Expr::Alias(Alias { expr, name, .. }) => { - Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + // Use new_from_field to preserve metadata from dst_schema + Ok(cast_to_field(*expr, Arc::clone(dst_field), src_schema)? + .alias(name)) } #[expect(deprecated)] Expr::Wildcard { .. } => Ok(expr), @@ -267,9 +274,18 @@ fn coerce_exprs_for_schema( // (see: https://github.com/apache/datafusion/issues/18818) Expr::Column(ref column) => { let name = column.name().to_owned(); - Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + // Use new_from_field to preserve metadata from dst_schema + Ok(cast_to_field( + expr, + Arc::clone(dst_field), + src_schema, + )? + .alias(name)) + } + _ => { + // Use new_from_field to preserve metadata from dst_schema + cast_to_field(expr, Arc::clone(dst_field), src_schema) } - _ => Ok(expr.cast_to(new_type, src_schema)?), } } } @@ -280,6 +296,48 @@ fn coerce_exprs_for_schema( .collect::>() } +// TODO: move to `ExprSchemable::cast_to_field`? + +/// Cast an expression to a target field, preserving field metadata. 
+/// This is similar to `ExprSchemable::cast_to` but uses the full field +/// (including metadata) rather than just the data type. +fn cast_to_field( + expr: Expr, + target_field: FieldRef, + schema: &dyn ExprSchema, +) -> Result { + use arrow::datatypes::DataType; + + let this_type = expr.get_type(schema)?; + let cast_to_type = target_field.data_type(); + if &this_type == cast_to_type { + return Ok(expr); + } + + // Special handling for struct-to-struct casts with name-based field matching + let can_cast = match (&this_type, cast_to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // Always allow struct-to-struct casts; field matching happens at runtime + true + } + _ => can_cast_types(&this_type, cast_to_type), + }; + + if can_cast { + match expr { + Expr::ScalarSubquery(subquery) => { + Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) + } + _ => Ok(Expr::Cast(Cast::new_from_field( + Box::new(expr), + target_field, + ))), + } + } else { + plan_err!("Cannot automatically convert {this_type} to {cast_to_type}") + } +} + /// Recursively un-alias an expressions #[inline] pub fn unalias(expr: Expr) -> Expr { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index c989bab3048ad..22dcb697dd03d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -73,17 +73,21 @@ pub trait ExprSchemable { /// For `TryCast`, `force_nullable` is `true` since a failed cast returns NULL. fn cast_output_field( source_field: &FieldRef, - target_type: &DataType, + target_field: &FieldRef, force_nullable: bool, ) -> Arc { - let mut f = source_field + // Do not propagate metadata through casts because extension metadata (1) + // should be derived from the target_field and (2) source extension metadata + // may become non-sensical if applied to an unrelated storage output type. 
+ let mut f = target_field .as_ref() .clone() - .with_data_type(target_type.clone()) - .with_metadata(source_field.metadata().clone()); + .with_nullable(source_field.is_nullable()); + if force_nullable { f = f.with_nullable(true); } + Arc::new(f) } @@ -594,21 +598,16 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } - // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, field }) => { - expr.to_field(schema).map(|(_table_ref, src)| { - cast_output_field(&src, field.data_type(), false) - }) - } + Expr::Cast(Cast { expr, field }) => expr + .to_field(schema) + .map(|(_table_ref, src)| cast_output_field(&src, field, false)), Expr::Placeholder(Placeholder { id: _, field: Some(field), }) => Ok(Arc::clone(field).renamed(&schema_name)), - Expr::TryCast(TryCast { expr, field }) => { - expr.to_field(schema).map(|(_table_ref, src)| { - cast_output_field(&src, field.data_type(), true) - }) - } + Expr::TryCast(TryCast { expr, field }) => expr + .to_field(schema) + .map(|(_table_ref, src)| cast_output_field(&src, field, true)), Expr::LambdaVariable(LambdaVariable { field: Some(field), .. 
}) => Ok(Arc::clone(field).renamed(&schema_name)), @@ -1043,17 +1042,9 @@ mod tests { .with_data_type(DataType::Int32) .with_metadata(meta.clone()); - // col, alias, and cast should be metadata-preserving + // col and alias should be metadata-preserving assert_eq!(meta, expr.metadata(&schema).unwrap()); assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap()); - assert_eq!( - meta, - expr.clone() - .cast_to(&DataType::Int64, &schema) - .unwrap() - .metadata(&schema) - .unwrap() - ); let schema = DFSchema::from_unqualified_fields( vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(), diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index f03cc5936c6ed..848ad0434ea02 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -27,6 +27,7 @@ use arrow_schema::extension::{ Bool8, ExtensionType, FixedShapeTensor, Json, Opaque, TimestampWithOffset, Uuid, VariableShapeTensor, }; +use datafusion_common::nested_struct::{CastExtension, VecCastExtension}; use datafusion_common::types::{ DFBool8, DFExtensionTypeRef, DFFixedShapeTensor, DFJson, DFOpaque, DFTimestampWithOffset, DFUuid, DFVariableShapeTensor, @@ -346,6 +347,14 @@ pub trait ExtensionTypeRegistry: Debug + Send + Sync { &self, name: &str, ) -> Result>; + + fn cast_extension( + &self, + _source_field: &Field, + _target_field: &Field, + ) -> Option> { + None + } } /// A factory that creates instances of extension types from a storage [`DataType`] and the @@ -431,6 +440,7 @@ impl Debug for ExtensionTypeRegistration { pub struct MemoryExtensionTypeRegistry { /// Holds a mapping between the name of an extension type and its logical type. 
extension_types: Arc>>, + cast_extensions: Arc, } impl Default for MemoryExtensionTypeRegistry { @@ -444,6 +454,7 @@ impl MemoryExtensionTypeRegistry { pub fn new_empty() -> Self { Self { extension_types: Arc::new(RwLock::new(HashMap::new())), + cast_extensions: Arc::new(VecCastExtension::new(vec![])), } } @@ -504,6 +515,8 @@ impl MemoryExtensionTypeRegistry { ), ]; + let cast_extensions = DFUuid::cast_extensions(); + let mut extension_types = HashMap::new(); for registration in mapping.into_iter() { extension_types.insert(registration.type_name().to_owned(), registration); @@ -511,14 +524,11 @@ impl MemoryExtensionTypeRegistry { Self { extension_types: Arc::new(RwLock::new(HashMap::from(extension_types))), + cast_extensions: Arc::new(VecCastExtension::new(cast_extensions)), } } /// Creates a new [MemoryExtensionTypeRegistry] with the provided `types`. - /// - /// # Errors - /// - /// Returns an error if one of the `types` is a native type. pub fn new_with_types( types: impl IntoIterator, ) -> Result { @@ -528,6 +538,7 @@ impl MemoryExtensionTypeRegistry { .collect::>(); Ok(Self { extension_types: Arc::new(RwLock::new(extension_types)), + cast_extensions: Arc::new(VecCastExtension::new(vec![])), }) } @@ -585,12 +596,29 @@ impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry { .expect("Extension type registry lock poisoned") .remove(name)) } + + fn cast_extension( + &self, + source_field: &Field, + target_field: &Field, + ) -> Option> { + if self + .cast_extensions + .can_cast_fields(source_field, target_field) + .unwrap_or(false) + { + Some(Arc::clone(&self.cast_extensions) as Arc) + } else { + None + } + } } impl From> for MemoryExtensionTypeRegistry { fn from(value: HashMap) -> Self { Self { extension_types: Arc::new(RwLock::new(value)), + cast_extensions: Arc::new(VecCastExtension::new(vec![])), } } } diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 522cf122a273c..d2fe2739de150 100644 --- 
a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -24,6 +24,7 @@ use chrono::{DateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use crate::registry::ExtensionTypeRegistry; use crate::{Expr, ExprSchemable}; /// Provides simplification information based on schema, query execution time, @@ -38,6 +39,7 @@ pub struct SimplifyContext { schema: DFSchemaRef, query_execution_start_time: Option>, config_options: Arc, + extension_types: Option>, } /// Builder for [`SimplifyContext`]. @@ -46,6 +48,7 @@ pub struct SimplifyContextBuilder { schema: Option, query_execution_start_time: Option>, config_options: Option>, + extension_types: Option>, } impl Default for SimplifyContext { @@ -54,6 +57,7 @@ impl Default for SimplifyContext { schema: Arc::new(DFSchema::empty()), query_execution_start_time: None, config_options: Arc::new(ConfigOptions::default()), + extension_types: None, } } } @@ -137,6 +141,10 @@ impl SimplifyContext { pub fn config_options(&self) -> &Arc { &self.config_options } + + pub fn extension_types(&self) -> Option<&Arc> { + self.extension_types.as_ref() + } } impl SimplifyContextBuilder { @@ -167,6 +175,14 @@ impl SimplifyContextBuilder { self } + pub fn with_extension_types( + mut self, + extension_types: Option>, + ) -> Self { + self.extension_types = extension_types; + self + } + /// Build a [`SimplifyContext`], filling in any unspecified fields with defaults. 
pub fn build(self) -> SimplifyContext { SimplifyContext {
@@ -175,6 +191,7 @@ impl SimplifyContextBuilder { config_options: self .config_options .unwrap_or_else(|| Arc::new(ConfigOptions::default())), + extension_types: self.extension_types, } } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 7b81feab47a99..f55e86b63d2cb 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1223,6 +1223,7 @@ fn coerce_union_schema_with_schema( ); } + // TODO: this type coercion caused a regression in one of the benchmarks; investigate before changing it // coerce data type and nullability for each field for (union_datatype, union_nullable, union_field_map, plan_field) in izip!( union_datatypes.iter_mut(), diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index d0fbb31414dab..eef2db8b7b64e 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -21,7 +21,7 @@ use std::fmt::Debug; use std::sync::Arc; use chrono::{DateTime, Utc}; -use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::registry::{ExtensionTypeRegistry, FunctionRegistry}; use datafusion_expr::{InvariantLevel, assert_expected_schema}; use log::{debug, warn}; @@ -146,6 +146,10 @@ pub trait OptimizerConfig { fn function_registry(&self) -> Option<&dyn FunctionRegistry> { None } + + fn extension_types(&self) -> Option> { + None + } } /// A standalone [`OptimizerConfig`] that can be used independently diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 143d8eae695af..8be9b6b9a24da 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -43,6 +43,7 @@ use datafusion_expr::expr::HigherOrderFunction; use 
datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, + registry::ExtensionTypeRegistry, }; use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; @@ -211,7 +212,10 @@ impl ExprSimplifier { ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); let config_options = Some(Arc::clone(self.info.config_options())); - let mut const_evaluator = ConstEvaluator::try_new(config_options)?; + let mut const_evaluator = ConstEvaluator::try_new( + config_options, + self.info.extension_types().cloned(), + )?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let guarantees_map: HashMap<&Expr, &NullableInterval> = self.guarantees.iter().map(|(k, v)| (k, v)).collect(); @@ -597,12 +601,16 @@ impl ConstEvaluator { /// /// The `config_options` parameter is used to pass session configuration /// (like timezone) to scalar functions during constant evaluation. 
- pub fn try_new(config_options: Option>) -> Result { + pub fn try_new( + config_options: Option>, + extension_types: Option>, + ) -> Result { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated let mut execution_props = ExecutionProps::new(); execution_props.config_options = config_options; + execution_props.extension_types = extension_types; Ok(Self { can_evaluate: vec![], diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 3e495f5355103..240d7a257fc14 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -102,6 +102,7 @@ impl SimplifyExpressions { .with_schema(schema) .with_config_options(config.options()) .with_query_execution_start_time(config.query_execution_start_time()) + .with_extension_types(config.extension_types().clone()) .build(); // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 9fb4950317ff8..d1e04c54f5f63 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -438,6 +438,7 @@ impl DefaultPhysicalExprAdapterRewriter { resolved_column.name(), physical_field.data_type(), logical_field.data_type(), + None // TODO: can we get a cast extension here? ) .map_err(|e| { DataFusionError::Execution(format!( @@ -451,6 +452,7 @@ impl DefaultPhysicalExprAdapterRewriter { Ok(Transformed::yes(Arc::new(CastExpr::new_with_target_field( Arc::new(resolved_column), Arc::new(logical_field.clone()), + None, // TODO: can we get a cast extension here? 
None, )))) } @@ -854,6 +856,7 @@ mod tests { Arc::new(Column::new("data", 0)), logical_field, None, + None, )) as Arc; assert_eq!(result.to_string(), expected.to_string()); diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 2ebc71559fcf4..d3fa5e0f4b8a0 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -942,6 +942,7 @@ mod tests { col_c, Arc::new(Field::new("c", DataType::Date32, true)), None, + None, )) as _; let required_sort = vec![PhysicalSortExpr::new_default(col("c", &schema)?)]; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index ad214a89ceb71..189b75e5096e0 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -26,10 +26,11 @@ use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::datatype::DataTypeExt; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::nested_struct::{ - requires_nested_struct_cast, validate_data_type_compatibility, + CastExtension, requires_nested_struct_cast, validate_data_type_compatibility, }; -use datafusion_common::{Result, not_impl_err}; +use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -50,12 +51,16 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { /// planning-time validation matches runtime validation, enabling fail-fast behavior /// instead of deferring errors to execution. 
Handles structs at any nesting level /// (e.g., `List`, `Dictionary<_, Struct>`). -fn can_cast_named_struct_types(source: &DataType, target: &DataType) -> bool { - validate_data_type_compatibility("", source, target).is_ok() +fn can_cast_named_struct_types( + source: &DataType, + target: &DataType, + cast_extension: Option<&dyn CastExtension>, +) -> bool { + validate_data_type_compatibility("", source, target, cast_extension).is_ok() } /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast -#[derive(Debug, Clone, Eq)] +#[derive(Debug, Clone)] pub struct CastExpr { /// The expression to cast pub expr: Arc, @@ -63,6 +68,8 @@ pub struct CastExpr { target_field: FieldRef, /// Cast options cast_options: CastOptions<'static>, + /// Optional extension used to perform casts involving extension-typed fields + cast_extension: Option>, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 @@ -82,6 +89,8 @@ impl Hash for CastExpr { } } +impl Eq for CastExpr {} + impl CastExpr { /// Create a new `CastExpr` using only a `DataType`. 
/// @@ -105,6 +114,7 @@ impl CastExpr { Self::new_with_target_field( expr, cast_type.into_nullable_field_ref(), + None, cast_options, ) } @@ -121,12 +131,14 @@ impl CastExpr { pub fn new_with_target_field( expr: Arc, target_field: FieldRef, + cast_extension: Option>, cast_options: Option>, ) -> Self { Self { expr, target_field, cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), + cast_extension, } } @@ -153,12 +165,20 @@ impl CastExpr { fn resolved_target_field(&self, input_schema: &Schema) -> Result { if is_default_target_field(&self.target_field) { self.expr.return_field(input_schema).map(|field| { - Arc::new( - field - .as_ref() - .clone() - .with_data_type(self.cast_type().clone()), - ) + let cast_type = self.cast_type(); + let mut out_field = + field.as_ref().clone().with_data_type(cast_type.clone()); + + // If we modify the storage type we can't ensure that the metadata + // is valid on the target type (e.g., a cast from UUID with extension + // metadata to Utf8 should not result in extension metadata + // on a Utf8 type, which would be invalid and may be rejected by + // consumers). 
+ if field.data_type() != cast_type { + out_field = out_field.with_metadata(Default::default()); + } + + Arc::new(out_field) }) } else { Ok(Arc::clone(&self.target_field)) @@ -213,6 +233,7 @@ pub(crate) fn cast_expr_properties( child: &ExprProperties, target_type: &DataType, ) -> Result { + // TODO check the cast extension for this property let unbounded = Interval::make_unbounded(target_type)?; if is_order_preserving_cast_family(&child.range.data_type(), target_type) { Ok(child.clone().with_range(unbounded)) @@ -245,7 +266,37 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - value.cast_to(self.cast_type(), Some(&self.cast_options)) + if let Some(cast_extension) = &self.cast_extension { + let from_field = self.expr.return_field(&batch.schema())?; + let to_field = self.return_field(&batch.schema())?; + match value { + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(cast_extension.cast_array_fields( + &array, + &from_field, + &to_field, + &self.cast_options, + )?)) + } + ColumnarValue::Scalar(scalar_value) => { + let array = scalar_value.to_array()?; + let array_result = cast_extension.cast_array_fields( + &array, + &from_field, + &to_field, + &self.cast_options, + )?; + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &array_result, + 0, + )?)) + } + } + } else { + // TODO: this should use the struct casting directly so we can pass on the + // cast extension + value.cast_to(self.cast_type(), Some(&self.cast_options)) + } } fn return_field(&self, input_schema: &Schema) -> Result { @@ -263,12 +314,14 @@ impl PhysicalExpr for CastExpr { Ok(Arc::new(CastExpr::new_with_target_field( Arc::clone(&children[0]), Arc::clone(&self.target_field), + self.cast_extension.clone(), Some(self.cast_options.clone()), ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { // Cast current node's interval to the right type: + // TODO: check the cast extension or cast the interval 
children[0].cast_to(self.cast_type(), &self.cast_options) } @@ -277,6 +330,7 @@ impl PhysicalExpr for CastExpr { interval: &Interval, children: &[&Interval], ) -> Result>> { + // TODO: consult the cast extension when propagating interval constraints as well let child_interval = children[0]; // Get child's datatype: let cast_type = child_interval.data_type(); @@ -314,6 +368,7 @@ pub fn cast_with_options( expr, input_schema, cast_type.into_nullable_field_ref(), + None, cast_options, ) } @@ -330,32 +385,52 @@ pub fn cast_with_target_field( expr: Arc, input_schema: &Schema, target_field: FieldRef, + cast_extension: Option>, cast_options: Option>, ) -> Result> { - let expr_type = expr.data_type(input_schema)?; + let expr_field = expr.return_field(input_schema)?; + if let Some(cast_extension_ref) = cast_extension.as_deref() + && cast_extension_ref.can_cast_fields(&expr_field, target_field.as_ref())? + { + return Ok(Arc::new(CastExpr::new_with_target_field( + expr, + target_field, + cast_extension, + cast_options, + ))); + } + + let expr_type = expr_field.data_type(); let cast_type = target_field.data_type(); - if expr_type == *cast_type && is_default_target_field(&target_field) { + if expr_type == cast_type && is_default_target_field(&target_field) { return Ok(Arc::clone(&expr)); } - let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) { + let can_build_cast = if requires_nested_struct_cast(expr_type, cast_type) { // Allow casts involving structs (including nested inside Lists, Dictionaries, // etc.) that pass name-based compatibility validation. This validation is // applied at planning time (now) to fail fast, rather than deferring errors // to execution time. The name-based casting logic will be executed at runtime // via ColumnarValue::cast_to. 
- can_cast_named_struct_types(&expr_type, cast_type) + // TODO: we can pass the cast extension here if we will end up using it for + // the nested casting + can_cast_named_struct_types(expr_type, cast_type, None) } else { - can_cast_types(&expr_type, cast_type) + can_cast_types(expr_type, cast_type) }; + let source_fmt = format_type_and_metadata(expr_type, Some(expr_field.metadata())); + let target_fmt = + format_type_and_metadata(target_field.data_type(), Some(target_field.metadata())); if !can_build_cast { - return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}"); + return not_impl_err!("Unsupported CAST from {source_fmt} to {target_fmt}"); } + // TODO: pass the cast extension here anyway so that nested casts work Ok(Arc::new(CastExpr::new_with_target_field( expr, target_field, + None, cast_options, ))) } @@ -414,6 +489,7 @@ mod tests { col(column, schema.as_ref())?, Arc::new(target_field), None, + None, ); let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; @@ -939,6 +1015,7 @@ mod tests { .with_metadata(metadata.clone()), ), None, + None, ); let field = expr.return_field(&schema)?; @@ -1108,7 +1185,8 @@ mod tests { let literal = Arc::new(crate::expressions::Literal::new(ScalarValue::Struct( Arc::new(scalar_struct), ))); - let expr = CastExpr::new_with_target_field(literal, Arc::new(target_field), None); + let expr = + CastExpr::new_with_target_field(literal, Arc::new(target_field), None, None); let batch = RecordBatch::new_empty(schema); let result = expr.evaluate(&batch)?; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index d0d0508a106a5..756fd853fafce 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -294,12 +294,23 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) 
} - Expr::Cast(Cast { expr, field }) => expressions::cast_with_target_field( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - Arc::clone(field), - None, - ), + Expr::Cast(Cast { expr, field }) => { + let (_, src_field) = expr.to_field(input_dfschema)?; + let cast_extension = + if let Some(extension_types) = &execution_props.extension_types { + extension_types.cast_extension(&src_field, field) + } else { + None + }; + + expressions::cast_with_target_field( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + Arc::clone(field), + cast_extension, + None, + ) + } Expr::TryCast(TryCast { expr, field }) => { if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 0dafcf6bd3390..45a14f00c8735 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -163,6 +163,7 @@ pub fn project_plan_to_schema( Arc::new(CastExpr::new_with_target_field( column, Arc::clone(expected_field), + None, // TODO: can we get a cast extension here? 
None, )) as _ } else { diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index ec9ea376e0b6d..e17d363dd93b0 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -1307,4 +1307,108 @@ mod tests { )); Ok(()) } + + #[test] + fn test_union_schema_metadata_preservation() { + use crate::empty::EmptyExec; + use std::collections::HashMap; + + // Create schemas - one with metadata, one without + let mut metadata = HashMap::new(); + metadata.insert("key".to_string(), "value".to_string()); + + let schema_with_metadata = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true).with_metadata(metadata.clone()), + ])); + + let schema_without_metadata = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + + // Create two EmptyExec plans with different schemas + let input1 = Arc::new(EmptyExec::new(Arc::clone(&schema_with_metadata))); + let input2 = Arc::new(EmptyExec::new(Arc::clone(&schema_without_metadata))); + + // Test both orderings + let inputs_with_first = vec![ + Arc::clone(&input1) as Arc, + Arc::clone(&input2) as Arc, + ]; + let inputs_without_first = vec![ + Arc::clone(&input2) as Arc, + Arc::clone(&input1) as Arc, + ]; + + // Call union_schema directly + let result1 = union_schema(&inputs_with_first).unwrap(); + let result2 = union_schema(&inputs_without_first).unwrap(); + + // Both should have the metadata + assert!( + !result1.field(0).metadata().is_empty(), + "Expected metadata in result1 (with metadata first), got empty" + ); + assert!( + !result2.field(0).metadata().is_empty(), + "Expected metadata in result2 (without metadata first), got empty" + ); + assert_eq!( + result1.field(0).metadata().get("key"), + Some(&"value".to_string()) + ); + assert_eq!( + result2.field(0).metadata().get("key"), + Some(&"value".to_string()) + ); + } + + #[test] + fn test_union_schema_metadata_with_non_nullable() { + use crate::empty::EmptyExec; + use 
std::collections::HashMap; + + // Test case that matches the failing test: + // input 0: nonnull_name (NOT nullable, has metadata) + // input 1: NULL::string (nullable, no metadata) + + let mut metadata = HashMap::new(); + metadata.insert("key".to_string(), "value".to_string()); + + // input 0: NOT nullable, has metadata + let schema_with_metadata = Arc::new(Schema::new(vec![ + Field::new( + "name", + DataType::Utf8, + false, // NOT nullable + ) + .with_metadata(metadata.clone()), + ])); + + // input 1: nullable, no metadata + let schema_without_metadata = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); // nullable + + let input1 = Arc::new(EmptyExec::new(Arc::clone(&schema_with_metadata))); + let input2 = Arc::new(EmptyExec::new(Arc::clone(&schema_without_metadata))); + + let inputs = vec![ + Arc::clone(&input1) as Arc, + Arc::clone(&input2) as Arc, + ]; + + let result = union_schema(&inputs).unwrap(); + + // The result should be nullable (since one input is nullable) and have metadata + assert!( + result.field(0).is_nullable(), + "Expected nullable field in union result" + ); + assert!( + !result.field(0).metadata().is_empty(), + "Expected metadata preserved from non-nullable input, got empty" + ); + assert_eq!( + result.field(0).metadata().get("key"), + Some(&"value".to_string()) + ); + } } diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 76cf14be88f5a..4027b03c66c42 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -1126,6 +1126,7 @@ fn rewrite_expr_to_prunable( let left = Arc::new(phys_expr::CastExpr::new_with_target_field( left, Arc::clone(cast.target_field()), + None, // TODO: can we get a CastExtension here? None, )); // PruningPredicate does not support pruning on nested fields yet. 
diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 3fea8df260f05..1d2cdf494005b 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -218,33 +218,33 @@ FROM table_with_metadata; 2020-09-08 2020-09-08 -# Regression test: CAST should preserve source field metadata +# CAST should not preserve source field metadata query DT SELECT CAST(ts AS DATE) as casted, arrow_metadata(CAST(ts AS DATE), 'metadata_key') FROM table_with_metadata; ---- -2020-09-08 ts non-nullable field -2020-09-08 ts non-nullable field -2020-09-08 ts non-nullable field +2020-09-08 NULL +2020-09-08 NULL +2020-09-08 NULL -# Regression test: CAST preserves metadata on integer column +# CAST should not preserve metadata on integer column query IT SELECT CAST(id AS BIGINT) as casted, arrow_metadata(CAST(id AS BIGINT), 'metadata_key') FROM table_with_metadata; ---- -1 the id field -NULL the id field -3 the id field +1 NULL +NULL NULL +3 NULL -# Regression test: CAST with single-argument arrow_metadata (returns full map) +# CAST with single-argument arrow_metadata (returns full map) query ? select arrow_metadata(CAST(id AS BIGINT)) from table_with_metadata limit 1; ---- -{metadata_key: the id field} +{} # Regression test: distinct with cast query D