@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
@@ -221,26 +222,73 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow])
}

/** A common trait for MemorySinks with methods used for testing */
trait MemorySinkBase extends BaseStreamingSink {
trait MemorySinkBase extends BaseStreamingSink with Logging {
def allData: Seq[Row]
def latestBatchData: Seq[Row]
def dataSinceBatch(sinceBatchId: Long): Seq[Row]
def latestBatchId: Option[Long]

/**
* Truncates the given rows to return at most batchLimit rows.
* @param rows The data that may need to be truncated.
* @param batchLimit Number of rows to keep in this batch; the rest will be truncated
* @param sinkLimit Total number of rows kept in this sink, for logging purposes.
* @param batchId The ID of the batch that sent these rows, for logging purposes.
* @return Truncated rows.
*/
protected def truncateRowsIfNeeded(
rows: Array[Row],
batchLimit: Int,
sinkLimit: Int,
batchId: Long): Array[Row] = {
if (rows.length > batchLimit && batchLimit >= 0) {
logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
Review comment (Contributor):
nit: not sure if these sinks get used by Continuous processing too. If so I would rename batch to trigger version.

Reply (Contributor Author):
This piece is shared by MemorySink and MemorySinkV2, and the MemorySinkV2 (continuous processing) sink still calls them batches.

rows.take(batchLimit)
} else {
rows
}
}
}

/**
* Companion object to MemorySinkBase.
*/
object MemorySinkBase {
val MAX_MEMORY_SINK_ROWS = "maxRows"
val MAX_MEMORY_SINK_ROWS_DEFAULT = -1

/**
* Gets the max number of rows a MemorySink should store. This number is based on the memory
* sink row limit option if it is set. If not, we use a large value so that data truncates
* rather than causing out of memory errors.
* @param options Options for writing from which we get the max rows option
* @return The maximum number of rows a memorySink should store.
*/
def getMemorySinkCapacity(options: DataSourceOptions): Int = {
val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
if (maxRows >= 0) maxRows else Int.MaxValue - 10
}
}

/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
with MemorySinkBase with Logging {
class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
extends Sink with MemorySinkBase with Logging {

private case class AddedData(batchId: Long, data: Array[Row])

/** An ordered list of batches that have been written to this [[Sink]]. */
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()

/** The number of rows in this MemorySink. */
private var numRows = 0

/** The capacity in rows of this sink. */
val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)

/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
batches.flatMap(_.data)
@@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, data.collect())
synchronized { batches += rows }
var rowsToAdd = data.collect()
synchronized {
rowsToAdd =
truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
val rows = AddedData(batchId, rowsToAdd)
batches += rows
numRows += rowsToAdd.length
}

case Complete =>
val rows = AddedData(batchId, data.collect())
var rowsToAdd = data.collect()
synchronized {
rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
numRows = rowsToAdd.length
}

case _ =>
@@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink

def clear(): Unit = synchronized {
batches.clear()
numRows = 0
}

override def toString(): String = "MemorySink"
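
Below is a minimal, standalone sketch (not part of this PR) of the bookkeeping MemorySink now performs: a running row count plus per-batch truncation against the remaining capacity, mirroring truncateRowsIfNeeded above. The object and value names are illustrative only.

// Sketch: Append/Update batches may only use the capacity that is still free,
// while Complete batches replace everything and are capped at the full capacity.
object TruncationSketch {
  private def truncate(rows: Array[Int], batchLimit: Int): Array[Int] =
    if (rows.length > batchLimit && batchLimit >= 0) rows.take(batchLimit) else rows

  def main(args: Array[String]): Unit = {
    val sinkCapacity = 5
    var numRows = 0
    val stored = scala.collection.mutable.ArrayBuffer.empty[Int]

    // Append/Update path: truncate each batch to the remaining capacity.
    for (batch <- Seq(Array(1, 2, 3), Array(4, 5, 6))) {
      val kept = truncate(batch, sinkCapacity - numRows)
      stored ++= kept
      numRows += kept.length
    }
    println(stored.mkString(","))   // 1,2,3,4,5 -- the sixth row is dropped

    // Complete path: the new batch replaces the old data, truncated to the full capacity.
    val complete = truncate((10 to 16).toArray, sinkCapacity)
    println(complete.mkString(",")) // 10,11,12,13,14
  }
}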
@@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
new MemoryStreamWriter(this, mode)
new MemoryStreamWriter(this, mode, options)
}

private case class AddedData(batchId: Long, data: Array[Row])
@@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()

/** The number of rows in this MemorySink. */
private var numRows = 0

/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
batches.flatMap(_.data)
@@ -81,22 +84,33 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
}.mkString("\n")
}

def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
def write(
batchId: Long,
outputMode: OutputMode,
newRows: Array[Row],
sinkCapacity: Int): Unit = {
val notCommitted = synchronized {
latestBatchId.isEmpty || batchId > latestBatchId.get
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
val rows = AddedData(batchId, newRows)
synchronized { batches += rows }
synchronized {
val rowsToAdd =
truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
val rows = AddedData(batchId, rowsToAdd)
batches += rows
numRows += rowsToAdd.length
}

case Complete =>
val rows = AddedData(batchId, newRows)
synchronized {
val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
numRows = rowsToAdd.length
}

case _ =>
@@ -110,40 +124,52 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB

def clear(): Unit = synchronized {
batches.clear()
numRows = 0
}

override def toString(): String = "MemorySinkV2"
}

case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}

class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
class MemoryWriter(
sink: MemorySinkV2,
batchId: Long,
outputMode: OutputMode,
options: DataSourceOptions)
extends DataSourceWriter with Logging {

val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)

override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)

def commit(messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
sink.write(batchId, outputMode, newRows)
sink.write(batchId, outputMode, newRows, sinkCapacity)
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {
// Don't accept any of the new input.
}
}

class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
class MemoryStreamWriter(
val sink: MemorySinkV2,
outputMode: OutputMode,
options: DataSourceOptions)
extends StreamWriter {

val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)

override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
sink.write(epochId, outputMode, newRows)
sink.write(epochId, outputMode, newRows, sinkCapacity)
}

override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
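
For reference, a small sketch (not part of this PR) of how the new writers resolve their capacity from DataSourceOptions via MemorySinkBase.getMemorySinkCapacity; it assumes MemorySinkBase lives in org.apache.spark.sql.execution.streaming, as in the file above, and could be tried in a spark-shell session.

import scala.collection.JavaConverters._
import org.apache.spark.sql.execution.streaming.MemorySinkBase
import org.apache.spark.sql.sources.v2.DataSourceOptions

// No "maxRows" option: the capacity falls back to Int.MaxValue - 10.
val unlimited = MemorySinkBase.getMemorySinkCapacity(DataSourceOptions.empty())

// "maxRows" set to 100: the capacity is 100, and MemoryStreamWriter passes it
// to sink.write(epochId, outputMode, newRows, sinkCapacity) on every commit.
val limited = MemorySinkBase.getMemorySinkCapacity(
  new DataSourceOptions(Map(MemorySinkBase.MAX_MEMORY_SINK_ROWS -> "100").asJava))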
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
import org.apache.spark.sql.sources.v2.StreamWriteSupport
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}

/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
(s, r)
case _ =>
val s = new MemorySink(df.schema, outputMode)
val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
(s, r)
}
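
A hypothetical end-to-end use of the new option, assuming a streaming DataFrame named df: the "maxRows" key matches MemorySinkBase.MAX_MEMORY_SINK_ROWS and reaches the sink through extraOptions as wired above; the query name is illustrative.

val query = df.writeStream
  .format("memory")
  .queryName("capped_output")  // in-memory table name (illustrative)
  .option("maxRows", "100")    // keep at most 100 rows in the memory sink
  .outputMode("append")
  .start()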
@@ -17,11 +17,13 @@

package org.apache.spark.sql.execution.streaming

import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql._
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {

test("directly add data in Append output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
val sink = new MemorySink(schema, OutputMode.Append)
val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())

// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
checkAnswer(sink.allData, 1 to 9)
}

test("directly add data in Append output mode with row limit") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))

var optionsMap = new scala.collection.mutable.HashMap[String, String]
optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
var options = new DataSourceOptions(optionsMap.toMap.asJava)
val sink = new MemorySink(schema, OutputMode.Append, options)

// Before adding data, check output
assert(sink.latestBatchId === None)
checkAnswer(sink.latestBatchData, Seq.empty)
checkAnswer(sink.allData, Seq.empty)

// Add batch 0 and check outputs
sink.addBatch(0, 1 to 3)
assert(sink.latestBatchId === Some(0))
checkAnswer(sink.latestBatchData, 1 to 3)
checkAnswer(sink.allData, 1 to 3)

// Add batch 1 and check outputs
sink.addBatch(1, 4 to 6)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 5)
checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit
}

test("directly add data in Update output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
val sink = new MemorySink(schema, OutputMode.Update)
val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty())

// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {

test("directly add data in Complete output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
val sink = new MemorySink(schema, OutputMode.Complete)
val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty())

// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
checkAnswer(sink.allData, 7 to 9)
}

test("directly add data in Complete output mode with row limit") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))

var optionsMap = new scala.collection.mutable.HashMap[String, String]
optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
var options = new DataSourceOptions(optionsMap.toMap.asJava)
val sink = new MemorySink(schema, OutputMode.Complete, options)

// Before adding data, check output
assert(sink.latestBatchId === None)
checkAnswer(sink.latestBatchData, Seq.empty)
checkAnswer(sink.allData, Seq.empty)

// Add batch 0 and check outputs
sink.addBatch(0, 1 to 3)
assert(sink.latestBatchId === Some(0))
checkAnswer(sink.latestBatchData, 1 to 3)
checkAnswer(sink.allData, 1 to 3)

// Add batch 1 and check outputs
sink.addBatch(1, 4 to 10)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 8)
checkAnswer(sink.allData, 4 to 8) // new data should replace old data
}


test("registering as a table in Append output mode") {
val input = MemoryStream[Int]
@@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {

test("MemoryPlan statistics") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
val sink = new MemorySink(schema, OutputMode.Append)
val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
val plan = new MemoryPlan(sink)

// Before adding data, check output
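
The new tests above cover MemorySink; a comparable check against MemorySinkV2, which now takes the capacity directly in write(), might look like the sketch below (not part of this PR; it assumes the V2 suite's usual imports plus org.apache.spark.sql.execution.streaming.sources.MemorySinkV2).

test("memory sink v2 respects row limit in Append mode (sketch)") {
  val sink = new MemorySinkV2
  val sinkCapacity = 5

  // First batch fits entirely.
  sink.write(0, OutputMode.Append(), Array(Row(1), Row(2), Row(3)), sinkCapacity)
  assert(sink.allData.length === 3)

  // Second batch is truncated to the remaining capacity of 2 rows.
  sink.write(1, OutputMode.Append(), Array(Row(4), Row(5), Row(6)), sinkCapacity)
  assert(sink.allData.length === sinkCapacity)
}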