diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
index 32923dc9f5a6b..5f0802b466039 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter with SupportsWriteInternalRow {
+  extends StreamWriter {

   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)

-  override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
+  override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
index 7eedc85a5d6f3..385fc294fea82 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.sources.v2.writer;

 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.StreamWriteSupport;
 import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
    */
-  DataWriterFactory<Row> createWriterFactory();
+  DataWriterFactory<InternalRow> createWriterFactory();

   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
index 1626c0013e4e7..27dc5ea224fe2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
@@ -53,9 +53,7 @@
  * successfully, and have a way to revert committed data writers without the commit message, because
  * Spark only accepts the commit message that arrives first and ignore others.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
- * that mix in {@link SupportsWriteInternalRow}.
+ * Note that currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
  */
 @InterfaceStability.Evolving
 public interface DataWriter<T> {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index 0932ff8f8f8a7..3d337b6e0bdfd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -33,7 +33,10 @@
 public interface DataWriterFactory<T> extends Serializable {

   /**
-   * Returns a data writer to do the actual writing work.
+   * Returns a data writer to do the actual writing work. Note that Spark will reuse the same data
+   * object instance when sending data to the data writer, for better performance. Data writers
+   * are responsible for making defensive copies if necessary, e.g. copying the data before
+   * buffering it in a list.
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
deleted file mode 100644
index d2cf7e01c08c8..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.writer;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
-
-/**
- * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this
- * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side.
- * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get
- * changed in the future Spark versions.
- */
-
-@InterfaceStability.Unstable
-public interface SupportsWriteInternalRow extends DataSourceWriter {
-
-  @Override
-  default DataWriterFactory<Row> createWriterFactory() {
-    throw new IllegalStateException(
-      "createWriterFactory should not be called with SupportsWriteInternalRow.");
-  }
-
-  DataWriterFactory<InternalRow> createInternalRowWriterFactory();
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index b1148c0f62f7c..0399970495bec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    val writeTask = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
     val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging {
     })
   }
 }
-
-class InternalRowDataWriterFactory(
-    rowWriterFactory: DataWriterFactory[Row],
-    schema: StructType) extends DataWriterFactory[InternalRow] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[InternalRow] = {
-    new InternalRowDataWriter(
-      rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
-      RowEncoder.apply(schema).resolveAndBind())
-  }
-}
-
-class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
-  extends DataWriter[InternalRow] {
-
-  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
-
-  override def commit(): WriterCommitMessage = rowWriter.commit()
-
-  override def abort(): Unit = rowWriter.abort()
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index abb807def6239..c759f5be8ba35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.{Clock, Utils}
@@ -498,12 +497,7 @@ class MicroBatchExecution(
           newAttributePlan.schema,
           outputMode,
           new DataSourceOptions(extraOptions.asJava))
-        if (writer.isInstanceOf[SupportsWriteInternalRow]) {
-          WriteToDataSourceV2(
-            new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        } else {
-          WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        }
+        WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 76f3f5baa8d56..967dbe24a3705 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -17,13 +17,10 @@

 package org.apache.spark.sql.execution.streaming.continuous

-import java.util.concurrent.atomic.AtomicLong
-
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
 import org.apache.spark.util.Utils

 /**
@@ -47,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
         SparkEnv.get)
     EpochTracker.initializeCurrentEpoch(
       context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
-
     while (!context.isInterrupted() && !context.isCompleted()) {
       var dataWriter: DataWriter[InternalRow] = null
       // write the data and commit this writer.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index e0af3a2f1b85d..927d3a84e296b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous

 import scala.util.control.NonFatal

-import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
 import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.util.Utils

 /**
  * The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
   override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writerFactory = writer.createWriterFactory()
     val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)

     logInfo(s"Start processing data source writer: $writer. " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
index d276403190b3c..fd45ba509091e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
@@ -17,10 +17,10 @@

 package org.apache.spark.sql.execution.streaming.sources

-import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get

-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
     println(printMessage)
     println("-------------------------------------------")
     // scalastyle:off println
-    spark
-      .createDataFrame(rows.toList.asJava, schema)
+    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
       .show(numRowsToShow, isTruncated)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
index bc9b6d93ce7d9..e8ce21cc12044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -17,13 +17,13 @@

 package org.apache.spark.sql.execution.streaming.sources

-import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
+import org.apache.spark.sql.{ForeachWriter, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.python.PythonForeachWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
@@ -46,11 +46,11 @@ case class ForeachWriterProvider[T](
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new StreamWriter with SupportsWriteInternalRow {
+    new StreamWriter {
       override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
       override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

-      override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+      override def createWriterFactory(): DataWriterFactory[InternalRow] = {
         val rowConverter: InternalRow => T = converter match {
           case Left(enc) =>
             val boundEnc = enc.resolveAndBind(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
index 56f7ff25cbed0..d023a35ea20b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
@@ -17,9 +17,8 @@

 package org.apache.spark.sql.execution.streaming.sources

-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

 /**
@@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr

   override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

-  override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
-}
-
-class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
-  extends DataSourceWriter with SupportsWriteInternalRow {
-  override def commit(messages: Array[WriterCommitMessage]): Unit = {
-    writer.commit(batchId, messages)
-  }
-
-  override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)
-
-  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
-    writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => throw new IllegalStateException(
-        "InternalRowMicroBatchWriter should only be created with base writer support")
-    }
+  override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index b501d90c81f06..f26e11d842b29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
 import scala.collection.mutable

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}

 /**
@@ -30,11 +30,11 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
  * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
  * for production-quality sinks. It's intended for use in tests.
  */
-case object PackedRowWriterFactory extends DataWriterFactory[Row] {
+case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] {
   override def createDataWriter(
       partitionId: Int,
       taskId: Long,
-      epochId: Long): DataWriter[Row] = {
+      epochId: Long): DataWriter[InternalRow] = {
     new PackedRowDataWriter()
   }
 }
@@ -43,15 +43,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] {
  * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most
  * recent interval.
  */
-case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage
+case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage

 /**
  * A simple [[DataWriter]] that just sends all the rows it's received as a commit message.
  */
-class PackedRowDataWriter() extends DataWriter[Row] with Logging {
-  private val data = mutable.Buffer[Row]()
+class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
+  private val data = mutable.Buffer[InternalRow]()

-  override def write(row: Row): Unit = data.append(row)
+  // Spark reuses the same `InternalRow` instance, so we copy it before buffering it.
+  override def write(row: InternalRow): Unit = data.append(row.copy())

   override def commit(): PackedRowCommitMessage = {
     val msg = PackedRowCommitMessage(data.toArray)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index f2a35a90af24a..afacb2f72c926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -25,6 +25,8 @@ import scala.util.control.NonFatal

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
@@ -46,7 +48,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode)
+    new MemoryStreamWriter(this, mode, schema)
   }

   private case class AddedData(batchId: Long, data: Array[Row])
@@ -115,12 +117,13 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
   override def toString(): String = "MemorySinkV2"
 }

-case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
+case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
+  extends WriterCommitMessage {}
OutputMode) +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) extends DataSourceWriter with Logging { - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -134,10 +137,10 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) extends StreamWriter { - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -151,22 +154,26 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } } -case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { +case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { - new MemoryDataWriter(partitionId, outputMode) + epochId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode) - extends DataWriter[Row] with Logging { +class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) + extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() - override def write(row: Row): Unit = { - data.append(row) + private val encoder = RowEncoder(schema).resolveAndBind() + + override def write(row: InternalRow): Unit = { + data.append(encoder.fromRow(row)) } override def commit(): MemoryWriterCommitMessage = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..b4d9b68c78152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 - val writer = new MemoryDataWriter(partition, OutputMode.Append()) - writer.write(Row(1)) - writer.write(Row(2)) - writer.write(Row(44)) + val writer = new MemoryDataWriter( + partition, OutputMode.Append(), new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) val msg = writer.commit() assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) assert(msg.partition == partition) @@ -40,7 +43,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { 
test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +65,8 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + val schema = new StructType().add("i", "int") + new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +74,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index c7da137219894..2496ac7bfdce9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,6 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -243,13 +242,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. 
         assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
-
-        // test internal row writer
-        spark.range(5).select('id, -'id).write.format(cls.getName)
-          .option("path", path).option("internal", "true").mode("overwrite").save()
-        checkAnswer(
-          spark.read.format(cls.getName).option("path", path).load(),
-          spark.range(5).select('id, -'id))
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index 183d0399d3bcd..e1b8e9c44d725 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}

 import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.writer._
@@ -65,9 +65,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
   }

   class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter {
-    override def createWriterFactory(): DataWriterFactory[Row] = {
+    override def createWriterFactory(): DataWriterFactory[InternalRow] = {
       SimpleCounter.resetCounter
-      new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
+      new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
     }

     override def onDataWriterCommit(message: WriterCommitMessage): Unit = {
@@ -97,18 +97,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     }
   }

-  class InternalRowWriter(jobId: String, path: String, conf: Configuration)
-    extends Writer(jobId, path, conf) with SupportsWriteInternalRow {
-
-    override def createWriterFactory(): DataWriterFactory[Row] = {
-      throw new IllegalArgumentException("not expected!")
-    }
-
-    override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
-      new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
-    }
-  }
-
   override def createReader(options: DataSourceOptions): DataSourceReader = {
     val path = new Path(options.get("path").get())
     val conf = SparkContext.getActive.get.hadoopConfiguration
@@ -124,7 +112,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false))

     val path = new Path(options.get("path").get())
-    val internal = options.get("internal").isPresent
     val conf = SparkContext.getActive.get.hadoopConfiguration
     val fs = path.getFileSystem(conf)

@@ -142,17 +129,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
         fs.delete(path, true)
     }

-    Optional.of(createWriter(jobId, path, conf, internal))
-  }
-
-  private def createWriter(
-      jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = {
     val pathStr = path.toUri.toString
-    if (internal) {
-      new InternalRowWriter(jobId, pathStr, conf)
-    } else {
-      new Writer(jobId, pathStr, conf)
-    }
+    Optional.of(new Writer(jobId, pathStr, conf))
   }
 }

@@ -204,43 +182,7 @@ private[v2] object SimpleCounter {
   }
 }

-class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
-  extends DataWriterFactory[Row] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[Row] = {
-    val jobPath = new Path(new Path(path, "_temporary"), jobId)
-    val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
-    val fs = filePath.getFileSystem(conf.value)
-    new SimpleCSVDataWriter(fs, filePath)
-  }
-}
-
-class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
-
-  private val out = fs.create(file)
-
-  override def write(record: Row): Unit = {
-    out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n")
-  }
-
-  override def commit(): WriterCommitMessage = {
-    out.close()
-    null
-  }
-
-  override def abort(): Unit = {
-    try {
-      out.close()
-    } finally {
-      fs.delete(file, false)
-    }
-  }
-}
-
-class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
+class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[InternalRow] {

   override def createDataWriter(
@@ -250,11 +192,11 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
     val fs = filePath.getFileSystem(conf.value)
-    new InternalRowCSVDataWriter(fs, filePath)
+    new CSVDataWriter(fs, filePath)
   }
 }

-class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {
+class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {

   private val out = fs.create(file)