
Commit e4fee39

Mukul Murthy authored and brkyvz committed
[SPARK-24525][SS] Provide an option to limit number of rows in a MemorySink
## What changes were proposed in this pull request?

Provide an option to limit the number of rows in a MemorySink. Currently, MemorySink and MemorySinkV2 have unbounded size, meaning that if they're used on big data, they can OOM the stream. This change adds a maxMemorySinkRows option to limit how many rows MemorySink and MemorySinkV2 can hold. By default, they are still unbounded.

## How was this patch tested?

Added new unit tests.

Author: Mukul Murthy <[email protected]>

Closes #21559 from mukulmurthy/SPARK-24525.
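A rough usage sketch (not part of this commit): once this lands, the cap can be set through writer options. The "memory" format and the "maxRows" key come from the diff below; the session name, source, and query name here are illustrative only.

  // Sketch only: assumes a running SparkSession named `spark` and any streaming source.
  val query = spark.readStream
    .format("rate")                  // illustrative source
    .load()
    .writeStream
    .format("memory")                // routed to MemorySink / MemorySinkV2
    .queryName("capped_results")     // illustrative in-memory table name
    .option("maxRows", "100")        // cap the sink at 100 rows instead of unbounded
    .start()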
1 parent 90da7dc commit e4fee39

File tree

6 files changed: +239 -25 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 64 additions & 6 deletions
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
@@ -221,26 +222,73 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow])
 }
 
 /** A common trait for MemorySinks with methods used for testing */
-trait MemorySinkBase extends BaseStreamingSink {
+trait MemorySinkBase extends BaseStreamingSink with Logging {
   def allData: Seq[Row]
   def latestBatchData: Seq[Row]
   def dataSinceBatch(sinceBatchId: Long): Seq[Row]
   def latestBatchId: Option[Long]
+
+  /**
+   * Truncates the given rows to return at most maxRows rows.
+   * @param rows The data that may need to be truncated.
+   * @param batchLimit Number of rows to keep in this batch; the rest will be truncated
+   * @param sinkLimit Total number of rows kept in this sink, for logging purposes.
+   * @param batchId The ID of the batch that sent these rows, for logging purposes.
+   * @return Truncated rows.
+   */
+  protected def truncateRowsIfNeeded(
+      rows: Array[Row],
+      batchLimit: Int,
+      sinkLimit: Int,
+      batchId: Long): Array[Row] = {
+    if (rows.length > batchLimit && batchLimit >= 0) {
+      logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
+      rows.take(batchLimit)
+    } else {
+      rows
+    }
+  }
+}
+
+/**
+ * Companion object to MemorySinkBase.
+ */
+object MemorySinkBase {
+  val MAX_MEMORY_SINK_ROWS = "maxRows"
+  val MAX_MEMORY_SINK_ROWS_DEFAULT = -1
+
+  /**
+   * Gets the max number of rows a MemorySink should store. This number is based on the memory
+   * sink row limit option if it is set. If not, we use a large value so that data truncates
+   * rather than causing out of memory errors.
+   * @param options Options for writing from which we get the max rows option
+   * @return The maximum number of rows a memorySink should store.
+   */
+  def getMemorySinkCapacity(options: DataSourceOptions): Int = {
+    val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
+    if (maxRows >= 0) maxRows else Int.MaxValue - 10
+  }
 }
 
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
-  with MemorySinkBase with Logging {
+class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
+  extends Sink with MemorySinkBase with Logging {
 
   private case class AddedData(batchId: Long, data: Array[Row])
 
   /** An order list of batches that have been written to this [[Sink]]. */
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
+  /** The number of rows in this MemorySink. */
+  private var numRows = 0
+
+  /** The capacity in rows of this sink. */
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
     logDebug(s"Committing batch $batchId to $this")
     outputMode match {
       case Append | Update =>
-        val rows = AddedData(batchId, data.collect())
-        synchronized { batches += rows }
+        var rowsToAdd = data.collect()
+        synchronized {
+          rowsToAdd =
+            truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
+          val rows = AddedData(batchId, rowsToAdd)
+          batches += rows
+          numRows += rowsToAdd.length
+        }
 
       case Complete =>
-        val rows = AddedData(batchId, data.collect())
+        var rowsToAdd = data.collect()
         synchronized {
+          rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
+          val rows = AddedData(batchId, rowsToAdd)
          batches.clear()
          batches += rows
+          numRows = rowsToAdd.length
         }
 
       case _ =>
@@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
 
   def clear(): Unit = synchronized {
     batches.clear()
+    numRows = 0
   }
 
   override def toString(): String = "MemorySink"

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala

Lines changed: 35 additions & 9 deletions
@@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode)
+    new MemoryStreamWriter(this, mode, options)
   }
 
   private case class AddedData(batchId: Long, data: Array[Row])
@@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
+  /** The number of rows in this MemorySink. */
+  private var numRows = 0
+
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -81,22 +84,33 @@
     }.mkString("\n")
   }
 
-  def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+  def write(
+      batchId: Long,
+      outputMode: OutputMode,
+      newRows: Array[Row],
+      sinkCapacity: Int): Unit = {
     val notCommitted = synchronized {
       latestBatchId.isEmpty || batchId > latestBatchId.get
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
        case Append | Update =>
-          val rows = AddedData(batchId, newRows)
-          synchronized { batches += rows }
+          synchronized {
+            val rowsToAdd =
+              truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
+            batches += rows
+            numRows += rowsToAdd.length
+          }
 
        case Complete =>
-          val rows = AddedData(batchId, newRows)
          synchronized {
+            val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
            batches.clear()
            batches += rows
+            numRows = rowsToAdd.length
          }
 
        case _ =>
@@ -110,40 +124,52 @@
 
   def clear(): Unit = synchronized {
     batches.clear()
+    numRows = 0
   }
 
   override def toString(): String = "MemorySinkV2"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
 
-class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+class MemoryWriter(
+    sink: MemorySinkV2,
+    batchId: Long,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends DataSourceWriter with Logging {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   def commit(messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(batchId, outputMode, newRows)
+    sink.write(batchId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
 }
 
-class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+class MemoryStreamWriter(
+    val sink: MemorySinkV2,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends StreamWriter {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(epochId, outputMode, newRows)
+    sink.write(epochId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
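Both MemoryWriter and MemoryStreamWriter compute sinkCapacity once from the passed-in DataSourceOptions and forward it to MemorySinkV2.write(), so the sink itself never parses options. Note that the "unbounded" default is really Int.MaxValue - 10, per the MemorySinkBase companion object above. A minimal stand-in for that lookup, with a plain Map replacing DataSourceOptions and illustrative names:

  // Sketch: stand-in for MemorySinkBase.getMemorySinkCapacity(options).
  def sinkCapacity(options: Map[String, String]): Int = {
    val maxRows = options.get("maxRows").map(_.toInt).getOrElse(-1) // -1 mirrors MAX_MEMORY_SINK_ROWS_DEFAULT
    if (maxRows >= 0) maxRows else Int.MaxValue - 10                // "unbounded" still gets a safety cap
  }

  println(sinkCapacity(Map("maxRows" -> "5"))) // 5
  println(sinkCapacity(Map.empty))             // 2147483637 (Int.MaxValue - 10)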

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
-import org.apache.spark.sql.sources.v2.StreamWriteSupport
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
           val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
           (s, r)
         case _ =>
-          val s = new MemorySink(df.schema, outputMode)
+          val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
           val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
           (s, r)
       }
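With this change, whatever the user passes through DataStreamWriter.option(...) reaches MemorySink as a DataSourceOptions wrapper around extraOptions. A short hedged example of the same wrapping done by hand (assumes a Spark 2.3/2.4-era spark-sql jar on the classpath; the key and value are illustrative):

  import scala.collection.JavaConverters._
  import org.apache.spark.sql.sources.v2.DataSourceOptions

  // Wrap a plain Scala map the way DataStreamWriter now wraps extraOptions.
  val opts = new DataSourceOptions(Map("maxRows" -> "100").asJava)
  assert(opts.getInt("maxRows", -1) == 100) // same getInt lookup getMemorySinkCapacity uses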

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala

Lines changed: 58 additions & 4 deletions
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Append output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append)
+    val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, 1 to 9)
   }
 
+  test("directly add data in Append output mode with row limit") {
+    implicit val schema = new StructType().add(new StructField("value", IntegerType))
+
+    var optionsMap = new scala.collection.mutable.HashMap[String, String]
+    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
+    var options = new DataSourceOptions(optionsMap.toMap.asJava)
+    val sink = new MemorySink(schema, OutputMode.Append, options)
+
+    // Before adding data, check output
+    assert(sink.latestBatchId === None)
+    checkAnswer(sink.latestBatchData, Seq.empty)
+    checkAnswer(sink.allData, Seq.empty)
+
+    // Add batch 0 and check outputs
+    sink.addBatch(0, 1 to 3)
+    assert(sink.latestBatchId === Some(0))
+    checkAnswer(sink.latestBatchData, 1 to 3)
+    checkAnswer(sink.allData, 1 to 3)
+
+    // Add batch 1 and check outputs
+    sink.addBatch(1, 4 to 6)
+    assert(sink.latestBatchId === Some(1))
+    checkAnswer(sink.latestBatchData, 4 to 5)
+    checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit
+  }
+
   test("directly add data in Update output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Update)
+    val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty())
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Complete output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Complete)
+    val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty())
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, 7 to 9)
   }
 
+  test("directly add data in Complete output mode with row limit") {
+    implicit val schema = new StructType().add(new StructField("value", IntegerType))
+
+    var optionsMap = new scala.collection.mutable.HashMap[String, String]
+    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
+    var options = new DataSourceOptions(optionsMap.toMap.asJava)
+    val sink = new MemorySink(schema, OutputMode.Complete, options)
+
+    // Before adding data, check output
+    assert(sink.latestBatchId === None)
+    checkAnswer(sink.latestBatchData, Seq.empty)
+    checkAnswer(sink.allData, Seq.empty)
+
+    // Add batch 0 and check outputs
+    sink.addBatch(0, 1 to 3)
+    assert(sink.latestBatchId === Some(0))
+    checkAnswer(sink.latestBatchData, 1 to 3)
+    checkAnswer(sink.allData, 1 to 3)
+
+    // Add batch 1 and check outputs
+    sink.addBatch(1, 4 to 10)
+    assert(sink.latestBatchId === Some(1))
+    checkAnswer(sink.latestBatchData, 4 to 8)
+    checkAnswer(sink.allData, 4 to 8) // new data should replace old data
+  }
+
 
   test("registering as a table in Append output mode") {
     val input = MemoryStream[Int]
@@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("MemoryPlan statistics") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append)
+    val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
     val plan = new MemoryPlan(sink)
 
     // Before adding data, check output
