
Commit 29b8ab6

Added handling of task failures and send the error message to Python
1 parent: a1f811a

2 files changed: 59 additions & 29 deletions


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 30 additions & 10 deletions
@@ -163,6 +163,17 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+   * collected as separate jobs, by order of index. Partition data is first requested by a
+   * non-zero integer to start a collection job. The response is prefaced by an integer with 1
+   * meaning partition data will be served, 0 meaning the local iterator has been consumed,
+   * and -1 meaning an error occurred during collection. This function is used by
+   * pyspark.rdd._local_iterator_from_socket().
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, and the secret for authentication.
+   */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
     val (port, secret) = SocketAuthServer.setupOneConnectionServer(
       authHelper, "serve toLocalIterator") { s =>
@@ -177,17 +188,26 @@ private[spark] object PythonRDD extends Logging {
 
       // Read request for data and send next partition if nonzero
       var complete = false
-      while (in.readInt() != 0 && !complete) {
+      while (!complete && in.readInt() != 0) {
         if (collectPartitionIter.hasNext) {
-
-          // Send response there is a partition to read
-          out.writeInt(1)
-
-          // Write the next object and signal end of data for this iteration
-          val partitionArray = collectPartitionIter.next()
-          writeIteratorToStream(partitionArray.toIterator, out)
-          out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-          out.flush()
+          try {
+            // Attempt to collect the next partition
+            val partitionArray = collectPartitionIter.next()
+
+            // Send response there is a partition to read
+            out.writeInt(1)
+
+            // Write the next object and signal end of data for this iteration
+            writeIteratorToStream(partitionArray.toIterator, out)
+            out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+            out.flush()
+          } catch {
+            case e: SparkException =>
+              // Send response that an error occurred followed by error message
+              out.writeInt(-1)
+              writeUTF(e.getMessage, out)
+              complete = true
+          }
         } else {
           // Send response there are no more partitions to read and close
           out.writeInt(0)
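
A minimal client-side sketch of the request/response protocol described in the new doc comment above (illustration only, not PySpark's implementation): the caller sends any non-zero integer to request the next collection job, then reads a status integer of 1, 0, or -1. Here read_int and write_int are local stand-ins for the helpers in pyspark.serializers, and load_partition is a hypothetical deserializer callback.

import struct


def read_int(stream):
    # Stand-in for pyspark.serializers.read_int: one big-endian 4-byte signed int
    return struct.unpack("!i", stream.read(4))[0]


def write_int(value, stream):
    # Stand-in for pyspark.serializers.write_int
    stream.write(struct.pack("!i", value))


def iterate_partitions(sockfile, load_partition):
    # Hypothetical client loop for the 1 / 0 / -1 protocol sketched above
    while True:
        # Any non-zero integer requests the next collection job
        write_int(1, sockfile)
        sockfile.flush()

        status = read_int(sockfile)
        if status == 1:
            # Partition data follows; load_partition deserializes one partition's stream
            for item in load_partition(sockfile):
                yield item
        elif status == 0:
            # The local iterator on the JVM side has been fully consumed
            return
        else:
            # status == -1: the collection job failed and an error message follows
            raise RuntimeError("collection job failed during toLocalIterator")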

python/pyspark/rdd.py

Lines changed: 29 additions & 19 deletions
@@ -39,9 +39,9 @@
 from itertools import imap as map, ifilter as filter
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int
+from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
+    CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
+    UTF8Deserializer, pack_long, read_int, write_int
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -161,32 +161,42 @@ def __init__(self, _sock_info, _serializer):
             self._sockfile = _create_local_socket(_sock_info)
             self._serializer = _serializer
             self._read_iter = iter([])  # Initialize as empty iterator
+            self._read_status = 1
 
         def __iter__(self):
-            while True:
+            while self._read_status == 1:
                 # Request next partition data from Java
                 write_int(1, self._sockfile)
                 self._sockfile.flush()
 
-                # If nonzero response, then there is a partition to read
-                if read_int(self._sockfile) == 0:
-                    break
+                # If response is 1, then there is a partition to read
+                self._read_status = read_int(self._sockfile)
+                if self._read_status == 1:
 
-                # Load the partition data as a stream and read each item
-                self._read_iter = self._serializer.load_stream(self._sockfile)
-                for item in self._read_iter:
-                    yield item
+                    # Load the partition data as a stream and read each item
+                    self._read_iter = self._serializer.load_stream(self._sockfile)
+                    for item in self._read_iter:
+                        yield item
+
+                # An error occurred, read error message and raise it
+                elif self._read_status == -1:
+                    error_msg = UTF8Deserializer().loads(self._sockfile)
+                    raise RuntimeError("An error occurred while reading the next element from "
+                                       "toLocalIterator: {}".format(error_msg))
 
         def __del__(self):
-            try:
-                # Finish consuming partition data stream
-                for _ in self._read_iter:
+            # If local iterator is not fully consumed,
+            if self._read_status == 1:
+                try:
+                    # Finish consuming partition data stream
+                    for _ in self._read_iter:
+                        pass
+                    # Tell Java to stop sending data and close connection
+                    write_int(0, self._sockfile)
+                    self._sockfile.flush()
+                except Exception:
+                    # Ignore any errors, socket is automatically closed when garbage-collected
                     pass
-                # Tell Java to stop sending data and close connection
-                write_int(0, self._sockfile)
-                self._sockfile.flush()
-            except Exception:
-                pass  # Ignore any errors, socket is automatically closed when garbage-collected
 
     return iter(PyLocalIterable(sock_info, serializer))
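
A usage sketch of the user-visible effect, assuming a local SparkContext (the exact error text depends on the Spark version and the failing task): a task failure during a collection job is now reported back over the socket and re-raised on the Python side as a RuntimeError.

from pyspark import SparkContext

sc = SparkContext("local[2]", "local-iterator-error-demo")


def fail_on_seven(x):
    # Deliberately fail one task so a collection job raises a SparkException on the JVM side
    if x == 7:
        raise ValueError("boom")
    return x


it = sc.parallelize(range(10), 4).map(fail_on_seven).toLocalIterator()

try:
    for value in it:
        print(value)
except RuntimeError as err:
    # The Scala side writes -1 plus the exception message, and the Python
    # iterator re-raises it with the "An error occurred while reading ..." prefix
    print("toLocalIterator failed: {}".format(err))
finally:
    sc.stop()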
