Skip to content

Commit 8584276

Browse files
Davies Liumarmbrus
authored andcommitted
[SPARK-6638] [SQL] Improve performance of StringType in SQL
This PR change the internal representation for StringType from java.lang.String to UTF8String, which is implemented use ArrayByte. This PR should not break any public API, Row.getString() will still return java.lang.String. This is the first step of improve the performance of String in SQL. cc rxin Author: Davies Liu <[email protected]> Closes #5350 from davies/string and squashes the following commits: 3b7bfa8 [Davies Liu] fix schema of AddJar 2772f0d [Davies Liu] fix new test failure 6d776a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 59025c8 [Davies Liu] address comments from @marmbrus 341ec2c [Davies Liu] turn off scala style check in UTF8StringSuite 744788f [Davies Liu] Merge branch 'master' of github.com:apache/spark into string b04a19c [Davies Liu] add comment for getString/setString 08d897b [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 5116b43 [Davies Liu] rollback unrelated changes 1314a37 [Davies Liu] address comments from Yin 867bf50 [Davies Liu] fix String filter push down 13d9d42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 2089d24 [Davies Liu] add hashcode check back ac18ae6 [Davies Liu] address comment fd11364 [Davies Liu] optimize UTF8String 8d17f21 [Davies Liu] fix hive compatibility tests e5fa5b8 [Davies Liu] remove clone in UTF8String 28f3d81 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 28d6f32 [Davies Liu] refactor 537631c [Davies Liu] some comment about Date 9f4c194 [Davies Liu] convert data type for data source 956b0a4 [Davies Liu] fix hive tests 73e4363 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 9dc32d1 [Davies Liu] fix some hive tests 23a766c [Davies Liu] refactor 8b45864 [Davies Liu] fix codegen with UTF8String bb52e44 [Davies Liu] fix scala style c7dd4d2 [Davies Liu] fix some catalyst tests 38c303e [Davies Liu] fix python sql tests 5f9e120 [Davies Liu] fix sql tests 6b499ac [Davies Liu] fix style a85fb27 [Davies Liu] refactor d32abd1 [Davies Liu] fix utf8 for python api 4699c3a [Davies Liu] use Array[Byte] in UTF8String 21f67c6 [Davies Liu] cleanup 685fd07 [Davies Liu] use UTF8String instead of String for StringType
1 parent 785f955 commit 8584276

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+742
-298
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def join(self, other, joinExprs=None, joinType=None):
456456
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
457457
458458
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
459-
[Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
459+
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
460460
"""
461461

462462
if joinExprs is None:
@@ -637,9 +637,9 @@ def groupBy(self, *cols):
637637
>>> df.groupBy().avg().collect()
638638
[Row(AVG(age)=3.5)]
639639
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
640-
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
640+
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
641641
>>> df.groupBy(df.name).avg().collect()
642-
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
642+
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
643643
"""
644644
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
645645
self._sc._gateway._gateway_client)
@@ -867,11 +867,11 @@ def agg(self, *exprs):
867867
868868
>>> gdf = df.groupBy(df.name)
869869
>>> gdf.agg({"*": "count"}).collect()
870-
[Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
870+
[Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
871871
872872
>>> from pyspark.sql import functions as F
873873
>>> gdf.agg(F.min(df.age)).collect()
874-
[Row(MIN(age)=5), Row(MIN(age)=2)]
874+
[Row(MIN(age)=2), Row(MIN(age)=5)]
875875
"""
876876
assert exprs, "exprs should not be empty"
877877
if len(exprs) == 1 and isinstance(exprs[0], dict):

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql
2020
import scala.util.hashing.MurmurHash3
2121

2222
import org.apache.spark.sql.catalyst.expressions.GenericRow
23-
import org.apache.spark.sql.types.{StructType, DateUtils}
23+
import org.apache.spark.sql.types.StructType
2424

2525
object Row {
2626
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
257257
*
258258
* @throws ClassCastException when data type does not match.
259259
*/
260+
// TODO(davies): This is not the right default implementation, we use Int as Date internally
260261
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
261262

262263
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
7777
}
7878
new GenericRowWithSchema(ar, structType)
7979

80+
case (d: String, _) =>
81+
UTF8String(d)
82+
8083
case (d: BigDecimal, _) =>
8184
Decimal(d)
8285

@@ -175,6 +178,11 @@ object CatalystTypeConverters {
175178
case other => other
176179
}
177180

181+
case dataType: StringType => (item: Any) => extractOption(item) match {
182+
case s: String => UTF8String(s)
183+
case other => other
184+
}
185+
178186
case _ =>
179187
(item: Any) => extractOption(item) match {
180188
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@ object CatalystTypeConverters {
184192
}
185193
}
186194

195+
/**
196+
* Converts Scala objects to catalyst rows / types.
197+
*
198+
* Note: This should be called before do evaluation on Row
199+
* (It does not support UDT)
200+
* This is used to create an RDD or test results with correct types for Catalyst.
201+
*/
202+
def convertToCatalyst(a: Any): Any = a match {
203+
case s: String => UTF8String(s)
204+
case d: java.sql.Date => DateUtils.fromJavaDate(d)
205+
case d: BigDecimal => Decimal(d)
206+
case d: java.math.BigDecimal => Decimal(d)
207+
case seq: Seq[Any] => seq.map(convertToCatalyst)
208+
case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
209+
case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
210+
case m: Map[Any, Any] =>
211+
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
212+
case other => other
213+
}
214+
187215
/**
188216
* Converts Catalyst types used internally in rows to standard Scala types
189217
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
211239
case (i: Int, DateType) =>
212240
DateUtils.toJavaDate(i)
213241

242+
case (s: UTF8String, StringType) =>
243+
s.toString()
244+
214245
case (other, _) =>
215246
other
216247
}
@@ -262,6 +293,12 @@ object CatalystTypeConverters {
262293
case other => other
263294
}
264295

296+
case StringType =>
297+
(item: Any) => item match {
298+
case s: UTF8String => s.toString()
299+
case other => other
300+
}
301+
265302
case other =>
266303
(item: Any) => item
267304
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ trait ScalaReflection {
138138
// The data type can be determined without ambiguity.
139139
case obj: BooleanType.JvmType => BooleanType
140140
case obj: BinaryType.JvmType => BinaryType
141+
case obj: String => StringType
141142
case obj: StringType.JvmType => StringType
142143
case obj: ByteType.JvmType => ByteType
143144
case obj: ShortType.JvmType => ShortType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
115115
* the appropriate numeric equivalent.
116116
*/
117117
object ConvertNaNs extends Rule[LogicalPlan] {
118-
val stringNaN = Literal.create("NaN", StringType)
118+
val stringNaN = Literal("NaN")
119119

120120
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
121121
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
563563
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
564564
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
565565

566+
// Compatible with Hive
567+
case Substring(e, start, len) if e.dataType != StringType =>
568+
Substring(Cast(e, StringType), start, len)
569+
566570
// Coalesce should return the first non-null value, which could be any column
567571
// from the list. So we need to make sure the return type is deterministic and
568572
// compatible with every child column.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
2121
import java.text.{DateFormat, SimpleDateFormat}
2222

2323
import org.apache.spark.Logging
24-
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2524
import org.apache.spark.sql.types._
2625

2726
/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
112111

113112
// UDFToString
114113
private[this] def castToString(from: DataType): Any => Any = from match {
115-
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
116-
case DateType => buildCast[Int](_, d => DateUtils.toString(d))
117-
case TimestampType => buildCast[Timestamp](_, timestampToString)
118-
case _ => buildCast[Any](_, _.toString)
114+
case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
115+
case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
116+
case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
117+
case _ => buildCast[Any](_, o => UTF8String(o.toString))
119118
}
120119

121120
// BinaryConverter
122121
private[this] def castToBinary(from: DataType): Any => Any = from match {
123-
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
122+
case StringType => buildCast[UTF8String](_, _.getBytes)
124123
}
125124

126125
// UDFToBoolean
127126
private[this] def castToBoolean(from: DataType): Any => Any = from match {
128127
case StringType =>
129-
buildCast[String](_, _.length() != 0)
128+
buildCast[UTF8String](_, _.length() != 0)
130129
case TimestampType =>
131130
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
132131
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
151150
// TimestampConverter
152151
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
153152
case StringType =>
154-
buildCast[String](_, s => {
153+
buildCast[UTF8String](_, utfs => {
155154
// Throw away extra if more than 9 decimal places
155+
val s = utfs.toString
156156
val periodIdx = s.indexOf(".")
157157
var n = s
158158
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
227227
// DateConverter
228228
private[this] def castToDate(from: DataType): Any => Any = from match {
229229
case StringType =>
230-
buildCast[String](_, s =>
231-
try DateUtils.fromJavaDate(Date.valueOf(s))
230+
buildCast[UTF8String](_, s =>
231+
try DateUtils.fromJavaDate(Date.valueOf(s.toString))
232232
catch { case _: java.lang.IllegalArgumentException => null }
233233
)
234234
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
245245
// LongConverter
246246
private[this] def castToLong(from: DataType): Any => Any = from match {
247247
case StringType =>
248-
buildCast[String](_, s => try s.toLong catch {
248+
buildCast[UTF8String](_, s => try s.toString.toLong catch {
249249
case _: NumberFormatException => null
250250
})
251251
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
261261
// IntConverter
262262
private[this] def castToInt(from: DataType): Any => Any = from match {
263263
case StringType =>
264-
buildCast[String](_, s => try s.toInt catch {
264+
buildCast[UTF8String](_, s => try s.toString.toInt catch {
265265
case _: NumberFormatException => null
266266
})
267267
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
277277
// ShortConverter
278278
private[this] def castToShort(from: DataType): Any => Any = from match {
279279
case StringType =>
280-
buildCast[String](_, s => try s.toShort catch {
280+
buildCast[UTF8String](_, s => try s.toString.toShort catch {
281281
case _: NumberFormatException => null
282282
})
283283
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
293293
// ByteConverter
294294
private[this] def castToByte(from: DataType): Any => Any = from match {
295295
case StringType =>
296-
buildCast[String](_, s => try s.toByte catch {
296+
buildCast[UTF8String](_, s => try s.toString.toByte catch {
297297
case _: NumberFormatException => null
298298
})
299299
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
323323

324324
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
325325
case StringType =>
326-
buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
326+
buildCast[UTF8String](_, s => try {
327+
changePrecision(Decimal(s.toString.toDouble), target)
328+
} catch {
327329
case _: NumberFormatException => null
328330
})
329331
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
348350
// DoubleConverter
349351
private[this] def castToDouble(from: DataType): Any => Any = from match {
350352
case StringType =>
351-
buildCast[String](_, s => try s.toDouble catch {
353+
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
352354
case _: NumberFormatException => null
353355
})
354356
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
364366
// FloatConverter
365367
private[this] def castToFloat(from: DataType): Any => Any = from match {
366368
case StringType =>
367-
buildCast[String](_, s => try s.toFloat catch {
369+
buildCast[UTF8String](_, s => try s.toString.toFloat catch {
368370
case _: NumberFormatException => null
369371
})
370372
case BooleanType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
230230
new GenericRow(newValues)
231231
}
232232

233-
override def update(ordinal: Int, value: Any): Unit = {
234-
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
233+
override def update(ordinal: Int, value: Any) {
234+
if (value == null) {
235+
setNullAt(ordinal)
236+
} else {
237+
values(ordinal).update(value)
238+
}
235239
}
236240

237-
override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
241+
override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
238242

239-
override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
243+
override def getString(ordinal: Int): String = apply(ordinal).toString
240244

241245
override def setInt(ordinal: Int, value: Int): Unit = {
242246
val currentValue = values(ordinal).asInstanceOf[MutableInt]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
216216
val $primitiveTerm: ${termForType(dataType)} = $value
217217
""".children
218218

219-
case expressions.Literal(value: String, dataType) =>
219+
case expressions.Literal(value: UTF8String, dataType) =>
220220
q"""
221221
val $nullTerm = ${value == null}
222-
val $primitiveTerm: ${termForType(dataType)} = $value
222+
val $primitiveTerm: ${termForType(dataType)} =
223+
org.apache.spark.sql.types.UTF8String(${value.getBytes})
223224
""".children
224225

225226
case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
243244
if($nullTerm)
244245
${defaultPrimitive(StringType)}
245246
else
246-
new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
247+
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
247248
""".children
248249

249250
case Cast(child @ DateType(), StringType) =>
250-
child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
251+
child.castOrNull(c =>
252+
q"""org.apache.spark.sql.types.UTF8String(
253+
org.apache.spark.sql.types.DateUtils.toString($c))""",
254+
StringType)
251255

252256
case Cast(child @ NumericType(), IntegerType) =>
253257
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
272276
if($nullTerm)
273277
${defaultPrimitive(StringType)}
274278
else
275-
${eval.primitiveTerm}.toString
279+
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
276280
""".children
277281

282+
case EqualTo(e1: BinaryType, e2: BinaryType) =>
283+
(e1, e2).evaluateAs (BooleanType) {
284+
case (eval1, eval2) =>
285+
q"""
286+
java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
287+
$eval2.asInstanceOf[Array[Byte]])
288+
"""
289+
}
290+
278291
case EqualTo(e1, e2) =>
279292
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
280293

@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
597610
val localLogger = log
598611
val localLoggerTree = reify { localLogger }
599612
q"""
600-
$localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
613+
$localLoggerTree.debug(
614+
${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
601615
""" :: Nil
602616
} else {
603617
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
608622

609623
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
610624
dataType match {
625+
case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
611626
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
612627
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
613628
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
619634
ordinal: Int,
620635
value: TermName) = {
621636
dataType match {
637+
case StringType => q"$destinationRow.update($ordinal, $value)"
622638
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
623639
case _ => q"$destinationRow.update($ordinal, $value)"
624640
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
642658
case DoubleType => "Double"
643659
case FloatType => "Float"
644660
case BooleanType => "Boolean"
645-
case StringType => "String"
661+
case StringType => "org.apache.spark.sql.types.UTF8String"
646662
}
647663

648664
protected def defaultPrimitive(dt: DataType) = dt match {
649665
case BooleanType => ru.Literal(Constant(false))
650666
case FloatType => ru.Literal(Constant(-1.0.toFloat))
651-
case StringType => ru.Literal(Constant("<uninit>"))
667+
case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
652668
case ShortType => ru.Literal(Constant(-1.toShort))
653669
case LongType => ru.Literal(Constant(-1L))
654670
case ByteType => ru.Literal(Constant(-1.toByte))

0 commit comments

Comments
 (0)