Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ import java.util.Objects
import javax.xml.bind.DatatypeConverter

import scala.math.{BigDecimal, BigInt}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import org.json4s.JsonAST._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -153,6 +155,14 @@ object Literal {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}

def create[T : TypeTag](v: T): Literal = Try {
val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
Literal(convert(v), dataType)
}.getOrElse {
Literal(v)
}

/**
* Create a literal with default value for given DataType
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions

import java.nio.charset.StandardCharsets

import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("boolean literals") {
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)

checkEvaluation(Literal.create(true), true)
checkEvaluation(Literal.create(false), false)
}

test("int literals") {
Expand All @@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal(d.toLong), d.toLong)
checkEvaluation(Literal(d.toShort), d.toShort)
checkEvaluation(Literal(d.toByte), d.toByte)

checkEvaluation(Literal.create(d), d)
checkEvaluation(Literal.create(d.toLong), d.toLong)
checkEvaluation(Literal.create(d.toShort), d.toShort)
checkEvaluation(Literal.create(d.toByte), d.toByte)
}
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)

checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
}

test("double literals") {
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)

checkEvaluation(Literal.create(d), d)
checkEvaluation(Literal.create(d.toFloat), d.toFloat)
}
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)

checkEvaluation(Literal.create(Double.MinValue), Double.MinValue)
checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue)
checkEvaluation(Literal.create(Float.MinValue), Float.MinValue)
checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue)

}

test("string literals") {
checkEvaluation(Literal(""), "")
checkEvaluation(Literal("test"), "test")
checkEvaluation(Literal("\u0000"), "\u0000")

checkEvaluation(Literal.create(""), "")
checkEvaluation(Literal.create("test"), "test")
checkEvaluation(Literal.create("\u0000"), "\u0000")
}

test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
}

test("binary literals") {
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))

checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0))
checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2))
}

test("decimal") {
Expand All @@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Decimal((d * 1000L).toLong, 10, 3))
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))

checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
Decimal((d * 1000L).toLong, 10, 3))
checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))

}
}

private def toCatalyst[T: TypeTag](value: T): Any = {
val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
CatalystTypeConverters.createToCatalystConverter(dataType)(value)
}

test("array") {
def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
val toCatalyst = (a: Array[_], elementType: DataType) => {
CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
}
checkEvaluation(Literal(a), toCatalyst(a, elementType))
def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
checkEvaluation(Literal(a), toCatalyst(a))
checkEvaluation(Literal.create(a), toCatalyst(a))
}
checkArrayLiteral(Array(1, 2, 3))
checkArrayLiteral(Array("a", "b", "c"))
checkArrayLiteral(Array(1.0, 4.0))
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
}

test("seq") {
def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
checkEvaluation(Literal.create(a), toCatalyst(a))
}
checkArrayLiteral(Array(1, 2, 3), IntegerType)
checkArrayLiteral(Array("a", "b", "c"), StringType)
checkArrayLiteral(Array(1.0, 4.0), DoubleType)
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
checkSeqLiteral(Seq(1, 2, 3), IntegerType)
checkSeqLiteral(Seq("a", "b", "c"), StringType)
checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
CalendarIntervalType)
}

test("unsupported types (map and struct) in literals") {
test("map") {
def checkMapLiteral[T: TypeTag](m: T): Unit = {
checkEvaluation(Literal.create(m), toCatalyst(m))
}
checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3))
checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0))
}

test("struct") {
def checkStructLiteral[T: TypeTag](s: T): Unit = {
checkEvaluation(Literal.create(s), toCatalyst(s))
}
checkStructLiteral((1, 3.0, "abcde"))
checkStructLiteral(("de", 1, 2.0f))
checkStructLiteral((1, ("fgh", 3.0)))
}

test("unsupported types (map and struct) in Literal.apply") {
def checkUnsupportedTypeInLiteral(v: Any): Unit = {
val errMsgMap = intercept[RuntimeException] {
Literal(v)
Expand Down
25 changes: 17 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,24 @@ object functions {
* @group normal_funcs
* @since 1.3.0
*/
def lit(literal: Any): Column = {
literal match {
case c: Column => return c
case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
case _ => // continue
}
def lit(literal: Any): Column = typedLit(literal)

val literalExpr = Literal(literal)
Column(literalExpr)
/**
* Creates a [[Column]] of literal value.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should explain when you want to use this method instead of lit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I'll update

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

*
* The passed in object is returned directly if it is already a [[Column]].
* If the object is a Scala Symbol, it is converted into a [[Column]] also.
* Otherwise, a new [[Column]] is created to represent the literal value.
* The difference between this function and [[lit]] is that this function
* can handle parameterized scala types e.g.: List, Seq and Map.
*
* @group normal_funcs
* @since 2.2.0
*/
def typedLit[T : TypeTag](literal: T): Column = literal match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @cloud-fan wdyt about the naming?

case c: Column => c
case s: Symbol => new ColumnName(s.name)
case _ => Column(Literal.create(literal))
}

//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
}

test("typedLit") {
val df = Seq(Tuple1(0)).toDF("a")
// Only check the types `lit` cannot handle
checkAnswer(
df.select(typedLit(Seq(1, 2, 3))),
Row(Seq(1, 2, 3)) :: Nil)
checkAnswer(
df.select(typedLit(Map("a" -> 1, "b" -> 2))),
Row(Map("a" -> 1, "b" -> 2)) :: Nil)
checkAnswer(
df.select(typedLit(("a", 2, 1.0))),
Row(Row("a", 2, 1.0)) :: Nil)
}
}