From 685fd071ce453cc6b956f98c897c869ad31702a9 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 30 Mar 2015 22:42:07 -0700
Subject: [PATCH 01/30] use UTF8String instead of String for StringType

---
 .../main/scala/org/apache/spark/sql/Row.scala |  16 ++-
 .../spark/sql/catalyst/ScalaReflection.scala  |   3 +
 .../apache/spark/sql/catalyst/SqlParser.scala |   2 +-
 .../catalyst/analysis/HiveTypeCoercion.scala  |   2 +-
 .../spark/sql/catalyst/expressions/Cast.scala |  34 +++---
 .../expressions/SpecificMutableRow.scala      |  34 +++++-
 .../expressions/codegen/CodeGenerator.scala   |  10 +-
 .../codegen/GenerateProjection.scala          |  40 +++++--
 .../sql/catalyst/expressions/literals.scala   |   6 +-
 .../sql/catalyst/expressions/predicates.scala |   3 +-
 .../spark/sql/catalyst/expressions/rows.scala |  12 +-
 .../expressions/stringOperations.scala        |  92 +++++++++------
 .../apache/spark/sql/types/UTF8String.scala   | 108 ++++++++++++++++++
 .../apache/spark/sql/types/dataTypes.scala    |   9 +-
 .../org/apache/spark/sql/DataFrame.scala      |   3 +-
 .../spark/sql/columnar/ColumnStats.scala      |   4 +-
 .../spark/sql/columnar/ColumnType.scala       |  16 +--
 .../spark/sql/execution/debug/package.scala   |   1 +
 .../org/apache/spark/sql/jdbc/JDBCRDD.scala   |   1 +
 .../org/apache/spark/sql/jdbc/jdbc.scala      |   7 +-
 .../org/apache/spark/sql/json/JsonRDD.scala   |   4 +-
 .../spark/sql/parquet/ParquetConverter.scala  |  19 ++-
 .../spark/sql/parquet/ParquetFilters.scala    |  12 +-
 .../sql/parquet/ParquetTableSupport.scala     |   4 +-
 .../apache/spark/sql/parquet/newParquet.scala |   3 +-
 .../org/apache/spark/sql/JavaRowSuite.java    |  13 +--
 .../org/apache/spark/sql/DataFrameSuite.scala |   2 +-
 .../scala/org/apache/spark/sql/RowSuite.scala |   2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |   1 +
 .../spark/sql/columnar/ColumnTypeSuite.scala  |   4 +-
 .../sql/columnar/ColumnarTestUtils.scala      |   4 +-
 .../ParquetPartitionDiscoverySuite.scala      |   4 +-
 32 files changed, 343 insertions(+), 132 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d794f034f5578..92717a60ce6bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import scala.util.hashing.MurmurHash3
 
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.{StructType, DateUtils}
+import org.apache.spark.sql.types.{UTF8String, StructType}
 
 object Row {
   /**
@@ -39,12 +39,22 @@ object Row {
   /**
    * This method can be used to construct a [[Row]] with the given values.
    */
-  def apply(values: Any*): Row = new GenericRow(values.toArray)
+  def apply(values: Any*): Row = {
+    new GenericRow(values.map {
+      case s: String => UTF8String(s)
+      case other => other
+    }.toArray)
+  }
 
   /**
    * This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/ - def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + def fromSeq(values: Seq[Any]): Row = { + new GenericRow(values.map { + case s: String => UTF8String(s) + case other => other + }.toArray) + } def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2220970085462..e24603530e428 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -72,6 +72,7 @@ trait ScalaReflection { case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) + case (s: String, st: StringType) => UTF8String(s) case (other, _) => other } @@ -86,6 +87,7 @@ trait ScalaReflection { case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal case (i: Int, DateType) => DateUtils.toJavaDate(i) + case (s: UTF8String, StringType) => s.toString() case (other, _) => other } @@ -188,6 +190,7 @@ trait ScalaReflection { // The data type can be determined without ambiguity. case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType + case obj: String => StringType case obj: StringType.JvmType => StringType case obj: ByteType.JvmType => ByteType case obj: ShortType.JvmType => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b176f7e729a42..1e697749eeda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -316,7 +316,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } + | stringLit ^^ {case s => Literal(UTF8String(s), StringType) } | NULL ^^^ Literal(null, NullType) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..4f54086c2b8d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -114,7 +114,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. 
*/ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 31f1a5fdc7e53..7935723418903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. */ @@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Int](_, d => DateUtils.toString(d)) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes("UTF-8")) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s)) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => @@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends 
UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { + buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -323,7 +323,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try changePrecision(Decimal(s.toString.toDouble), target) catch { case _: NumberFormatException => null }) case BooleanType => @@ -348,7 +348,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -364,7 +364,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 47b6f358ed1b1..05b8ac5fba279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -170,6 +170,21 @@ final class MutableByte extends MutableValue { } } +//final class MutableString extends MutableValue { +// var value: UTF8String = _ +// override def boxed: Any = if (isNull) null else value +// override def update(v: Any): Unit = { +// isNull = false +// value = v.asInstanceOf[UTF8String] +// } +// override def copy(): MutableString = { +// val newCopy = new MutableString +// newCopy.isNull = isNull +// newCopy.value = value +// newCopy.asInstanceOf[MutableString] +// } +//} + final class MutableAny extends MutableValue { var value: Any = _ override def boxed: Any = if (isNull) null else value @@ -202,6 +217,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case 
LongType => new MutableLong +// case StringType => new MutableString case _ => new MutableAny }.toArray) @@ -230,13 +246,23 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any): Unit = value match { + case null => setNullAt(ordinal) + case s: String => update(ordinal, UTF8String(s)) + case other => values(ordinal).update(value) } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = { + update(ordinal, UTF8String(value)) + } - override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = { + //TODO(davies): FIXME + apply(ordinal) match { + case s: UTF8String => s.toString() + //case ms: MutableString => ms.value.toString() + } + } override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1abf3c0b64a5..0a291ca7c037e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,6 +216,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children +// case expressions.Literal(value: UTF8String, dataType) => +// q""" +// val $nullTerm = ${value == null} +// val $primitiveTerm: ${termForType(dataType)} = $value +// """.children + case expressions.Literal(value: String, dataType) => q""" val $nullTerm = ${value == null} @@ -243,7 +249,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => @@ -584,6 +590,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { + case StringType => q"$inputRow.apply($ordinal).asInstanceOf[UTF8String]" case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -595,6 +602,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { + case StringType => q"$destinationRow.setString($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a8880..dc0a29611a779 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -118,12 +118,20 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + q""" + override def getString(i: Int): String = { + ..$ifStatements; + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } val specificMutatorFunctions = NativeType.all.map { dataType => @@ -135,12 +143,20 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + q""" + override def setString(i: Int, value: String): Unit = { + ..$ifStatements; + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 19f3fc9c2291a..f9ed478e498c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -29,7 +29,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -70,7 +70,9 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row): Any = value + override def eval(input: Row): Any = { + value + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7e47cb3fffe12..b34fa4ac10921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.compare( - l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 + else BinaryType.ordering.equiv(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..e32c03cb39dd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} +import org.apache.spark.sql.types.{UTF8String, StructType, NativeType} /** @@ -114,7 +114,12 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - values(i).asInstanceOf[String] + val utf8 = UTF8String(values(i)) + if (utf8 != null) { + utf8.toString + } else { + null + } } // Custom hashCode function that matches the efficient code generated version. @@ -189,8 +194,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = UTF8String(value) } override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 3cdca4e9dd2d1..ae7031a5fbcd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -23,7 +23,7 @@ import scala.collection.IndexedSeqOptimized import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -60,38 +60,17 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ @@ -134,12 +113,33 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = 
child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(UTF8String(evaluated)) + } + } +} + /** * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def toString: String = s"Upper($child)" } @@ -149,7 +149,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def toString: String = s"Lower($child)" } @@ -163,15 +163,15 @@ trait StringComparison { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null else compare(UTF8String(leftEval), UTF8String(rightEval)) } } @@ -185,7 +185,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.contains(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** @@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.startsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** @@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.endsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -248,6 +248,29 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends str.slice(start, end) } + @inline + def slice(str: UTF8String, startPos: Int, sliceLen: Int): Any = { + val len = str.length + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. 
+
+    val start = startPos match {
+      case pos if pos > 0 => pos - 1
+      case neg if neg < 0 => len + neg
+      case _ => 0
+    }
+
+    val end = sliceLen match {
+      case max if max == Integer.MAX_VALUE => max
+      case x => start + x
+    }
+
+    str.slice(start, end)
+  }
+
   override def eval(input: Row): Any = {
     val string = str.eval(input)
@@ -262,7 +285,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
 
     string match {
       case ba: Array[Byte] => slice(ba, start, length)
-      case other => slice(other.toString, start, length)
+      case s: UTF8String => slice(s, start, length)
+      case other => slice(UTF8String(other), start, length)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
new file mode 100644
index 0000000000000..a8adfa823ffb3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -0,0 +1,108 @@
+package org.apache.spark.sql.types
+
+/**
+ * A mutable UTF-8 String
+ */
+
+final class UTF8String extends Ordered[UTF8String] with Serializable {
+  private var s: String = _
+  def set(str: String): UTF8String = {
+    this.s = str
+    this
+  }
+
+  def set(bytes: Array[Byte]): UTF8String = {
+    this.s = new String(bytes, "utf-8")
+    this
+  }
+
+  def set(a: UTF8String): UTF8String = {
+    this.s = a.s
+    this
+  }
+
+  def length(): Int = {
+    this.s.length
+  }
+
+  def getBytes(): Array[Byte] = {
+    this.s.getBytes("utf-8")
+  }
+
+  def getBytes(encoding: String): Array[Byte] = {
+    this.s.getBytes(encoding)
+  }
+
+  def slice(start: Int, end: Int): UTF8String = {
+    UTF8String(this.s.slice(start, end))
+  }
+
+  def contains(sub: UTF8String): Boolean = {
+    this.s.contains(sub.s)
+  }
+
+  def startsWith(prefix: UTF8String): Boolean = {
+    this.s.startsWith(prefix.s)
+  }
+
+  def endsWith(suffix: UTF8String): Boolean = {
+    this.s.endsWith(suffix.s)
+  }
+
+  def toUpperCase(): UTF8String = {
+    UTF8String(s.toUpperCase)
+  }
+
+  def toLowerCase(): UTF8String = {
+    UTF8String(s.toLowerCase)
+  }
+
+  override def toString(): String = {
+    this.s
+  }
+
+  override def clone(): UTF8String = new UTF8String().set(this)
+
+  override def compare(other: UTF8String): Int = {
+    this.s.compare(other.s)
+  }
+
+  def compare(other: String): Int = {
+    this.s.compare(other)
+  }
+
+  override def compareTo(other: UTF8String): Int = {
+    this.s.compareTo(other.s)
+  }
+
+  def compareTo(other: String): Int = {
+    this.s.compareTo(other)
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case s: UTF8String =>
+      compare(s) == 0
+    case s: String =>
+      this.s.compare(s) == 0
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = {
+    this.s.hashCode
+  }
+}
+
+object UTF8String {
+  implicit def apply(s: String): UTF8String = new UTF8String().set(s)
+  implicit def toString(utf: UTF8String): String = utf.toString
+  def apply(bytes: Array[Byte]): UTF8String = new UTF8String().set(bytes)
+  def apply(utf8: UTF8String): UTF8String = utf8
+  def apply(o: Any): UTF8String = o match {
+    case null => null
+    case utf8: UTF8String => utf8
+    case s: String => new UTF8String().set(s)
+    case bytes: Array[Byte]=> new UTF8String().set(bytes)
+    case other => new UTF8String().set(other.toString)
+  }
+}
\ No newline at end of file
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index 952cf5c75688d..499eb1e0a2929 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -349,7 +349,7 @@ class StringType private() extends NativeType with PrimitiveType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = String + private[sql] type JvmType = UTF8String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -386,6 +386,13 @@ class BinaryType private() extends NativeType with PrimitiveType { } x.length - y.length } + override def equiv(x: Array[Byte], y: Array[Byte]): Boolean = { + if (x.length != y.length) { + false + } else { + compare(x, y) == 0 + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 423ef3912bc89..436b088079ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -723,7 +723,8 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f((if (row(0).isInstanceOf[UTF8String]) row(0).toString else row(0)).asInstanceOf[A]). + map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 87a6631da8300..eac6c06592729 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -216,8 +216,8 @@ private[sql] class IntColumnStats extends ColumnStats { } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c47497e0662d9..d79bfa076cb26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -312,23 +312,25 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: ByteBuffer): Unit = { + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { val stringBytes = v.getBytes("utf-8") buffer.putInt(stringBytes.length).put(stringBytes, 0, 
stringBytes.length) } - override def extract(buffer: ByteBuffer): String = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row.apply(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setString(toOrdinal, from.getString(fromOrdinal)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e916e68e58b5d..7ca8c38163fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,6 +164,7 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => + case (_: UTF8String, StringType) => case (_: String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 463e1dcc268bc..38b7e344d57cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -234,6 +234,7 @@ private[sql] class JDBCRDD( */ private def compileValue(value: Any): Any = value match { case stringValue: String => s"'${escapeSql(stringValue)}'" + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34f864f5fda7a..5e7e41c982943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} +import org.apache.spark.Logging import org.apache.spark.sql.types._ package object jdbc { @@ -80,7 +77,7 @@ package object jdbc { case ShortType => stmt.setInt(i + 1, row.getShort(i)) case ByteType => stmt.setInt(i + 1, row.getByte(i)) case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) + case StringType => stmt.setString(i + 1, row.getString(i).toString) case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 2b0358c4e2a1e..6340b71957dbd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) @@ -421,6 +421,7 @@ private[sql] object JsonRDD extends Logging { value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case MapType(StringType, valueType, _) => val map = value.asInstanceOf[Map[String, Any]] + //TODO(davies): use UTF8String for key? map.mapValues(enforceCorrectType(_, valueType)).map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) @@ -451,6 +452,7 @@ private[sql] object JsonRDD extends Logging { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) + case (StringType, v: UTF8String) => gen.writeString(v.toString) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 43ca359b51735..5cf217e3e162a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, readTimestamp(value)) @@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter( private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - private[this] var dict: Array[String] = null + private[this] var dict: Array[Array[Byte]] = null override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - + dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).getBytes} 
override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) + parent.updateString(fieldIndex, value.getBytes) } private[parquet] object CatalystArrayConverter { @@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 0357dcc4688be..e693d13f51fb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes())).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes())) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + 
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 5a1b15490d273..357baf6f15361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -199,7 +199,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { schema match { case StringType => writer.addBinary( Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") + value.asInstanceOf[UTF8String].getBytes("utf-8") ) ) case BinaryType => writer.addBinary( @@ -349,7 +349,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes("utf-8"))) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 53f765ee26a13..854d57c612c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -959,7 +959,8 @@ private[sql] object ParquetRelation2 extends Logging { .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + if (raw == defaultPartitionName) Literal(null, NullType) else + Literal(UTF8String(raw), StringType) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26a..3f9c1e7f55935 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -17,6 +17,12 @@ package test.org.apache.spark.sql; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import java.math.BigDecimal; import java.sql.Date; import java.sql.Timestamp; @@ -25,13 +31,6 @@ import java.util.List; import java.util.Map; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; - public class JavaRowSuite { private byte byteValue; private short shortValue; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6761d996fd975..1d6f41ff4231d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -128,7 +128,7 @@ class DataFrameSuite extends QueryTest { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = df.explode('letters) { - 
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + case Row(letters: UTF8String) => letters.toString.split(" ").map(Tuple1(_)).toSeq } checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 36465cc2fa11a..bf6cf1321a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a3c0076e16d6c..24752c3a14425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -698,6 +698,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val v4 = try values(3).toInt catch { case _: NumberFormatException => null } + print("rowRDD1",values, v4) Row(values(0).toInt, values(1), values(2).toBoolean, v4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 5f08834f73c6b..1dba551e87fc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -65,7 +65,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) @@ -108,7 +108,7 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { + (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes("utf-8") buffer.putInt(bytes.length) buffer.put(bytes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index c7a40845db16c..33ac4d551c921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{Decimal, DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} object ColumnarTestUtils { def makeNullRow(length: Int) = { @@ -48,7 +48,7 @@ object ColumnarTestUtils { case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => Random.nextString(Random.nextInt(32)) + case STRING => 
UTF8String(Random.nextString(Random.nextInt(32))) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index adb3c9391f6c2..865618bed9dd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -48,7 +48,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check("10", Literal(10, IntegerType)) check("1000000000000000", Literal(1000000000000000L, LongType)) check("1.5", Literal(1.5, FloatType)) - check("hello", Literal("hello", StringType)) + check("hello", Literal(UTF8String("hello"), StringType)) check(defaultPartitionName, Literal(null, NullType)) } @@ -83,7 +83,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer("a", "b", "c"), ArrayBuffer( Literal(10, IntegerType), - Literal("hello", StringType), + Literal(UTF8String("hello"), StringType), Literal(1.5, FloatType)))) check( From 21f67c6fda3504caa0b13524d4e498c6e4c9c701 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 31 Mar 2015 00:50:11 -0700 Subject: [PATCH 02/30] cleanup --- .../main/scala/org/apache/spark/sql/Row.scala | 16 +++--------- .../spark/sql/catalyst/ScalaReflection.scala | 3 ++- .../expressions/SpecificMutableRow.scala | 12 ++------- .../sql/catalyst/expressions/generators.scala | 1 + .../sql/catalyst/expressions/literals.scala | 12 +++++---- .../spark/sql/catalyst/expressions/rows.scala | 9 +++---- .../expressions/stringOperations.scala | 6 ++--- .../apache/spark/sql/types/UTF8String.scala | 19 +++++++------- .../org/apache/spark/sql/DataFrame.scala | 13 +++++++--- .../org/apache/spark/sql/SQLContext.scala | 6 +++-- .../spark/sql/columnar/ColumnStats.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 25 ++++++++++++++++--- .../org/apache/spark/sql/jdbc/jdbc.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 1 - .../apache/spark/sql/parquet/newParquet.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +++----- .../spark/sql/columnar/ColumnTypeSuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 4 +-- 19 files changed, 76 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 92717a60ce6bd..fb2346c1b1831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{UTF8String, StructType} +import org.apache.spark.sql.types.StructType object Row { /** @@ -39,22 +39,12 @@ object Row { /** * This method can be used to construct a [[Row]] with the given values. */ - def apply(values: Any*): Row = { - new GenericRow(values.map { - case s: String => UTF8String(s) - case other => other - }.toArray) - } + def apply(values: Any*): Row = new GenericRow(values.toArray) /** * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ - def fromSeq(values: Seq[Any]): Row = { - new GenericRow(values.map { - case s: String => UTF8String(s) - case other => other - }.toArray) - } + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e24603530e428..2945303e39aac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -87,7 +87,8 @@ trait ScalaReflection { case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal case (i: Int, DateType) => DateUtils.toJavaDate(i) - case (s: UTF8String, StringType) => s.toString() + case (s: UTF8String, StringType) => + s.toString() case (other, _) => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 05b8ac5fba279..85c9a4c00a2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -252,17 +252,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case other => values(ordinal).update(value) } - override def setString(ordinal: Int, value: String): Unit = { - update(ordinal, UTF8String(value)) - } + override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) - override def getString(ordinal: Int): String = { - //TODO(davies): FIXME - apply(ordinal) match { - case s: UTF8String => s.toString() - //case ms: MutableString => ms.value.toString() - } - } + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..99c5fafe593f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,6 +86,7 @@ case class UserDefinedGenerator( override def eval(input: Row): TraversableOnce[Row] = { val inputRow = new InterpretedProjection(children) + //TODO(davies): convertToScala function(inputRow(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index f9ed478e498c9..148253fa3cbf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -29,7 +29,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(UTF8String(s), StringType) + case s: String => Literal(s, StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => 
Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -62,7 +62,11 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { +case class Literal(var value: Any, dataType: DataType) extends LeafExpression { + + if (dataType == StringType && value.isInstanceOf[String]) { + value = UTF8String(value.asInstanceOf[String]) + } override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -70,9 +74,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row): Any = { - value - } + override def eval(input: Row): Any = value } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index e32c03cb39dd0..548ce741eead4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -114,11 +114,10 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - val utf8 = UTF8String(values(i)) - if (utf8 != null) { - utf8.toString - } else { - null + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ae7031a5fbcd7..dd0e133b64537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -129,7 +129,7 @@ trait CaseConversionExpression { if (evaluated == null) { null } else { - convert(UTF8String(evaluated)) + convert(evaluated.asInstanceOf[UTF8String]) } } } @@ -171,7 +171,8 @@ trait StringComparison { null } else { val rightEval = right.eval(input) - if (rightEval == null) null else compare(UTF8String(leftEval), UTF8String(rightEval)) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } @@ -286,7 +287,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends string match { case ba: Array[Byte] => slice(ba, start, length) case s: UTF8String => slice(s, start, length) - case other => slice(UTF8String(other), start, length) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index a8adfa823ffb3..cec379440e404 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -94,15 +94,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { } object UTF8String { - implicit def apply(s: String): UTF8String = new UTF8String().set(s) - implicit def toString(utf: UTF8String): String = utf.toString + def apply(s: String): UTF8String = new UTF8String().set(s) def apply(bytes: Array[Byte]): UTF8String = new UTF8String().set(bytes) - def apply(utf8: UTF8String): UTF8String = utf8 - def 
apply(o: Any): UTF8String = o match { - case null => null - case utf8: UTF8String => utf8 - case s: String => new UTF8String().set(s) - case bytes: Array[Byte]=> new UTF8String().set(bytes) - case other => new UTF8String().set(other.toString) - } + //def apply(utf8: UTF8String): UTF8String = utf8 +// def apply(o: Any): UTF8String = o match { +// case null => null +// case utf8: UTF8String => utf8 +// case s: String => new UTF8String().set(s) +// case bytes: Array[Byte]=> new UTF8String().set(bytes) +// case other => new UTF8String().set(other.toString) +// } } \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 436b088079ada..4174357e4c2ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -701,8 +701,10 @@ class DataFrame private[sql]( def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributes = schema.toAttributes - val rowFunction = - f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + def rowFunction(row: Row): TraversableOnce[Row] = { + val rows = f(row) + rows.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) + } val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, None, logicalPlan) @@ -722,9 +724,12 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil + def convert[A](x: Any): A = x match { + case utf: UTF8String => x.toString.asInstanceOf[A] + case other: A => other + } def rowFunction(row: Row): TraversableOnce[Row] = { - f((if (row(0).isInstanceOf[UTF8String]) row(0).toString else row(0)).asInstanceOf[A]). - map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f(convert[A](row(0))).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b8100782ec937..aef24227e06de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} @@ -390,7 +390,9 @@ class SQLContext(@transient val sparkContext: SparkContext) def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. 
- val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + // TODO(davies): only do convertion when needed (having StringType) + val logicalPlan = LogicalRDD(schema.toAttributes, + RDDConversions.rowToRowRdd(rowRDD, schema))(self) DataFrame(this, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index eac6c06592729..b0f983c180673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -222,7 +222,7 @@ private[sql] class StringColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d8955725e59b1..261a1ebfbcd45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType - -import scala.collection.immutable +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -49,6 +47,27 @@ object RDDConversions { i += 1 } + mutableRow + } + } + } + } + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) + val schemaFields = schema.fields.toArray + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = + ScalaReflection.convertToCatalyst(r(i), schemaFields(i).dataType) + i += 1 + } + mutableRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 5e7e41c982943..d4e0abc040bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -77,7 +77,7 @@ package object jdbc { case ShortType => stmt.setInt(i + 1, row.getShort(i)) case ByteType => stmt.setInt(i + 1, row.getByte(i)) case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i).toString) + case StringType => stmt.setString(i + 1, row.getString(i)) case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 6340b71957dbd..fd2072cd28371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -421,7 +421,6 @@ private[sql] object JsonRDD extends Logging { value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case MapType(StringType, valueType, _) => val map = value.asInstanceOf[Map[String, Any]] - //TODO(davies): use UTF8String for key? map.mapValues(enforceCorrectType(_, valueType)).map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 854d57c612c8a..a486cfaf278aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -960,7 +960,7 @@ private[sql] object ParquetRelation2 extends Logging { // Then falls back to string .getOrElse { if (raw == defaultPartitionName) Literal(null, NullType) else - Literal(UTF8String(raw), StringType) + Literal(raw, StringType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1d6f41ff4231d..23eec2d2c0223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -128,6 +128,7 @@ class DataFrameSuite extends QueryTest { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = df.explode('letters) { + // FIXME case Row(letters: UTF8String) => letters.toString.split(" ").map(Tuple1(_)).toSeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 24752c3a14425..81357dafed32e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { @@ -698,7 +695,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val v4 = try values(3).toInt catch { case _: NumberFormatException => null } - print("rowRDD1",values, v4) Row(values(0).toInt, values(1), values(2).toBoolean, v4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1dba551e87fc0..6d513cbcb659d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -117,7 +117,7 @@ 
class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 865618bed9dd3..adb3c9391f6c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -48,7 +48,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check("10", Literal(10, IntegerType)) check("1000000000000000", Literal(1000000000000000L, LongType)) check("1.5", Literal(1.5, FloatType)) - check("hello", Literal(UTF8String("hello"), StringType)) + check("hello", Literal("hello", StringType)) check(defaultPartitionName, Literal(null, NullType)) } @@ -83,7 +83,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer("a", "b", "c"), ArrayBuffer( Literal(10, IntegerType), - Literal(UTF8String("hello"), StringType), + Literal("hello", StringType), Literal(1.5, FloatType)))) check( From 4699c3ae1dab6482b26dd3d3739193e68cd77ca3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 31 Mar 2015 13:46:42 -0700 Subject: [PATCH 03/30] use Array[Byte] in UTF8String --- .../spark/sql/catalyst/ScalaReflection.scala | 16 +++- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/stringOperations.scala | 3 +- .../apache/spark/sql/types/UTF8String.scala | 91 ++++++++++--------- .../spark/sql/types/UTF8StringSuite.scala | 59 ++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 29 +++++- .../spark/sql/columnar/ColumnType.scala | 2 +- .../apache/spark/sql/execution/commands.scala | 9 +- .../spark/sql/parquet/ParquetFilters.scala | 4 +- .../sql/parquet/ParquetTableSupport.scala | 7 +- .../apache/spark/sql/parquet/newParquet.scala | 5 +- .../apache/spark/sql/sources/commands.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 2 +- 14 files changed, 162 insertions(+), 75 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2945303e39aac..a66d72453df06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -56,11 +56,12 @@ trait ScalaReflection { case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) { - s.toSeq - } else { - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) - } + case (s: Array[_], arrayType: ArrayType) => + if (arrayType.elementType.isPrimitive) { + s.toSeq + } else { + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + } case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) } @@ -69,6 
+70,11 @@ trait ScalaReflection { p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => convertToCatalyst(elem, field.dataType) }.toArray) + case (r: Row, structType: StructType) => + new GenericRow( + r.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7935723418903..b50b483d22854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -119,7 +119,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[UTF8String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index dd0e133b64537..006ae77f6e39b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -251,7 +251,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends @inline def slice(str: UTF8String, startPos: Int, sliceLen: Int): Any = { - val len = str.length // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. 
If a start index i is less than 0, it refers @@ -260,7 +259,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => str.length + neg case _ => 0 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index cec379440e404..90866536893de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -5,103 +5,106 @@ package org.apache.spark.sql.types */ final class UTF8String extends Ordered[UTF8String] with Serializable { - private var s: String = _ + private var bytes: Array[Byte] = _ + def set(str: String): UTF8String = { - this.s = str + this.bytes = str.getBytes("utf-8") this } def set(bytes: Array[Byte]): UTF8String = { - this.s = new String(bytes, "utf-8") - this - } - - def set(a: UTF8String): UTF8String = { - this.s = a.s + this.bytes = bytes.clone() this } def length(): Int = { - this.s.length - } - - def getBytes(): Array[Byte] = { - this.s.getBytes("utf-8") + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + val b = bytes(i) + i += 1 + if (b >= 196) { + i += UTF8String.bytesFromUTF8(b - 196) + } + len += 1 + } + len } - def getBytes(encoding: String): Array[Byte] = { - this.s.getBytes(encoding) + def getBytes: Array[Byte] = { + bytes } def slice(start: Int, end: Int): UTF8String = { - UTF8String(this.s.slice(start, end)) + UTF8String(toString().slice(start, end)) } def contains(sub: UTF8String): Boolean = { - this.s.contains(sub.s) + bytes.containsSlice(sub.bytes) } def startsWith(prefix: UTF8String): Boolean = { - this.s.startsWith(prefix.s) + bytes.startsWith(prefix.bytes) } def endsWith(suffix: UTF8String): Boolean = { - this.s.endsWith(suffix.s) + bytes.endsWith(suffix.bytes) } def toUpperCase(): UTF8String = { - UTF8String(s.toUpperCase) + UTF8String(toString().toUpperCase) } def toLowerCase(): UTF8String = { - UTF8String(s.toLowerCase) + UTF8String(toString().toLowerCase) } override def toString(): String = { - this.s + new String(bytes, "utf-8") } - override def clone(): UTF8String = new UTF8String().set(this) + override def clone(): UTF8String = new UTF8String().set(this.bytes) override def compare(other: UTF8String): Int = { - this.s.compare(other.s) - } - - def compare(other: String): Int = { - this.s.compare(other) + var i: Int = 0 + while (i < bytes.length && i < other.bytes.length) { + val res = bytes(i).compareTo(other.bytes(i)) + if (res != 0) return res + i += 1 + } + bytes.length - other.bytes.length } override def compareTo(other: UTF8String): Int = { - this.s.compareTo(other.s) - } - - def compareTo(other: String): Int = { - this.s.compareTo(other) + compare(other) } override def equals(other: Any): Boolean = other match { case s: UTF8String => - compare(s) == 0 + bytes.length == s.bytes.length && compare(s) == 0 case s: String => - this.s.compare(s) == 0 + toString() == s case _ => false } override def hashCode(): Int = { - this.s.hashCode + var h: Int = 1 + var i: Int = 0 + while (i < bytes.length) { + h = h * 31 + bytes(i) + i += 1 + } + h } } object UTF8String { + val bytesFromUTF8: Array[Int] = Array(1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, + 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5) + 
def apply(s: String): UTF8String = new UTF8String().set(s) def apply(bytes: Array[Byte]): UTF8String = new UTF8String().set(bytes) - //def apply(utf8: UTF8String): UTF8String = utf8 -// def apply(o: Any): UTF8String = o match { -// case null => null -// case utf8: UTF8String => utf8 -// case s: String => new UTF8String().set(s) -// case bytes: Array[Byte]=> new UTF8String().set(bytes) -// case other => new UTF8String().set(other.toString) -// } } \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 0000000000000..46c7a3e92b4f2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains, prefix and suffix") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + + assert(UTF8String("hello").slice(1, 3) == UTF8String("ell")) + assert(UTF8String("大千世界").slice(1, 2) == UTF8String("千世")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4174357e4c2ad..5719dcbac7d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -726,7 +726,7 @@ class DataFrame private[sql]( val attributes = AttributeReference(outputColumn, dataType)() :: Nil def convert[A](x: Any): A = x match { case utf: UTF8String => 
x.toString.asInstanceOf[A] - case other: A => other + case other => other.asInstanceOf[A] } def rowFunction(row: Row): TraversableOnce[Row] = { f(convert[A](row(0))).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) @@ -898,8 +898,8 @@ class DataFrame private[sql]( * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) + val repartitioned = queryExecution.toRdd.map(_.copy()).repartition(numPartitions) + DataFrame(sqlContext, LogicalRDD(schema.toAttributes, repartitioned)(sqlContext)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index aef24227e06de..4ae8dd08b2370 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -390,10 +390,31 @@ class SQLContext(@transient val sparkContext: SparkContext) def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - // TODO(davies): only do convertion when needed (having StringType) - val logicalPlan = LogicalRDD(schema.toAttributes, - RDDConversions.rowToRowRdd(rowRDD, schema))(self) - DataFrame(this, logicalPlan) + def needsConversion(dt: DataType): Boolean = dt match { + case StringType => true + case dt: ArrayType => needsConversion(dt.elementType) + case dt: MapType => needsConversion(dt.keyType) || needsConversion(dt.valueType) + case dt: StructType => + !dt.fields.forall(f => !needsConversion(f.dataType)) + // TODO(davies): check other types and values + case other => false + } + val convertedRdd = if (needsConversion(schema)) { + RDDConversions.rowToRowRdd(rowRDD, schema) + } else { + rowRDD + } + DataFrame(this, LogicalRDD(schema.toAttributes, convertedRdd)(self)) + } + + /** + * An internal API to apply a new schema to an existing DataFrame without doing + * the conversion for Rows. + */ + private[sql] def createDataFrame(df: DataFrame, schema: StructType): DataFrame = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied + // schema differs from the existing schema on any field data type.
+ DataFrame(this, LogicalRDD(schema.toAttributes, df.queryExecution.toRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index d79bfa076cb26..721719e734efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -313,7 +313,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index fad7a281dc1e2..1f0954a481050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} @@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => ScalaReflection.convertToCatalyst(r, this.schema) + .asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index e693d13f51fb1..5eb1c6abc2432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes())).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes())) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala 
index 357baf6f15361..e05a4c20b0d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[UTF8String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) @@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index a486cfaf278aa..e23fd0abb55cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -118,10 +118,7 @@ private[sql] class DefaultSource val relation = if (doInsertion) { // This is a hack. We always set nullable/containsNull/valueContainsNull to true // for the schema of a parquet data. - val df = - sqlContext.createDataFrame( - data.queryExecution.toRdd, - data.schema.asNullable) + val df = sqlContext.createDataFrame(data, data.schema.asNullable) val createdRelation = createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 9bbe06e59ba30..3c4adc30a5ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -31,7 +31,7 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sqlContext.createDataFrame(data, logicalRelation.schema) relation.insert(df, overwrite) // Invalidate the cache. 
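As context for the UTF8String changes in this patch: the class now keeps raw UTF-8 bytes and derives the number of code points by looking only at leading bytes. Below is a minimal standalone sketch of that counting idea, assuming plain JVM Scala with no Spark dependencies; the object and method names are illustrative and not part of the patch. It matches the expectations exercised in UTF8StringSuite ("hello" has 5 code points, "世 界" has 3).

// Minimal sketch (illustrative, not part of the patch): count code points
// directly from UTF-8 encoded bytes by skipping continuation bytes.
object Utf8LengthSketch {
  def codePointCount(bytes: Array[Byte]): Int = {
    var i = 0
    var count = 0
    while (i < bytes.length) {
      // The leading byte of each UTF-8 sequence encodes its total width;
      // continuation bytes (0x80-0xBF) never start a new code point.
      val b = bytes(i) & 0xFF
      i += (if (b < 0x80) 1 else if (b < 0xE0) 2 else if (b < 0xF0) 3 else 4)
      count += 1
    }
    count
  }

  def main(args: Array[String]): Unit = {
    // Same expectations as UTF8StringSuite: "hello" -> 5, "世 界" -> 3.
    assert(codePointCount("hello".getBytes("UTF-8")) == 5)
    assert(codePointCount("世 界".getBytes("UTF-8")) == 3)
  }
}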
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 6d513cbcb659d..c86ef338fc644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -109,7 +109,7 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, (buffer: ByteBuffer, string: UTF8String) => { - val bytes = string.getBytes("utf-8") + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, From d32abd1e8e6b7b5ef92a34a5d3a42919db58a43c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 31 Mar 2015 13:57:17 -0700 Subject: [PATCH 04/30] fix utf8 for python api --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 1 + .../scala/org/apache/spark/sql/execution/pythonUdfs.scala | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4ae8dd08b2370..ba040958c264c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1207,6 +1207,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 5b308d88d4cdf..7a43bfd8bc8d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,6 +140,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other @@ -192,7 +193,8 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString + case (c: String, StringType) => UTF8String(c) + case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } From a85fb275d742dd9384e15f22878b545e9a77a106 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 31 Mar 2015 16:42:18 -0700 Subject: [PATCH 05/30] refactor --- .../expressions/stringOperations.scala | 3 +- .../apache/spark/sql/types/UTF8String.scala | 88 +++++++++++++++---- .../spark/sql/types/UTF8StringSuite.scala | 4 +- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 006ae77f6e39b..232b1f3d8da56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -228,7 +228,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends @inline def 
slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers @@ -237,7 +236,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => str.length + neg case _ => 0 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index 90866536893de..0b3101c6d4456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -1,14 +1,31 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + package org.apache.spark.sql.types /** - * A mutable UTF-8 String + * A UTF-8 String used only in SparkSQL */ -final class UTF8String extends Ordered[UTF8String] with Serializable { +private[sql] final class UTF8String extends Ordered[UTF8String] with Serializable { private var bytes: Array[Byte] = _ def set(str: String): UTF8String = { - this.bytes = str.getBytes("utf-8") + bytes = str.getBytes("utf-8") this } @@ -36,7 +53,30 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { } def slice(start: Int, end: Int): UTF8String = { - UTF8String(toString().slice(start, end)) + if (end <= start || start >= bytes.length || bytes == null) { + new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + val b = bytes(i) + i += 1 + if (b >= 196) { + i += UTF8String.bytesFromUTF8(b - 196) + } + c += 1 + } + val bstart = i + while (c < end && i < bytes.length) { + val b = bytes(i) + i += 1 + if (b >= 196) { + i += UTF8String.bytesFromUTF8(b - 196) + } + c += 1 + } + UTF8String(java.util.Arrays.copyOfRange(bytes, bstart, i)) } def contains(sub: UTF8String): Boolean = { @@ -81,30 +121,44 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { override def equals(other: Any): Boolean = other match { case s: UTF8String => - bytes.length == s.bytes.length && compare(s) == 0 + java.util.Arrays.equals(bytes, s.bytes) case s: String => - toString() == s + bytes.length >= s.length && length() == s.length && toString() == s case _ => false } override def hashCode(): Int = { - var h: Int = 1 - var i: Int = 0 - while (i < bytes.length) { - h = h * 31 + bytes(i) - i += 1 - } - h + java.util.Arrays.hashCode(bytes) } } -object UTF8String { - val bytesFromUTF8: Array[Int] = Array(1, 1, 1, 1, 1, +private[sql] object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + private[types] val bytesFromUTF8: Array[Int] = Array(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5) - def apply(s: String): UTF8String = new UTF8String().set(s) - def apply(bytes: Array[Byte]): UTF8String = new UTF8String().set(bytes) + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } } \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index 46c7a3e92b4f2..94de6a792c99a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -54,6 +54,8 @@ class UTF8StringSuite extends FunSuite { assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) assert(UTF8String("hello").slice(1, 3) == UTF8String("ell")) - assert(UTF8String("大千世界").slice(1, 2) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) } } From 6b499ac13528f8062e1e81ddb0cd462975960067 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: 
Tue, 31 Mar 2015 23:19:47 -0700 Subject: [PATCH 06/30] fix style --- .../spark/sql/catalyst/ScalaReflection.scala | 3 +- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../expressions/SpecificMutableRow.scala | 31 ++++----- .../codegen/GenerateProjection.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 7 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 5 +- .../apache/spark/sql/types/UTF8String.scala | 64 +++++++++++++------ .../apache/spark/sql/types/dataTypes.scala | 7 -- .../spark/sql/types/UTF8StringSuite.scala | 6 ++ .../org/apache/spark/sql/DataFrame.scala | 12 +--- .../spark/sql/columnar/ColumnType.scala | 2 +- .../apache/spark/sql/execution/commands.scala | 2 +- .../spark/sql/execution/debug/package.scala | 1 - .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 1 - .../org/apache/spark/sql/json/JsonRDD.scala | 1 - .../spark/sql/parquet/ParquetConverter.scala | 2 +- .../apache/spark/sql/parquet/newParquet.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 3 +- 19 files changed, 90 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index a66d72453df06..7856064d2a2ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -93,8 +93,7 @@ trait ScalaReflection { case (r: Row, s: StructType) => convertRowToScala(r, s) case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal case (i: Int, DateType) => DateUtils.toJavaDate(i) - case (s: UTF8String, StringType) => - s.toString() + case (s: UTF8String, StringType) => s.toString() case (other, _) => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b50b483d22854..adf941ab2a45f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try changePrecision(Decimal(s.toString.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 85c9a4c00a2b6..a0ac9e21ca240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -170,20 +170,20 @@ final class MutableByte extends MutableValue { } } -//final class MutableString extends MutableValue { -// var value: UTF8String = _ -// override def boxed: Any = if (isNull) null else value -// override def update(v: Any): Unit = { -// isNull = false -// value = v.asInstanceOf[UTF8String] -// } -// override def copy(): MutableString = { -// val newCopy = new 
MutableString -// newCopy.isNull = isNull -// newCopy.value = value -// newCopy.asInstanceOf[MutableString] -// } -//} +final class MutableString extends MutableValue { + var value: UTF8String = _ + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { + isNull = false + value = v.asInstanceOf[UTF8String] + } + override def copy(): MutableString = { + val newCopy = new MutableString + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[MutableString] + } +} final class MutableAny extends MutableValue { var value: Any = _ @@ -217,7 +217,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case LongType => new MutableLong -// case StringType => new MutableString + // TODO(davies): Enable this + // case StringType => new MutableString case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index dc0a29611a779..9a0d32e241824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }""" case other => q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}):Unit = { ..$ifStatements; $accessorFailure }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 99c5fafe593f6..872a07ce1b17f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.types._ /** @@ -82,12 +82,13 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator{ + var input_schema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { val inputRow = new InterpretedProjection(children) - //TODO(davies): convertToScala - function(inputRow(input)) + function(ScalaReflection.convertToCatalyst(inputRow(input), input_schema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b34fa4ac10921..fcd6352079b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,7 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends 
BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.equiv(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 548ce741eead4..4c7ab93f8490a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -193,7 +193,10 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = UTF8String(value) } + override def setString(ordinal: Int, value: String): Unit = { + // TODO(davies): need this? + values(ordinal) = UTF8String(value) + } override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index 0b3101c6d4456..6e9a907596d70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -17,31 +17,50 @@ package org.apache.spark.sql.types +import java.util.Arrays + /** - * A UTF-8 String used only in SparkSQL + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. */ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializable { + private var bytes: Array[Byte] = _ + /** + * Update the UTF8String with String. + */ def set(str: String): UTF8String = { bytes = str.getBytes("utf-8") this } + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ def set(bytes: Array[Byte]): UTF8String = { this.bytes = bytes.clone() this } + /** + * Return the number of code points in it. + * + * This is only used by Substring() when `start` is negative. 
+ */ def length(): Int = { var len = 0 var i: Int = 0 while (i < bytes.length) { val b = bytes(i) i += 1 - if (b >= 196) { - i += UTF8String.bytesFromUTF8(b - 196) + if (b >= 192) { + i += UTF8String.tailBytesOfUTF8(b - 192) } len += 1 } @@ -52,8 +71,13 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl bytes } - def slice(start: Int, end: Int): UTF8String = { - if (end <= start || start >= bytes.length || bytes == null) { + /** + * Return a substring of this, + * @param start the position of first code point + * @param until the position after last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (until <= start || start >= bytes.length || bytes == null) { new UTF8String } @@ -62,21 +86,21 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl while (c < start && i < bytes.length) { val b = bytes(i) i += 1 - if (b >= 196) { - i += UTF8String.bytesFromUTF8(b - 196) + if (b >= 192) { + i += UTF8String.tailBytesOfUTF8(b - 192) } c += 1 } - val bstart = i - while (c < end && i < bytes.length) { - val b = bytes(i) - i += 1 - if (b >= 196) { - i += UTF8String.bytesFromUTF8(b - 196) + var j = i + while (c < until && j < bytes.length) { + val b = bytes(j) + j += 1 + if (b >= 192) { + j += UTF8String.tailBytesOfUTF8(b - 192) } c += 1 } - UTF8String(java.util.Arrays.copyOfRange(bytes, bstart, i)) + UTF8String(Arrays.copyOfRange(bytes, i, j)) } def contains(sub: UTF8String): Boolean = { @@ -92,10 +116,12 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl } def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. UTF8String(toString().toUpperCase) } def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. 
UTF8String(toString().toLowerCase) } @@ -121,21 +147,23 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl override def equals(other: Any): Boolean = other match { case s: UTF8String => - java.util.Arrays.equals(bytes, s.bytes) + Arrays.equals(bytes, s.bytes) case s: String => + // fail fast bytes.length >= s.length && length() == s.length && toString() == s case _ => false } override def hashCode(): Int = { - java.util.Arrays.hashCode(bytes) + Arrays.hashCode(bytes) } } private[sql] object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point - private[types] val bytesFromUTF8: Array[Int] = Array(1, 1, 1, 1, 1, + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val tailBytesOfUTF8: Array[Int] = Array(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5) @@ -161,4 +189,4 @@ private[sql] object UTF8String { null } } -} \ No newline at end of file +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 499eb1e0a2929..045c4cc923471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -386,13 +386,6 @@ class BinaryType private() extends NativeType with PrimitiveType { } x.length - y.length } - override def equiv(x: Array[Byte], y: Array[Byte]): Boolean = { - if (x.length != y.length) { - false - } else { - compare(x, y) == 0 - } - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index 94de6a792c99a..987087e0b7a64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -42,17 +42,23 @@ class UTF8StringSuite extends FunSuite { assert(!UTF8String("hello").contains(UTF8String("vello"))) assert(UTF8String("大千世界").contains(UTF8String("千世"))) assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + test("prefix") { assert(UTF8String("hello").startsWith(UTF8String("hell"))) assert(!UTF8String("hello").startsWith(UTF8String("ell"))) assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + test("suffix") { assert(UTF8String("hello").endsWith(UTF8String("ello"))) assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + test("slice") { assert(UTF8String("hello").slice(1, 3) == UTF8String("ell")) assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5719dcbac7d6b..80b0c22939480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -701,10 +701,8 @@ class DataFrame private[sql]( def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val 
attributes = schema.toAttributes - def rowFunction(row: Row): TraversableOnce[Row] = { - val rows = f(row) - rows.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) - } + val rowFunction = + f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, None, logicalPlan) @@ -724,12 +722,8 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil - def convert[A](x: Any): A = x match { - case utf: UTF8String => x.toString.asInstanceOf[A] - case other => other.asInstanceOf[A] - } def rowFunction(row: Row): TraversableOnce[Row] = { - f(convert[A](row(0))).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 721719e734efb..bc5b7f2371890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -329,7 +329,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } override def getField(row: Row, ordinal: Int): UTF8String = { - row.apply(ordinal).asInstanceOf[UTF8String] + row(ordinal).asInstanceOf[UTF8String] } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 1f0954a481050..24e48dd6de018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -63,7 +63,7 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray override def execute(): RDD[Row] = { - val converted = sideEffectResult.map(r => ScalaReflection.convertToCatalyst(r, this.schema) + val converted = sideEffectResult.map(r => ScalaReflection.convertToCatalyst(r, schema) .asInstanceOf[Row]) sqlContext.sparkContext.parallelize(converted, 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 7ca8c38163fb2..710787096e6cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -165,7 +165,6 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => case (_: UTF8String, StringType) => - case (_: String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 38b7e344d57cd..9447a8b77f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -233,7 
+233,6 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index fd2072cd28371..9996f9fb709bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -450,7 +450,6 @@ private[sql] object JsonRDD extends Logging { private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: String) => gen.writeString(v) case (StringType, v: UTF8String) => gen.writeString(v.toString) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 5cf217e3e162a..bc108e37dfb0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -480,7 +480,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).getBytes} + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e23fd0abb55cd..9d399653e4f37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -956,8 +956,7 @@ private[sql] object ParquetRelation2 extends Logging { .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else - Literal(raw, StringType) + if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 23eec2d2c0223..6761d996fd975 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -128,8 +128,7 @@ class DataFrameSuite extends QueryTest { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = df.explode('letters) { - // FIXME - case Row(letters: UTF8String) => letters.toString.split(" ").map(Tuple1(_)).toSeq + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq } checkAnswer( From 5f9e1207a3896cb779d17b4aff6c58961e54b827 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 31 Mar 2015 23:48:58 -0700 Subject: [PATCH 07/30] fix sql tests --- 
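Editor's note (not part of the original patch): this commit has UserDefinedGenerator convert
catalyst rows back to Scala values before invoking the user-supplied function, so user code
keeps seeing java.lang.String while StringType is stored internally as UTF8String. A minimal
sketch of that round-trip, assuming only the UTF8String API introduced earlier in this series:

    import org.apache.spark.sql.types.UTF8String

    val internal = UTF8String("hello")          // internal representation behind StringType
    val external: String = internal.toString    // what convertToScala hands back to user code
    assert(UTF8String(external) == internal)    // re-wrapping is lossless for valid UTF-8
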
.../apache/spark/sql/catalyst/expressions/generators.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/pythonUdfs.scala | 1 - .../src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 872a07ce1b17f..17daddb129f33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -82,13 +82,13 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator{ - var input_schema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + val input_schema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(ScalaReflection.convertToCatalyst(inputRow(input), input_schema).asInstanceOf[Row]) + function(ScalaReflection.convertToScala(inputRow(input), input_schema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 7a43bfd8bc8d9..05672592c6933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,7 +140,6 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) - case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 9996f9fb709bf..df94d08451954 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -450,7 +450,7 @@ private[sql] object JsonRDD extends Logging { private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: UTF8String) => gen.writeString(v.toString) + case (StringType, v: String) => gen.writeString(v.toString) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) From 38c303ede6a96157158a1624472a9f92035289a0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 00:10:30 -0700 Subject: [PATCH 08/30] fix python sql tests --- python/pyspark/sql/dataframe.py | 10 +++++----- .../org/apache/spark/sql/execution/pythonUdfs.scala | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 23c0e63e77812..d2ddb8cd3fe71 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -475,7 +475,7 @@ def join(self, other, joinExprs=None, 
joinType=None): :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] """ if joinExprs is None: @@ -645,9 +645,9 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) @@ -774,11 +774,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(MIN(age)=5), Row(MIN(age)=2)] + [Row(MIN(age)=2), Row(MIN(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 05672592c6933..53c2664a47b00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,6 +140,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other @@ -229,6 +230,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def execute(): RDD[Row] = { // TODO: Clean up after ourselves? 
+ // TODO(davies): convert internal type to Scala Type val childResults = child.execute().map(_.copy()).cache() val parent = childResults.mapPartitions { iter => From c7dd4d285bc08313f98334225435e77ca0f79bb9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 11:40:38 -0700 Subject: [PATCH 09/30] fix some catalyst tests --- .../expressions/codegen/CodeGenerator.scala | 27 ++--- .../sql/catalyst/expressions/literals.scala | 9 +- .../sql/catalyst/optimizer/Optimizer.scala | 21 ++-- .../apache/spark/sql/types/UTF8String.scala | 10 +- .../ExpressionEvaluationSuite.scala | 101 ++++++++++-------- .../spark/sql/types/UTF8StringSuite.scala | 4 +- 6 files changed, 93 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0a291ca7c037e..53ff8fb284619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,16 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children -// case expressions.Literal(value: UTF8String, dataType) => -// q""" -// val $nullTerm = ${value == null} -// val $primitiveTerm: ${termForType(dataType)} = $value -// """.children - - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.toString}) """.children case expressions.Literal(value: Int, dataType) => @@ -249,11 +244,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + child.castOrNull(c => q"org.apache.spark.sql.types.UTF8String(org.apache.spark.sql.types.DateUtils.toString($c))", StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -278,7 +273,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children case EqualTo(e1, e2) => @@ -579,7 +574,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -590,7 +585,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { - case StringType => q"$inputRow.apply($ordinal).asInstanceOf[UTF8String]" + case 
StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -602,7 +597,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { - case StringType => q"$destinationRow.setString($ordinal, $value)" + case StringType => q"$destinationRow.update($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } @@ -626,13 +621,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 148253fa3cbf1..215785e5964a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -64,8 +64,13 @@ object IntegerLiteral { case class Literal(var value: Any, dataType: DataType) extends LeafExpression { - if (dataType == StringType && value.isInstanceOf[String]) { - value = UTF8String(value.asInstanceOf[String]) + // TODO(davies): FIXME + (value, dataType) match { + case (s: String, StringType) => + value = UTF8String(s) + case (seq: Seq[String], dt:ArrayType) if dt.elementType == StringType => + value = seq.map(UTF8String(_)) + case _ => } override def foldable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..4d11b58533910 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + 
Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal(utf, StringType)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index 6e9a907596d70..bea8526003eae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -28,7 +28,7 @@ import java.util.Arrays * Note: This is not designed for general use cases, should not be used outside SQL. */ -private[sql] final class UTF8String extends Ordered[UTF8String] with Serializable { +final class UTF8String extends Ordered[UTF8String] with Serializable { private var bytes: Array[Byte] = _ @@ -57,7 +57,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl var len = 0 var i: Int = 0 while (i < bytes.length) { - val b = bytes(i) + val b = bytes(i) & 0xFF i += 1 if (b >= 192) { i += UTF8String.tailBytesOfUTF8(b - 192) @@ -84,7 +84,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl var c = 0 var i: Int = 0 while (c < start && i < bytes.length) { - val b = bytes(i) + val b = bytes(i) & 0xFF i += 1 if (b >= 192) { i += UTF8String.tailBytesOfUTF8(b - 192) @@ -93,7 +93,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl } var j = i while (c < until && j < bytes.length) { - val b = bytes(j) + val b = bytes(j) & 0xFF j += 1 if (b >= 192) { j += UTF8String.tailBytesOfUTF8(b - 192) @@ -160,7 +160,7 @@ private[sql] final class UTF8String extends Ordered[UTF8String] with Serializabl } } -private[sql] object UTF8String { +object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 private[types] val tailBytesOfUTF8: Array[Int] = Array(1, 1, 1, 1, 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index dcfd8b28cb02a..9d91fc1d41a70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -32,6 +32,13 @@ import org.apache.spark.sql.types._ class ExpressionEvaluationSuite extends FunSuite { + def create_row(values: Array[Any]): Row = { + new GenericRow(values.toSeq.map { + case s: String => UTF8String(s) + case other => other + }.toArray) + } + test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) @@ -242,23 +249,23 @@ class ExpressionEvaluationSuite extends FunSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new 
GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(Array[Any](null))) + checkEvaluation("abdef" like regEx, true, create_row(Array[Any]("abdef"))) + checkEvaluation("a_%b" like regEx, true, create_row(Array[Any]("a\\__b"))) + checkEvaluation("addb" like regEx, true, create_row(Array[Any]("a_%b"))) + checkEvaluation("addb" like regEx, false, create_row(Array[Any]("a\\__b"))) + checkEvaluation("addb" like regEx, false, create_row(Array[Any]("a%\\%b"))) + checkEvaluation("a_%b" like regEx, true, create_row(Array[Any]("a%\\%b"))) + checkEvaluation("addb" like regEx, true, create_row(Array[Any]("a%"))) + checkEvaluation("addb" like regEx, false, create_row(Array[Any]("**"))) + checkEvaluation("abc" like regEx, true, create_row(Array[Any]("a%"))) + checkEvaluation("abc" like regEx, false, create_row(Array[Any]("b%"))) + checkEvaluation("abc" like regEx, false, create_row(Array[Any]("bc%"))) + checkEvaluation("a\nb" like regEx, true, create_row(Array[Any]("a_b"))) + checkEvaluation("ab" like regEx, true, create_row(Array[Any]("a%b"))) + checkEvaluation("a\nb" like regEx, true, create_row(Array[Any]("a%b"))) + + checkEvaluation(Literal(null, StringType) like regEx, null, create_row(Array[Any]("bc%"))) } test("RLIKE literal Regular Expression") { @@ -289,14 +296,14 @@ class ExpressionEvaluationSuite extends FunSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row(Array[Any]("abdef"))) + checkEvaluation("abbbbc" rlike regEx, true, create_row(Array[Any]("a.*c"))) + checkEvaluation("fofo" rlike regEx, true, create_row(Array[Any]("^fo"))) + checkEvaluation("fo\no" rlike regEx, true, create_row(Array[Any]("^fo\no$"))) + checkEvaluation("Bn" rlike regEx, true, create_row(Array[Any]("^Ba*n"))) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row(Array[Any]("**"))) } } @@ -551,10 +558,10 @@ class ExpressionEvaluationSuite extends FunSuite { test("map casting") { val map = Literal( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u(""), u("d") -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u("")), MapType(StringType, StringType, 
valueContainsNull = false)) { @@ -612,14 +619,14 @@ class ExpressionEvaluationSuite extends FunSuite { test("struct casting") { val struct = Literal( - Row("123", "abc", "", null), + Row(u("123"), u("abc"), u(""), null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal( - Row("123", "abc", ""), + Row(u("123"), u("abc"), u("")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -705,11 +712,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + def u(s: String): UTF8String = UTF8String(s) + test("complex casting") { val complex = Literal( Row( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Seq(u("123"), u("abc"), u("")), + Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u("")), Row(0)), StructType(Seq( StructField("a", @@ -737,7 +746,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row(Array[Any]("^Ba*n", null, true, null)) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -776,7 +785,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(Array[Any](null, false, true, "a", "b", "c")) val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -819,12 +828,12 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 + val row = create_row(Array[Any]( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row(Array[Any]("aa", "bb")), // 2 + Map(u("aa")->u("bb")), // 3 + Seq(u("aa"), u("bb")) // 4 )) val typeS = StructType( @@ -877,7 +886,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(Array[Any](1, 2, 3, null)) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -901,7 +910,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(Array[Any](1.1, 2.0, 3.1, null)) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) @@ -924,7 +933,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(Array[Any](1, 2, 3, null, 3, null)) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -953,7 +962,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row(Array[Any]("abc", null)) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) @@ -974,7 +983,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row(Array[Any]("example", "example".toArray.map(_.toByte))) val s = 'a.string.at(0) @@ -1006,7 +1015,7 @@ class 
ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, create_row(Array[Any](null))) // substring(_, null, _) -> null checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) @@ -1037,20 +1046,20 @@ class ExpressionEvaluationSuite extends FunSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(Array[Any](l.toDouble))) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal(null, DoubleType)), null, create_row(Array[Any](null))) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(Array[Any](1, 2, 3, null)) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index 987087e0b7a64..04a435da33584 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -37,7 +37,7 @@ class UTF8StringSuite extends FunSuite { check("世 界", 3) } - test("contains, prefix and suffix") { + test("contains") { assert(UTF8String("hello").contains(UTF8String("ello"))) assert(!UTF8String("hello").contains(UTF8String("vello"))) assert(UTF8String("大千世界").contains(UTF8String("千世"))) @@ -59,7 +59,7 @@ class UTF8StringSuite extends FunSuite { } test("slice") { - assert(UTF8String("hello").slice(1, 3) == UTF8String("ell")) + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) From bb52e442779c3ab46fc0dc24ac5332ae17e22ad7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 13:00:03 -0700 Subject: [PATCH 10/30] fix scala style --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 53ff8fb284619..c2e94d1334f26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -248,7 +248,10 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => 
q"org.apache.spark.sql.types.UTF8String(org.apache.spark.sql.types.DateUtils.toString($c))", StringType) + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -574,7 +577,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm.toString)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil From 8b458644a269d9e1542a2a4ab24d735aa95ed49d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 14:41:12 -0700 Subject: [PATCH 11/30] fix codegen with UTF8String --- .../expressions/codegen/GenerateProjection.scala | 6 ++++-- .../expressions/GeneratedMutableEvaluationSuite.scala | 9 --------- .../scala/org/apache/spark/sql/columnar/ColumnType.scala | 2 +- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 9a0d32e241824..cad0611005937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,7 +111,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? @@ -136,7 +137,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificMutatorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebcd..925e994b9cf55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -44,14 +43,6 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { val actual = plan(inputRow) val expectedRow = new GenericRow(Array[Any](expected)) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} - """.stripMargin) - } if (actual != expectedRow) { val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index bc5b7f2371890..1b9e0df2dcb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -333,7 +333,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } From 23a766cf8eb9142ef641a7eb69f136a817d8fb32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 15:33:28 -0700 Subject: [PATCH 12/30] refactor --- .../sql/catalyst/expressions/literals.scala | 25 ++++++++++------ .../ExpressionEvaluationSuite.scala | 29 ++++++++----------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 215785e5964a7..1196a9cbf806c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -41,6 +41,19 @@ object Literal { case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + + /** + * convert String in `v` as UTF8String + */ + def convertToUTF8String(v: Any): Any = v match { + case s: String => UTF8String(s) + case seq: Seq[Any] => seq.map(convertToUTF8String) + case r: Row => Row(r.toSeq.map(convertToUTF8String): _*) + case arr: Array[Any] => arr.toSeq.map(convertToUTF8String).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToUTF8String(k), convertToUTF8String(v)) }.toMap + case other => other + } } /** @@ -62,16 +75,10 @@ object IntegerLiteral { } } -case class Literal(var value: Any, dataType: DataType) extends LeafExpression { +case class Literal (var value: Any, dataType: DataType) extends LeafExpression { - // TODO(davies): FIXME - (value, dataType) match { - case (s: 
String, StringType) => - value = UTF8String(s) - case (seq: Seq[String], dt:ArrayType) if dt.elementType == StringType => - value = seq.map(UTF8String(_)) - case _ => - } + // TODO(davies): move this out of constructor + value = Literal.convertToUTF8String(value) override def foldable: Boolean = true override def nullable: Boolean = value == null diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 9d91fc1d41a70..af583cb8bfd46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -33,10 +33,7 @@ import org.apache.spark.sql.types._ class ExpressionEvaluationSuite extends FunSuite { def create_row(values: Array[Any]): Row = { - new GenericRow(values.toSeq.map { - case s: String => UTF8String(s) - case other => other - }.toArray) + new GenericRow(values.toSeq.map(Literal.convertToUTF8String).toArray) } test("literals") { @@ -558,10 +555,10 @@ class ExpressionEvaluationSuite extends FunSuite { test("map casting") { val map = Literal( - Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u(""), u("d") -> null), + Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal( - Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u("")), + Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) { @@ -619,14 +616,14 @@ class ExpressionEvaluationSuite extends FunSuite { test("struct casting") { val struct = Literal( - Row(u("123"), u("abc"), u(""), null), + Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal( - Row(u("123"), u("abc"), u("")), + Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -712,13 +709,11 @@ class ExpressionEvaluationSuite extends FunSuite { } } - def u(s: String): UTF8String = UTF8String(s) - test("complex casting") { val complex = Literal( Row( - Seq(u("123"), u("abc"), u("")), - Map(u("a") -> u("123"), u("b") -> u("abc"), u("c") -> u("")), + Seq("123", "abc", ""), + Map("a" -> "123", "b" -> "abc", "c" -> ""), Row(0)), StructType(Seq( StructField("a", @@ -829,11 +824,11 @@ class ExpressionEvaluationSuite extends FunSuite { test("complex type") { val row = create_row(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row(Array[Any]("aa", "bb")), // 2 - Map(u("aa")->u("bb")), // 3 - Seq(u("aa"), u("bb")) // 4 + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row(Array[Any]("aa", "bb")), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 )) val typeS = StructType( From 9dc32d1b08cd6c76a7a096beba67bcbca89b0634 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 21:42:34 -0700 Subject: [PATCH 13/30] fix some hive tests --- .../spark/sql/hive/HiveInspectors.scala | 11 +++--- .../org/apache/spark/sql/hive/Shim13.scala | 36 +++++++++---------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4afa2e71d77cc..ae060442adb98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,7 +239,8 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => poi.getWritableConstantValue.getHiveVarchar.getValue case poi: WritableConstantHiveDecimalObjectInspector => @@ -287,7 +288,7 @@ private[hive] trait HiveInspectors { hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -340,7 +341,7 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => new HiveVarchar(o.asInstanceOf[UTF8String].toString, o.asInstanceOf[String].size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) @@ -409,7 +410,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. 
case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7577309900209..4523de5bffbb2 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,36 +17,33 @@ package org.apache.spark.sql.hive -import java.util -import java.util.{ArrayList => JArrayList} -import java.util.Properties import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** @@ -63,11 +60,12 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo - import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF + import java.io.{InputStream, 
OutputStream} + +import com.esotericsoftware.kryo.Kryo + import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} + +import org.apache.spark.util.Utils._ @transient private val methodDeSerialize = { @@ -224,7 +222,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) From 956b0a48bfccc6550f66319c4475220eb6101b23 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 01:12:40 -0700 Subject: [PATCH 14/30] fix hive tests --- .../sql/catalyst/expressions/literals.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../ExpressionEvaluationSuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 48 ++++++++++--------- .../spark/sql/hive/HiveStrategies.scala | 7 ++- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0c0de195e1fd3..2505334426c46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -29,7 +29,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9d017d957cd0e..7c80634d2c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -209,7 +209,7 @@ object LikeSimplification extends Rule[LogicalPlan] { case equalTo(pattern) => EqualTo(l, Literal(pattern)) case _ => - Like(l, Literal(utf, StringType)) + Like(l, Literal.create(utf, StringType)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index cd28ab36737d3..0b4e2f4c1f9ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1054,7 +1054,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new create_row(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(Array[Any](null))) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index b7561ce7298cb..d294dfbac0563 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.ParquetRelation2._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ @@ -40,6 +40,10 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { val defaultPartitionName = "__NULL__" + def create_row(values: Any*): Row = { + new GenericRow(values.map(Literal.convertToUTF8String).toArray) + } + test("column type inference") { def check(raw: String, literal: Literal): Unit = { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) @@ -107,7 +111,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructType(Seq( StructField("a", IntegerType), StructField("b", StringType))), - Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + Seq(Partition(create_row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) check(Seq( "hdfs://host:9000/path/a=10/b=20", @@ -117,8 +121,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", FloatType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), - Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + Partition(create_row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(create_row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=20", @@ -128,8 +132,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", IntegerType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), - Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + Partition(create_row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(create_row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", @@ -139,8 +143,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", FloatType), StructField("b", StringType))), Seq( - Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), - Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + Partition(create_row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(create_row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) } test("read partitioned table - normal case") { @@ -163,7 +167,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", "bar") - } yield Row(i, i.toString, pi, ps)) + } yield create_row(i, i.toString, pi, ps)) checkAnswer( sql("SELECT intField, pi FROM t"), @@ -171,21 +175,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- 
Seq(1, 2) _ <- Seq("foo", "bar") - } yield Row(i, pi)) + } yield create_row(i, pi)) checkAnswer( sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") - } yield Row(i, i.toString, 1, ps)) + } yield create_row(i, i.toString, 1, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield Row(i, i.toString, pi, "foo")) + } yield create_row(i, i.toString, pi, "foo")) } } } @@ -210,7 +214,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", "bar") - } yield Row(i, pi, i.toString, ps)) + } yield create_row(i, pi, i.toString, ps)) checkAnswer( sql("SELECT intField, pi FROM t"), @@ -218,21 +222,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) _ <- Seq("foo", "bar") - } yield Row(i, pi)) + } yield create_row(i, pi)) checkAnswer( sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") - } yield Row(i, 1, i.toString, ps)) + } yield create_row(i, 1, i.toString, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield Row(i, pi, i.toString, "foo")) + } yield create_row(i, pi, i.toString, "foo")) } } } @@ -264,21 +268,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) ps <- Seq("foo", null.asInstanceOf[String]) - } yield Row(i, i.toString, pi, ps)) + } yield create_row(i, i.toString, pi, ps)) checkAnswer( sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) - } yield Row(i, i.toString, null, ps)) + } yield create_row(i, i.toString, null, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) - } yield Row(i, i.toString, pi, null)) + } yield create_row(i, i.toString, pi, null)) } } } @@ -309,14 +313,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", null.asInstanceOf[String]) - } yield Row(i, pi, i.toString, ps)) + } yield create_row(i, pi, i.toString, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield Row(i, pi, i.toString, null)) + } yield create_row(i, pi, i.toString, null)) } } } @@ -336,7 +340,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( sql("SELECT * FROM t"), - (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) + (1 to 10).map(i => create_row(i, null, 1)) ++ (1 to 10).map(i => create_row(i, i.toString, 2))) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5f7e897295117..66b9089bf2a01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{UTF8String, StringType} private[hive] trait HiveStrategies { @@ -131,7 +131,10 @@ private[hive] trait HiveStrategies { val partitionValues = 
part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = partitionValues(i) match { + case s: String => UTF8String(s) + case other => other + } i += 1 } pruningCondition(inputData) From 9f4c194cbc44a3b163a6f2d3e1f7f831518d7776 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 14:28:05 -0700 Subject: [PATCH 15/30] convert data type for data source --- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 1 + .../apache/spark/sql/jdbc/JDBCRelation.scala | 2 + .../apache/spark/sql/json/JSONRelation.scala | 8 ++-- .../apache/spark/sql/parquet/newParquet.scala | 3 ++ .../sql/sources/DataSourceStrategy.scala | 21 ++++++--- .../apache/spark/sql/sources/interfaces.scala | 7 +++ .../ParquetPartitionDiscoverySuite.scala | 46 +++++++++---------- .../spark/sql/sources/TableScanSuite.scala | 10 ++-- .../spark/sql/hive/HiveInspectors.scala | 4 +- 9 files changed, 62 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 9447a8b77f541..1f7f4d6bfb838 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -355,6 +355,7 @@ private[sql] class JDBCRDD( case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7e..99b755c9f25d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -130,6 +130,8 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { + override val needConversion: Boolean = false + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f4c99b4b56606..e3352d02787fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row - -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} private[sql] class DefaultSource @@ -113,6 +113,8 @@ private[sql] case class JSONRelation( // TODO: Support partitioned JSON relation. 
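The needConversion override that JDBCRelation gains above, and that the JSON and Parquet relations below also set to false, tells the planner that the source already emits catalyst-internal values, so the new createPhysicalRDD path further below can skip the per-row conversion. A minimal sketch of a source opting out, assuming the 1.3-era external data source API as it appears in this patch; the relation and its data are invented for illustration:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{Row, SQLContext}
  import org.apache.spark.sql.sources.{BaseRelation, TableScan}
  import org.apache.spark.sql.types.{StringType, StructField, StructType, UTF8String}

  // Hypothetical relation whose rows already carry UTF8String, so no conversion pass is needed.
  case class InternalStringScan(sqlContext: SQLContext) extends BaseRelation with TableScan {
    override val schema = StructType(StructField("name", StringType) :: Nil)
    override val needConversion: Boolean = false   // rows produced below are already internal
    override def buildScan(): RDD[Row] =
      sqlContext.sparkContext.parallelize(1 to 3).map(i => Row(UTF8String(s"name_$i")))
  }

Sources that keep returning java.lang.String (or java.sql.Date, BigDecimal) simply leave needConversion at its default of true and are converted by the engine.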
private def baseRDD = sqlContext.sparkContext.textFile(path) + override val needConversion: Boolean = false + override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index fbc6309992d4f..6a884c83e8009 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -404,6 +404,9 @@ private[sql] case class ParquetRelation2( file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + // Skip type conversion + override val needConversion: Boolean = false + // TODO Should calculate per scan size // It's common that a query only scans a fraction of a large Parquet file. Returning size of the // whole Parquet file disables some optimizations in this case (e.g. broadcast join). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index e13759b7feb7b..564084998b939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -53,7 +53,7 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if part.isEmpty => @@ -102,20 +102,29 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } + private[this] def createPhysicalRDD(relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]) = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8f9946a5a801e..d431db42814ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -126,6 +126,13 @@ abstract class BaseRelation { * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). 
*/ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + */ + def needConversion: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index d294dfbac0563..fc0d418f99351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -40,10 +40,6 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { val defaultPartitionName = "__NULL__" - def create_row(values: Any*): Row = { - new GenericRow(values.map(Literal.convertToUTF8String).toArray) - } - test("column type inference") { def check(raw: String, literal: Literal): Unit = { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) @@ -111,7 +107,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructType(Seq( StructField("a", IntegerType), StructField("b", StringType))), - Seq(Partition(create_row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) check(Seq( "hdfs://host:9000/path/a=10/b=20", @@ -121,8 +117,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", FloatType), StructField("b", StringType))), Seq( - Partition(create_row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), - Partition(create_row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=20", @@ -132,8 +128,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", IntegerType), StructField("b", StringType))), Seq( - Partition(create_row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), - Partition(create_row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", @@ -143,8 +139,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", FloatType), StructField("b", StringType))), Seq( - Partition(create_row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), - Partition(create_row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) } test("read partitioned table - normal case") { @@ -167,7 +163,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", "bar") - } yield create_row(i, i.toString, pi, ps)) + } yield Row(i, i.toString, pi, ps)) checkAnswer( sql("SELECT intField, pi FROM t"), @@ -175,21 +171,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) _ <- Seq("foo", "bar") - } yield 
create_row(i, pi)) + } yield Row(i, pi)) checkAnswer( sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") - } yield create_row(i, i.toString, 1, ps)) + } yield Row(i, i.toString, 1, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield create_row(i, i.toString, pi, "foo")) + } yield Row(i, i.toString, pi, "foo")) } } } @@ -214,7 +210,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", "bar") - } yield create_row(i, pi, i.toString, ps)) + } yield Row(i, pi, i.toString, ps)) checkAnswer( sql("SELECT intField, pi FROM t"), @@ -222,21 +218,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) _ <- Seq("foo", "bar") - } yield create_row(i, pi)) + } yield Row(i, pi)) checkAnswer( sql("SELECT * FROM t WHERE pi = 1"), for { i <- 1 to 10 ps <- Seq("foo", "bar") - } yield create_row(i, 1, i.toString, ps)) + } yield Row(i, 1, i.toString, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps = 'foo'"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield create_row(i, pi, i.toString, "foo")) + } yield Row(i, pi, i.toString, "foo")) } } } @@ -268,21 +264,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) ps <- Seq("foo", null.asInstanceOf[String]) - } yield create_row(i, i.toString, pi, ps)) + } yield Row(i, i.toString, pi, ps)) checkAnswer( sql("SELECT * FROM t WHERE pi IS NULL"), for { i <- 1 to 10 ps <- Seq("foo", null.asInstanceOf[String]) - } yield create_row(i, i.toString, null, ps)) + } yield Row(i, i.toString, null, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, null.asInstanceOf[Integer]) - } yield create_row(i, i.toString, pi, null)) + } yield Row(i, i.toString, pi, null)) } } } @@ -313,14 +309,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { i <- 1 to 10 pi <- Seq(1, 2) ps <- Seq("foo", null.asInstanceOf[String]) - } yield create_row(i, pi, i.toString, ps)) + } yield Row(i, pi, i.toString, ps)) checkAnswer( sql("SELECT * FROM t WHERE ps IS NULL"), for { i <- 1 to 10 pi <- Seq(1, 2) - } yield create_row(i, pi, i.toString, null)) + } yield Row(i, pi, i.toString, null)) } } } @@ -340,7 +336,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempTable("t") { checkAnswer( sql("SELECT * FROM t"), - (1 to 10).map(i => create_row(i, null, 1)) ++ (1 to 10).map(i => create_row(i, i.toString, 2))) + (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 7928600ac2fb5..cadb3dfa1dadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -73,7 +73,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -81,7 +81,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), 
Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -102,7 +102,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -110,7 +110,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -265,7 +265,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 563866e7ee504..0d0a5b134f07f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -341,7 +341,9 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[UTF8String].toString, o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) From 537631c09c5167dc8b13785b5aa678d5c984d73a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 14:42:33 -0700 Subject: [PATCH 16/30] some comment about Date --- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 1 + .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 3 +++ .../src/main/scala/org/apache/spark/sql/types/DateUtils.scala | 2 ++ .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 3 ++- 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index fb2346c1b1831..ac8a782976465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -257,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. 
*/ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 4c7ab93f8490a..91b89d8c11620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -37,6 +37,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -121,6 +122,8 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 8a1a3b81b3d2c..8e4105109f7d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -39,6 +39,8 @@ object DateUtils { millisToDays(d.getTime) } + // TODO(davies): This is buggy, it will be wrong if the date is not aligned with day + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 1f7f4d6bfb838..b9022fcd9e3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -349,13 +349,14 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + // TODO(davies): convert Date into Int case DateConversion => mutableRow.update(i, rs.getDate(pos)) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) - // TODO(davies): use getBytes for better performance + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) From 28d6f32eda151ed51f35117eb5beb1ec6b6882d1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 16:38:31 -0700 Subject: [PATCH 17/30] refactor --- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++++- .../codegen/GenerateProjection.scala | 6 +-- .../sql/catalyst/expressions/literals.scala | 16 ++----- .../spark/sql/catalyst/expressions/rows.scala | 5 +-- .../expressions/stringOperations.scala | 44 +++++-------------- .../ExpressionEvaluationSuite.scala | 3 +- 
.../org/apache/spark/sql/SQLContext.scala | 6 +-- .../spark/sql/execution/pythonUdfs.scala | 1 - .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 18 +++----- .../org/apache/spark/sql/hive/Shim13.scala | 14 +++--- 12 files changed, 54 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7856064d2a2ee..fa9732ea7a9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -78,10 +78,28 @@ trait ScalaReflection { case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) - case (s: String, st: StringType) => UTF8String(s) + case (s: String, _) => UTF8String(s) case (other, _) => other } + /** + * Converts Scala objects to catalyst rows / types. + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + /** Converts Catalyst types used internally in rows to standard Scala types */ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { // Check UDT first since UDTs can override other types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index cad0611005937..463046f491244 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -128,7 +128,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }""" case other => q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { ..$ifStatements; $accessorFailure }""" @@ -148,13 +148,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { dataType match { case StringType => q""" - override def setString(i: Int, value: String): Unit = { + override def setString(i: Int, value: String) { ..$ifStatements; $accessorFailure }""" case other => q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}):Unit = { + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { ..$ifStatements; $accessorFailure }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 2505334426c46..db470fe4cd8e0 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ object Literal { @@ -42,20 +43,9 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } - /** - * convert String in `v` as UTF8String - */ - def convertToUTF8String(v: Any): Any = v match { - case s: String => UTF8String(s) - case seq: Seq[Any] => seq.map(convertToUTF8String) - case r: Row => Row(r.toSeq.map(convertToUTF8String): _*) - case arr: Array[Any] => arr.toSeq.map(convertToUTF8String).toArray - case m: Map[Any, Any] => - m.map { case (k, v) => (convertToUTF8String(k), convertToUTF8String(v)) }.toMap - case other => other + def create(v: Any, dataType: DataType): Literal = { + Literal(ScalaReflection.convertToCatalyst(v), dataType) } - - def create(v: Any, dataType: DataType): Literal = Literal(convertToUTF8String(v), dataType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 91b89d8c11620..f52d75a5bfc6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -196,10 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { - // TODO(davies): need this? - values(ordinal) = UTF8String(value) - } + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 232b1f3d8da56..2dc01256848dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types._ @@ -226,8 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. 
If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers @@ -236,29 +232,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => str.length + neg - case _ => 0 - } - - val end = sliceLen match { - case max if max == Integer.MAX_VALUE => max - case x => start + x - } - - str.slice(start, end) - } - - @inline - def slice(str: UTF8String, startPos: Int, sliceLen: Int): Any = { - // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it - // refers to element i-1 in the sequence. If a start index i is less than 0, it refers - // to the -ith element before the end of the sequence. If a start index i is 0, it - // refers to the first element. - - val start = startPos match { - case pos if pos > 0 => pos - 1 - case neg if neg < 0 => str.length + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -267,12 +241,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -280,11 +253,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case s: UTF8String => slice(s, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 0b4e2f4c1f9ab..ad76aa095cc72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,6 +25,7 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ @@ -60,7 +61,7 @@ class ExpressionEvaluationBaseSuite extends FunSuite { class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { def create_row(values: Array[Any]): Row = { - new GenericRow(values.toSeq.map(Literal.convertToUTF8String).toArray) + new GenericRow(values.toSeq.map(ScalaReflection.convertToCatalyst).toArray) } test("literals") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e0344b53cdf60..660fcc9e11b80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -396,11 +396,11 @@ class SQLContext(@transient val sparkContext: SparkContext) // schema differs from the existing schema 
on any field data type. def needsConversion(dt: DataType): Boolean = dt match { case StringType => true + case DateType => true + case DecimalType() => true case dt: ArrayType => needsConversion(dt.elementType) case dt: MapType => needsConversion(dt.keyType) || needsConversion(dt.valueType) - case dt: StructType => - !dt.fields.forall(f => !needsConversion(f.dataType)) - // TODO(davies): check other types and values + case dt: StructType => !dt.fields.forall(f => !needsConversion(f.dataType)) case other => false } val convertedRdd = if (needsConversion(schema)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 53c2664a47b00..7a43bfd8bc8d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -230,7 +230,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def execute(): RDD[Row] = { // TODO: Clean up after ourselves? - // TODO(davies): convert internal type to Scala Type val childResults = child.execute().map(_.copy()).cache() val parent = childResults.mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 6b59336b37f56..816bf7f474041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -450,7 +450,7 @@ private[sql] object JsonRDD extends Logging { private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: String) => gen.writeString(v.toString) + case (StringType, v: String) => gen.writeString(v) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index fc0d418f99351..b7561ce7298cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.parquet.ParquetRelation2._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 66b9089bf2a01..1eb3ea7126528 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,25 +17,22 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.ScalaReflection import 
org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} -import org.apache.spark.sql.types.{UTF8String, StringType} +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} +import org.apache.spark.sql.types.StringType private[hive] trait HiveStrategies { @@ -131,10 +128,7 @@ private[hive] trait HiveStrategies { val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) match { - case s: String => UTF8String(s) - case other => other - } + inputData(i) = ScalaReflection.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 4523de5bffbb2..d331c210e8939 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -23,12 +23,14 @@ import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory @@ -45,7 +47,6 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} - /** * This class provides the UDF creation and also the UDF instance serialization and * de-serialization cross process boundary. 
@@ -60,19 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{InputStream, OutputStream} - -import com.esotericsoftware.kryo.Kryo - import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} - -import org.apache.spark.util.Utils._ + import org.apache.spark.util.Utils._ @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], classOf[Class[_]]) method.setAccessible(true) @@ -85,7 +81,7 @@ import org.apache.spark.util.Utils._ "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method From e5fa5b824b62670b8ba76399831ddd0e9e25efa5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 17:50:11 -0700 Subject: [PATCH 18/30] remove clone in UTF8String --- .../src/main/scala/org/apache/spark/sql/types/UTF8String.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index bea8526003eae..cdb2a6c6e1e59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -44,7 +44,7 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 */ def set(bytes: Array[Byte]): UTF8String = { - this.bytes = bytes.clone() + this.bytes = bytes this } From 8d17f21b1f4bb7daa1dcf950330e53d8aebd074b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 00:39:07 -0700 Subject: [PATCH 19/30] fix hive compatibility tests --- .../expressions/SpecificMutableRow.scala | 16 +++++++++++----- .../catalyst/expressions/stringOperations.scala | 4 ++++ .../apache/spark/sql/parquet/newParquet.scala | 5 +++-- .../apache/spark/sql/hive/HiveInspectors.scala | 9 ++++++--- .../hive/execution/ScriptTransformation.scala | 10 +++++----- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index a0ac9e21ca240..96f00964adcc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -175,12 +175,16 @@ final class MutableString extends MutableValue { override def boxed: Any = if (isNull) null else value override def update(v: Any): Unit = { isNull = false - value = v.asInstanceOf[UTF8String] + if (value == null) { + value = v.asInstanceOf[UTF8String] + } else { + value.set(v.asInstanceOf[UTF8String].getBytes) + } } override def copy(): MutableString = { val newCopy = new MutableString newCopy.isNull = isNull - newCopy.value = value + newCopy.value = value.clone() newCopy.asInstanceOf[MutableString] } } @@ -217,7 +221,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case LongType => new MutableLong - // TODO(davies): Enable this + // TODO(davies): enable this // case StringType => new MutableString 
case _ => new MutableAny }.toArray) @@ -249,11 +253,13 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def update(ordinal: Int, value: Any): Unit = value match { case null => setNullAt(ordinal) - case s: String => update(ordinal, UTF8String(s)) + case s: String => + // for tests + throw new Exception("String should be converted into UTF8String") case other => values(ordinal).update(value) } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) override def getString(ordinal: Int): String = apply(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 2dc01256848dd..23d9290b24a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -261,6 +261,10 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case s: UTF8String => val (st, end) = slicePos(start, length, () => s.length) s.slice(st, end) + case other => + val s = other.toString + val (st, end) = slicePos(start, length, () => s.length) + UTF8String(s.slice(st, end)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e51870d3e2cc1..8115b277427ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -45,7 +45,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ @@ -528,7 +528,8 @@ private[sql] case class ParquetRelation2( baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => p.values + case p if split.getPath.getParent.toString == p.path => + ScalaReflection.convertToCatalyst(p.values).asInstanceOf[Row] }.get val requiredPartOrdinal = partitionKeyLocations.keys.toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0d0a5b134f07f..74ae984f34866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -242,7 +242,7 @@ private[hive] trait HiveInspectors { case poi: WritableConstantStringObjectInspector => UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( 
PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -285,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8efed7f0299bf..204b9b69d12c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ @@ -121,14 +122,13 @@ case class ScriptTransformation( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + new GenericRow(ScalaReflection.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + new GenericRow(ScalaReflection.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } } else { From fd113643c48b633eb505540a13b8fd4798c0197d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 11:54:24 -0700 Subject: [PATCH 20/30] optimize UTF8String --- .../apache/spark/sql/types/UTF8String.scala | 71 ++++++++++++------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index cdb2a6c6e1e59..d393bc97bd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -30,7 +30,7 @@ import java.util.Arrays final class UTF8String extends Ordered[UTF8String] with Serializable { - private var bytes: Array[Byte] = _ + private[this] var bytes: Array[Byte] = _ /** * Update the UTF8String with String. 
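The remaining UTF8String hunks below rework length, slice, contains, startsWith, endsWith and compare to operate directly on the underlying byte array. For orientation, a minimal editorial sketch (not part of the patch) of the behaviour those methods expose and that the rewrite preserves: lengths and slice positions count code points, while searches and comparisons work on the raw UTF-8 bytes.

  import org.apache.spark.sql.types.UTF8String

  val s = UTF8String("naïve")                  // the 'ï' needs two bytes in UTF-8
  assert(s.length() == 5)                      // code points
  assert(s.getBytes.length == 6)               // raw UTF-8 bytes
  assert(s.contains(UTF8String("aï")))
  assert(s.startsWith(UTF8String("na")) && s.endsWith(UTF8String("ve")))
  assert(s.slice(2, 4) == UTF8String("ïv"))    // slice bounds are code-point indices

The numOfBytes helper added below drives the code-point arithmetic: the first byte of a UTF-8 sequence encodes how many bytes the whole sequence occupies, which is what the reworked tailBytesOfUTF8 table now stores.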
@@ -48,6 +48,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { this } + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.tailBytesOfUTF8(offset) else 1 + } + /** * Return the number of code points in it. * @@ -57,11 +63,7 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { var len = 0 var i: Int = 0 while (i < bytes.length) { - val b = bytes(i) & 0xFF - i += 1 - if (b >= 192) { - i += UTF8String.tailBytesOfUTF8(b - 192) - } + i += numOfBytes(bytes(i)) len += 1 } len @@ -84,35 +86,47 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { var c = 0 var i: Int = 0 while (c < start && i < bytes.length) { - val b = bytes(i) & 0xFF - i += 1 - if (b >= 192) { - i += UTF8String.tailBytesOfUTF8(b - 192) - } + i += numOfBytes(bytes(i)) c += 1 } var j = i while (c < until && j < bytes.length) { - val b = bytes(j) & 0xFF - j += 1 - if (b >= 192) { - j += UTF8String.tailBytesOfUTF8(b - 192) - } + j += numOfBytes(bytes(j)) c += 1 } UTF8String(Arrays.copyOfRange(bytes, i, j)) } def contains(sub: UTF8String): Boolean = { - bytes.containsSlice(sub.bytes) + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In worst case, it's O(N*K), but should works fine with SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false } def startsWith(prefix: UTF8String): Boolean = { - bytes.startsWith(prefix.bytes) + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) } def endsWith(suffix: UTF8String): Boolean = { - bytes.endsWith(suffix.bytes) + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) } def toUpperCase(): UTF8String = { @@ -133,12 +147,13 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { override def compare(other: UTF8String): Int = { var i: Int = 0 - while (i < bytes.length && i < other.bytes.length) { - val res = bytes(i).compareTo(other.bytes(i)) + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) if (res != 0) return res i += 1 } - bytes.length - other.bytes.length + bytes.length - b.length } override def compareTo(other: UTF8String): Int = { @@ -147,7 +162,7 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { override def equals(other: Any): Boolean = other match { case s: UTF8String => - Arrays.equals(bytes, s.bytes) + Arrays.equals(bytes, s.getBytes) case s: String => // fail fast bytes.length >= s.length && length() == s.length && toString() == s @@ -163,10 +178,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 - private[types] val tailBytesOfUTF8: Array[Int] = Array(1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, - 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5) + private[types] val tailBytesOfUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 
4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) /** * Create a UTF-8 String from String From ac18ae63e20041c48166eb6863c995e21bafa64d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 13:39:37 -0700 Subject: [PATCH 21/30] address comment --- .../catalyst/expressions/SpecificMutableRow.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 96f00964adcc4..5b7685b75a994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -251,12 +251,14 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = value match { - case null => setNullAt(ordinal) - case s: String => + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { // for tests - throw new Exception("String should be converted into UTF8String") - case other => values(ordinal).update(value) + assert(!value.isInstanceOf[String], "String should be converted into UTF8String") + values(ordinal).update(value) + } } override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) From 2089d24123c2e7defe808c6d56de0271c9476796 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 15:46:21 -0700 Subject: [PATCH 22/30] add hashcode check back --- .../catalyst/analysis/HiveTypeCoercion.scala | 4 + .../expressions/stringOperations.scala | 4 - .../ExpressionEvaluationSuite.scala | 78 +++++++++---------- .../GeneratedMutableEvaluationSuite.scala | 11 ++- 4 files changed, 53 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 1780c5a648bac..a2e507b15e4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -505,6 +505,10 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. 
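This new coercion case matters because Substring.eval was narrowed earlier in the series to accept only Array[Byte] and UTF8String inputs, while Hive lets SUBSTR stringify arbitrary arguments. A small editorial sketch of the rewrite's effect, using literal-only expressions so no analyzer setup is needed:

  import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Substring}
  import org.apache.spark.sql.types.StringType

  // Without the rule, an integer input reaches Substring.eval and fails to match.
  val raw = Substring(Literal(12345), Literal(1), Literal(3))

  // With the rule, the input is cast to StringType first, so SUBSTR(12345, 1, 3)
  // evaluates to UTF8String("123"), matching Hive's behaviour.
  val coerced = Substring(Cast(Literal(12345), StringType), Literal(1), Literal(3))

Doing the cast once at analysis time keeps Substring.eval free of per-row type dispatch for non-string inputs.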
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 23d9290b24a95..2dc01256848dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -261,10 +261,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case s: UTF8String => val (st, end) = slicePos(start, length, () => s.length) s.slice(st, end) - case other => - val s = other.toString - val (st, end) = slicePos(start, length, () => s.length) - UTF8String(s.slice(st, end)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index ad76aa095cc72..b851f8cdeac33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -60,8 +60,8 @@ class ExpressionEvaluationBaseSuite extends FunSuite { class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { - def create_row(values: Array[Any]): Row = { - new GenericRow(values.toSeq.map(ScalaReflection.convertToCatalyst).toArray) + def create_row(values: Any*): Row = { + new GenericRow(values.map(ScalaReflection.convertToCatalyst).toArray) } test("literals") { @@ -253,23 +253,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, create_row(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, create_row(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, create_row(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, create_row(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, create_row(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, create_row(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, create_row(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, create_row(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, create_row(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, create_row(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, create_row(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, create_row(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, create_row(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, create_row(Array[Any]("a%b"))) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) 
+ checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { @@ -300,14 +300,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, create_row(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, create_row(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, create_row(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, create_row(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -748,7 +748,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("null checking") { - val row = create_row(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -787,7 +787,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("case when") { - val row = create_row(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -830,13 +830,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex type") { - val row = create_row(Array[Any]( + val row = create_row( "^Ba*n", // 0 null.asInstanceOf[UTF8String], // 1 - create_row(Array[Any]("aa", "bb")), // 2 + create_row("aa", "bb"), // 2 Map("aa"->"bb"), // 3 Seq("aa", "bb") // 4 - )) + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil @@ -888,7 +888,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("arithmetic") { - val row = create_row(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -912,7 +912,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("fractional arithmetic") { - val row = create_row(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) @@ -935,7 +935,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("BinaryComparison") { - val row = create_row(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 
'a.int.at(2) @@ -964,7 +964,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("StringComparison") { - val row = create_row(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) @@ -985,7 +985,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("Substring") { - val row = create_row(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) @@ -1017,7 +1017,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), null, create_row(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), null, create_row(null)) // substring(_, null, _) -> null checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), null, row) @@ -1048,20 +1048,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = create_row(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 925e994b9cf55..0a8f6ac377dcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.ScalaReflection /** * Overrides our expression evaluation tests to use generated code on mutable rows. 
@@ -42,7 +43,15 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new GenericRow(Array[Any](ScalaReflection.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } if (actual != expectedRow) { val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") From 867bf5077d0670e7674df5a9d18d97d7ae90b416 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 22:15:31 -0700 Subject: [PATCH 23/30] fix String filter push down --- .../spark/sql/sources/DataSourceStrategy.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 564084998b939..ce60ae2fe9042 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -176,14 +176,14 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) - case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringStartsWith(a.name, v)) + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) - case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringEndsWith(a.name, v)) + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) - case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringContains(a.name, v)) + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) case _ => None } From 1314a3727327b3215d8c527e07e58f82ebf67e64 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Apr 2015 16:23:55 -0700 Subject: [PATCH 24/30] address comments from Yin --- .../spark/sql/catalyst/ScalaReflection.scala | 2 ++ .../expressions/SpecificMutableRow.scala | 23 ------------------- .../expressions/codegen/CodeGenerator.scala | 11 ++++++++- .../sql/catalyst/expressions/generators.scala | 5 ++-- .../apache/spark/sql/types/DateUtils.scala | 1 - .../apache/spark/sql/types/UTF8String.scala | 9 ++++++-- .../spark/sql/types/UTF8StringSuite.scala | 2 ++ .../spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/hive/Shim12.scala | 4 ++-- 9 files changed, 30 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4ae7988db5fbe..ac2629205c51f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -84,8 +84,10 @@ trait ScalaReflection { /** * Converts Scala objects to catalyst rows / types. + * * Note: This should be called before do evaluation on Row * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { case s: String => UTF8String(s) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 5b7685b75a994..3475ed05f4454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -170,25 +170,6 @@ final class MutableByte extends MutableValue { } } -final class MutableString extends MutableValue { - var value: UTF8String = _ - override def boxed: Any = if (isNull) null else value - override def update(v: Any): Unit = { - isNull = false - if (value == null) { - value = v.asInstanceOf[UTF8String] - } else { - value.set(v.asInstanceOf[UTF8String].getBytes) - } - } - override def copy(): MutableString = { - val newCopy = new MutableString - newCopy.isNull = isNull - newCopy.value = value.clone() - newCopy.asInstanceOf[MutableString] - } -} - final class MutableAny extends MutableValue { var value: Any = _ override def boxed: Any = if (isNull) null else value @@ -221,8 +202,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case LongType => new MutableLong - // TODO(davies): enable this - // case StringType => new MutableString case _ => new MutableAny }.toArray) @@ -255,8 +234,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) { setNullAt(ordinal) } else { - // for tests - assert(!value.isInstanceOf[String], "String should be converted into UTF8String") values(ordinal).update(value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c2e94d1334f26..a3bdce0ca9139 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -220,7 +220,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin q""" val $nullTerm = ${value == null} val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.toString}) + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -279,6 +279,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1: BinaryType, e2: BinaryType) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + 
java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 17daddb129f33..7d8ff7a6dfec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,9 +86,10 @@ case class UserDefinedGenerator( override def eval(input: Row): TraversableOnce[Row] = { // TODO(davies): improve this - val input_schema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + // Convert the objects into Scala types before calling the function; we need the schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(ScalaReflection.convertToScala(inputRow(input), input_schema).asInstanceOf[Row]) + function(ScalaReflection.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 8e4105109f7d8..140e52a8494e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -39,7 +39,6 @@ object DateUtils { millisToDays(d.getTime) } - // TODO(davies): This is buggy, it will be wrong if the date is not aligned with day // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index d393bc97bd086..fc02ba6c9c43e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -48,10 +48,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { this } + /** + * Return the number of bytes of the code point whose first byte is `b` + * @param b The first byte of a code point + */ @inline private[this] def numOfBytes(b: Byte): Int = { val offset = (b & 0xFF) - 192 - if (offset >= 0) UTF8String.tailBytesOfUTF8(offset) else 1 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 } /** @@ -164,6 +168,7 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { case s: UTF8String => Arrays.equals(bytes, s.getBytes) case s: String => + // This is only used for Catalyst unit tests // fail fast bytes.length >= s.length && length() == s.length && toString() == s case _ => @@ -178,7 +183,7 @@ object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 - private[types] val tailBytesOfUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + private[types] val
bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index 04a435da33584..954494c78633f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -28,6 +28,8 @@ class UTF8StringSuite extends FunSuite { assert(UTF8String(str) == str) assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 261a1ebfbcd45..4cb24384644e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -52,6 +52,10 @@ object RDDConversions { } } } + + /** + * Convert the objects inside Row into the types Catalyst expected. + */ def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 0ed93c2c5b1fa..33e96eaabfbf6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} private[hive] case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { @@ -135,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) From 5116b438bbb8d5dcd3e25dc196076a9b3d52b951 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Apr 2015 16:28:55 -0700 Subject: [PATCH 25/30] rollback unrelated changes --- .../test/org/apache/spark/sql/JavaRowSuite.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 3f9c1e7f55935..4ce1d1dddb26a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -17,12 +17,6 @@ package test.org.apache.spark.sql; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; 
-import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - import java.math.BigDecimal; import java.sql.Date; import java.sql.Timestamp; @@ -31,6 +25,13 @@ import java.util.List; import java.util.Map; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; + public class JavaRowSuite { private byte byteValue; private short shortValue; From b04a19c8978abaac2a68173e3f412728d7e826e1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Apr 2015 10:55:53 -0700 Subject: [PATCH 26/30] add comment for getString/setString --- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 463046f491244..6f572ff959fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -120,10 +120,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case _ => Nil } dataType match { + // Row() need this interface to compile case StringType => q""" override def getString(i: Int): String = { - ..$ifStatements; $accessorFailure }""" case other => @@ -147,9 +147,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } dataType match { case StringType => + // MutableRow() need this interface to compile q""" override def setString(i: Int, value: String) { - ..$ifStatements; $accessorFailure }""" case other => From 341ec2c925a90f34e439de9254d8f44d9e86520e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 13 Apr 2015 12:15:02 -0700 Subject: [PATCH 27/30] turn off scala style check in UTF8StringSuite --- .../test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index 954494c78633f..a22aa6f244c48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.types import org.scalatest.FunSuite +// scalastyle:off class UTF8StringSuite extends FunSuite { test("basic") { def check(str: String, len: Int) { From 59025c8f8c6824efa587c6876f3e6b978925e890 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 14 Apr 2015 17:31:59 -0700 Subject: [PATCH 28/30] address comments from @marmbrus --- .../org/apache/spark/sql/sources/DataSourceStrategy.scala | 8 +++++--- .../scala/org/apache/spark/sql/sources/interfaces.scala | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index ce60ae2fe9042..e3a890942c8b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} @@ -114,9 +115,10 @@ private[sql] object DataSourceStrategy extends Strategy { } } - private[this] def createPhysicalRDD(relation: BaseRelation, - output: Seq[Attribute], - rdd: RDD[Row]) = { + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { val converted = if (relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, relation.schema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index d431db42814ad..ca53dcdb92c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -131,6 +131,9 @@ abstract class BaseRelation { * Whether does it need to convert the objects in Row to internal representation, for example: * java.lang.String -> UTF8String * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. */ def needConversion: Boolean = true } From 2772f0d8face2f9c634718fb8719fe56c5d8d676 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 14 Apr 2015 20:04:21 -0700 Subject: [PATCH 29/30] fix new test failure --- .../apache/spark/sql/catalyst/CatalystTypeConverters.scala | 4 ++-- .../main/scala/org/apache/spark/sql/execution/commands.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4f9fdacda4fb..72f8545ec166d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -321,9 +321,9 @@ object CatalystTypeConverters { row: Row, schema: StructType, converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) + val ar = new Array[Any](converters.size) var idx = 0 - while (idx < row.size) { + while (idx < converters.size && idx < row.size) { ar(idx) = converters(idx)(row(idx)) idx += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 99c45cfa641c3..99f24910fd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -63,8 +63,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray override def execute(): RDD[Row] = { - val converted = sideEffectResult.map(r => CatalystTypeConverters.convertToCatalyst(r, schema) - .asInstanceOf[Row]) + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) sqlContext.sparkContext.parallelize(converted, 1) } } From 3b7bfa8f37e7f2b9aefdfd0e5e57d7b5c6b516ce Mon Sep 17 00:00:00 2001 
From: Davies Liu Date: Tue, 14 Apr 2015 21:56:40 -0700 Subject: [PATCH 30/30] fix schema of AddJar --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 4 ++-- .../org/apache/spark/sql/hive/execution/commands.scala | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 72f8545ec166d..d4f9fdacda4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -321,9 +321,9 @@ object CatalystTypeConverters { row: Row, schema: StructType, converters: Array[Any => Any]): Row = { - val ar = new Array[Any](converters.size) + val ar = new Array[Any](row.size) var idx = 0 - while (idx < converters.size && idx < row.size) { + while (idx < row.size) { ar(idx) = converters(idx)(row(idx)) idx += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 902a12785e3e9..a40a1e53117cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** * Analyzes the given table in the current database to generate statistics, which will be @@ -76,6 +76,12 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path")
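Stepping back from the individual hunks, the core idea this series relies on is that UTF8String keeps the raw UTF-8 bytes and does length and slice arithmetic on them directly: the first byte of every code point already says how many bytes the code point occupies, which is what numOfBytes and the bytesOfCodePointInUTF8 table compute. The standalone Scala sketch below illustrates that first-byte arithmetic; it is illustrative only (Utf8LengthSketch and its members are not part of these patches) and checks the byte-level count against java.lang.String.codePointCount.

  object Utf8LengthSketch {
    // Bytes occupied by the UTF-8 code point whose first byte is `b`:
    // 0xxxxxxx -> 1, 110xxxxx -> 2, 1110xxxx -> 3, 11110xxx -> 4.
    private def numBytes(b: Byte): Int = {
      val first = b & 0xFF
      if (first < 0xC0) 1      // ASCII, or a stray continuation byte counted defensively as 1
      else if (first < 0xE0) 2
      else if (first < 0xF0) 3
      else 4
    }

    // Length in code points, computed on the bytes without decoding to java.lang.String.
    def lengthInCodePoints(bytes: Array[Byte]): Int = {
      var i = 0
      var len = 0
      while (i < bytes.length) {
        i += numBytes(bytes(i))
        len += 1
      }
      len
    }

    def main(args: Array[String]): Unit = {
      val s = "hello 世界"
      val fromBytes = lengthInCodePoints(s.getBytes("UTF-8"))
      val fromString = s.codePointCount(0, s.length)
      println(s"$fromBytes == $fromString") // prints "8 == 8"
    }
  }

Working directly on the bytes this way is what lets the patched code keep string values in their serialized form end to end and only fall back to java.lang.String (via toString) at the boundaries, such as the Hive writable and data source filter conversions shown above.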