Skip to content

Commit c7dd4d2

Browse files
author
Davies Liu
committed
fix some catalyst tests
1 parent 38c303e commit c7dd4d2

File tree

6 files changed

+93
-79
lines changed

6 files changed

+93
-79
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,16 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
216216
val $primitiveTerm: ${termForType(dataType)} = $value
217217
""".children
218218

219-
// case expressions.Literal(value: UTF8String, dataType) =>
220-
// q"""
221-
// val $nullTerm = ${value == null}
222-
// val $primitiveTerm: ${termForType(dataType)} = $value
223-
// """.children
224-
225-
case expressions.Literal(value: String, dataType) =>
219+
case expressions.Literal(value: UTF8String, dataType) =>
226220
q"""
227221
val $nullTerm = ${value == null}
228-
val $primitiveTerm: ${termForType(dataType)} = $value
222+
val $primitiveTerm: ${termForType(dataType)} =
223+
org.apache.spark.sql.types.UTF8String(${value.toString})
229224
""".children
230225

231226
case expressions.Literal(value: Int, dataType) =>
@@ -249,11 +244,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
249244
if($nullTerm)
250245
${defaultPrimitive(StringType)}
251246
else
252-
UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
247+
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
253248
""".children
254249

255250
case Cast(child @ DateType(), StringType) =>
256-
child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
251+
child.castOrNull(c => q"org.apache.spark.sql.types.UTF8String(org.apache.spark.sql.types.DateUtils.toString($c))", StringType)
257252

258253
case Cast(child @ NumericType(), IntegerType) =>
259254
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -278,7 +273,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
278273
if($nullTerm)
279274
${defaultPrimitive(StringType)}
280275
else
281-
${eval.primitiveTerm}.toString
276+
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
282277
""".children
283278

284279
case EqualTo(e1, e2) =>
@@ -579,7 +574,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
579574
val localLogger = log
580575
val localLoggerTree = reify { localLogger }
581576
q"""
582-
$localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
577+
$localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm.toString))
583578
""" :: Nil
584579
} else {
585580
Nil
@@ -590,7 +585,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
590585

591586
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
592587
dataType match {
593-
case StringType => q"$inputRow.apply($ordinal).asInstanceOf[UTF8String]"
588+
case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
594589
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
595590
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
596591
}
@@ -602,7 +597,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
602597
ordinal: Int,
603598
value: TermName) = {
604599
dataType match {
605-
case StringType => q"$destinationRow.setString($ordinal, $value)"
600+
case StringType => q"$destinationRow.update($ordinal, $value)"
606601
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
607602
case _ => q"$destinationRow.update($ordinal, $value)"
608603
}
@@ -626,13 +621,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
626621
case DoubleType => "Double"
627622
case FloatType => "Float"
628623
case BooleanType => "Boolean"
629-
case StringType => "String"
624+
case StringType => "org.apache.spark.sql.types.UTF8String"
630625
}
631626

632627
protected def defaultPrimitive(dt: DataType) = dt match {
633628
case BooleanType => ru.Literal(Constant(false))
634629
case FloatType => ru.Literal(Constant(-1.0.toFloat))
635-
case StringType => ru.Literal(Constant("<uninit>"))
630+
case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
636631
case ShortType => ru.Literal(Constant(-1.toShort))
637632
case LongType => ru.Literal(Constant(-1L))
638633
case ByteType => ru.Literal(Constant(-1.toByte))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ object IntegerLiteral {
6464

6565
case class Literal(var value: Any, dataType: DataType) extends LeafExpression {
6666

67-
if (dataType == StringType && value.isInstanceOf[String]) {
68-
value = UTF8String(value.asInstanceOf[String])
67+
// TODO(davies): FIXME
68+
(value, dataType) match {
69+
case (s: String, StringType) =>
70+
value = UTF8String(s)
71+
case (seq: Seq[String], dt:ArrayType) if dt.elementType == StringType =>
72+
value = seq.map(UTF8String(_))
73+
case _ =>
6974
}
7075

7176
override def foldable: Boolean = true

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] {
198198
val equalTo = "([^_%]*)".r
199199

200200
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
201-
case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
202-
StartsWith(l, Literal(pattern))
203-
case Like(l, Literal(endsWith(pattern), StringType)) =>
204-
EndsWith(l, Literal(pattern))
205-
case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
206-
Contains(l, Literal(pattern))
207-
case Like(l, Literal(equalTo(pattern), StringType)) =>
208-
EqualTo(l, Literal(pattern))
201+
case Like(l, Literal(utf, StringType)) =>
202+
utf.toString match {
203+
case startsWith(pattern) if !pattern.endsWith("\\") =>
204+
StartsWith(l, Literal(pattern))
205+
case endsWith(pattern) =>
206+
EndsWith(l, Literal(pattern))
207+
case contains(pattern) if !pattern.endsWith("\\") =>
208+
Contains(l, Literal(pattern))
209+
case equalTo(pattern) =>
210+
EqualTo(l, Literal(pattern))
211+
case _ =>
212+
Like(l, Literal(utf, StringType))
213+
}
209214
}
210215
}
211216

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import java.util.Arrays
2828
* Note: This is not designed for general use cases, should not be used outside SQL.
2929
*/
3030

31-
private[sql] final class UTF8String extends Ordered[UTF8String] with Serializable {
31+
final class UTF8String extends Ordered[UTF8String] with Serializable {
3232

3333
private var bytes: Array[Byte] = _
3434

@@ -57,7 +57,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl
5757
var len = 0
5858
var i: Int = 0
5959
while (i < bytes.length) {
60-
val b = bytes(i)
60+
val b = bytes(i) & 0xFF
6161
i += 1
6262
if (b >= 192) {
6363
i += UTF8String.tailBytesOfUTF8(b - 192)
@@ -84,7 +84,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl
8484
var c = 0
8585
var i: Int = 0
8686
while (c < start && i < bytes.length) {
87-
val b = bytes(i)
87+
val b = bytes(i) & 0xFF
8888
i += 1
8989
if (b >= 192) {
9090
i += UTF8String.tailBytesOfUTF8(b - 192)
@@ -93,7 +93,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl
9393
}
9494
var j = i
9595
while (c < until && j < bytes.length) {
96-
val b = bytes(j)
96+
val b = bytes(j) & 0xFF
9797
j += 1
9898
if (b >= 192) {
9999
j += UTF8String.tailBytesOfUTF8(b - 192)
@@ -160,7 +160,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl
160160
}
161161
}
162162

163-
private[sql] object UTF8String {
163+
object UTF8String {
164164
// number of tailing bytes in a UTF8 sequence for a code point
165165
// see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
166166
private[types] val tailBytesOfUTF8: Array[Int] = Array(1, 1, 1, 1, 1,

0 commit comments

Comments
 (0)