@@ -28,7 +28,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
-import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,13 +39,13 @@ import org.apache.spark.sql.types.StructType
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
+class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging {
   override def createStreamWriter(
       queryId: String,
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode)
+    new MemoryStreamWriter(this, mode, options)
   }
 
   private case class AddedData(batchId: Long, data: Array[Row])
@@ -54,6 +54,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
+  /** The number of rows in this MemorySink. */
+  private var numRows = 0
+
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -67,6 +70,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
     batches.lastOption.toSeq.flatten(_.data)
   }
 
+  def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
+    batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
+  }
+
   def toDebugString: String = synchronized {
     batches.map { case AddedData(batchId, data) =>
       val dataStr = try data.mkString(" ") catch {
@@ -76,27 +83,38 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
     }.mkString("\n")
   }
 
-  def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+  def write(
+      batchId: Long,
+      outputMode: OutputMode,
+      newRows: Array[Row],
+      sinkCapacity: Int): Unit = {
     val notCommitted = synchronized {
       latestBatchId.isEmpty || batchId > latestBatchId.get
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          val rows = AddedData(batchId, newRows)
-          synchronized { batches += rows }
+          synchronized {
+            val rowsToAdd =
+              truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
+            batches += rows
+            numRows += rowsToAdd.length
+          }
 
         case Complete =>
-          val rows = AddedData(batchId, newRows)
           synchronized {
+            val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
+            val rows = AddedData(batchId, rowsToAdd)
             batches.clear()
             batches += rows
+            numRows = rowsToAdd.length
           }
 
         case _ =>
           throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
+            s"Output mode $outputMode is not supported by MemorySinkV2")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")
@@ -105,40 +123,52 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
 
   def clear(): Unit = synchronized {
     batches.clear()
+    numRows = 0
   }
 
-  override def toString(): String = "MemorySink"
+  override def toString(): String = "MemorySinkV2"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
 
-class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+class MemoryWriter(
+    sink: MemorySinkV2,
+    batchId: Long,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends DataSourceWriter with Logging {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   def commit(messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(batchId, outputMode, newRows)
+    sink.write(batchId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
 }
 
-class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+class MemoryStreamWriter(
+    val sink: MemorySinkV2,
+    outputMode: OutputMode,
+    options: DataSourceOptions)
   extends StreamWriter {
 
+  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(epochId, outputMode, newRows)
+    sink.write(epochId, outputMode, newRows, sinkCapacity)
   }
 
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
@@ -175,7 +205,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
 
 
 /**
- * Used to query the data that has been written into a [[MemorySink]].
+ * Used to query the data that has been written into a [[MemorySinkV2]].
  */
 case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
   private val sizePerRow = output.map(_.dataType.defaultSize).sum
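
The change above depends on two helpers from the new MemorySinkBase trait that this diff does not show: MemorySinkBase.getMemorySinkCapacity(options) and truncateRowsIfNeeded(...). For readers reviewing the hunks in isolation, here is a minimal sketch of what those helpers plausibly look like; the option key "maxRows" and the unbounded default are assumptions for illustration, not necessarily the real implementation.

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.DataSourceOptions

// Sketch only: the real MemorySinkBase lives in
// org.apache.spark.sql.execution.streaming. The option key and default
// capacity below are assumed for illustration.
object MemorySinkBase {
  val MAX_MEMORY_SINK_ROWS = "maxRows"            // assumed option key
  val MAX_MEMORY_SINK_ROWS_DEFAULT = Int.MaxValue // assumed default: effectively unbounded

  /** Reads the sink capacity from the writer options, falling back to the default. */
  def getMemorySinkCapacity(options: DataSourceOptions): Int =
    options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
}

trait MemorySinkBase extends Logging {
  /**
   * Keeps at most `maxRows` of `rows` (clamped at zero), logging a warning when
   * rows are dropped. `sinkCapacity` and `batchId` only feed the log message.
   */
  def truncateRowsIfNeeded(
      rows: Array[Row],
      maxRows: Int,
      sinkCapacity: Int,
      batchId: Long): Array[Row] = {
    val limit = math.max(maxRows, 0)
    if (rows.length <= limit) {
      rows
    } else {
      logWarning(s"Truncating batch $batchId to $limit rows because the sink " +
        s"can hold at most $sinkCapacity rows.")
      rows.take(limit)
    }
  }
}

Under these assumptions a test would cap the sink with something like .option("maxRows", "10") on the stream writer. Append and Update batches are then truncated against the remaining capacity (sinkCapacity - numRows), while Complete batches, which replace the sink's contents wholesale, are truncated against the full capacity.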