Skip to content

Commit 5ce7680

Browse files
hvanhovell authored and cloud-fan committed
[SPARK-20329][SQL] Make timezone aware expression without timezone unresolved
## What changes were proposed in this pull request?

A cast expression with a resolved time zone is not equal to a cast expression without a resolved time zone. The `ResolveAggregateFunctions` rule assumed that these expressions were the same, and would fail to resolve `HAVING` clauses which contain a `Cast` expression.

This is in essence caused by the fact that a `TimeZoneAwareExpression` can be resolved without a set time zone. This PR fixes this, and makes a `TimeZoneAwareExpression` unresolved as long as it has no time zone set.

## How was this patch tested?

Added a regression test to the `SQLQueryTestSuite.having` file.

Author: Herman van Hovell <[email protected]>

Closes #17641 from hvanhovell/SPARK-20329.

(cherry picked from commit 760c8d0)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent d17dea8 commit 5ce7680

File tree

19 files changed

+148
-78
lines changed

19 files changed

+148
-78
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class Analyzer(
150150
ResolveAggregateFunctions ::
151151
TimeWindowing ::
152152
ResolveInlineTables(conf) ::
153+
ResolveTimeZone(conf) ::
153154
TypeCoercion.typeCoercionRules ++
154155
extendedResolutionRules : _*),
155156
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -161,8 +162,6 @@ class Analyzer(
161162
HandleNullInputsForUDF),
162163
Batch("FixNullability", Once,
163164
FixNullability),
164-
Batch("ResolveTimeZone", Once,
165-
ResolveTimeZone),
166165
Batch("Subquery", Once,
167166
UpdateOuterReferences),
168167
Batch("Cleanup", fixedPoint,
@@ -2341,23 +2340,6 @@ class Analyzer(
23412340
}
23422341
}
23432342
}
2344-
2345-
/**
2346-
* Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
2347-
* time zone.
2348-
*/
2349-
object ResolveTimeZone extends Rule[LogicalPlan] {
2350-
2351-
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
2352-
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
2353-
e.withTimeZone(conf.sessionLocalTimeZone)
2354-
// Casts could be added in the subquery plan through the rule TypeCoercion while coercing
2355-
// the types between the value expression and list query expression of IN expression.
2356-
// We need to subject the subquery plan through ResolveTimeZone again to setup timezone
2357-
// information for time zone aware expressions.
2358-
case e: ListQuery => e.withNewPlan(apply(e.plan))
2359-
}
2360-
}
23612343
}
23622344

23632345
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
2020
import scala.util.control.NonFatal
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
2423
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
2524
import org.apache.spark.sql.catalyst.rules.Rule
2625
import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
2928
/**
3029
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
3130
*/
32-
case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
31+
case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
3332
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
3433
case table: UnresolvedInlineTable if table.expressionsResolved =>
3534
validateInputDimension(table)
@@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
9998
val castedExpr = if (e.dataType.sameType(targetType)) {
10099
e
101100
} else {
102-
Cast(e, targetType)
101+
cast(e, targetType)
103102
}
104-
castedExpr.transform {
105-
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
106-
e.withTimeZone(conf.sessionLocalTimeZone)
107-
}.eval()
103+
castedExpr.eval()
108104
} catch {
109105
case NonFatal(ex) =>
110106
table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.catalyst.analysis
18+
19+
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
20+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.internal.SQLConf
23+
import org.apache.spark.sql.types.DataType
24+
25+
/**
26+
* Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
27+
* time zone.
28+
*/
29+
case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
30+
private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
31+
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
32+
e.withTimeZone(conf.sessionLocalTimeZone)
33+
// Casts could be added in the subquery plan through the rule TypeCoercion while coercing
34+
// the types between the value expression and list query expression of IN expression.
35+
// We need to subject the subquery plan through ResolveTimeZone again to setup timezone
36+
// information for time zone aware expressions.
37+
case e: ListQuery => e.withNewPlan(apply(e.plan))
38+
}
39+
40+
override def apply(plan: LogicalPlan): LogicalPlan =
41+
plan.resolveExpressions(transformTimeZoneExprs)
42+
43+
def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
44+
}
45+
46+
/**
47+
* Mix-in trait for constructing valid [[Cast]] expressions.
48+
*/
49+
trait CastSupport {
50+
/**
51+
* Configuration used to create a valid cast expression.
52+
*/
53+
def conf: SQLConf
54+
55+
/**
56+
* Create a Cast expression with the session local time zone.
57+
*/
58+
def cast(child: Expression, dataType: DataType): Cast = {
59+
Cast(child, dataType, Option(conf.sessionLocalTimeZone))
60+
}
61+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf
4747
* This should be only done after the batch of Resolution, because the view attributes are not
4848
* completely resolved during the batch of Resolution.
4949
*/
50-
case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
50+
case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
5151
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
5252
case v @ View(desc, output, child) if child.resolved && output != child.output =>
5353
val resolver = conf.resolver
@@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
7878
throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
7979
s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n")
8080
} else {
81-
Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
81+
Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
8282
qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata))
8383
}
8484
case (_, originAttr) => originAttr

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone}
2424
import scala.util.control.NonFatal
2525

2626
import org.apache.spark.sql.catalyst.InternalRow
27-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2827
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
2928
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3029
import org.apache.spark.sql.types._
@@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
3433
* Common base class for time zone aware expressions.
3534
*/
3635
trait TimeZoneAwareExpression extends Expression {
36+
/** The expression is only resolved when the time zone has been set. */
37+
override lazy val resolved: Boolean =
38+
childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined
3739

3840
/** the timezone ID to be used to evaluate value. */
3941
def timeZoneId: Option[String]

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter
2222
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
25+
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
2526
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
2627

2728
/**
@@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
9192
test("convert TimeZoneAwareExpression") {
9293
val table = UnresolvedInlineTable(Seq("c1"),
9394
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
94-
val converted = ResolveInlineTables(conf).convert(table)
95+
val withTimeZone = ResolveTimeZone(conf).apply(table)
96+
val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
9597
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
9698
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
97-
assert(converted.output.map(_.dataType) == Seq(TimestampType))
98-
assert(converted.data.size == 1)
99-
assert(converted.data(0).getLong(0) == correct)
99+
assert(output.map(_.dataType) == Seq(TimestampType))
100+
assert(data.size == 1)
101+
assert(data.head.getLong(0) == correct)
100102
}
101103

102104
test("nullability inference in convert") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.plans.PlanTest
2626
import org.apache.spark.sql.catalyst.plans.logical._
2727
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
28+
import org.apache.spark.sql.internal.SQLConf
2829
import org.apache.spark.sql.types._
2930
import org.apache.spark.unsafe.types.CalendarInterval
3031

@@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest {
787788
}
788789
}
789790

791+
private val timeZoneResolver = ResolveTimeZone(new SQLConf)
792+
793+
private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
794+
timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
795+
}
796+
790797
test("WidenSetOperationTypes for except and intersect") {
791798
val firstTable = LocalRelation(
792799
AttributeReference("i", IntegerType)(),
@@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest {
799806
AttributeReference("f", FloatType)(),
800807
AttributeReference("l", LongType)())
801808

802-
val wt = TypeCoercion.WidenSetOperationTypes
803809
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
804810

805-
val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
806-
val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
811+
val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
812+
val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
807813
checkOutput(r1.left, expectedTypes)
808814
checkOutput(r1.right, expectedTypes)
809815
checkOutput(r2.left, expectedTypes)
@@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest {
838844
AttributeReference("p", ByteType)(),
839845
AttributeReference("q", DoubleType)())
840846

841-
val wt = TypeCoercion.WidenSetOperationTypes
842847
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
843848

844-
val unionRelation = wt(
849+
val unionRelation = widenSetOperationTypes(
845850
Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
846851
assert(unionRelation.children.length == 4)
847852
checkOutput(unionRelation.children.head, expectedTypes)
@@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest {
862867
}
863868
}
864869

865-
val dp = TypeCoercion.WidenSetOperationTypes
866-
867870
val left1 = LocalRelation(
868871
AttributeReference("l", DecimalType(10, 8))())
869872
val right1 = LocalRelation(
870873
AttributeReference("r", DecimalType(5, 5))())
871874
val expectedType1 = Seq(DecimalType(10, 8))
872875

873-
val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
874-
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
875-
val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
876+
val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
877+
val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
878+
val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]
876879

877880
checkOutput(r1.children.head, expectedType1)
878881
checkOutput(r1.children.last, expectedType1)
@@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest {
891894
val plan2 = LocalRelation(
892895
AttributeReference("r", rType)())
893896

894-
val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
895-
val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
896-
val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
897+
val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
898+
val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
899+
val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]
897900

898901
checkOutput(r1.children.last, Seq(expectedType))
899902
checkOutput(r2.right, Seq(expectedType))
900903
checkOutput(r3.right, Seq(expectedType))
901904

902-
val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
903-
val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
904-
val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
905+
val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
906+
val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
907+
val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]
905908

906909
checkOutput(r4.children.last, Seq(expectedType))
907910
checkOutput(r5.left, Seq(expectedType))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
3434
*/
3535
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
3636

37-
private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
37+
private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = {
3838
v match {
3939
case lit: Expression => Cast(lit, targetType, timeZoneId)
4040
case _ => Cast(Literal(v), targetType, timeZoneId)
@@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
4747
}
4848

4949
private def checkNullCast(from: DataType, to: DataType): Unit = {
50-
checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
50+
checkEvaluation(cast(Literal.create(null, from), to), null)
5151
}
5252

5353
test("null cast") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
160160

161161
test("Seconds") {
162162
assert(Second(Literal.create(null, DateType), gmtId).resolved === false)
163-
assert(Second(Cast(Literal(d), TimestampType), None).resolved === true)
163+
assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true)
164164
checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
165165
checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15)
166166
checkEvaluation(Second(Literal(ts), gmtId), 15)
@@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
220220

221221
test("Hour") {
222222
assert(Hour(Literal.create(null, DateType), gmtId).resolved === false)
223-
assert(Hour(Literal(ts), None).resolved === true)
223+
assert(Hour(Literal(ts), gmtId).resolved === true)
224224
checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
225225
checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13)
226226
checkEvaluation(Hour(Literal(ts), gmtId), 13)
@@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
246246

247247
test("Minute") {
248248
assert(Minute(Literal.create(null, DateType), gmtId).resolved === false)
249-
assert(Minute(Literal(ts), None).resolved === true)
249+
assert(Minute(Literal(ts), gmtId).resolved === true)
250250
checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
251251
checkEvaluation(
252252
Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
2525
import org.apache.spark.{SparkConf, SparkFunSuite}
2626
import org.apache.spark.serializer.JavaSerializer
2727
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
28+
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
2829
import org.apache.spark.sql.catalyst.expressions.codegen._
2930
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
3031
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
31-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
32+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
33+
import org.apache.spark.sql.internal.SQLConf
3234
import org.apache.spark.sql.types._
3335
import org.apache.spark.util.Utils
3436

@@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
4547
protected def checkEvaluation(
4648
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
4749
val serializer = new JavaSerializer(new SparkConf()).newInstance
48-
val expr: Expression = serializer.deserialize(serializer.serialize(expression))
50+
val resolver = ResolveTimeZone(new SQLConf)
51+
val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
4952
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
5053
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
5154
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)

0 commit comments

Comments
 (0)