Skip to content

Commit 14bb398

Browse files
maropu authored and hvanhovell committed
[SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit
## What changes were proposed in this pull request? This PR is to support Seq, Map, and Struct in functions.lit; it adds a new API named `typedLit` (written as `lit2` in the original description) that takes a `TypeTag` to avoid type erasure. ## How was this patch tested? Added tests in `LiteralExpressionSuite` and `ColumnExpressionSuite`. Author: Takeshi Yamamuro <[email protected]> Author: Takeshi YAMAMURO <[email protected]> Closes #16610 from maropu/SPARK-19254.
1 parent f48461a commit 14bb398

File tree

4 files changed

+121
-20
lines changed

4 files changed

+121
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ import java.util.Objects
3232
import javax.xml.bind.DatatypeConverter
3333

3434
import scala.math.{BigDecimal, BigInt}
35+
import scala.reflect.runtime.universe.TypeTag
36+
import scala.util.Try
3537

3638
import org.json4s.JsonAST._
3739

3840
import org.apache.spark.sql.AnalysisException
39-
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
41+
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
4042
import org.apache.spark.sql.catalyst.expressions.codegen._
4143
import org.apache.spark.sql.catalyst.util.DateTimeUtils
4244
import org.apache.spark.sql.types._
@@ -153,6 +155,14 @@ object Literal {
153155
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
154156
}
155157

158+
/**
 * Creates a [[Literal]] from a Scala value, using the static type `T` (captured by the
 * `TypeTag`) to pick the Catalyst data type.
 *
 * The data type is derived with `ScalaReflection.schemaFor[T]`, and the value is converted
 * with the matching `CatalystTypeConverters` converter. If any of those steps throws a
 * non-fatal exception, the method falls back to `Literal(v)`.
 *
 * @tparam T static type of the value; its `TypeTag` drives the schema lookup
 * @param v the Scala value to wrap
 * @return a literal holding the converted value
 */
def create[T : TypeTag](v: T): Literal = {
  // Typed path: derive the data type from T, then convert the value accordingly.
  val typedLiteral = Try {
    val ScalaReflection.Schema(catalystType, _) = ScalaReflection.schemaFor[T]
    val toCatalystValue = CatalystTypeConverters.createToCatalystConverter(catalystType)
    Literal(toCatalystValue(v), catalystType)
  }
  // Best-effort fallback: only evaluated when the typed path failed.
  typedLiteral.getOrElse(Literal(v))
}
165+
156166
/**
157167
* Create a literal with default value for given DataType
158168
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.nio.charset.StandardCharsets
2121

22+
import scala.reflect.runtime.universe.{typeTag, TypeTag}
23+
2224
import org.apache.spark.SparkFunSuite
2325
import org.apache.spark.sql.Row
24-
import org.apache.spark.sql.catalyst.CatalystTypeConverters
26+
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
2527
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
2628
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2729
import org.apache.spark.sql.types._
@@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
7577
test("boolean literals") {
7678
checkEvaluation(Literal(true), true)
7779
checkEvaluation(Literal(false), false)
80+
81+
checkEvaluation(Literal.create(true), true)
82+
checkEvaluation(Literal.create(false), false)
7883
}
7984

8085
test("int literals") {
@@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
8388
checkEvaluation(Literal(d.toLong), d.toLong)
8489
checkEvaluation(Literal(d.toShort), d.toShort)
8590
checkEvaluation(Literal(d.toByte), d.toByte)
91+
92+
checkEvaluation(Literal.create(d), d)
93+
checkEvaluation(Literal.create(d.toLong), d.toLong)
94+
checkEvaluation(Literal.create(d.toShort), d.toShort)
95+
checkEvaluation(Literal.create(d.toByte), d.toByte)
8696
}
8797
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
8898
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
99+
100+
checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
101+
checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
89102
}
90103

91104
test("double literals") {
92105
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
93106
checkEvaluation(Literal(d), d)
94107
checkEvaluation(Literal(d.toFloat), d.toFloat)
108+
109+
checkEvaluation(Literal.create(d), d)
110+
checkEvaluation(Literal.create(d.toFloat), d.toFloat)
95111
}
96112
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
97113
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
98114
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
99115
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
100116

117+
checkEvaluation(Literal.create(Double.MinValue), Double.MinValue)
118+
checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue)
119+
checkEvaluation(Literal.create(Float.MinValue), Float.MinValue)
120+
checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue)
121+
101122
}
102123

103124
test("string literals") {
104125
checkEvaluation(Literal(""), "")
105126
checkEvaluation(Literal("test"), "test")
106127
checkEvaluation(Literal("\u0000"), "\u0000")
128+
129+
checkEvaluation(Literal.create(""), "")
130+
checkEvaluation(Literal.create("test"), "test")
131+
checkEvaluation(Literal.create("\u0000"), "\u0000")
107132
}
108133

109134
test("sum two literals") {
110135
checkEvaluation(Add(Literal(1), Literal(1)), 2)
136+
checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
111137
}
112138

113139
test("binary literals") {
114140
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
115141
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
142+
143+
checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0))
144+
checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2))
116145
}
117146

118147
test("decimal") {
@@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
124153
Decimal((d * 1000L).toLong, 10, 3))
125154
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
126155
checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))
156+
157+
checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
158+
checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
159+
checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
160+
checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
161+
Decimal((d * 1000L).toLong, 10, 3))
162+
checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
163+
checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))
164+
127165
}
128166
}
129167

168+
/**
 * Test helper: converts a Scala value to its Catalyst form, using the data type that
 * `ScalaReflection.schemaFor[T]` reports for the static type `T`.
 */
private def toCatalyst[T: TypeTag](value: T): Any = {
  val ScalaReflection.Schema(catalystType, _) = ScalaReflection.schemaFor[T]
  val converter = CatalystTypeConverters.createToCatalystConverter(catalystType)
  converter(value)
}
172+
130173
test("array") {
131-
def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
132-
val toCatalyst = (a: Array[_], elementType: DataType) => {
133-
CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
134-
}
135-
checkEvaluation(Literal(a), toCatalyst(a, elementType))
174+
def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
175+
checkEvaluation(Literal(a), toCatalyst(a))
176+
checkEvaluation(Literal.create(a), toCatalyst(a))
177+
}
178+
checkArrayLiteral(Array(1, 2, 3))
179+
checkArrayLiteral(Array("a", "b", "c"))
180+
checkArrayLiteral(Array(1.0, 4.0))
181+
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
182+
}
183+
184+
test("seq") {
185+
def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
186+
checkEvaluation(Literal.create(a), toCatalyst(a))
136187
}
137-
checkArrayLiteral(Array(1, 2, 3), IntegerType)
138-
checkArrayLiteral(Array("a", "b", "c"), StringType)
139-
checkArrayLiteral(Array(1.0, 4.0), DoubleType)
140-
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
188+
checkSeqLiteral(Seq(1, 2, 3), IntegerType)
189+
checkSeqLiteral(Seq("a", "b", "c"), StringType)
190+
checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
191+
checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
141192
CalendarIntervalType)
142193
}
143194

144-
test("unsupported types (map and struct) in literals") {
195+
test("map") {
196+
def checkMapLiteral[T: TypeTag](m: T): Unit = {
197+
checkEvaluation(Literal.create(m), toCatalyst(m))
198+
}
199+
checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3))
200+
checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0))
201+
}
202+
203+
test("struct") {
204+
def checkStructLiteral[T: TypeTag](s: T): Unit = {
205+
checkEvaluation(Literal.create(s), toCatalyst(s))
206+
}
207+
checkStructLiteral((1, 3.0, "abcde"))
208+
checkStructLiteral(("de", 1, 2.0f))
209+
checkStructLiteral((1, ("fgh", 3.0)))
210+
}
211+
212+
test("unsupported types (map and struct) in Literal.apply") {
145213
def checkUnsupportedTypeInLiteral(v: Any): Unit = {
146214
val errMsgMap = intercept[RuntimeException] {
147215
Literal(v)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,24 @@ object functions {
9191
* @group normal_funcs
9292
* @since 1.3.0
9393
*/
94-
def lit(literal: Any): Column = {
95-
literal match {
96-
case c: Column => return c
97-
case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
98-
case _ => // continue
99-
}
94+
def lit(literal: Any): Column = typedLit(literal)
10095

101-
val literalExpr = Literal(literal)
102-
Column(literalExpr)
96+
/**
97+
* Creates a [[Column]] of literal value.
98+
*
99+
* The passed in object is returned directly if it is already a [[Column]].
100+
* If the object is a Scala Symbol, it is converted into a [[Column]] also.
101+
* Otherwise, a new [[Column]] is created to represent the literal value.
102+
* The difference between this function and [[lit]] is that this function
103+
* can handle parameterized Scala types, e.g. List, Seq and Map.
104+
*
105+
* @group normal_funcs
106+
* @since 2.2.0
107+
*/
108+
def typedLit[T : TypeTag](literal: T): Column = {
  literal match {
    // Already a Column: hand it back untouched.
    case column: Column => column
    // A Symbol is turned into a ColumnName built from the symbol's name.
    case symbol: Symbol => new ColumnName(symbol.name)
    // Anything else becomes a literal; Literal.create receives T's TypeTag.
    case _ => Column(Literal.create(literal))
  }
}
104113

105114
//////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
712712
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
713713
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
714714
}
715+
716+
test("typedLit") {
717+
val df = Seq(Tuple1(0)).toDF("a")
718+
// Only check the types `lit` cannot handle
719+
checkAnswer(
720+
df.select(typedLit(Seq(1, 2, 3))),
721+
Row(Seq(1, 2, 3)) :: Nil)
722+
checkAnswer(
723+
df.select(typedLit(Map("a" -> 1, "b" -> 2))),
724+
Row(Map("a" -> 1, "b" -> 2)) :: Nil)
725+
checkAnswer(
726+
df.select(typedLit(("a", 2, 1.0))),
727+
Row(Row("a", 2, 1.0)) :: Nil)
728+
}
715729
}

0 commit comments

Comments
 (0)