Skip to content

Commit a4206d5

Browse files
cloud-fanHyukjinKwon
authored andcommitted
[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
## What changes were proposed in this pull request? This is a followup of apache#20136 . apache#20136 didn't really work because in the test, we are using local backend, which shares the driver side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` doesn't work. This PR changes the check to `TaskContext.get != null`, and move the check to `SQLConf.get`, and fix all the places that violate this check: * `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. apache#21223 merged * `DataType#sameType` may be executed in executor side, for things like json schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This contributes to most of the code changes, as we need to add `caseSensitive` parameter to a lot of methods. * `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't can't call `conf.parquetFilterPushDownDate` there. apache#21224 merged * `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. apache#21225 merged * `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. apache#21226 merged ## How was this patch tested? existing test Author: Wenchen Fan <[email protected]> Closes apache#21190 from cloud-fan/minor.
1 parent d3c426a commit a4206d5

File tree

10 files changed

+188
-140
lines changed

10 files changed

+188
-140
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2424
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.catalyst.plans.logical._
27+
import org.apache.spark.sql.internal.SQLConf
2728
import org.apache.spark.sql.types._
2829

2930
/**
@@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper {
260261
// Check if the data types match.
261262
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
262263
// SPARK-18058: we shall not care about the nullability of columns
263-
if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
264+
val widerType = TypeCoercion.findWiderTypeForTwo(
265+
dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis)
266+
if (widerType.isEmpty) {
264267
failAnalysis(
265268
s"""
266269
|${operator.nodeName} can only be performed on tables with the compatible

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
8383
// For each column, traverse all the values and find a common data type and nullability.
8484
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
8585
val inputTypes = column.map(_.dataType)
86-
val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
86+
val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion(
87+
inputTypes, conf.caseSensitiveAnalysis)
88+
val tpe = wideType.getOrElse {
8789
table.failAnalysis(s"incompatible types found in column $name for inline table")
8890
}
8991
StructField(name, tpe, nullable = column.exists(_.nullable))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 89 additions & 67 deletions
Large diffs are not rendered by default.

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.util.matching.Regex
2727

2828
import org.apache.hadoop.fs.Path
2929

30-
import org.apache.spark.{SparkContext, SparkEnv}
30+
import org.apache.spark.TaskContext
3131
import org.apache.spark.internal.Logging
3232
import org.apache.spark.internal.config._
3333
import org.apache.spark.network.util.ByteUnit
@@ -107,7 +107,13 @@ object SQLConf {
107107
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
108108
* run unit tests (that does not involve SparkSession) in serial order.
109109
*/
110-
def get: SQLConf = confGetter.get()()
110+
def get: SQLConf = {
111+
if (Utils.isTesting && TaskContext.get != null) {
112+
// we're accessing it during task execution, fail.
113+
throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
114+
}
115+
confGetter.get()()
116+
}
111117

112118
val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
113119
.internal()
@@ -1274,12 +1280,6 @@ object SQLConf {
12741280
class SQLConf extends Serializable with Logging {
12751281
import SQLConf._
12761282

1277-
if (Utils.isTesting && SparkEnv.get != null) {
1278-
// assert that we're only accessing it on the driver.
1279-
assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
1280-
"SQLConf should only be created and accessed on the driver.")
1281-
}
1282-
12831283
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
12841284
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
12851285
new java.util.HashMap[String, String]())

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType {
8181
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
8282
*/
8383
private[spark] def sameType(other: DataType): Boolean =
84-
if (SQLConf.get.caseSensitiveAnalysis) {
85-
DataType.equalsIgnoreNullability(this, other)
86-
} else {
87-
DataType.equalsIgnoreCaseAndNullability(this, other)
88-
}
84+
DataType.equalsIgnoreNullability(this, other)
8985

9086
/**
9187
* Returns the same data type but set all nullability fields are true
@@ -218,7 +214,7 @@ object DataType {
218214
/**
219215
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
220216
*/
221-
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
217+
private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
222218
(left, right) match {
223219
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
224220
equalsIgnoreNullability(leftElementType, rightElementType)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest {
128128
}
129129

130130
private def checkWidenType(
131-
widenFunc: (DataType, DataType) => Option[DataType],
131+
widenFunc: (DataType, DataType, Boolean) => Option[DataType],
132132
t1: DataType,
133133
t2: DataType,
134134
expected: Option[DataType],
135135
isSymmetric: Boolean = true): Unit = {
136-
var found = widenFunc(t1, t2)
136+
var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
137137
assert(found == expected,
138138
s"Expected $expected as wider common type for $t1 and $t2, found $found")
139139
// Test both directions to make sure the widening is symmetric.
140140
if (isSymmetric) {
141-
found = widenFunc(t2, t1)
141+
found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
142142
assert(found == expected,
143143
s"Expected $expected as wider common type for $t2 and $t1, found $found")
144144
}
@@ -524,29 +524,29 @@ class TypeCoercionSuite extends AnalysisTest {
524524
test("cast NullType for expressions that implement ExpectsInputTypes") {
525525
import TypeCoercionSuite._
526526

527-
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
527+
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
528528
AnyTypeUnaryExpression(Literal.create(null, NullType)),
529529
AnyTypeUnaryExpression(Literal.create(null, NullType)))
530530

531-
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
531+
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
532532
NumericTypeUnaryExpression(Literal.create(null, NullType)),
533533
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
534534
}
535535

536536
test("cast NullType for binary operators") {
537537
import TypeCoercionSuite._
538538

539-
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
539+
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
540540
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
541541
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
542542

543-
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
543+
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
544544
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
545545
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
546546
}
547547

548548
test("coalesce casts") {
549-
val rule = TypeCoercion.FunctionArgumentConversion
549+
val rule = TypeCoercion.FunctionArgumentConversion(conf)
550550

551551
val intLit = Literal(1)
552552
val longLit = Literal.create(1L)
@@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest {
606606
}
607607

608608
test("CreateArray casts") {
609-
ruleTest(TypeCoercion.FunctionArgumentConversion,
609+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
610610
CreateArray(Literal(1.0)
611611
:: Literal(1)
612612
:: Literal.create(1.0, FloatType)
@@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest {
616616
:: Cast(Literal.create(1.0, FloatType), DoubleType)
617617
:: Nil))
618618

619-
ruleTest(TypeCoercion.FunctionArgumentConversion,
619+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
620620
CreateArray(Literal(1.0)
621621
:: Literal(1)
622622
:: Literal("a")
@@ -626,15 +626,15 @@ class TypeCoercionSuite extends AnalysisTest {
626626
:: Cast(Literal("a"), StringType)
627627
:: Nil))
628628

629-
ruleTest(TypeCoercion.FunctionArgumentConversion,
629+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
630630
CreateArray(Literal.create(null, DecimalType(5, 3))
631631
:: Literal(1)
632632
:: Nil),
633633
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
634634
:: Literal(1).cast(DecimalType(13, 3))
635635
:: Nil))
636636

637-
ruleTest(TypeCoercion.FunctionArgumentConversion,
637+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
638638
CreateArray(Literal.create(null, DecimalType(5, 3))
639639
:: Literal.create(null, DecimalType(22, 10))
640640
:: Literal.create(null, DecimalType(38, 38))
@@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest {
647647

648648
test("CreateMap casts") {
649649
// type coercion for map keys
650-
ruleTest(TypeCoercion.FunctionArgumentConversion,
650+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
651651
CreateMap(Literal(1)
652652
:: Literal("a")
653653
:: Literal.create(2.0, FloatType)
@@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest {
658658
:: Cast(Literal.create(2.0, FloatType), FloatType)
659659
:: Literal("b")
660660
:: Nil))
661-
ruleTest(TypeCoercion.FunctionArgumentConversion,
661+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
662662
CreateMap(Literal.create(null, DecimalType(5, 3))
663663
:: Literal("a")
664664
:: Literal.create(2.0, FloatType)
@@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest {
670670
:: Literal("b")
671671
:: Nil))
672672
// type coercion for map values
673-
ruleTest(TypeCoercion.FunctionArgumentConversion,
673+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
674674
CreateMap(Literal(1)
675675
:: Literal("a")
676676
:: Literal(2)
@@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest {
681681
:: Literal(2)
682682
:: Cast(Literal(3.0), StringType)
683683
:: Nil))
684-
ruleTest(TypeCoercion.FunctionArgumentConversion,
684+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
685685
CreateMap(Literal(1)
686686
:: Literal.create(null, DecimalType(38, 0))
687687
:: Literal(2)
@@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest {
693693
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
694694
:: Nil))
695695
// type coercion for both map keys and values
696-
ruleTest(TypeCoercion.FunctionArgumentConversion,
696+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
697697
CreateMap(Literal(1)
698698
:: Literal("a")
699699
:: Literal(2.0)
@@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest {
708708

709709
test("greatest/least cast") {
710710
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
711-
ruleTest(TypeCoercion.FunctionArgumentConversion,
711+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
712712
operator(Literal(1.0)
713713
:: Literal(1)
714714
:: Literal.create(1.0, FloatType)
@@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest {
717717
:: Cast(Literal(1), DoubleType)
718718
:: Cast(Literal.create(1.0, FloatType), DoubleType)
719719
:: Nil))
720-
ruleTest(TypeCoercion.FunctionArgumentConversion,
720+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
721721
operator(Literal(1L)
722722
:: Literal(1)
723723
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest {
726726
:: Cast(Literal(1), DecimalType(22, 0))
727727
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
728728
:: Nil))
729-
ruleTest(TypeCoercion.FunctionArgumentConversion,
729+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
730730
operator(Literal(1.0)
731731
:: Literal.create(null, DecimalType(10, 5))
732732
:: Literal(1)
@@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest {
735735
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
736736
:: Literal(1).cast(DoubleType)
737737
:: Nil))
738-
ruleTest(TypeCoercion.FunctionArgumentConversion,
738+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
739739
operator(Literal.create(null, DecimalType(15, 0))
740740
:: Literal.create(null, DecimalType(10, 5))
741741
:: Literal(1)
@@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest {
744744
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
745745
:: Literal(1).cast(DecimalType(20, 5))
746746
:: Nil))
747-
ruleTest(TypeCoercion.FunctionArgumentConversion,
747+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
748748
operator(Literal.create(2L, LongType)
749749
:: Literal(1)
750750
:: Literal.create(null, DecimalType(10, 5))
@@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest {
757757
}
758758

759759
test("nanvl casts") {
760-
ruleTest(TypeCoercion.FunctionArgumentConversion,
760+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
761761
NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
762762
NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
763-
ruleTest(TypeCoercion.FunctionArgumentConversion,
763+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
764764
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
765765
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
766-
ruleTest(TypeCoercion.FunctionArgumentConversion,
766+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
767767
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
768768
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
769-
ruleTest(TypeCoercion.FunctionArgumentConversion,
769+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
770770
NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
771771
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
772-
ruleTest(TypeCoercion.FunctionArgumentConversion,
772+
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
773773
NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
774774
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
775775
}
776776

777777
test("type coercion for If") {
778-
val rule = TypeCoercion.IfCoercion
778+
val rule = TypeCoercion.IfCoercion(conf)
779779
val intLit = Literal(1)
780780
val doubleLit = Literal(1.0)
781781
val trueLit = Literal.create(true, BooleanType)
@@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest {
823823
}
824824

825825
test("type coercion for CaseKeyWhen") {
826-
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
826+
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
827827
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
828828
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
829829
)
830-
ruleTest(TypeCoercion.CaseWhenCoercion,
830+
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
831831
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
832832
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
833833
)
834-
ruleTest(TypeCoercion.CaseWhenCoercion,
834+
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
835835
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
836836
CaseWhen(Seq((Literal(true), Literal(1.2))),
837837
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
838838
)
839-
ruleTest(TypeCoercion.CaseWhenCoercion,
839+
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
840840
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
841841
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
842842
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
@@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest {
10851085
private val timeZoneResolver = ResolveTimeZone(new SQLConf)
10861086

10871087
private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
1088-
timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
1088+
timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
10891089
}
10901090

10911091
test("WidenSetOperationTypes for except and intersect") {
@@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest {
12561256

12571257
test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
12581258
"in aggregation function like sum") {
1259-
val rules = Seq(FunctionArgumentConversion, Division)
1259+
val rules = Seq(FunctionArgumentConversion(conf), Division)
12601260
// Casts Integer to Double
12611261
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
12621262
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest {
12751275
}
12761276

12771277
test("SPARK-17117 null type coercion in divide") {
1278-
val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
1278+
val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
12791279
val nullLit = Literal.create(null, NullType)
12801280
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
12811281
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))

0 commit comments

Comments
 (0)