From 88679ea64a2ee6cf4c65b7ab7bcede77d0a8991a Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 3 Jan 2024 09:53:10 -0500
Subject: [PATCH 1/3] revert eb8aff7becaf5d4a44c723b29445deb958fbe3b4 /
 Materialize dictionaries in group keys

---
 datafusion/core/tests/path_partition.rs      | 15 ++++++--
 .../src/aggregates/group_values/row.rs       | 27 +++++++++++---
 .../physical-plan/src/aggregates/mod.rs      | 35 ++-----------------
 .../physical-plan/src/aggregates/row_hash.rs |  4 +--
 .../sqllogictest/test_files/aggregate.slt    |  9 -----
 5 files changed, 39 insertions(+), 51 deletions(-)

diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs
index abe6ab283aff..dd8eb52f67c7 100644
--- a/datafusion/core/tests/path_partition.rs
+++ b/datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     assert_eq!(min_limit, resulting_limit);
 
     let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
-    let month = match s {
-        ScalarValue::Utf8(Some(month)) => month,
-        s => panic!("Expected month as Utf8 found {s:?}"),
+    let month = match extract_as_utf(&s) {
+        Some(month) => month,
+        s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
     };
 
     let sql_on_partition_boundary = format!(
@@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     Ok(())
 }
 
+fn extract_as_utf(v: &ScalarValue) -> Option<String> {
+    if let ScalarValue::Dictionary(_, v) = v {
+        if let ScalarValue::Utf8(v) = v.as_ref() {
+            return v.clone();
+        }
+    }
+    None
+}
+
 #[tokio::test]
 async fn csv_filter_with_file_col() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index e7c7a42cf902..10ff9edb8912 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,18 +17,22 @@
 
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
+use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::ArrayRef;
-use arrow_schema::SchemaRef;
+use arrow_array::{Array, ArrayRef};
+use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_physical_expr::EmitTo;
 use hashbrown::raw::RawTable;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
+    /// The output schema
+    schema: SchemaRef,
+
     /// Converter for the group values
     row_converter: RowConverter,
@@ -75,6 +79,7 @@ impl GroupValuesRows {
         let map = RawTable::with_capacity(0);
 
         Ok(Self {
+            schema,
             row_converter,
             map,
             map_size: 0,
@@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows {
             .take()
             .expect("Can not emit from empty rows");
 
-        let output = match emit_to {
+        let mut output = match emit_to {
             EmitTo::All => {
                 let output = self.row_converter.convert_rows(&group_values)?;
                 group_values.clear();
@@ -198,6 +203,20 @@ impl GroupValues for GroupValuesRows {
             }
         };
 
+        // TODO: Materialize dictionaries in group keys (#7647)
+        for (field, array) in self.schema.fields.iter().zip(&mut output) {
+            let expected = field.data_type();
+            if let DataType::Dictionary(_, v) = expected {
+                let actual = array.data_type();
+                if v.as_ref() != actual {
+                    return Err(DataFusionError::Internal(format!(
"Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index a38044de02e3..0b94dd01cfd4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,7 +36,6 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -254,9 +253,6 @@ pub struct AggregateExec { limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, - /// Original aggregation schema, could be different from `schema` before dictionary group - /// keys get materialized - original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the @@ -287,7 +283,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let original_schema = create_schema( + let schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -295,11 +291,7 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(materialize_dict_group_keys( - &original_schema, - group_by.expr.len(), - )); - let original_schema = Arc::new(original_schema); + let schema = Arc::new(schema); AggregateExec::try_new_with_schema( mode, group_by, @@ -308,7 +300,6 @@ impl AggregateExec { input, input_schema, schema, - original_schema, ) } @@ -329,7 +320,6 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, schema: SchemaRef, - original_schema: SchemaRef, ) -> Result { let input_eq_properties = input.equivalence_properties(); // Get GROUP BY expressions: @@ -382,7 +372,6 @@ impl AggregateExec { aggr_expr, filter_expr, input, - original_schema, schema, input_schema, projection_mapping, @@ -693,7 +682,7 @@ impl ExecutionPlan for AggregateExec { children[0].clone(), self.input_schema.clone(), self.schema.clone(), - self.original_schema.clone(), + //self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -800,24 +789,6 @@ fn create_schema( Ok(Schema::new(fields)) } -/// returns schema with dictionary group keys materialized as their value types -/// The actual convertion happens in `RowConverter` and we don't do unnecessary -/// conversion back into dictionaries -fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { - let fields = schema - .fields - .iter() - .enumerate() - .map(|(i, field)| match field.data_type() { - DataType::Dictionary(_, value_data_type) if i < group_count => { - Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) - } - _ => Field::clone(field), - }) - .collect::>(); - Schema::new(fields) -} - fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 89614fd3020c..6a0c02f5caf3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -324,9 +324,7 @@ impl 
             .map(create_group_accumulator)
             .collect::<Result<Vec<_>>>()?;
 
-        // we need to use original schema so RowConverter in group_values below
-        // will do the proper coversion of dictionaries into value types
-        let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
         let spill_expr = group_schema
             .fields
             .into_iter()
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 78575c9dffc5..96af710f99fb 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2466,15 +2466,6 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
 4
 5
 
-query T
-select arrow_typeof(x_dict) from value_dict group by x_dict;
-----
-Int32
-Int32
-Int32
-Int32
-Int32
-
 statement ok
 drop table value
 

From c180c5d9d502a9d10dbc19855bba94ea7a8192d9 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 5 Jan 2024 14:35:21 -0500
Subject: [PATCH 2/3] Update tests

---
 .../sqllogictest/test_files/aggregate.slt  | 13 +++++++++--
 .../sqllogictest/test_files/dictionary.slt | 23 ++++++++++++++++++-
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 96af710f99fb..23989d8561ac 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2466,6 +2466,15 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
 4
 5
 
+query T
+select arrow_typeof(x_dict) from value_dict group by x_dict;
+----
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+
 statement ok
 drop table value
 
@@ -2560,7 +2569,7 @@ select trace_id, other, MIN(timestamp) from traces group by trace_id, other orde
 a -1 -1
 b 0 0
 NULL 0 0
-c 1 1
+a 1 1
 
 query TII
 select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
@@ -2685,7 +2694,7 @@ select trace_id, other, MIN(timestamp) from traces group by trace_id, other orde
 a -1 -1
 b 0 0
 NULL 0 0
-c 1 1
+a 1 1
 
 query TII
 select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt
index d4ad46711b9f..ff49070f4590 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -169,7 +169,7 @@ order by date_bin('30 minutes', time) DESC
 
 # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
 # This query should work correctly
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Utf8 got Dictionary\(Int32, Utf8\)
+query P?TT
 SELECT
     "data"."timestamp" as "time",
     "data"."tag_id",
@@ -201,3 +201,24 @@ ORDER BY
     "time",
     "data"."tag_id"
 ;
+----
+2023-12-20T00:00:00 1000 f1 32.0
+2023-12-20T00:00:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f2 foo
+2023-12-20T00:30:00 1000 f1 32.0
+2023-12-20T00:30:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f2 foo
+2023-12-20T01:00:00 1000 f1 32.0
+2023-12-20T01:00:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
+2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
+2023-12-20T01:30:00 1000 f1 32.0

From 920d058432e44150946c547478b941855b5ed380 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 5 Jan 2024 14:54:54 -0500
Subject: [PATCH 3/3] Update tests

---
 .../sqllogictest/test_files/aggregate.slt  |  4 +-
 .../sqllogictest/test_files/dictionary.slt | 68 +++++++++++++++++--
 2 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 23989d8561ac..aa512f6e2600 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2569,7 +2569,7 @@ select trace_id, other, MIN(timestamp) from traces group by trace_id, other orde
 a -1 -1
 b 0 0
 NULL 0 0
-a 1 1
+c 1 1
 
 query TII
 select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
@@ -2694,7 +2694,7 @@ select trace_id, other, MIN(timestamp) from traces group by trace_id, other orde
 a -1 -1
 b 0 0
 NULL 0 0
-a 1 1
+c 1 1
 
 query TII
 select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt
index ff49070f4590..bfe1712a4d91 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -169,7 +169,7 @@ order by date_bin('30 minutes', time) DESC
 
 # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
 # This query should work correctly
-query P?TT
+query P?TT rowsort
 SELECT
     "data"."timestamp" as "time",
     "data"."tag_id",
@@ -204,21 +204,79 @@ ORDER BY
 ----
 2023-12-20T00:00:00 1000 f1 32.0
 2023-12-20T00:00:00 1000 f2 foo
-2023-12-20T00:10:00 1000 f2 foo
 2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:10:00 1000 f2 foo
 2023-12-20T00:20:00 1000 f1 32.0
 2023-12-20T00:20:00 1000 f2 foo
 2023-12-20T00:30:00 1000 f1 32.0
 2023-12-20T00:30:00 1000 f2 foo
-2023-12-20T00:40:00 1000 f2 foo
 2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:40:00 1000 f2 foo
 2023-12-20T00:50:00 1000 f1 32.0
 2023-12-20T00:50:00 1000 f2 foo
 2023-12-20T01:00:00 1000 f1 32.0
 2023-12-20T01:00:00 1000 f2 foo
-2023-12-20T01:10:00 1000 f2 foo
 2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:10:00 1000 f2 foo
-2023-12-20T01:20:00 1000 f2 foo
 2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
-2023-12-20T01:30:00 1000 f2 foo
 2023-12-20T01:30:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
+
+
+# deterministic sort (so we can avoid rowsort)
+query P?TT
+SELECT
+    "data"."timestamp" as "time",
+    "data"."tag_id",
+    "data"."field",
+    "data"."value"
+FROM (
+  (
+    SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value"
+    FROM "m2"
+    WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m2"."f5" IS NOT NULL
+    AND "m2"."type" IN ('active')
+    AND "m2"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value"
+    FROM "m1"
+    WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m1"."f1" IS NOT NULL
+    AND "m1"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value"
"m1"."tag_id", 'f2' as "field", "m1"."f2" as "value" + FROM "m1" + WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00' + AND "m1"."f2" IS NOT NULL + AND "m1"."tag_id" IN ('1000') + ) +) as "data" +ORDER BY + "time", + "data"."tag_id", + "data"."field", + "data"."value" +; +---- +2023-12-20T00:00:00 1000 f1 32.0 +2023-12-20T00:00:00 1000 f2 foo +2023-12-20T00:10:00 1000 f1 32.0 +2023-12-20T00:10:00 1000 f2 foo +2023-12-20T00:20:00 1000 f1 32.0 +2023-12-20T00:20:00 1000 f2 foo +2023-12-20T00:30:00 1000 f1 32.0 +2023-12-20T00:30:00 1000 f2 foo +2023-12-20T00:40:00 1000 f1 32.0 +2023-12-20T00:40:00 1000 f2 foo +2023-12-20T00:50:00 1000 f1 32.0 +2023-12-20T00:50:00 1000 f2 foo +2023-12-20T01:00:00 1000 f1 32.0 +2023-12-20T01:00:00 1000 f2 foo +2023-12-20T01:10:00 1000 f1 32.0 +2023-12-20T01:10:00 1000 f2 foo +2023-12-20T01:20:00 1000 f1 32.0 +2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 +2023-12-20T01:30:00 1000 f2 foo