From 4c3e94df03863bf43eac7b47be51764f2f4be659 Mon Sep 17 00:00:00 2001 From: huangxingbo Date: Wed, 28 Apr 2021 20:39:34 +0800 Subject: [PATCH] [FLINK-22511][python] Fix the bug of non-composite result type in Python TableAggregateFunction --- .../pyflink/fn_execution/table/aggregate_fast.pxd | 2 +- .../pyflink/fn_execution/table/aggregate_fast.pyx | 13 +++++++++++-- .../pyflink/fn_execution/table/aggregate_slow.py | 10 +++++++++- .../pyflink/table/tests/test_row_based_operation.py | 9 ++++----- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/flink-python/pyflink/fn_execution/table/aggregate_fast.pxd b/flink-python/pyflink/fn_execution/table/aggregate_fast.pxd index 50935a0602765..79469ff871ecc 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_fast.pxd +++ b/flink-python/pyflink/fn_execution/table/aggregate_fast.pxd @@ -59,7 +59,7 @@ cdef class SimpleAggsHandleFunction(SimpleAggsHandleFunctionBase): cdef size_t _get_value_indexes_length cdef class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase): - pass + cdef list _convert_to_row(self, data) cdef class RecordCounter: cdef bint record_count_is_zero(self, list acc) diff --git a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx index 564bd65e2562b..bf1719ee8537b 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx +++ b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx @@ -24,6 +24,7 @@ from typing import List, Dict from apache_beam.coders import PickleCoder, Coder +from pyflink.common import Row from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \ PerKeyStateDataViewStore from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend @@ -379,12 +380,20 @@ cdef class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase): results = [] for x in udf.emit_value(self._accumulators[0]): if is_retract: - result = join_row(current_key, x._values, InternalRowKind.DELETE) + result = join_row(current_key, self._convert_to_row(x), InternalRowKind.DELETE) else: - result = join_row(current_key, x._values, InternalRowKind.INSERT) + result = join_row(current_key, self._convert_to_row(x), InternalRowKind.INSERT) results.append(result) return results + cdef list _convert_to_row(self, data): + if isinstance(data, Row): + return data._values + elif isinstance(data, tuple): + return list(data) + else: + return [data] + cdef class RecordCounter: """ The RecordCounter is used to count the number of input records under the current key. diff --git a/flink-python/pyflink/fn_execution/table/aggregate_slow.py b/flink-python/pyflink/fn_execution/table/aggregate_slow.py index 2edb328a7af79..ca434c6a7a5f6 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_slow.py +++ b/flink-python/pyflink/fn_execution/table/aggregate_slow.py @@ -355,13 +355,21 @@ def emit_value(self, current_key: List, is_retract: bool): udf = self._udfs[0] # type: TableAggregateFunction results = udf.emit_value(self._accumulators[0]) for x in results: - result = join_row(current_key, x._values) + result = join_row(current_key, self._convert_to_row(x)) if is_retract: result.set_row_kind(RowKind.DELETE) else: result.set_row_kind(RowKind.INSERT) yield result + def _convert_to_row(self, data): + if isinstance(data, Row): + return data._values + elif isinstance(data, tuple): + return list(data) + else: + return [data] + class RecordCounter(ABC): """ diff --git a/flink-python/pyflink/table/tests/test_row_based_operation.py b/flink-python/pyflink/table/tests/test_row_based_operation.py index 25698adaf2cff..1ba989a61b150 100644 --- a/flink-python/pyflink/table/tests/test_row_based_operation.py +++ b/flink-python/pyflink/table/tests/test_row_based_operation.py @@ -264,7 +264,7 @@ def test_flat_aggregate(self): (2, 'Hi', 'Hello')], ['a', 'b', 'c']) result = t.select(t.a, t.c) \ .group_by(t.c) \ - .flat_aggregate(mytop) \ + .flat_aggregate(mytop.alias('a')) \ .select(t.a) \ .flat_aggregate(mytop.alias("b")) \ .select("b") \ @@ -339,8 +339,8 @@ def get_result_type(self): class Top2(TableAggregateFunction): def emit_value(self, accumulator): - yield Row(accumulator[0]) - yield Row(accumulator[1]) + yield accumulator[0] + yield accumulator[1] def create_accumulator(self): return [None, None] @@ -365,8 +365,7 @@ def get_accumulator_type(self): return DataTypes.ARRAY(DataTypes.BIGINT()) def get_result_type(self): - return DataTypes.ROW( - [DataTypes.FIELD("a", DataTypes.BIGINT())]) + return DataTypes.BIGINT() class ListViewConcatTableAggregateFunction(TableAggregateFunction):