33 changes: 33 additions & 0 deletions python/pyspark/serializers.py
@@ -185,6 +185,39 @@ def loads(self, obj):
        raise NotImplementedError


class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices
        num = read_int(stream)
        batch_order = []
        for i in xrange(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer


class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
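For reference, the stream that ArrowCollectSerializer.load_stream consumes is the regular Arrow IPC stream of record batches followed by a small trailer: a 4-byte big-endian count and then that many 4-byte big-endian indices, the same framing pyspark's read_int uses. Below is a minimal sketch of just that trailer, written with struct directly; the helper names are illustrative and not part of the patch.

import io
import struct

def write_batch_order_trailer(stream, batch_order):
    # Append the count, then each index, as 4-byte big-endian ints
    # (the framing pyspark.serializers.read_int expects).
    stream.write(struct.pack("!i", len(batch_order)))
    for index in batch_order:
        stream.write(struct.pack("!i", index))

def read_batch_order_trailer(stream):
    # Mirrors the tail of ArrowCollectSerializer.load_stream: read the count,
    # then that many indices.
    num = struct.unpack("!i", stream.read(4))[0]
    return [struct.unpack("!i", stream.read(4))[0] for _ in range(num)]

buf = io.BytesIO()
write_batch_order_trailer(buf, [2, 0, 1])
buf.seek(0)
assert read_batch_order_trailer(buf) == [2, 0, 1]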
11 changes: 9 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@

from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
    UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ def _collectAsArrow(self):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
return list(_load_from_socket(sock_info, ArrowStreamSerializer()))

# Collect list of un-ordered batches where last element is a list of correct order indices
results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
batches = results[:-1]
batch_order = results[-1]

# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]

##########################################################################################
# Pandas compatibility
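To make the direction of the permutation concrete: batch_order[j] holds the arrival index of the batch that belongs at logical position j, so a single list comprehension restores the order. A toy sketch with strings standing in for pyarrow.RecordBatch objects (the values are illustrative):

arrived = ["batch_from_part1", "batch_from_part2", "batch_from_part0"]  # order received from the JVM
batch_order = [2, 0, 1]  # batch_order[j] = arrival index of the batch for position j

reordered = [arrived[i] for i in batch_order]
assert reordered == ["batch_from_part0", "batch_from_part1", "batch_from_part2"]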
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -381,6 +381,34 @@ def test_timestamp_dst(self):
        self.assertPandasEqual(pdf, df_from_python.toPandas())
        self.assertPandasEqual(pdf, df_from_pandas.toPandas())

    def test_toPandas_batch_order(self):

        def delay_first_part(partition_index, iterator):
            if partition_index == 0:
                time.sleep(0.1)
            return iterator

        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
        def run_test(num_records, num_parts, max_records, use_delay=False):
            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
            if use_delay:
                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
                self.assertPandasEqual(pdf, pdf_arrow)

        cases = [
            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
            (64, 64, 1),       # Test single batch per partition
            (64, 1, 64),       # Test single partition, single batch
            (64, 1, 8),        # Test single partition, multiple batches
            (30, 7, 2),        # Test different sized partitions
        ]

        for case in cases:
            run_test(*case)


class EncryptionArrowTests(ArrowTests):

45 changes: 25 additions & 20 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,10 @@

package org.apache.spark.sql

-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}

import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.util.control.NonFatal

@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

    withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
        val arrowBatchRdd = toArrowBatchRdd(plan)
        val numPartitions = arrowBatchRdd.partitions.length

-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1 // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0

-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python as they arrive, un-ordered
        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
            batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
            }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
-            }
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
+            }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
          }
        }

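To see how the index stream produced by this handler lines up with what ArrowCollectSerializer and _collectAsArrow expect, here is a small Python sketch of the same bookkeeping; the toy batch names and variables are illustrative, not Spark APIs.

# Suppose partition 1 finishes before partition 0 and each partition yields two batches.
arrival_keys = [(1, 0), (1, 1), (0, 0), (0, 1)]     # (partition index, batch index) recorded per batch, in arrival order
batches_as_sent = ["p1b0", "p1b1", "p0b0", "p0b1"]  # transfer order seen by the Python side

# JVM side: the analogue of batchOrder.zipWithIndex.sortBy(_._1) -- sort arrival
# positions by their (partition, batch) key and send the arrival index for each slot.
batch_order = [arrival_idx for arrival_idx, _ in
               sorted(enumerate(arrival_keys), key=lambda kv: kv[1])]
assert batch_order == [2, 3, 0, 1]

# Python side (_collectAsArrow): index into the transferred batches to restore order.
assert [batches_as_sent[i] for i in batch_order] == ["p0b0", "p0b1", "p1b0", "p1b1"]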