|
39 | 39 | from itertools import imap as map, ifilter as filter |
40 | 40 |
|
41 | 41 | from pyspark.java_gateway import local_connect_and_auth |
42 | | -from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ |
43 | | - BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ |
44 | | - PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int |
| 42 | +from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \ |
| 43 | + CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \ |
| 44 | + UTF8Deserializer, pack_long, read_int, write_int |
45 | 45 | from pyspark.join import python_join, python_left_outer_join, \ |
46 | 46 | python_right_outer_join, python_full_outer_join, python_cogroup |
47 | 47 | from pyspark.statcounter import StatCounter |
@@ -161,32 +161,42 @@ def __init__(self, _sock_info, _serializer): |
161 | 161 | self._sockfile = _create_local_socket(_sock_info) |
162 | 162 | self._serializer = _serializer |
163 | 163 | self._read_iter = iter([]) # Initialize as empty iterator |
| 164 | + self._read_status = 1 |
164 | 165 |
|
165 | 166 | def __iter__(self): |
166 | | - while True: |
| 167 | + while self._read_status == 1: |
167 | 168 | # Request next partition data from Java |
168 | 169 | write_int(1, self._sockfile) |
169 | 170 | self._sockfile.flush() |
170 | 171 |
|
171 | | - # If nonzero response, then there is a partition to read |
172 | | - if read_int(self._sockfile) == 0: |
173 | | - break |
| 172 | + # If response is 1, then there is a partition to read |
| 173 | + self._read_status = read_int(self._sockfile) |
| 174 | + if self._read_status == 1: |
174 | 175 |
|
175 | | - # Load the partition data as a stream and read each item |
176 | | - self._read_iter = self._serializer.load_stream(self._sockfile) |
177 | | - for item in self._read_iter: |
178 | | - yield item |
| 176 | + # Load the partition data as a stream and read each item |
| 177 | + self._read_iter = self._serializer.load_stream(self._sockfile) |
| 178 | + for item in self._read_iter: |
| 179 | + yield item |
| 180 | + |
| 181 | +            # An error occurred: read the error message and raise it |
| 182 | + elif self._read_status == -1: |
| 183 | + error_msg = UTF8Deserializer().loads(self._sockfile) |
| 184 | + raise RuntimeError("An error occurred while reading the next element from " |
| 185 | + "toLocalIterator: {}".format(error_msg)) |
179 | 186 |
|
180 | 187 | def __del__(self): |
181 | | - try: |
182 | | - # Finish consuming partition data stream |
183 | | - for _ in self._read_iter: |
| 188 | +        # If the local iterator has not been fully consumed, drain it and tell Java to stop |
| 189 | + if self._read_status == 1: |
| 190 | + try: |
| 191 | + # Finish consuming partition data stream |
| 192 | + for _ in self._read_iter: |
| 193 | + pass |
| 194 | + # Tell Java to stop sending data and close connection |
| 195 | + write_int(0, self._sockfile) |
| 196 | + self._sockfile.flush() |
| 197 | + except Exception: |
| 198 | +            # Ignore any errors; the socket is closed automatically when garbage-collected |
184 | 199 | pass |
185 | | - # Tell Java to stop sending data and close connection |
186 | | - write_int(0, self._sockfile) |
187 | | - self._sockfile.flush() |
188 | | - except Exception: |
189 | | - pass # Ignore any errors, socket is automatically closed when garbage-collected |
190 | 200 |
|
191 | 201 | return iter(PyLocalIterable(sock_info, serializer)) |
192 | 202 |
|
|
0 commit comments