diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index badaa69cc303c..48b91dfe764e9 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.reader._
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
     metadataPath: String,
     initialOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends ContinuousReader with SupportsScanUnsafeRow with Logging {
+  extends ContinuousReader with Logging {

   private lazy val session = SparkSession.getActiveSession.get
   private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
     KafkaSourceOffset(JsonUtils.partitionOffsets(json))
   }

-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     import scala.collection.JavaConverters._

     val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -107,8 +108,8 @@ class KafkaContinuousReader(
     startOffsets.toSeq.map {
       case (topicPartition, start) =>
         KafkaContinuousInputPartition(
-          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
-          .asInstanceOf[InputPartition[UnsafeRow]]
+          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
+        ): InputPartition[InternalRow]
     }.asJava
   }
@@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
+    failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {

-  override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
+  override def createContinuousReader(
+      offset: PartitionOffset): InputPartitionReader[InternalRow] = {
     val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
     require(kafkaOffset.topicPartition == topicPartition,
       s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
@@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
+    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
   private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
   private val converter = new KafkaRecordToUnsafeRowConverter
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
index 737da2e51b125..6c95b2b2560c4 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
+  extends MicroBatchReader with Logging {

   private var startPartitionOffsets: PartitionOffsetMap = _
   private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +102,7 @@ private[kafka010] class KafkaMicroBatchReader(
     }
   }

-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     // Find the new partitions, and get their earliest offsets
     val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
     val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -142,11 +143,11 @@ private[kafka010] class KafkaMicroBatchReader(
     val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size

     // Generate factories based on the offset ranges
-    val factories = offsetRanges.map { range =>
+    offsetRanges.map { range =>
       new KafkaMicroBatchInputPartition(
-        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
-    }
-    factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
+        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
+      ): InputPartition[InternalRow]
+    }.asJava
   }

   override def getStartOffset: Offset = {
@@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
+    reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] {

   override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray

-  override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
     new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
       failOnDataLoss, reuseKafkaConsumer)
 }
@@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
+    reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging {

   private val consumer = KafkaDataConsumer.acquire(
     offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
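Note on the Kafka changes above: they compile without any row conversion because `UnsafeRow` is an `InternalRow`, so the existing `UnsafeRow`-producing partition readers already satisfy the new `InputPartition[InternalRow]` contract; only the declared element type and the method name change. A minimal sketch of that contract, assuming illustrative names (`ExampleInputPartition`, `ExamplePlanner`) that are not part of this patch:

```scala
import java.{util => ju}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}

// A partition whose reader emits InternalRow. An UnsafeRow-producing reader already
// satisfies this contract, because UnsafeRow is an InternalRow subclass.
class ExampleInputPartition(rows: Seq[InternalRow]) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new InputPartitionReader[InternalRow] {
      private val iter = rows.iterator
      private var current: InternalRow = _
      override def next(): Boolean = {
        val hasNext = iter.hasNext
        if (hasNext) current = iter.next()
        hasNext
      }
      override def get(): InternalRow = current
      override def close(): Unit = ()
    }
}

object ExamplePlanner {
  // The new planInputPartitions shape: a type ascription replaces the old asInstanceOf cast.
  def planInputPartitions(blocks: Seq[Seq[InternalRow]]): ju.List[InputPartition[InternalRow]] =
    blocks.map(b => new ExampleInputPartition(b): InputPartition[InternalRow]).asJava
}
```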
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index c6412eac97dba..5d5e57323cff5 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
         Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
         Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
       )
-      val factories = reader.planUnsafeInputPartitions().asScala
+      val factories = reader.planInputPartitions().asScala
         .map(_.asInstanceOf[KafkaMicroBatchInputPartition])
       withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
         assert(factories.size == numPartitionsGenerated)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
index 36a3e542b5a11..ad9c838992fa8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
@@ -20,7 +20,7 @@
 import java.util.List;

 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.ReadSupport;
 import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -43,7 +43,7 @@
  * Names of these interfaces start with `SupportsScan`. Note that a reader should only
  * implement at most one of the special scans, if more than one special scans are implemented,
  * only one of them would be respected, according to the priority list from high to low:
- * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
+ * {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}.
 *
 * If an exception was throw when applying any of these query optimizations, the action will fail
 * and no Spark job will be submitted.
@@ -76,5 +76,5 @@ public interface DataSourceReader {
   * If this method fails (by throwing an exception), the action will fail and no Spark job will be
   * submitted.
   */
-  List<InputPartition<Row>> planInputPartitions();
+  List<InputPartition<InternalRow>> planInputPartitions();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
index 33fa7be4c1b20..7cf382e52f67e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
@@ -26,9 +26,10 @@
  * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is
  * responsible for outputting data for a RDD partition.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
- * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
- * partition readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
+ * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data
+ * source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row}
+ * for data source readers that mix in {@link SupportsDeprecatedScanRow}.
 */
@InterfaceStability.Evolving
public interface InputPartitionReader<T> extends Closeable {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java
similarity index 62%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java
index f2220f6d31093..595943cf4d8ac 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java
@@ -17,30 +17,23 @@

 package org.apache.spark.sql.sources.v2.reader;

-import java.util.List;
-
 import org.apache.spark.annotation.InterfaceStability;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.InternalRow;
+
+import java.util.List;

 /**
  * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
- * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
- * changed in the future Spark versions.
+ * interface to output {@link Row} instead of {@link InternalRow}.
+ * This is an experimental and unstable interface.
 */
@InterfaceStability.Unstable
-public interface SupportsScanUnsafeRow extends DataSourceReader {
-
-  @Override
-  default List<InputPartition<Row>> planInputPartitions() {
+public interface SupportsDeprecatedScanRow extends DataSourceReader {
+  default List<InputPartition<InternalRow>> planInputPartitions() {
     throw new IllegalStateException(
-        "planInputPartitions not supported by default within SupportsScanUnsafeRow");
+        "planInputPartitions not supported by default within SupportsDeprecatedScanRow");
   }

-  /**
-   * Similar to {@link DataSourceReader#planInputPartitions()},
-   * but returns data in unsafe row format.
-   */
-  List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
+  List<InputPartition<Row>> planRowInputPartitions();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
index 0faf81db24605..f4da686740d11 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
@@ -20,7 +20,7 @@
 import java.util.List;

 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.vectorized.ColumnarBatch;

 /**
@@ -30,7 +30,7 @@
 @InterfaceStability.Evolving
 public interface SupportsScanColumnarBatch extends DataSourceReader {
   @Override
-  default List<InputPartition<Row>> planInputPartitions() {
+  default List<InputPartition<InternalRow>> planInputPartitions() {
     throw new IllegalStateException(
       "planInputPartitions not supported by default within SupportsScanColumnarBatch.");
   }
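For sources that still produce external `Row`s, the migration path is to mix in `SupportsDeprecatedScanRow` and implement `planRowInputPartitions()` instead of `planInputPartitions()`, which is what the built-in streaming sources further down in this patch do. A hedged sketch of that shape, using illustrative names (`ExampleRowReader`, `SingleValuePartition`) that are not part of this patch:

```scala
import java.{util => ju}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsDeprecatedScanRow}
import org.apache.spark.sql.types.StructType

// A reader that keeps emitting external Rows through the deprecated mix-in.
class ExampleRowReader extends DataSourceReader with SupportsDeprecatedScanRow {
  override def readSchema(): StructType = new StructType().add("value", "string")

  override def planRowInputPartitions(): ju.List[InputPartition[Row]] =
    ju.Arrays.asList[InputPartition[Row]](
      new SingleValuePartition("a"), new SingleValuePartition("b"))
}

// A trivial single-record partition that doubles as its own reader.
class SingleValuePartition(value: String)
  extends InputPartition[Row] with InputPartitionReader[Row] {
  private var consumed = false
  override def createPartitionReader(): InputPartitionReader[Row] = this
  override def next(): Boolean = {
    val hasNext = !consumed
    consumed = true
    hasNext
  }
  override def get(): Row = Row(value)
  override def close(): Unit = ()
}
```

At runtime Spark wraps such partitions in `RowToUnsafeRowInputPartition` (see `DataSourceV2ScanExec` below), converting each `Row` back into a catalyst row with the schema's encoder.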
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 8d6fb3820d420..7ea53424ae100 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,7 +17,6 @@

 package org.apache.spark.sql.execution.datasources.v2

-import scala.collection.JavaConverters._
 import scala.reflect.ClassTag

 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index c6a7684bf6ab0..b030b9a929b08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -75,12 +75,13 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }

-  private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
-    case _ =>
-      reader.planInputPartitions().asScala.map {
-        new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
+  private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match {
+    case r: SupportsDeprecatedScanRow =>
+      r.planRowInputPartitions().asScala.map {
+        new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow]
       }
+    case _ =>
+      reader.planInputPartitions().asScala
   }

   private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
@@ -132,11 +133,11 @@ case class DataSourceV2ScanExec(
 }

 class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
-  extends InputPartition[UnsafeRow] {
+  extends InputPartition[InternalRow] {

   override def preferredLocations: Array[String] = partition.preferredLocations

-  override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
+  override def createPartitionReader: InputPartitionReader[InternalRow] = {
     new RowToUnsafeInputPartitionReader(
       partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
   }
@@ -146,7 +147,7 @@ class RowToUnsafeInputPartitionReader(
     val rowReader: InputPartitionReader[Row],
     encoder: ExpressionEncoder[Row])
-  extends InputPartitionReader[UnsafeRow] {
+  extends InputPartitionReader[InternalRow] {

   override def next: Boolean = rowReader.next
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 2a7f1de2c7c19..9414e68155b98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -125,16 +125,13 @@ object DataSourceV2Strategy extends Strategy {
       val filterCondition = postScanFilters.reduceLeftOption(And)
       val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

-      val withProjection = if (withFilter.output != project) {
-        ProjectExec(project, withFilter)
-      } else {
-        withFilter
-      }
-
-      withProjection :: Nil
+      // always add the projection, which will produce unsafe rows required by some operators
+      ProjectExec(project, withFilter) :: Nil

     case r: StreamingDataSourceV2Relation =>
-      DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
+      // ensure there is a projection, which will produce unsafe rows required by some operators
+      ProjectExec(r.output,
+        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil

     case WriteToDataSourceV2(writer, query) =>
      WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
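The strategy change above is needed because the scan now hands Spark whatever `InternalRow` implementation the source produced; operators that require `UnsafeRow` get it from the unconditional `ProjectExec`, whose generated output projection converts generic rows. A rough, illustrative sketch of that conversion (standalone demo, not code from this patch), assuming a two-column int schema:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object UnsafeProjectionDemo {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("i", IntegerType), StructField("j", IntegerType)))
    val toUnsafe = UnsafeProjection.create(schema)

    // A generic InternalRow coming out of a data source...
    val generic: InternalRow = InternalRow(1, 2)
    // ...becomes an UnsafeRow after passing through a projection, which is roughly
    // what the ProjectExec added by the strategy guarantees for downstream operators.
    val unsafe: UnsafeRow = toUnsafe(generic)
    println(unsafe.getInt(0) + ", " + unsafe.getInt(1))
  }
}
```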
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index 73868d5967e90..1ffa1d02f1432 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.streaming.continuous

 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader
 import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}
-import org.apache.spark.util.{NextIterator, ThreadUtils}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
+import org.apache.spark.util.NextIterator

 class ContinuousDataSourceRDDPartition(
     val index: Int,
-    val inputPartition: InputPartition[UnsafeRow])
+    val inputPartition: InputPartition[InternalRow])
   extends Partition with Serializable {

   // This is semantically a lazy val - it's initialized once the first time a call to
@@ -51,8 +51,8 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
-  extends RDD[UnsafeRow](sc, Nil) {
+    private val readerInputPartitions: Seq[InputPartition[InternalRow]])
+  extends RDD[InternalRow](sc, Nil) {

   override protected def getPartitions: Array[Partition] = {
     readerInputPartitions.zipWithIndex.map {
@@ -64,7 +64,7 @@ class ContinuousDataSourceRDD(
   * Initialize the shared reader for this partition if needed, then read rows from it until
   * it returns null to signal the end of the epoch.
   */
-  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     // If attempt number isn't 0, this is a task retry, which we don't support.
     if (context.attemptNumber() != 0) {
       throw new ContinuousTaskRetryException()
@@ -80,8 +80,8 @@ class ContinuousDataSourceRDD(
       partition.queueReader
     }

-    new NextIterator[UnsafeRow] {
-      override def getNext(): UnsafeRow = {
+    new NextIterator[InternalRow] {
+      override def getNext(): InternalRow = {
         readerForPartition.next() match {
           case null =>
             finished = true
@@ -101,9 +101,9 @@ class ContinuousDataSourceRDD(

 object ContinuousDataSourceRDD {
   private[continuous] def getContinuousReader(
-      reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
+      reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
     reader match {
-      case r: ContinuousInputPartitionReader[UnsafeRow] => r
+      case r: ContinuousInputPartitionReader[InternalRow] => r
       case wrapped: RowToUnsafeInputPartitionReader =>
         wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
       case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index 8c74b8244d096..bfb87053db475 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal

 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
 import org.apache.spark.util.ThreadUtils
@@ -52,7 +52,7 @@ class ContinuousQueuedDataReader(
   */
  sealed trait ContinuousRecord
  case object EpochMarker extends ContinuousRecord
- case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord
+ case class ContinuousRow(row: InternalRow, offset: PartitionOffset) extends ContinuousRecord

  private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize)
@@ -79,12 +79,12 @@ class ContinuousQueuedDataReader(
   }

   /**
-   * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done.
+   * Return the next row to be read in the current epoch, or null if the epoch is done.
    *
    * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch
    * will call next() again to start getting rows.
    */
-  def next(): UnsafeRow = {
+  def next(): InternalRow = {
     val POLL_TIMEOUT_MS = 1000
     var currentEntry: ContinuousRecord = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 516a563bdcc7a..55ce3ae38ee3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -35,7 +35,7 @@ case class RateStreamPartitionOffset(
    partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset

 class RateStreamContinuousReader(options: DataSourceOptions)
-  extends ContinuousReader {
+  extends ContinuousReader with SupportsDeprecatedScanRow {
   implicit val defaultFormats: DefaultFormats = DefaultFormats

   val creationTime = System.currentTimeMillis()
@@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)

   override def getStartOffset(): Offset = offset

-  override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
+  override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = {
     val partitionStartMap = offset match {
       case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
       case off =>
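The rate-stream continuous reader above keeps its `Row`-based implementation by opting into `SupportsDeprecatedScanRow`; at execution time each `Row` is turned back into a catalyst row by an encoder-based wrapper, the `RowToUnsafeInputPartitionReader` shown earlier. A simplified sketch of that adapter under illustrative names (`RowConvertingReader` is not part of this patch):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.StructType

// Wraps a Row-emitting reader and converts each record with the schema's encoder,
// similar in spirit to RowToUnsafeInputPartitionReader in DataSourceV2ScanExec.
class RowConvertingReader(underlying: InputPartitionReader[Row], schema: StructType)
  extends InputPartitionReader[InternalRow] {

  private val encoder = RowEncoder(schema).resolveAndBind()

  override def next(): Boolean = underlying.next()
  override def get(): InternalRow = encoder.toRow(underlying.get())
  override def close(): Unit = underlying.close()
}
```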
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index b137f98045c5a..f81abdcc3711a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -28,12 +28,13 @@ import scala.util.control.NonFatal

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
@@ -79,8 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends MemoryStreamBase[A](sqlContext)
-      with MicroBatchReader with SupportsScanUnsafeRow with Logging {
+    extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging {

   protected val logicalPlan: LogicalPlan =
     StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
@@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     if (endOffset.offset == -1) null else endOffset
   }

-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     synchronized {
       // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
       val startOrdinal = startOffset.offset.toInt + 1
@@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
       logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))

       newBlocks.map { block =>
-        new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]]
+        new MemoryStreamInputPartition(block): InputPartition[InternalRow]
       }.asJava
     }
   }
@@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

 class MemoryStreamInputPartition(records: Array[UnsafeRow])
-  extends InputPartition[UnsafeRow] {
-  override def createPartitionReader(): InputPartitionReader[UnsafeRow] = {
-    new InputPartitionReader[UnsafeRow] {
+  extends InputPartition[InternalRow] {
+  override def createPartitionReader(): InputPartitionReader[InternalRow] = {
+    new InputPartitionReader[InternalRow] {
       private var currentIndex = -1

       override def next(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index 0bf90b8063326..e776ebc08e30d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
-import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.RpcUtils
@@ -49,7 +49,8 @@ import org.apache.spark.util.RpcUtils
  * the specified offset within the list, or null if that offset doesn't yet have a record.
  */
 class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
-  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
+  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport
+  with SupportsDeprecatedScanRow {
   private implicit val formats = Serialization.formats(NoTypeHints)

   protected val logicalPlan =
@@ -99,7 +100,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
     )
   }

-  override def planInputPartitions(): ju.List[InputPartition[Row]] = {
+  override def planRowInputPartitions(): ju.List[InputPartition[Row]] = {
     synchronized {
       val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
       endpointRef =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
index b393c48baee8d..7a3452aa315cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ManualClock, SystemClock}

 class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
-  extends MicroBatchReader with Logging {
+  extends MicroBatchReader with SupportsDeprecatedScanRow with Logging {
   import RateStreamProvider._

   private[sources] val clock = {
@@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
     LongOffset(json.toLong)
   }

-  override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
+  override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = {
     val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
     val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
     assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
index 91e3b7179c34a..e3a2c007a9ce4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.streaming.LongOffset
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
 import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}

@@ -50,7 +50,8 @@ object TextSocketMicroBatchReader {
  * debugging. This MicroBatchReader will *not* work in production applications due to multiple
  * reasons, including no support for fault recovery.
  */
-class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging {
+class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader
+  with SupportsDeprecatedScanRow with Logging {

   private var startOffset: Offset = _
   private var endOffset: Offset = _
@@ -141,7 +142,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
     }
   }

-  override def planInputPartitions(): JList[InputPartition[Row]] = {
+  override def planRowInputPartitions(): JList[InputPartition[Row]] = {
     assert(startOffset != null && endOffset != null,
       "start offset and end offset should already be set before create read tasks.")
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 445cb29f5ee3a..c130b5f1e2513 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {

   public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
-      SupportsPushDownFilters {
+      SupportsPushDownFilters, SupportsDeprecatedScanRow {

     // Exposed for testing.
     public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
@@ -79,7 +79,7 @@ public Filter[] pushedFilters() {
     }

     @Override
-    public List<InputPartition<Row>> planInputPartitions() {
+    public List<InputPartition<Row>> planRowInputPartitions() {
       List<InputPartition<Row>> res = new ArrayList<>();

       Integer lowerBound = null;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index e49c8cf8b9e16..35aafb532d80d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -34,7 +34,7 @@

 public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {

-  class Reader implements DataSourceReader, SupportsReportPartitioning {
+  class Reader implements DataSourceReader, SupportsReportPartitioning, SupportsDeprecatedScanRow {
     private final StructType schema = new StructType().add("a", "int").add("b", "int");

     @Override
@@ -43,7 +43,7 @@ public StructType readSchema() {
     }

     @Override
-    public List<InputPartition<Row>> planInputPartitions() {
+    public List<InputPartition<Row>> planRowInputPartitions() {
       return java.util.Arrays.asList(
         new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
         new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index 80eeffd95f83b..6dee94c34e21c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -25,11 +25,12 @@
 import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
 import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
 import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow;
 import org.apache.spark.sql.types.StructType;

 public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {

-  class Reader implements DataSourceReader {
+  class Reader implements DataSourceReader, SupportsDeprecatedScanRow {
     private final StructType schema;

     Reader(StructType schema) {
@@ -42,7 +43,7 @@ public StructType readSchema() {
     }

     @Override
-    public List<InputPartition<Row>> planInputPartitions() {
+    public List<InputPartition<Row>> planRowInputPartitions() {
       return java.util.Collections.emptyList();
     }
   }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 8522a63898a3b..5c2f351975c74 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -28,11 +28,12 @@
 import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
 import org.apache.spark.sql.sources.v2.reader.InputPartition;
 import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
+import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow;
 import org.apache.spark.sql.types.StructType;

 public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport {

-  class Reader implements DataSourceReader {
+  class Reader implements DataSourceReader, SupportsDeprecatedScanRow {
     private final StructType schema = new StructType().add("i", "int").add("j", "int");

     @Override
@@ -41,7 +42,7 @@ public StructType readSchema() {
     }

     @Override
-    public List<InputPartition<Row>> planInputPartitions() {
+    public List<InputPartition<Row>> planRowInputPartitions() {
       return java.util.Arrays.asList(
         new JavaSimpleInputPartition(0, 5),
         new JavaSimpleInputPartition(5, 10));
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
index 3ad8e7a0104ce..25b89c7fd36a9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.List;

+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.DataSourceV2;
@@ -29,7 +30,7 @@

 public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport {

-  class Reader implements DataSourceReader, SupportsScanUnsafeRow {
+  class Reader implements DataSourceReader {
     private final StructType schema = new StructType().add("i", "int").add("j", "int");

     @Override
@@ -38,7 +39,7 @@ public StructType readSchema() {
     }

     @Override
-    public List<InputPartition<UnsafeRow>> planUnsafeInputPartitions() {
+    public List<InputPartition<InternalRow>> planInputPartitions() {
       return java.util.Arrays.asList(
         new JavaUnsafeRowInputPartition(0, 5),
         new JavaUnsafeRowInputPartition(5, 10));
@@ -46,7 +47,7 @@ public List<InputPartition<UnsafeRow>> planUnsafeInputPartitions() {
   }

   static class JavaUnsafeRowInputPartition
-      implements InputPartition<UnsafeRow>, InputPartitionReader<UnsafeRow> {
+      implements InputPartition<InternalRow>, InputPartitionReader<InternalRow> {
     private int start;
     private int end;
     private UnsafeRow row;
@@ -59,7 +60,7 @@ static class JavaUnsafeRowInputPartition
     }

     @Override
-    public InputPartitionReader<UnsafeRow> createPartitionReader() {
+    public InputPartitionReader<InternalRow> createPartitionReader() {
       return new JavaUnsafeRowInputPartition(start - 1, end);
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index 9115a384d0790..260a0376daeb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -146,7 +146,7 @@ class RateSourceSuite extends StreamTest {
     val startOffset = LongOffset(0L)
     val endOffset = LongOffset(1L)
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.planInputPartitions()
+    val tasks = reader.planRowInputPartitions()
     assert(tasks.size == 1)
     val dataReader = tasks.get(0).createPartitionReader()
     val data = ArrayBuffer[Row]()
@@ -165,7 +165,7 @@ class RateSourceSuite extends StreamTest {
     val startOffset = LongOffset(0L)
     val endOffset = LongOffset(1L)
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.planInputPartitions()
+    val tasks = reader.planRowInputPartitions()
     assert(tasks.size == 11)

     val readData = tasks.asScala
@@ -311,7 +311,7 @@ class RateSourceSuite extends StreamTest {
     val reader = new RateStreamContinuousReader(
       new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
     reader.setStartOffset(Optional.empty())
-    val tasks = reader.planInputPartitions()
+    val tasks = reader.planRowInputPartitions()
     assert(tasks.size == 2)

     val data = scala.collection.mutable.ListBuffer[Row]()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e96cd4500458d..d73eebbc84b71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -23,6 +23,7 @@ import test.org.apache.spark.sql.sources.v2._

 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
@@ -344,10 +345,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {

 class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {

-  class Reader extends DataSourceReader {
+  class Reader extends DataSourceReader with SupportsDeprecatedScanRow {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

-    override def planInputPartitions(): JList[InputPartition[Row]] = {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] = {
       java.util.Arrays.asList(new SimpleInputPartition(0, 5))
     }
   }
@@ -357,10 +358,10 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {

 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {

-  class Reader extends DataSourceReader {
+  class Reader extends DataSourceReader with SupportsDeprecatedScanRow {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

-    override def planInputPartitions(): JList[InputPartition[Row]] = {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] = {
       java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10))
     }
   }
@@ -390,7 +391,7 @@ class SimpleInputPartition(start: Int, end: Int)

 class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {

-  class Reader extends DataSourceReader
+  class Reader extends DataSourceReader with SupportsDeprecatedScanRow
     with SupportsPushDownRequiredColumns with SupportsPushDownFilters {

     var requiredSchema = new StructType().add("i", "int").add("j", "int")
@@ -415,7 +416,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
       requiredSchema
     }

-    override def planInputPartitions(): JList[InputPartition[Row]] = {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] = {
       val lowerBound = filters.collect {
         case GreaterThan("i", v: Int) => v
       }.headOption
@@ -467,10 +468,10 @@ class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType)

 class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {

-  class Reader extends DataSourceReader with SupportsScanUnsafeRow {
+  class Reader extends DataSourceReader {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

-    override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = {
+    override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
       java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5),
         new UnsafeRowInputPartitionReader(5, 10))
     }
@@ -480,14 +481,14 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {
 }

 class UnsafeRowInputPartitionReader(start: Int, end: Int)
-  extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] {
+  extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {

   private val row = new UnsafeRow(2)
   row.pointTo(new Array[Byte](8 * 3), 8 * 3)

   private var current = start - 1

-  override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this
+  override def createPartitionReader(): InputPartitionReader[InternalRow] = this

   override def next(): Boolean = {
     current += 1
@@ -504,8 +505,8 @@ class UnsafeRowInputPartitionReader(start: Int, end: Int)

 class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {

-  class Reader(val readSchema: StructType) extends DataSourceReader {
-    override def planInputPartitions(): JList[InputPartition[Row]] =
+  class Reader(val readSchema: StructType) extends DataSourceReader with SupportsDeprecatedScanRow {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] =
       java.util.Collections.emptyList()
   }
@@ -568,10 +569,11 @@ class BatchInputPartitionReader(start: Int, end: Int)

 class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {

-  class Reader extends DataSourceReader with SupportsReportPartitioning {
+  class Reader extends DataSourceReader with SupportsReportPartitioning
+    with SupportsDeprecatedScanRow {
     override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")

-    override def planInputPartitions(): JList[InputPartition[Row]] = {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] = {
       // Note that we don't have same value of column `a` across partitions.
       java.util.Arrays.asList(
         new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index 1334cf71ae988..98d7eedbcb9c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsDeprecatedScanRow}
 import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -42,10 +42,11 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS

   private val schema = new StructType().add("i", "long").add("j", "long")

-  class Reader(path: String, conf: Configuration) extends DataSourceReader {
+  class Reader(path: String, conf: Configuration) extends DataSourceReader
+    with SupportsDeprecatedScanRow {
     override def readSchema(): StructType = schema

-    override def planInputPartitions(): JList[InputPartition[Row]] = {
+    override def planRowInputPartitions(): JList[InputPartition[Row]] = {
       val dataPath = new Path(path)
       val fs = dataPath.getFileSystem(conf)
       if (fs.exists(dataPath)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 936a076d647b6..78199b0a1c19a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.mockito.MockitoSugar
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
 import org.apache.spark.sql.functions._
@@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       }

       // getBatch should take 100 ms the first time it is called
-      override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+      override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
         synchronized {
           clock.waitTillTime(1350)
-          super.planUnsafeInputPartitions()
+          super.planInputPartitions()
         }
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index 0e7e6febb53df..4f198819b58d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.streaming.continuous

 import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}

-import org.mockito.{ArgumentCaptor, Matchers}
 import org.mockito.Mockito._
 import org.scalatest.mockito.MockitoSugar

 import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext}
 import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.reader.InputPartition
@@ -73,8 +73,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
    */
   private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = {
     val queue = new ArrayBlockingQueue[UnsafeRow](1024)
-    val factory = new InputPartition[UnsafeRow] {
-      override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] {
+    val factory = new InputPartition[InternalRow] {
+      override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] {
         var index = -1
         var curr: UnsafeRow = _
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index c1a28b9bc75ef..7c012158bd751 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -26,14 +26,15 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport}
-import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils

-case class FakeReader() extends MicroBatchReader with ContinuousReader {
+case class FakeReader() extends MicroBatchReader with ContinuousReader
+  with SupportsDeprecatedScanRow {
   def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {}
   def getStartOffset: Offset = RateStreamOffset(Map())
   def getEndOffset: Offset = RateStreamOffset(Map())
@@ -44,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader {
   def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
   def setStartOffset(start: Optional[Offset]): Unit = {}

-  def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = {
+  def planRowInputPartitions(): java.util.ArrayList[InputPartition[Row]] = {
     throw new IllegalStateException("fake source - cannot actually read")
   }
 }
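Migration summary for source authors, as a hedged before/after mirroring the test-source changes above: a reader that used to mix in `SupportsScanUnsafeRow` now implements plain `DataSourceReader` and overrides `planInputPartitions()` with `InternalRow` partitions, while a reader that still wants to produce external `Row`s mixes in `SupportsDeprecatedScanRow`. The class below is illustrative only:

```scala
import java.{util => ju}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
import org.apache.spark.sql.types.StructType

// Before (removed by this patch):
//   class Reader extends DataSourceReader with SupportsScanUnsafeRow {
//     override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = ...
//   }

// After: UnsafeRow-backed partitions are returned as InputPartition[InternalRow] directly.
class MigratedReader(schema: StructType, partitions: Seq[InputPartition[InternalRow]])
  extends DataSourceReader {

  override def readSchema(): StructType = schema

  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
    import scala.collection.JavaConverters._
    partitions.asJava
  }
}
```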