From 39704abb09271dc38d6afbc87652235add0ea32e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 13 Apr 2015 14:04:04 -0700 Subject: [PATCH 01/11] Specialized serializer for Exchange. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../apache/spark/sql/execution/Exchange.scala | 39 +- .../sql/execution/SparkSqlSerializer2.scala | 378 ++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 3 + .../execution/SparkSqlSerializer2Suite.scala | 198 +++++++++ 5 files changed, 618 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ee641bdfeb2d7..1070bf44abd6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -63,6 +63,8 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -139,6 +141,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "false").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..b69b89d74c099 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} @@ -45,6 +46,27 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + val useSqlSerializer2 = + !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) && + child.sqlContext.conf.useSqlSerializer2 && + SparkSqlSerializer2.support(keySchema) && + SparkSqlSerializer2.support(valueSchema) + + val serializer = if (useSqlSerializer2) { + logInfo("Use ShuffleSerializer") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Use SparkSqlSerializer") + new SparkSqlSerializer(new SparkConf(false)) + } + + 
serializer + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => @@ -70,7 +92,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -88,7 +114,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val part = new RangePartitioner(numPartitions, rdd, ascending = true) val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val keySchema = sortingExpressions.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) shuffled.map(_._1) @@ -107,7 +135,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) + shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 0000000000000..9e7b4ab63fe79 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for SparkSqlSerializer2. 
+ */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The deserialization stream for SparkSqlSerializer2. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. 
+ */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case decimal: DecimalType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + // TODO: Update it once the string improvement is in. + out.writeUTF(row.getString(i)) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. + */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. 
+ + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // TODO: Update it once the string improvement is in. + mutableRow.setString(i, in.readUTF()) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9a81fc5d72819..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 0000000000000..335cca219931e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,198 @@ +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.apache.spark.serializer.Serializer +import org.apache.spark.{SparkConf, ShuffleDependency, SparkContext} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + + // Because at the runtime we accepts three kinds of Decimals + // (Java BigDecimal, Scala BigDecimal, and Spark SQL's Decimal), we do support DecimalType + // right now. We will support it once we fixed the internal type. + checkSupported(DecimalType(10, 5), isSupported = false) + checkSupported(DecimalType.Unlimited, isSupported = false) + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. 
+ checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + + @transient var sparkContext: SparkContext = _ + @transient var sqlContext: SQLContext = _ + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + sqlContext.sql("set spark.sql.shuffle.partitions=5") + sqlContext.sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + new Date(i), + new Timestamp(i)) + } + + sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + sqlContext.dropTempTable("shuffle") + sparkContext.stop() + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + sqlContext.table("shuffle").collect()) + } + + test("value schema is null") { + val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** Tests SparkSqlSerializer2 with hash based shuffle. 
*/ +class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "hash") + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + // Since spark.sql.shuffle.partition is 5, we will not do sort merge when + // spark.shuffle.sort.bypassMergeThreshold is also 5. + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.sort.bypassMergeThreshold", "5") + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + val sparkConf = + new SparkConf() + .set("spark.sql.testkey", "true") + .set("spark.shuffle.manager", "sort") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge. + + sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) + sqlContext = new SQLContext(sparkContext) + super.beforeAll() + } +} From 2379eeb00c8beb55f028f74420036a390d32c7c9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 13 Apr 2015 14:15:01 -0700 Subject: [PATCH 02/11] ASF header. --- .../execution/SparkSqlSerializer2Suite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 335cca219931e..3bf89cdb066ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} From c9373c8951dd89fab1036e74f64922a0622fe937 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 13 Apr 2015 20:47:10 -0700 Subject: [PATCH 03/11] Support DecimalType. 
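This patch encodes a DecimalType field as the length-prefixed byte array of its unscaled BigInteger followed by an int for the scale, and rebuilds the value from those two pieces on the read side. As a rough, dependency-free sketch of that round trip using plain java.math types (illustrative only, not the actual Spark code path — the real serializer also writes the NULL/NOT_NULL marker byte and wraps the result in Spark SQL's Decimal):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import java.math.{BigDecimal, BigInteger}

    // Standalone illustration of the decimal wire format used below; not Spark code.
    object DecimalWireFormatSketch {
      def main(args: Array[String]): Unit = {
        val original = new BigDecimal("9223372036854775807.12345")

        // Write the unscaled value as a length-prefixed byte array, then the scale.
        val buffer = new ByteArrayOutputStream()
        val out = new DataOutputStream(buffer)
        val unscaled = original.unscaledValue().toByteArray
        out.writeInt(unscaled.length)
        out.write(unscaled)
        out.writeInt(original.scale())
        out.flush()

        // Read the unscaled bytes and the scale back and rebuild the decimal.
        val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
        val bytes = new Array[Byte](in.readInt())
        in.readFully(bytes)
        val decoded = new BigDecimal(new BigInteger(bytes), in.readInt())

        assert(decoded == original)
      }
    }

Because both the unscaled value and the scale are preserved exactly, the decoded value compares equal to the original, which is what lets the shuffle round-trip DecimalType columns without loss.
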
--- .../sql/execution/SparkSqlSerializer2.scala | 32 ++++++++++++++++++- .../execution/SparkSqlSerializer2Suite.scala | 16 +++++----- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 9e7b4ab63fe79..8e19efb718d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.io._ +import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer import java.sql.Timestamp @@ -143,7 +144,6 @@ private[sql] object SparkSqlSerializer2 { case array: ArrayType => return false case map: MapType => return false case struct: StructType => return false - case decimal: DecimalType => return false case _ => } i += 1 @@ -223,6 +223,21 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } + case decimal: DecimalType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + out.writeInt(bytes.length) + out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) + } + case DateType => if (row.isNullAt(i)) { out.writeByte(NULL) @@ -334,6 +349,21 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } + case decimal: DecimalType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // First, read in the unscaled value. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) + } + case DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 3bf89cdb066ae..0a34e605fa142 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -53,12 +53,9 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite { checkSupported(TimestampType, isSupported = true) checkSupported(StringType, isSupported = true) checkSupported(BinaryType, isSupported = true) + checkSupported(DecimalType(10, 5), isSupported = true) + checkSupported(DecimalType.Unlimited, isSupported = true) - // Because at the runtime we accepts three kinds of Decimals - // (Java BigDecimal, Scala BigDecimal, and Spark SQL's Decimal), we do support DecimalType - // right now. We will support it once we fixed the internal type. - checkSupported(DecimalType(10, 5), isSupported = false) - checkSupported(DecimalType.Unlimited, isSupported = false) // For now, ArrayType, MapType, and StructType are not supported. 
checkSupported(ArrayType(DoubleType, true), isSupported = false) checkSupported(ArrayType(StringType, false), isSupported = false) @@ -84,7 +81,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DateType, TimestampType) + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType) val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, true) @@ -103,9 +101,11 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll i.toByte, i.toShort, i, - i.toLong, + Long.MaxValue - i.toLong, (i + 0.25).toFloat, (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), new Timestamp(i)) } @@ -159,7 +159,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) } } From 8297732a5b1ec068769e904fef5a41ceeb77f132 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 14 Apr 2015 10:04:36 -0700 Subject: [PATCH 04/11] Fix test. --- .../sql/execution/SparkSqlSerializer2Suite.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 0a34e605fa142..f8c833f54648f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} import org.apache.spark.serializer.Serializer -import org.apache.spark.{SparkConf, ShuffleDependency, SparkContext} +import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.types._ import org.apache.spark.sql.Row @@ -70,6 +70,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll @transient var sparkContext: SparkContext = _ @transient var sqlContext: SQLContext = _ + // We may have an existing SparkEnv (e.g. the one used by TestSQLContext). + @transient val existingSparkEnv = SparkEnv.get var allColumns: String = _ val serializerClass: Class[Serializer] = classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] @@ -118,6 +120,10 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll override def afterAll(): Unit = { sqlContext.dropTempTable("shuffle") sparkContext.stop() + sqlContext = null + sparkContext = null + // Set the existing SparkEnv back. 
+ SparkEnv.set(existingSparkEnv) super.afterAll() } @@ -168,6 +174,7 @@ class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite { override def beforeAll(): Unit = { val sparkConf = new SparkConf() + .set("spark.driver.allowMultipleContexts", "true") .set("spark.sql.testkey", "true") .set("spark.shuffle.manager", "hash") @@ -184,6 +191,7 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { // spark.shuffle.sort.bypassMergeThreshold is also 5. val sparkConf = new SparkConf() + .set("spark.driver.allowMultipleContexts", "true") .set("spark.sql.testkey", "true") .set("spark.shuffle.manager", "sort") .set("spark.shuffle.sort.bypassMergeThreshold", "5") @@ -204,6 +212,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite override def beforeAll(): Unit = { val sparkConf = new SparkConf() + .set("spark.driver.allowMultipleContexts", "true") .set("spark.sql.testkey", "true") .set("spark.shuffle.manager", "sort") .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge. From 43b9fb47e4907e156aa2bbf36288d3a3cd039185 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 14 Apr 2015 10:36:03 -0700 Subject: [PATCH 05/11] Test. --- .../execution/SparkSqlSerializer2Suite.scala | 98 ++++++------------- 1 file changed, 32 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index f8c833f54648f..3bb78a8ccc9d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.apache.spark.serializer.Serializer -import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext} +import org.scalatest.{FunSuite, BeforeAndAfterAll} + import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -import org.scalatest.{FunSuite, BeforeAndAfterAll} - -import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest} +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} class SparkSqlSerializer2DataTypeSuite extends FunSuite { // Make sure that we will not use serializer2 for unsupported data types. @@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite { } abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - - @transient var sparkContext: SparkContext = _ - @transient var sqlContext: SQLContext = _ - // We may have an existing SparkEnv (e.g. the one used by TestSQLContext). 
- @transient val existingSparkEnv = SparkEnv.get var allColumns: String = _ val serializerClass: Class[Serializer] = classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + var numShufflePartitions: Int = _ + var useSerializer2: Boolean = _ override def beforeAll(): Unit = { - sqlContext.sql("set spark.sql.shuffle.partitions=5") - sqlContext.sql("set spark.sql.useSerializer2=true") + numShufflePartitions = conf.numShufflePartitions + useSerializer2 = conf.useSqlSerializer2 + + sql("set spark.sql.useSerializer2=true") val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, @@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll new Timestamp(i)) } - sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle") + createDataFrame(rdd, schema).registerTempTable("shuffle") super.beforeAll() } override def afterAll(): Unit = { - sqlContext.dropTempTable("shuffle") - sparkContext.stop() - sqlContext = null - sparkContext = null - // Set the existing SparkEnv back. - SparkEnv.set(existingSparkEnv) + dropTempTable("shuffle") + sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + sql(s"set spark.sql.useSerializer2=$useSerializer2") super.afterAll() } @@ -144,24 +141,33 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("key schema and value schema are not nulls") { - val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, - sqlContext.table("shuffle").collect()) + table("shuffle").collect()) } test("value schema is null") { - val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0") + val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") checkSerializer(df.queryExecution.executedPlan, serializerClass) assert( df.map(r => r.getString(0)).collect().toSeq === - sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + super.beforeAll() + // Sort merge will not be triggered. + sql("set spark.sql.shuffle.partitions = 200") } test("key schema is null") { val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle") + val df = sql(s"SELECT $aggregations FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, @@ -169,39 +175,6 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } } -/** Tests SparkSqlSerializer2 with hash based shuffle. */ -class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - val sparkConf = - new SparkConf() - .set("spark.driver.allowMultipleContexts", "true") - .set("spark.sql.testkey", "true") - .set("spark.shuffle.manager", "hash") - - sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) - sqlContext = new SQLContext(sparkContext) - super.beforeAll() - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. 
*/ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - // Since spark.sql.shuffle.partition is 5, we will not do sort merge when - // spark.shuffle.sort.bypassMergeThreshold is also 5. - val sparkConf = - new SparkConf() - .set("spark.driver.allowMultipleContexts", "true") - .set("spark.sql.testkey", "true") - .set("spark.shuffle.manager", "sort") - .set("spark.shuffle.sort.bypassMergeThreshold", "5") - - sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) - sqlContext = new SQLContext(sparkContext) - super.beforeAll() - } -} - /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { @@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] override def beforeAll(): Unit = { - val sparkConf = - new SparkConf() - .set("spark.driver.allowMultipleContexts", "true") - .set("spark.sql.testkey", "true") - .set("spark.shuffle.manager", "sort") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge. - - sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf) - sqlContext = new SQLContext(sparkContext) super.beforeAll() + // To trigger the sort merge. + sql("set spark.sql.shuffle.partitions = 201") } } From 3e0965578ce2c09b79565fbd032615133997ecd0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 14 Apr 2015 17:54:09 -0700 Subject: [PATCH 06/11] Use getAs for Date column. --- .../org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 8e19efb718d0d..fbccaa3dfd39c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -243,7 +243,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - out.writeInt(row.getInt(i)) + out.writeInt(row.getAs[Int](i)) } case TimestampType => From 791b96a555f4a18e061cca636ae08ba16f6cc30f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 15 Apr 2015 16:48:25 -0700 Subject: [PATCH 07/11] Use UTF8String. 
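With this change a StringType field is written as its raw UTF-8 bytes behind an int length prefix (instead of DataOutputStream.writeUTF), and read back by allocating a byte array of that length and wrapping it in a UTF8String. A minimal sketch of the same framing, using a plain java.lang.String as a stand-in for the internal UTF8String type (illustrative only, not the Spark code path):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object StringWireFormatSketch {
      def main(args: Array[String]): Unit = {
        // A plain String and its UTF-8 bytes stand in for Spark SQL's internal UTF8String here.
        val original = "str1: test serializer2."

        // Write an int length prefix followed by the raw UTF-8 bytes.
        val buffer = new ByteArrayOutputStream()
        val out = new DataOutputStream(buffer)
        val encoded = original.getBytes("UTF-8")
        out.writeInt(encoded.length)
        out.write(encoded)
        out.flush()

        // Read the length, then exactly that many bytes, and rebuild the string.
        val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
        val bytes = new Array[Byte](in.readInt())
        in.readFully(bytes)
        assert(new String(bytes, "UTF-8") == original)
      }
    }

Compared with writeUTF, which uses modified UTF-8 and is limited to 65535 encoded bytes, the explicit length prefix matches how BinaryType is already handled and avoids converting the internal UTF8String to a java.lang.String just to serialize it.
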
--- .../org/apache/spark/sql/execution/Exchange.scala | 8 ++++---- .../spark/sql/execution/SparkSqlSerializer2.scala | 12 ++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 355239c6d67b9..0eb5fab263546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} @@ -79,7 +79,7 @@ case class Exchange( } } - private lazy val sparkConf = child.sqlContext.sparkContext.getConf + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf def serializer( keySchema: Array[DataType], @@ -92,10 +92,10 @@ case class Exchange( SparkSqlSerializer2.support(valueSchema) val serializer = if (useSqlSerializer2) { - logInfo("Use ShuffleSerializer") + logInfo("Use SparkSqlSerializer2.") new SparkSqlSerializer2(keySchema, valueSchema) } else { - logInfo("Use SparkSqlSerializer") + logInfo("Use SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index fbccaa3dfd39c..fdfc8417a3ec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -263,8 +263,9 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - // TODO: Update it once the string improvement is in. - out.writeUTF(row.getString(i)) + val bytes = row.getAs[UTF8String](i).getBytes + out.writeInt(bytes.length) + out.write(bytes) } case BinaryType => @@ -386,8 +387,11 @@ private[sql] object SparkSqlSerializer2 { if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // TODO: Update it once the string improvement is in. - mutableRow.setString(i, in.readUTF()) + // TODO: reuse the byte array in the UTF8String. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, UTF8String(bytes)) } case BinaryType => From 09e587a1ff9136c81da05c2a7f2c0e9b85c09173 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 15 Apr 2015 16:59:46 -0700 Subject: [PATCH 08/11] Remove TODO. --- .../org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index fdfc8417a3ec6..974aae18fb2c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -387,7 +387,6 @@ private[sql] object SparkSqlSerializer2 { if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // TODO: reuse the byte array in the UTF8String. 
val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) From 4273b8c06dec0c7a20d05e1e09b2d00681b9f3a8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 17 Apr 2015 10:41:57 -0700 Subject: [PATCH 09/11] Enabled SparkSqlSerializer2. --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e7eb22983dc70..4fc5de7e824fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -149,7 +149,7 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to From 6d076789bf981cd41cb99271763a8fa7ca703219 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 17 Apr 2015 11:57:34 -0700 Subject: [PATCH 10/11] Address comments. --- .../apache/spark/sql/execution/Exchange.scala | 25 +++++++++++++------ .../sql/execution/SparkSqlSerializer2.scala | 14 +++++++++-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 0eb5fab263546..64e90be4d460f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -85,17 +85,28 @@ case class Exchange( keySchema: Array[DataType], valueSchema: Array[DataType], numPartitions: Int): Serializer = { + // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out + // through write(key) and then write(value) instead of write((key, value)). Because + // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use + // it when spillToMergeableFile in ExternalSorter will be used. + // So, we will not use SparkSqlSerializer2 when + // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater + // then the bypassMergeThreshold; or + // - newOrdering is defined. + val cannotUseSqlSerializer2 = + (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + val useSqlSerializer2 = - !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) && - child.sqlContext.conf.useSqlSerializer2 && - SparkSqlSerializer2.support(keySchema) && - SparkSqlSerializer2.support(valueSchema) + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) // The schema of value is supported. 
val serializer = if (useSqlSerializer2) { - logInfo("Use SparkSqlSerializer2.") + logInfo("Using SparkSqlSerializer2.") new SparkSqlSerializer2(keySchema, valueSchema) } else { - logInfo("Use SparkSqlSerializer.") + logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } @@ -160,7 +171,7 @@ case class Exchange( } else { new ShuffledRDD[Row, Null, Null](rdd, part) } - val keySchema = sortingExpressions.map(_.dataType).toArray + val keySchema = child.output.map(_.dataType).toArray shuffled.setSerializer(serializer(keySchema, null, numPartitions)) shuffled.map(_._1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 974aae18fb2c9..cec97de2cd8e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -31,7 +31,17 @@ import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.types._ /** - * The serialization stream for SparkSqlSerializer2. + * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in + * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the + * [[Product2]] are constructed based on their schemata. + * The benefit of this serialization stream is that compared with general-purpose serializers like + * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower + * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: + * 1. It does not support complex types, i.e. Map, Array, and Struct. + * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when + * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because + * the objects passed in the serializer are not in the type of [[Product2]]. Also also see + * the comment of the `serializer` method in [[Exchange]] for more information on it. */ private[sql] class Serializer2SerializationStream( keySchema: Array[DataType], @@ -61,7 +71,7 @@ private[sql] class Serializer2SerializationStream( } /** - * The deserialization stream for SparkSqlSerializer2. + * The corresponding deserialization stream for [[Serializer2SerializationStream]]. */ private[sql] class Serializer2DeserializationStream( keySchema: Array[DataType], From 50e0c3d8244f3210ac3f925904e08976108d6701 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 17 Apr 2015 20:52:57 -0700 Subject: [PATCH 11/11] When no filed is emitted to shuffle, use SparkSqlSerializer for now. --- .../org/apache/spark/sql/execution/Exchange.scala | 15 +++++++++++---- .../sql/execution/SparkSqlSerializer2Suite.scala | 5 +++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 64e90be4d460f..5b2e46962cd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -96,11 +96,18 @@ case class Exchange( val cannotUseSqlSerializer2 = (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + // It is true when there is no field that needs to be write out. 
+ // For now, we will not use SparkSqlSerializer2 when noField is true. + val noField = + (keySchema == null || keySchema.length == 0) && + (valueSchema == null || valueSchema.length == 0) + val useSqlSerializer2 = - child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - !cannotUseSqlSerializer2 && // Safe to use Serializer2. - SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. - SparkSqlSerializer2.support(valueSchema) // The schema of value is supported. + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + !noField val serializer = if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 3bb78a8ccc9d0..27f063d73a9a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -155,6 +155,11 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll df.map(r => r.getString(0)).collect().toSeq === table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) } + + test("no map output field") { + val df = sql(s"SELECT 1 + 1 FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) + } } /** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
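
Taken together, the patches frame every field the same way: one marker byte (NULL or NOT_NULL) followed by the field's value only when it is non-null, with the per-field read and write code chosen up front from the schema. A small, self-contained sketch of that framing for a single nullable int column (illustrative only — the real implementation builds these functions per schema inside SparkSqlSerializer2):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object NullByteFramingSketch {
      final val NULL = 0
      final val NOT_NULL = 1

      // Write a marker byte first; write the value only when it is not null.
      def writeNullableInt(out: DataOutputStream, value: java.lang.Integer): Unit = {
        if (value == null) {
          out.writeByte(NULL)
        } else {
          out.writeByte(NOT_NULL)
          out.writeInt(value)
        }
      }

      // Read the marker byte; read the value only when the marker says NOT_NULL.
      def readNullableInt(in: DataInputStream): java.lang.Integer = {
        if (in.readByte() == NULL) null else java.lang.Integer.valueOf(in.readInt())
      }

      def main(args: Array[String]): Unit = {
        val buffer = new ByteArrayOutputStream()
        val out = new DataOutputStream(buffer)
        writeNullableInt(out, 42)
        writeNullableInt(out, null)
        out.flush()

        val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
        assert(readNullableInt(in) == 42)
        assert(readNullableInt(in) == null)
      }
    }

Since patch 09 turns spark.sql.useSerializer2 on by default, the old code path should still be reachable with SET spark.sql.useSerializer2=false if the new serializer ever needs to be disabled.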