|
9 | 9 | import logging |
10 | 10 | import os |
11 | 11 | import redis |
| 12 | +from six.moves import queue |
12 | 13 | import sys |
13 | 14 | import threading |
14 | 15 | import time |
|
68 | 69 | logger = logging.getLogger(__name__) |
69 | 70 |
|
70 | 71 |
|
71 | | -# Visible for testing. |
72 | | -def _unhandled_error_handler(e: Exception): |
73 | | - logger.error("Unhandled error (suppress with " |
74 | | - "RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e)) |
75 | | - |
76 | | - |
77 | 72 | class Worker: |
78 | 73 | """A class used to define the control flow of a worker process. |
79 | 74 |
|
@@ -282,14 +277,6 @@ def put_object(self, value, object_ref=None): |
282 | 277 | self.core_worker.put_serialized_object( |
283 | 278 | serialized_value, object_ref=object_ref)) |
284 | 279 |
|
285 | | - def raise_errors(self, data_metadata_pairs, object_refs): |
286 | | - context = self.get_serialization_context() |
287 | | - out = context.deserialize_objects(data_metadata_pairs, object_refs) |
288 | | - if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ: |
289 | | - return |
290 | | - for e in out: |
291 | | - _unhandled_error_handler(e) |
292 | | - |
293 | 280 | def deserialize_objects(self, data_metadata_pairs, object_refs): |
294 | 281 | context = self.get_serialization_context() |
295 | 282 | return context.deserialize_objects(data_metadata_pairs, object_refs) |
@@ -878,6 +865,13 @@ def custom_excepthook(type, value, tb): |
878 | 865 |
|
879 | 866 | sys.excepthook = custom_excepthook |
880 | 867 |
|
| 868 | +# The last time we raised a TaskError in this process. We use this value to |
| 869 | +# suppress redundant error messages pushed from the workers. |
| 870 | +last_task_error_raise_time = 0 |
| 871 | + |
| 872 | +# The maximum number of seconds to wait before printing out an uncaught error. |
| 873 | +UNCAUGHT_ERROR_GRACE_PERIOD = 5 |
| 874 | + |
881 | 875 |
|
882 | 876 | def print_logs(redis_client, threads_stopped, job_id): |
883 | 877 | """Prints log messages from workers on all of the nodes. |
@@ -1028,14 +1022,51 @@ def color_for(data: Dict[str, str]) -> str: |
1028 | 1022 | file=print_file) |
1029 | 1023 |
|
1030 | 1024 |
|
1031 | | -def listen_error_messages_raylet(worker, threads_stopped): |
| 1025 | +def print_error_messages_raylet(task_error_queue, threads_stopped): |
| 1026 | + """Prints messages received from the given task error queue. |
| 1027 | +
|
| 1028 | + This checks periodically if any un-raised errors occurred in the |
| 1029 | + background. |
| 1030 | +
|
| 1031 | + Args: |
| 1032 | + task_error_queue (queue.Queue): A queue used to receive errors from the |
| 1033 | + thread that listens to Redis. |
| 1034 | + threads_stopped (threading.Event): A threading event used to signal to |
| 1035 | + the thread that it should exit. |
| 1036 | + """ |
| 1037 | + |
| 1038 | + while True: |
| 1039 | + # Exit if we received a signal that we should stop. |
| 1040 | + if threads_stopped.is_set(): |
| 1041 | + return |
| 1042 | + |
| 1043 | + try: |
| 1044 | + error, t = task_error_queue.get(block=False) |
| 1045 | + except queue.Empty: |
| 1046 | + threads_stopped.wait(timeout=0.01) |
| 1047 | + continue |
| 1048 | + # Hold each error for the grace period so that redundant messages |
| 1049 | + # originating from the worker can be suppressed. |
| 1050 | + while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): |
| 1051 | + threads_stopped.wait(timeout=1) |
| 1052 | + if threads_stopped.is_set(): |
| 1053 | + break |
| 1054 | + if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: |
| 1055 | + logger.debug(f"Suppressing error from worker: {error}") |
| 1056 | + else: |
| 1057 | + logger.error(f"Possible unhandled error from worker: {error}") |
| 1058 | + |
| 1059 | + |
| 1060 | +def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): |
1032 | 1061 | """Listen to error messages in the background on the driver. |
1033 | 1062 |
|
1034 | 1063 | This runs in a separate thread on the driver and pushes (error, time) |
1035 | 1064 | tuples to the output queue. |
1036 | 1065 |
|
1037 | 1066 | Args: |
1038 | 1067 | worker: The worker class that this thread belongs to. |
| 1068 | + task_error_queue (queue.Queue): A queue used to communicate with the |
| 1069 | + thread that prints the errors found by this thread. |
1039 | 1070 | threads_stopped (threading.Event): A threading event used to signal to |
1040 | 1071 | the thread that it should exit. |
1041 | 1072 | """ |
@@ -1074,9 +1105,8 @@ def listen_error_messages_raylet(worker, threads_stopped): |
1074 | 1105 |
|
1075 | 1106 | error_message = error_data.error_message |
1076 | 1107 | if (error_data.type == ray_constants.TASK_PUSH_ERROR): |
1077 | | - # TODO(ekl) remove task push errors entirely now that we have |
1078 | | - # the separate unhandled exception handler. |
1079 | | - pass |
| 1108 | + # Queue it with a timestamp; it may be suppressed later. |
| 1109 | + task_error_queue.put((error_message, time.time())) |
1080 | 1110 | else: |
1081 | 1111 | logger.warning(error_message) |
1082 | 1112 | except (OSError, redis.exceptions.ConnectionError) as e: |
@@ -1239,12 +1269,19 @@ def connect(node, |
1239 | 1269 | # temporarily using this implementation which constantly queries the |
1240 | 1270 | # scheduler for new error messages. |
1241 | 1271 | if mode == SCRIPT_MODE: |
| 1272 | + q = queue.Queue() |
1242 | 1273 | worker.listener_thread = threading.Thread( |
1243 | 1274 | target=listen_error_messages_raylet, |
1244 | 1275 | name="ray_listen_error_messages", |
1245 | | - args=(worker, worker.threads_stopped)) |
| 1276 | + args=(worker, q, worker.threads_stopped)) |
| 1277 | + worker.printer_thread = threading.Thread( |
| 1278 | + target=print_error_messages_raylet, |
| 1279 | + name="ray_print_error_messages", |
| 1280 | + args=(q, worker.threads_stopped)) |
1246 | 1281 | worker.listener_thread.daemon = True |
1247 | 1282 | worker.listener_thread.start() |
| 1283 | + worker.printer_thread.daemon = True |
| 1284 | + worker.printer_thread.start() |
1248 | 1285 | if log_to_driver: |
1249 | 1286 | global_worker_stdstream_dispatcher.add_handler( |
1250 | 1287 | "ray_print_logs", print_to_stdstream) |
@@ -1297,6 +1334,8 @@ def disconnect(exiting_interpreter=False): |
1297 | 1334 | worker.import_thread.join_import_thread() |
1298 | 1335 | if hasattr(worker, "listener_thread"): |
1299 | 1336 | worker.listener_thread.join() |
| 1337 | + if hasattr(worker, "printer_thread"): |
| 1338 | + worker.printer_thread.join() |
1300 | 1339 | if hasattr(worker, "logger_thread"): |
1301 | 1340 | worker.logger_thread.join() |
1302 | 1341 | worker.threads_stopped.clear() |
@@ -1408,11 +1447,13 @@ def get(object_refs, *, timeout=None): |
1408 | 1447 | raise ValueError("'object_refs' must either be an object ref " |
1409 | 1448 | "or a list of object refs.") |
1410 | 1449 |
|
| 1450 | + global last_task_error_raise_time |
1411 | 1451 | # TODO(ujvl): Consider how to allow user to retrieve the ready objects. |
1412 | 1452 | values, debugger_breakpoint = worker.get_objects( |
1413 | 1453 | object_refs, timeout=timeout) |
1414 | 1454 | for i, value in enumerate(values): |
1415 | 1455 | if isinstance(value, RayError): |
| 1456 | + last_task_error_raise_time = time.time() |
1416 | 1457 | if isinstance(value, ray.exceptions.ObjectLostError): |
1417 | 1458 | worker.core_worker.dump_object_store_memory_usage() |
1418 | 1459 | if isinstance(value, RayTaskError): |
|
0 commit comments