Skip to content

Commit 7212897

Browse files
adrian-wang authored and yhuai committed
[SPARK-6201] [SQL] promote string and do widen types for IN
huangjs Actually spark sql will first go through the analysis period, in which we do widen types and promote strings, and then optimization, where constant IN will be converted into INSET. So it turns out that we only need to fix this for IN. Author: Daoyuan Wang <[email protected]> Closes #4945 from adrian-wang/inset and squashes the following commits: 71e05cc [Daoyuan Wang] minor fix 581fa1c [Daoyuan Wang] mysql way f3f7baf [Daoyuan Wang] address comments 5eed4bc [Daoyuan Wang] promote string and do widen types for IN (cherry picked from commit c3eb441) Signed-off-by: Yin Huai <[email protected]>
1 parent f1a5caf commit 7212897

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ trait HiveTypeCoercion {
6969
val typeCoercionRules =
7070
PropagateTypes ::
7171
ConvertNaNs ::
72+
InConversion ::
7273
WidenTypes ::
7374
PromoteStrings ::
7475
DecimalPrecision ::
@@ -287,6 +288,16 @@ trait HiveTypeCoercion {
287288
}
288289
}
289290

291+
/**
292+
* Convert all expressions in in() list to the left operator type
293+
*/
294+
object InConversion extends Rule[LogicalPlan] {
295+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
296+
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
297+
i.makeCopy(Array(a, b.map(Cast(_, a.dataType))))
298+
}
299+
}
300+
290301
// scalastyle:off
291302
/**
292303
* Calculates and propagates precision for fixed-precision decimals. Hive has a number of

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
310310
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
311311
case q: LogicalPlan => q transformExpressionsDown {
312312
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
313-
val hSet = list.map(e => e.eval(null))
314-
InSet(v, HashSet() ++ hSet)
313+
val hSet = list.map(e => e.eval(null))
314+
InSet(v, HashSet() ++ hSet)
315315
}
316316
}
317317
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
120120
Row(1, 1) :: Nil)
121121
}
122122

123+
test("SPARK-6201 IN type conversion") {
124+
jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
125+
.registerTempTable("d")
126+
127+
checkAnswer(
128+
sql("select * from d where d.a in (1,2)"),
129+
Seq(Row("1"), Row("2")))
130+
}
131+
123132
test("SPARK-3176 Added Parser of SQL ABS()") {
124133
checkAnswer(
125134
sql("SELECT ABS(-1.3)"),

0 commit comments

Comments
 (0)