Skip to content

Commit 984551e

Browse files
committed
add multi ordering test case
Signed-off-by: jayzhan211 <[email protected]>
1 parent c8e1c84 commit 984551e

File tree

3 files changed

+60
-30
lines changed

3 files changed

+60
-30
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
c1,c2,c3
2+
1,20,0
3+
2,20,1
4+
3,10,2
5+
4,10,3
6+
5,30,4
7+
6,30,5
8+
7,30,6
9+
8,30,7
10+
9,30,8
11+
10,10,9

datafusion/physical-expr/src/aggregate/array_agg_ordered.rs

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr};
3030

3131
use arrow::array::ArrayRef;
3232
use arrow::datatypes::{DataType, Field};
33+
use arrow_array::cast::AsArray;
3334
use arrow_array::Array;
3435
use arrow_schema::{Fields, SortOptions};
35-
use datafusion_common::cast::as_list_array;
3636
use datafusion_common::utils::{compare_rows, get_row_at_idx};
3737
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
3838
use datafusion_expr::Accumulator;
@@ -214,7 +214,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
214214
// values received from its ordering requirement expression. (This information is necessary during merging.)
215215
let agg_orderings = &states[1];
216216

217-
if as_list_array(agg_orderings).is_ok() {
217+
if let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() {
218218
// Stores ARRAY_AGG results coming from each partition
219219
let mut partition_values = vec![];
220220
// Stores ordering requirement expression results coming from each partition
@@ -232,10 +232,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
232232
}
233233

234234
let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
235-
// Ordering requirement expression values for each entry in the ARRAY_AGG list
236-
let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?;
237-
for v in other_ordering_values.into_iter() {
238-
partition_ordering_values.push(v);
235+
236+
for partition_ordering_rows in orderings.into_iter() {
237+
// Extract the ordering column values from each struct to build the ordering rows for each group/partition
238+
let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
239+
if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row {
240+
Ok(ordering_columns_per_row)
241+
} else {
242+
exec_err!(
243+
"Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}",
244+
ordering_row.data_type()
245+
)
246+
}
247+
}).collect::<Result<Vec<_>>>()?;
248+
249+
partition_ordering_values.push(ordering_value);
239250
}
240251

241252
let sort_options = self
@@ -293,33 +304,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
293304
}
294305

295306
impl OrderSensitiveArrayAggAccumulator {
296-
/// Inner Vec\<ScalarValue> in the ordering_values can be thought of as the ordering information for each ScalarValue in the values array.
297-
/// See [`merge_ordered_arrays`] for more information.
298-
fn convert_array_agg_to_orderings(
299-
&self,
300-
array_agg: Vec<Vec<ScalarValue>>,
301-
) -> Result<Vec<Vec<Vec<ScalarValue>>>> {
302-
let mut orderings = vec![];
303-
// in_data is Vec<ScalarValue> where ScalarValue does not include ScalarValue::List
304-
for in_data in array_agg.into_iter() {
305-
let ordering = in_data.into_iter().map(|struct_vals| {
306-
if let ScalarValue::Struct(Some(orderings), _) = struct_vals {
307-
Ok(orderings)
308-
} else {
309-
exec_err!(
310-
"Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}",
311-
struct_vals.data_type()
312-
)
313-
}
314-
}).collect::<Result<Vec<_>>>()?;
315-
orderings.push(ordering);
316-
}
317-
Ok(orderings)
318-
}
319-
320307
fn evaluate_orderings(&self) -> Result<ScalarValue> {
321308
let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
322309
let struct_field = Fields::from(fields.clone());
310+
323311
let orderings: Vec<ScalarValue> = self
324312
.ordering_values
325313
.iter()
@@ -329,6 +317,7 @@ impl OrderSensitiveArrayAggAccumulator {
329317
.collect();
330318
let struct_type = DataType::Struct(Fields::from(fields));
331319

320+
// Wrap in a List so that we have the same data structure, ListArray(StructArray..), as in the group-by cases
332321
let arr = ScalarValue::new_list(&orderings, &struct_type);
333322
Ok(ScalarValue::List(arr))
334323
}

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,36 @@ FROM
106106
----
107107
[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8]
108108

109+
statement ok
110+
CREATE EXTERNAL TABLE agg_order (
111+
c1 INT NOT NULL,
112+
c2 INT NOT NULL,
113+
c3 INT NOT NULL
114+
)
115+
STORED AS CSV
116+
WITH HEADER ROW
117+
LOCATION '../core/tests/data/aggregate_agg_multi_order.csv';
118+
119+
# test array_agg with order by multiple columns
120+
query ?
121+
select array_agg(c1 order by c2 desc, c3) from agg_order;
122+
----
123+
[5, 6, 7, 8, 9, 1, 2, 3, 4, 10]
124+
125+
query TT
126+
explain select array_agg(c1 order by c2 desc, c3) from agg_order;
127+
----
128+
logical_plan
129+
Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]]
130+
--TableScan: agg_order projection=[c1, c2, c3]
131+
physical_plan
132+
AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1)]
133+
--CoalescePartitionsExec
134+
----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1)]
135+
------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST]
136+
--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
137+
----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true
138+
109139
statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1
110140
SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100
111141

0 commit comments

Comments
 (0)