From d368a88648de13aca529402c2e55e2e271653fcd Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Mon, 22 Jun 2015 13:15:02 +0900 Subject: [PATCH] [SPARK-8420] [SQL] Inconsistent behavior with Dataframe Timestamp between 1.3.1 and 1.4.0 --- .../catalyst/analysis/HiveTypeCoercion.scala | 13 +++-- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++ .../expressions/stringOperations.scala | 47 ++++++++++++++++++- .../apache/spark/sql/DataFrameDateSuite.scala | 4 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 32 +++++++++++++ 5 files changed, 92 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d4ab1fc643c33..7f908da64aea9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -280,11 +280,11 @@ trait HiveTypeCoercion { // For equality between string and timestamp we cast the string to a timestamp // so that things like rounding of subsecond precision does not affect the comparison. case p @ Equality(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) + p.makeCopy(Array(NormalizeTS(left), Cast(right, StringType))) case p @ Equality(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) + p.makeCopy(Array(Cast(left, StringType), NormalizeTS(right))) - // We should cast all relative timestamp/date/string comparison into string comparisions + // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true case p @ BinaryComparison(left @ StringType(), right @ DateType()) => @@ -292,11 +292,10 @@ trait HiveTypeCoercion { case p @ BinaryComparison(left @ DateType(), right @ StringType()) => p.makeCopy(Array(Cast(left, StringType), right)) case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - // Comparisons between dates and timestamps. + p.makeCopy(Array(NormalizeTS(left), Cast(right, StringType))) + case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(Cast(left, StringType), NormalizeTS(right))) case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ad920f287820c..5c665e90b7ba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -475,6 +475,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } object Cast { + + private[sql] val timestampRegex = + """^(\d{4}\-\d{1,2}\-\d{1,2})( \d{2}:\d{2}:\d{2}(\.(\d+))?)?$""".r + // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 315c63e63c635..1a8e2d456ede3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -160,6 +160,51 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } } +/** + * A function that converts string to normalized format for comparing it with timestamp type + */ +private[catalyst] case class NormalizeTS(child: Expression) + extends UnaryExpression with CaseConversionExpression with Logging { + + override def convert(v: UTF8String): UTF8String = { + val s = v.toString.trim + val m = Cast.timestampRegex.pattern.matcher(s) + if (m.matches()) { + UTF8String.fromString( + if (m.group(2) == null) { + s + " 00:00:00" + } else if (m.group(3) != null) { + // contains nano part + var nano = m.group(4) + if (nano.length > 9) { + nano = nano.substring(0, 9) // trim to max 9 + } + if (isAllZeros(nano)) { + s.substring(0, m.start(3)) // remove nano part + } else { + s.substring(0, m.start(4) + nano.length) + } + } else { + s + } + ) + } else { + v + } + } + + private def isAllZeros(s: String): Boolean = { + for (c <- s) { + if (c != '0') return false + } + true + } + + override def toString: String = { + s"NormalizeTS($child)" + } +} + /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison extends ExpectsInputTypes { self: BinaryExpression => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala index a4719a38de1d4..374ab75c00741 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala @@ -37,6 +37,10 @@ class DataFrameDateTimeSuite extends QueryTest { checkAnswer( df.select("t").filter($"t" >= "2014-06-01"), Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + + checkAnswer( + df.select("t").filter($"t" === "2014-01-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) } test("date comparison with date strings") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4441afd6bd811..e8a1c6d6b238a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -383,6 +383,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Nil) } + test("SPARK-8420 Inconsistent behavior with Dataframe Timestamp between 1.3.1 and 1.4.0") { + val timestamps = Seq( + Timestamp.valueOf("2015-06-16 00:00:00"), + Timestamp.valueOf("2015-06-18 06:00:00"), + Timestamp.valueOf("2015-06-20 12:00:00")) + timestamps.map(t => Tuple1(t)).toDF("time").registerTempTable("timestamps2") + + val expected = Row(java.sql.Timestamp.valueOf("2015-06-16 00:00:00")) + + // convert string to normalized format, if it's in timestamp format + checkAnswer( + sql("SELECT time FROM timestamps2 WHERE time='2015-06-16'"), expected) + + checkAnswer( + sql("SELECT time FROM timestamps2 WHERE time='2015-06-16 00:00:00'"), expected) + + checkAnswer( + sql("SELECT time FROM timestamps2 WHERE time='2015-06-16 00:00:00.00'"), expected) + + checkAnswer( + sql("SELECT time FROM timestamps2 WHERE time='2015-06-16 00:00:00.000000000123'"), expected) + + // if it's not timestamp format, use as-is string to be compared + checkAnswer(sql("SELECT time FROM timestamps2 WHERE time > '2015-06-2'"), + Row(java.sql.Timestamp.valueOf("2015-06-20 12:00:00"))) + + checkAnswer(sql("SELECT time FROM timestamps2 WHERE time < 'abcd'"), + Row(java.sql.Timestamp.valueOf("2015-06-16 00:00:00")) :: + Row(java.sql.Timestamp.valueOf("2015-06-18 06:00:00")) :: + Row(java.sql.Timestamp.valueOf("2015-06-20 12:00:00")) :: Nil) + } + test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),