
Commit 23d4ecb

Mukul Murthy authored and otterc committed
[SPARK-24525][SS] Provide an option to limit number of rows in a MemorySink
Provide an option to limit number of rows in a MemorySink. Currently, MemorySink and MemorySinkV2 have unbounded size, meaning that if they're used on big data, they can OOM the stream. This change adds a maxMemorySinkRows option to limit how many rows MemorySink and MemorySinkV2 can hold. By default, they are still unbounded.

Added new unit tests.

Author: Mukul Murthy <[email protected]>

Closes apache#21559 from mukulmurthy/SPARK-24525.

Ref: LIHADOOP-48531 RB=1852593 G=superfriends-reviewers R=mshen,fli,latang,yezhou,zolin A=
1 parent fcc3690 commit 23d4ecb

File tree

6 files changed: +261 additions, -30 deletions
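For context, a minimal usage sketch (not part of this commit; the query name, input stream df, and row count are illustrative). The option key is "maxRows", matching MemorySinkBase.MAX_MEMORY_SINK_ROWS added below, and DataStreamWriter forwards the writer's extra options to the sink:

// Illustrative only: df is an arbitrary streaming Dataset.
val query = df.writeStream
  .format("memory")
  .queryName("capped_table")   // hypothetical in-memory table name
  .option("maxRows", "100")    // keep at most 100 rows in the sink
  .outputMode("append")
  .start()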

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 76 additions & 6 deletions
@@ -24,7 +24,6 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
@@ -33,6 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.{OutputMode, Trigger}
@@ -221,21 +221,77 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow])
   }
 }
 
+/** A common trait for MemorySinks with methods used for testing */
+trait MemorySinkBase extends BaseStreamingSink with Logging {
+  def allData: Seq[Row]
+  def latestBatchData: Seq[Row]
+  def dataSinceBatch(sinceBatchId: Long): Seq[Row]
+  def latestBatchId: Option[Long]
+
+  /**
+   * Truncates the given rows to return at most maxRows rows.
+   * @param rows The data that may need to be truncated.
+   * @param batchLimit Number of rows to keep in this batch; the rest will be truncated
+   * @param sinkLimit Total number of rows kept in this sink, for logging purposes.
+   * @param batchId The ID of the batch that sent these rows, for logging purposes.
+   * @return Truncated rows.
+   */
+  protected def truncateRowsIfNeeded(
+      rows: Array[Row],
+      batchLimit: Int,
+      sinkLimit: Int,
+      batchId: Long): Array[Row] = {
+    if (rows.length > batchLimit && batchLimit >= 0) {
+      logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
+      rows.take(batchLimit)
+    } else {
+      rows
+    }
+  }
+}
+
+/**
+ * Companion object to MemorySinkBase.
+ */
+object MemorySinkBase {
+  val MAX_MEMORY_SINK_ROWS = "maxRows"
+  val MAX_MEMORY_SINK_ROWS_DEFAULT = -1
+
+  /**
+   * Gets the max number of rows a MemorySink should store. This number is based on the memory
+   * sink row limit option if it is set. If not, we use a large value so that data truncates
+   * rather than causing out of memory errors.
+   * @param options Options for writing from which we get the max rows option
+   * @return The maximum number of rows a MemorySink should store.
+   */
+  def getMemorySinkCapacity(options: DataSourceOptions): Int = {
+    val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
+    if (maxRows >= 0) maxRows else Int.MaxValue - 10
+  }
+}
+
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging {
+class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
+  extends Sink with MemorySinkBase with Logging {
 
   private case class AddedData(batchId: Long, data: Array[Row])
 
   /** An order list of batches that have been written to this [[Sink]]. */
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
+  /** The number of rows in this MemorySink. */
+  private var numRows = 0
+
+  /** The capacity in rows of this sink. */
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
-    batches.map(_.data).flatten
+    batches.flatMap(_.data)
   }
 
   def latestBatchId: Option[Long] = synchronized {
@@ -244,6 +300,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
 
   def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) }
 
+  def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
+    batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
+  }
+
   def toDebugString: String = synchronized {
     batches.map { case AddedData(batchId, data) =>
       val dataStr = try data.mkString(" ") catch {
@@ -261,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          val rows = AddedData(batchId, data.collect())
-          synchronized { batches += rows }
+          var rowsToAdd = data.collect()
+          synchronized {
+            rowsToAdd =
+              truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
+            batches += rows
+            numRows += rowsToAdd.length
+          }
 
         case Complete =>
-          val rows = AddedData(batchId, data.collect())
+          var rowsToAdd = data.collect()
           synchronized {
+            rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
             batches.clear()
             batches += rows
+            numRows = rowsToAdd.length
           }
 
         case _ =>
@@ -282,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
 
   def clear(): Unit = synchronized {
     batches.clear()
+    numRows = 0
   }
 
   override def toString(): String = "MemorySink"
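To make the new semantics concrete, a standalone sketch (illustrative values, not code from this patch): with no option set, getMemorySinkCapacity falls back to Int.MaxValue - 10, so behavior is effectively unchanged; with a limit set, truncateRowsIfNeeded keeps only the rows that fit into the sink's remaining headroom for the current batch.

// Standalone illustration of the truncation accounting; values are made up.
val sinkCapacity = 5                      // e.g. .option("maxRows", "5")
var numRows = 3                           // rows the sink already holds
val incoming = Array("a", "b", "c", "d")  // four new rows arriving in this batch

// Mirrors truncateRowsIfNeeded(incoming, sinkCapacity - numRows, sinkCapacity, batchId)
val batchLimit = sinkCapacity - numRows   // two rows of headroom remain
val kept =
  if (incoming.length > batchLimit && batchLimit >= 0) incoming.take(batchLimit) else incoming
numRows += kept.length                    // sink now holds 5 rows; "c" and "d" were dropped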

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala

Lines changed: 44 additions & 14 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
-import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,13 +39,13 @@ import org.apache.spark.sql.types.StructType
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
+class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging {
   override def createStreamWriter(
       queryId: String,
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode)
+    new MemoryStreamWriter(this, mode, options)
   }
 
   private case class AddedData(batchId: Long, data: Array[Row])
@@ -54,6 +54,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
+  /** The number of rows in this MemorySink. */
+  private var numRows = 0
+
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -67,6 +70,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
     batches.lastOption.toSeq.flatten(_.data)
   }
 
+  def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
+    batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
+  }
+
   def toDebugString: String = synchronized {
     batches.map { case AddedData(batchId, data) =>
       val dataStr = try data.mkString(" ") catch {
@@ -76,27 +83,38 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
     }.mkString("\n")
   }
 
-  def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+  def write(
+      batchId: Long,
+      outputMode: OutputMode,
+      newRows: Array[Row],
+      sinkCapacity: Int): Unit = {
     val notCommitted = synchronized {
       latestBatchId.isEmpty || batchId > latestBatchId.get
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          val rows = AddedData(batchId, newRows)
-          synchronized { batches += rows }
+          synchronized {
+            val rowsToAdd =
+              truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
+            batches += rows
+            numRows += rowsToAdd.length
+          }
 
         case Complete =>
-          val rows = AddedData(batchId, newRows)
           synchronized {
+            val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
             batches.clear()
             batches += rows
+            numRows = rowsToAdd.length
          }
 
         case _ =>
           throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
+            s"Output mode $outputMode is not supported by MemorySinkV2")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")
@@ -105,40 +123,52 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
 
   def clear(): Unit = synchronized {
     batches.clear()
+    numRows = 0
   }
 
-  override def toString(): String = "MemorySink"
+  override def toString(): String = "MemorySinkV2"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
 
-class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+class MemoryWriter(
+    sink: MemorySinkV2,
+    batchId: Long,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends DataSourceWriter with Logging {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   def commit(messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(batchId, outputMode, newRows)
+    sink.write(batchId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
 }
 
-class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+class MemoryStreamWriter(
+    val sink: MemorySinkV2,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends StreamWriter {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(epochId, outputMode, newRows)
+    sink.write(epochId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
@@ -175,7 +205,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
 
 
 /**
- * Used to query the data that has been written into a [[MemorySink]].
+ * Used to query the data that has been written into a [[MemorySinkV2]].
  */
 case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
   private val sizePerRow = output.map(_.dataType.defaultSize).sum
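The commit message mentions new unit tests; a hedged sketch of how the V2 path above could be exercised (row values and capacity are illustrative, and the assertion uses the allData accessor shown earlier):

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.OutputMode

val sink = new MemorySinkV2
// Write a 10-row batch while asking the sink to keep at most 5 rows.
sink.write(0, OutputMode.Append(), (1 to 10).map(Row(_)).toArray, sinkCapacity = 5)
assert(sink.allData.length == 5)  // overflow rows were truncated; a warning is logged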

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
-import org.apache.spark.sql.sources.v2.StreamWriteSupport
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
         (s, r)
       case _ =>
-        val s = new MemorySink(df.schema, outputMode)
+        val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
         val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
         (s, r)
     }
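For reference, a small sketch (illustrative, not from this diff) of what the new constructor argument amounts to: the writer's extra options become a DataSourceOptions map, from which getMemorySinkCapacity reads the "maxRows" key via getInt:

import scala.collection.JavaConverters._
import org.apache.spark.sql.sources.v2.DataSourceOptions

// Roughly what .option("maxRows", "100") contributes to extraOptions (hypothetical map).
val options = new DataSourceOptions(Map("maxRows" -> "100").asJava)
val maxRows = options.getInt("maxRows", -1)  // 100; -1 is returned when the key is absent,
                                             // which getMemorySinkCapacity maps to Int.MaxValue - 10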
