Skip to content

Commit 106f8b8

Browse files
committed
[FLINK-38525][python] Fix coder doesn't process TIMESTAMP_LTZ correctly
1 parent 64ce0c2 commit 106f8b8

File tree

6 files changed: 66 additions and 19 deletions.

flink-python/pyflink/fn_execution/coder_impl_fast.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ cdef class TimestampCoderImpl(FieldCoderImpl):
148148

149149
cdef _decode_timestamp_data_from_stream(self, InputStream in_stream)
150150

151+
cdef _to_utc_timestamp(self, value)
152+
153+
cdef _to_datetime(self, int64_t seconds, int32_t microseconds)
154+
151155
cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
152156
cdef object _timezone
153157

flink-python/pyflink/fn_execution/coder_impl_fast.pyx

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,9 @@ cdef class TimestampCoderImpl(FieldCoderImpl):
685685
cpdef encode_to_stream(self, value, OutputStream out_stream):
686686
cdef int32_t microseconds_of_second, nanoseconds
687687
cdef int64_t timestamp_seconds, timestamp_milliseconds
688-
timestamp_seconds = <int64_t> (value.replace(tzinfo=datetime.timezone.utc).timestamp())
689-
microseconds_of_second = value.microsecond
688+
utc_ts = self._to_utc_timestamp(value)
689+
timestamp_seconds = <int64_t> (utc_ts.timestamp())
690+
microseconds_of_second = utc_ts.microsecond
690691
timestamp_milliseconds = timestamp_seconds * 1000 + microseconds_of_second // 1000
691692
nanoseconds = microseconds_of_second % 1000 * 1000
692693
if self._is_compact:
@@ -709,7 +710,15 @@ cdef class TimestampCoderImpl(FieldCoderImpl):
709710
nanoseconds = in_stream.read_int32()
710711
seconds = milliseconds // 1000
711712
microseconds = milliseconds % 1000 * 1000 + nanoseconds // 1000
712-
return datetime.datetime.utcfromtimestamp(seconds).replace(microsecond=microseconds)
713+
return self._to_datetime(seconds, microseconds)
714+
715+
cdef _to_utc_timestamp(self, value):
    # Keep the wall-clock fields of `value` exactly as they are and only tag
    # them with the UTC zone; no clock conversion takes place here.
    utc = datetime.timezone.utc
    return value.replace(tzinfo=utc)
717+
718+
cdef _to_datetime(self, int64_t seconds, int32_t microseconds):
    # Build a naive datetime whose fields are the UTC rendering of `seconds`
    # since the epoch, then set the sub-second part from `microseconds`.
    #
    # The value has to be returned explicitly: a bare expression statement
    # would make this method evaluate to None for every decoded timestamp.
    return datetime.datetime.utcfromtimestamp(seconds).replace(microsecond=microseconds)
713722

714723
cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
715724
"""
@@ -723,6 +732,13 @@ cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
723732
cpdef decode_from_stream(self, InputStream in_stream, size_t size):
    # The decoded value is produced through _to_datetime() below, which already
    # yields an aware datetime expressed in self._timezone. It must not be passed
    # through self._timezone.localize() again: a pytz-style localize() accepts
    # naive datetimes only and rejects values that already carry a tzinfo.
    # NOTE(review): assumes the inherited _decode_timestamp_data_from_stream()
    # dispatches to self._to_datetime() - confirm against the base class.
    return self._decode_timestamp_data_from_stream(in_stream)

cpdef _to_utc_timestamp(self, value):
    # Convert the instant to UTC. astimezone() names its parameter `tz`;
    # passing it as `tzinfo=` raises TypeError, so hand it over positionally.
    return value.astimezone(datetime.timezone.utc)

cdef _to_datetime(self, int64_t seconds, int32_t microseconds):
    # Interpret `seconds` as an absolute instant (epoch seconds in UTC), attach
    # the sub-second part, then express that same instant in the coder's zone.
    utc_value = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
    return utc_value.replace(microsecond=microseconds).astimezone(self._timezone)
741+
726742
cdef class InstantCoderImpl(FieldCoderImpl):
727743
"""
728744
A coder for Instant.

flink-python/pyflink/fn_execution/coder_impl_slow.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,7 @@ def decode_from_stream(self, in_stream: InputStream, length=0):
570570
nanoseconds = in_stream.read_int32()
571571
return self.internal_to_timestamp(milliseconds, nanoseconds)
572572

573-
@staticmethod
574-
def timestamp_to_internal(timestamp):
573+
def timestamp_to_internal(self, timestamp):
575574
seconds = int(timestamp.replace(tzinfo=datetime.timezone.utc).timestamp())
576575
microseconds_of_second = timestamp.microsecond
577576
milliseconds = seconds * 1000 + microseconds_of_second // 1000
@@ -593,10 +592,18 @@ def __init__(self, precision, timezone):
593592
super(LocalZonedTimestampCoderImpl, self).__init__(precision)
594593
self.timezone = timezone
595594

595+
def timestamp_to_internal(self, timestamp):
    """Convert a datetime into the ``(milliseconds, nanoseconds)`` wire pair.

    ``milliseconds`` is the number of whole milliseconds since
    1970-01-01T00:00:00Z, rounded towards negative infinity, and
    ``nanoseconds`` is the non-negative remainder within that millisecond
    (always a multiple of 1000, because datetime has microsecond resolution).

    :param timestamp: the datetime to encode. An aware value is converted to
        UTC; a naive value is interpreted by ``astimezone`` as system local
        time.
    :return: a ``(milliseconds, nanoseconds)`` tuple of ints.
    """
    epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    elapsed = timestamp.astimezone(datetime.timezone.utc) - epoch
    # Exact integer arithmetic with floor division. Going through
    # int(float_seconds) truncates towards zero, which shifts instants before
    # the epoch that have a fractional second forward by one whole second.
    total_microseconds = elapsed // datetime.timedelta(microseconds=1)
    milliseconds, microseconds_of_millisecond = divmod(total_microseconds, 1000)
    return milliseconds, microseconds_of_millisecond * 1000
601+
596602
def internal_to_timestamp(self, milliseconds, nanoseconds):
597-
return self.timezone.localize(
598-
super(LocalZonedTimestampCoderImpl, self).internal_to_timestamp(
599-
milliseconds, nanoseconds))
603+
return (super(LocalZonedTimestampCoderImpl, self).internal_to_timestamp(
604+
milliseconds, nanoseconds)
605+
.replace(tzinfo=datetime.timezone.utc)
606+
.astimezone(self.timezone))
600607

601608

602609
class InstantCoderImpl(FieldCoderImpl):

flink-python/pyflink/table/tests/test_udf.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,13 @@ def timestamp_func(timestamp_param):
364364
'timestamp_param is wrong value %s !' % timestamp_param
365365
return timestamp_param
366366

367+
@udf(result_type=DataTypes.TIMESTAMP_LTZ(3))
def timestamp_ltz_func(timestamp_ltz_param):
    # Identity UDF: verifies that the argument equals the fixed UTC instant
    # used by this test and hands the same value back unchanged.
    from datetime import datetime, timezone
    expected = datetime(2018, 3, 11, 3, 0, 0, 123000, timezone.utc)
    assert timestamp_ltz_param == expected, \
        'timestamp_ltz_param is wrong value %s !' % timestamp_ltz_param
    return timestamp_ltz_param
373+
367374
@udf(result_type=DataTypes.ARRAY(DataTypes.BIGINT()))
368375
def array_func(array_param):
369376
assert array_param == [[1, 2, 3]] or array_param == ((1, 2, 3),), \
@@ -427,7 +434,8 @@ def varchar_func(varchar_param):
427434
q DECIMAL(38, 18),
428435
r BINARY(5),
429436
s CHAR(7),
430-
t VARCHAR(10)
437+
t VARCHAR(10),
438+
u TIMESTAMP_LTZ(3)
431439
) WITH ('connector'='test-sink')
432440
"""
433441
self.t_env.execute_sql(sink_table_ddl)
@@ -441,7 +449,8 @@ def varchar_func(varchar_param):
441449
datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [[1, 2, 3]],
442450
{1: 'flink', 2: 'pyflink'}, decimal.Decimal('1000000000000000000.05'),
443451
decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
444-
bytearray(b'flink'), 'pyflink', 'pyflink')],
452+
bytearray(b'flink'), 'pyflink', 'pyflink',
453+
datetime.datetime(2018, 3, 11, 3, 0, 0, 123000, datetime.timezone.utc))],
445454
DataTypes.ROW(
446455
[DataTypes.FIELD("a", DataTypes.BIGINT()),
447456
DataTypes.FIELD("b", DataTypes.BIGINT()),
@@ -462,7 +471,8 @@ def varchar_func(varchar_param):
462471
DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18)),
463472
DataTypes.FIELD("r", DataTypes.BINARY(5)),
464473
DataTypes.FIELD("s", DataTypes.CHAR(7)),
465-
DataTypes.FIELD("t", DataTypes.VARCHAR(10))]))
474+
DataTypes.FIELD("t", DataTypes.VARCHAR(10)),
475+
DataTypes.FIELD("u", DataTypes.TIMESTAMP_LTZ(3))]))
466476

467477
t.select(
468478
bigint_func(t.a),
@@ -484,7 +494,8 @@ def varchar_func(varchar_param):
484494
decimal_cut_func(t.q),
485495
binary_func(t.r),
486496
char_func(t.s),
487-
varchar_func(t.t)) \
497+
varchar_func(t.t),
498+
timestamp_ltz_func(t.u)) \
488499
.execute_insert(sink_table).wait()
489500
actual = source_sink_utils.results()
490501
# Currently the sink result precision of DataTypes.TIME(precision) only supports 0.
@@ -494,7 +505,7 @@ def varchar_func(varchar_param):
494505
"2018-03-11T03:00:00.123, [1, 2, 3], "
495506
"{1=flink, 2=pyflink}, 1000000000000000000.050000000000000000, "
496507
"1000000000000000000.059999999999999999, [102, 108, 105, 110, 107], "
497-
"pyflink, pyflink]"])
508+
"pyflink, pyflink, 2018-03-11T03:00:00.123Z]"])
498509

499510
def test_all_data_types(self):
500511
def boolean_func(bool_param):
@@ -995,7 +1006,7 @@ def local_zoned_timestamp_func(local_zoned_timestamp_param):
9951006
.execute_insert(sink_table) \
9961007
.wait()
9971008
actual = source_sink_utils.results()
998-
self.assert_equals(actual, ["+I[1970-01-01T00:00:00.123Z]"])
1009+
self.assert_equals(actual, ["+I[1969-12-31T16:00:01.123Z]"])
9991010

10001011
def test_execute_from_json_plan(self):
10011012
# create source file path
@@ -1161,6 +1172,8 @@ def echo(i: str):
11611172
if __name__ == '__main__':
11621173
import unittest
11631174

1175+
os.environ['_python_worker_execution_mode'] = "loopback"
1176+
11641177
try:
11651178
import xmlrunner
11661179

flink-python/pyflink/table/types.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from pyflink.common.types import _create_row
3535
from pyflink.util.api_stability_decorators import PublicEvolving
36+
from pyflink.util.exceptions import TableException
3637
from pyflink.util.java_utils import to_jarray, is_instance_of
3738
from pyflink.java_gateway import get_gateway
3839
from pyflink.common import Row, RowKind
@@ -498,14 +499,19 @@ def need_conversion(self):
498499

499500
def to_sql_type(self, dt):
500501
if dt is not None:
501-
seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
502-
else time.mktime(dt.timetuple()))
503-
return int(seconds) * 10 ** 6 + dt.microsecond + self.EPOCH_ORDINAL
502+
if dt.tzinfo is None:
503+
raise TableException(
504+
f"""The input field {dt} does not specify time zone but its SQL type \
505+
TIMESTAMP_LTZ requires TIME ZONE. Please use TIMESTAMP instead or use CAST \
506+
function to cast TIMESTAMP as TIMESTAMP_LTZ."""
507+
)
508+
seconds = calendar.timegm(dt.utctimetuple())
509+
return int(seconds) * 10 ** 6 + dt.microsecond
504510

505511
def from_sql_type(self, ts):
506512
if ts is not None:
507-
ts = ts - self.EPOCH_ORDINAL
508-
return datetime.datetime.fromtimestamp(ts // 10 ** 6).replace(microsecond=ts % 10 ** 6)
513+
return (datetime.datetime.fromtimestamp(ts // 10 ** 6, datetime.timezone.utc)
514+
.replace(microsecond=ts % 10 ** 6))
509515

510516

511517
class ZonedTimestampType(AtomicType):

flink-python/pyflink/testing/test_case_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def setUpClass(cls):
164164
super(PyFlinkStreamTableTestCase, cls).setUpClass()
165165
cls.env.set_runtime_mode(RuntimeExecutionMode.STREAMING)
166166
cls.env.set_parallelism(2)
167+
os.environ['_python_worker_execution_mode'] = "loopback"
167168
cls.t_env = StreamTableEnvironment.create(cls.env)
168169
cls.t_env.get_config().set("python.fn-execution.bundle.size", "1")
169170

0 commit comments

Comments
 (0)