diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index 4fc7a92804b4..80300af24ac4 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -22,7 +22,8 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, - DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, @@ -177,24 +178,32 @@ pub fn from_substrait_type( let value_type = map.value.as_ref().ok_or_else(|| { substrait_datafusion_err!("Map type must have value type") })?; - let key_field = Arc::new(Field::new( - "key", - from_substrait_type(consumer, key_type, dfs_names, name_idx)?, - false, - )); - let value_field = Arc::new(Field::new( - "value", - from_substrait_type(consumer, value_type, dfs_names, name_idx)?, - true, - )); - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), + let key_type = + from_substrait_type(consumer, key_type, dfs_names, name_idx)?; + let value_type = + from_substrait_type(consumer, value_type, dfs_names, name_idx)?; + + match map.type_variation_reference { + DEFAULT_MAP_TYPE_VARIATION_REF => { + let key_field = Arc::new(Field::new("key", key_type, false)); + let value_field = Arc::new(Field::new("value", value_type, true)); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } + DICTIONARY_MAP_TYPE_VARIATION_REF => Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), )), - false, // whether keys are sorted - )) + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 0c9266347529..d819c2042c08 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -21,7 +21,8 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, - DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, @@ -276,13 +277,25 @@ pub(crate) fn to_substrait_type( kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), value: Some(Box::new(value_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + type_variation_reference: DEFAULT_MAP_TYPE_VARIATION_REF, nullability, }))), }) } _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, + DataType::Dictionary(key_type, value_type) => { + let key_type = to_substrait_type(key_type, nullable)?; + let value_type = to_substrait_type(value_type, nullable)?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DICTIONARY_MAP_TYPE_VARIATION_REF, + nullability, + }))), + }) + } DataType::Struct(fields) => { let field_types = fields .iter() @@ -407,6 +420,10 @@ mod tests { .into(), false, ))?; + round_trip_type(DataType::Dictionary( + Box::new(DataType::Utf8), + Box::new(DataType::Int32), + ))?; round_trip_type(DataType::Struct( vec![ diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 74fc6035efae..a967e7d5ae48 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -55,6 +55,8 @@ pub const TIME_64_TYPE_VARIATION_REF: u32 = 1; pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; +pub const DEFAULT_MAP_TYPE_VARIATION_REF: u32 = 0; +pub const DICTIONARY_MAP_TYPE_VARIATION_REF: u32 = 1; pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; /// Used for the arrow type [`DataType::Interval`] with [`IntervalUnit::DayTime`].