diff --git a/datafusion/substrait/src/logical_plan/producer/expr/literal.rs b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs index 8882c992dca1c..1a0f12c7b32a9 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/literal.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs @@ -56,6 +56,10 @@ pub(crate) fn to_substrait_literal( producer: &mut impl SubstraitProducer, value: &ScalarValue, ) -> datafusion::common::Result { + if let ScalarValue::Dictionary(_, value) = value { + return to_substrait_literal(producer, value.as_ref()); + } + if value.is_null() { return Ok(Literal { nullable: true, @@ -558,4 +562,32 @@ mod tests { assert_eq!(scalar, roundtrip_scalar); Ok(()) } + + #[test] + fn dictionary_literals_are_serialized_as_inner_values() -> Result<()> { + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let scalar = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("req.latency".to_string()))), + ); + + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; + assert!(matches!( + substrait_literal.literal_type.as_ref(), + Some(LiteralType::String(value)) if value == "req.latency" + )); + assert_eq!( + substrait_literal.type_variation_reference, + DEFAULT_CONTAINER_TYPE_VARIATION_REF + ); + + let roundtrip_scalar = + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; + assert_eq!( + ScalarValue::Utf8(Some("req.latency".to_string())), + roundtrip_scalar + ); + Ok(()) + } }