|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name |
| 18 | +"""Multiprocessing via Popen. |
| 19 | +
|
| 20 | +This module provides a multi-processing pool backed by Popen. |
| 21 | +with additional timeout support. |
| 22 | +""" |
| 23 | +import os |
| 24 | +import sys |
| 25 | +import struct |
| 26 | +import threading |
| 27 | +import subprocess |
| 28 | +import concurrent.futures |
| 29 | +from enum import IntEnum |
| 30 | +from collections import namedtuple |
| 31 | +import pickle |
| 32 | + |
| 33 | + |
| 34 | +def kill_child_processes(pid): |
| 35 | + """Kill all child processes recursively for a given pid. |
| 36 | +
|
| 37 | + Parameters |
| 38 | + ---------- |
| 39 | + pid : int |
| 40 | + The given parameter id. |
| 41 | + """ |
| 42 | + # pylint: disable=import-outside-toplevel |
| 43 | + import psutil |
| 44 | + |
| 45 | + try: |
| 46 | + parent = psutil.Process(pid) |
| 47 | + except psutil.NoSuchProcess: |
| 48 | + return |
| 49 | + |
| 50 | + for process in parent.children(recursive=True): |
| 51 | + try: |
| 52 | + process.kill() |
| 53 | + except psutil.NoSuchProcess: |
| 54 | + pass |
| 55 | + |
| 56 | + |
| 57 | +class StatusKind(IntEnum): |
| 58 | + """Running and return value status.""" |
| 59 | + |
| 60 | + RUNNING = 0 |
| 61 | + COMPLETE = 1 |
| 62 | + EXCEPTION = 2 |
| 63 | + TIMEOUT = 3 |
| 64 | + |
| 65 | + |
| 66 | +class MapResult(namedtuple("MapResult", ["status", "value"])): |
| 67 | + """Result of map_with_error_catching. |
| 68 | +
|
| 69 | + Parameters |
| 70 | + ---------- |
| 71 | + status : StatusKind |
| 72 | + The status of the result. |
| 73 | +
|
| 74 | + value : Any |
| 75 | + The result value. |
| 76 | + """ |
| 77 | + |
| 78 | + __slots__ = [] |
| 79 | + |
| 80 | + |
| 81 | +class PopenWorker: |
| 82 | + """A subprocess worker via Popen. |
| 83 | +
|
| 84 | + PopenWorker provides a low-level |
| 85 | + API to interact with a separate process via Popen. |
| 86 | + """ |
| 87 | + |
| 88 | + def __init__(self): |
| 89 | + self._proc = None |
| 90 | + |
| 91 | + def __del__(self): |
| 92 | + try: |
| 93 | + self.kill() |
| 94 | + except ImportError: |
| 95 | + pass |
| 96 | + |
| 97 | + def kill(self): |
| 98 | + """Kill the current running process and cleanup. |
| 99 | +
|
| 100 | + Note |
| 101 | + ---- |
| 102 | + The worker can start a new process when send is called again. |
| 103 | + """ |
| 104 | + if self._proc is not None: |
| 105 | + # allow gracefully shutdown |
| 106 | + try: |
| 107 | + self._writer.close() |
| 108 | + except IOError: |
| 109 | + pass |
| 110 | + try: |
| 111 | + self._reader.close() |
| 112 | + except IOError: |
| 113 | + pass |
| 114 | + # kill all child processes recurisvely |
| 115 | + kill_child_processes(self._proc.pid) |
| 116 | + try: |
| 117 | + self._proc.kill() |
| 118 | + except OSError: |
| 119 | + pass |
| 120 | + self._proc = None |
| 121 | + |
| 122 | + def _start(self): |
| 123 | + """Start a new subprocess if nothing is available""" |
| 124 | + if self._proc is not None: |
| 125 | + return |
| 126 | + |
| 127 | + # connect subprocess with a pair of pipes |
| 128 | + main_read, worker_write = os.pipe() |
| 129 | + worker_read, main_write = os.pipe() |
| 130 | + |
| 131 | + cmd = [sys.executable, "-m", "tvm.exec.popen_worker"] |
| 132 | + if sys.platform == "win32": |
| 133 | + # pylint: disable=import-outside-toplevel |
| 134 | + import msvcrt |
| 135 | + |
| 136 | + worker_read_handle = msvcrt.get_osfhandle(worker_read) |
| 137 | + worker_write_handle = msvcrt.get_osfhandle(worker_write) |
| 138 | + os.set_handle_inheritable(worker_read_handle, True) |
| 139 | + os.set_handle_inheritable(worker_write_handle, True) |
| 140 | + cmd += [str(worker_read_handle), str(worker_write_handle)] |
| 141 | + self._proc = subprocess.Popen(cmd, close_fds=False) |
| 142 | + else: |
| 143 | + cmd += [str(worker_read), str(worker_write)] |
| 144 | + self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write)) |
| 145 | + |
| 146 | + # close worker side of the pipe |
| 147 | + os.close(worker_read) |
| 148 | + os.close(worker_write) |
| 149 | + self._reader = os.fdopen(main_read, "rb") |
| 150 | + self._writer = os.fdopen(main_write, "wb") |
| 151 | + |
| 152 | + def send(self, fn, args=(), kwargs=None, timeout=None): |
| 153 | + """Send a new function task fn(*args, **kwargs) to the subprocess. |
| 154 | +
|
| 155 | + Parameters |
| 156 | + ---------- |
| 157 | + fn : function |
| 158 | + The function to be invoked. |
| 159 | +
|
| 160 | + args : list |
| 161 | + Positional argument. |
| 162 | +
|
| 163 | + kwargs : dict |
| 164 | + Keyword arguments |
| 165 | +
|
| 166 | + timeout : float |
| 167 | + Timeout value when executing the function |
| 168 | +
|
| 169 | + Note |
| 170 | + ---- |
| 171 | + The caller must call recv before calling the next send in |
| 172 | + order to make sure the timeout and child process exit |
| 173 | + won't affect the later requests. |
| 174 | + """ |
| 175 | + # use cloud pickle |
| 176 | + # pylint: disable=import-outside-toplevel |
| 177 | + import cloudpickle |
| 178 | + |
| 179 | + if self._proc is None: |
| 180 | + self._start() |
| 181 | + kwargs = {} if not kwargs else kwargs |
| 182 | + data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL) |
| 183 | + try: |
| 184 | + self._writer.write(struct.pack("<i", len(data))) |
| 185 | + self._writer.write(data) |
| 186 | + self._writer.flush() |
| 187 | + except IOError: |
| 188 | + pass |
| 189 | + |
| 190 | + def _child_process_error(self): |
| 191 | + """Raise a child process error.""" |
| 192 | + # kill and lazily restart the process in the next send. |
| 193 | + self.kill() |
| 194 | + return ChildProcessError("Subprocess terminated") |
| 195 | + |
| 196 | + def recv(self): |
| 197 | + """Receive the result of the last send. |
| 198 | +
|
| 199 | + Returns |
| 200 | + ------- |
| 201 | + result: object |
| 202 | + The result of the last send. |
| 203 | +
|
| 204 | + Raises |
| 205 | + ------ |
| 206 | + ChildProcessError: if the child process exited abnormally. |
| 207 | + TimeoutError: if timeout happens |
| 208 | + Exception: if other exception happens during the execution. |
| 209 | + """ |
| 210 | + # pylint: disable=import-outside-toplevel |
| 211 | + import cloudpickle |
| 212 | + |
| 213 | + try: |
| 214 | + len_data = self._reader.read(4) |
| 215 | + except IOError: |
| 216 | + raise self._child_process_error() |
| 217 | + |
| 218 | + if len(len_data) == 0: |
| 219 | + raise self._child_process_error() |
| 220 | + |
| 221 | + try: |
| 222 | + recv_bytes = struct.unpack("<i", len_data)[0] |
| 223 | + status, value = cloudpickle.loads(self._reader.read(recv_bytes)) |
| 224 | + except IOError: |
| 225 | + raise self._child_process_error() |
| 226 | + |
| 227 | + if status == StatusKind.COMPLETE: |
| 228 | + return value |
| 229 | + if status == StatusKind.EXCEPTION: |
| 230 | + raise value |
| 231 | + assert status == StatusKind.TIMEOUT |
| 232 | + # kill and lazily restart the process in the next send. |
| 233 | + self.kill() |
| 234 | + raise TimeoutError() |
| 235 | + |
| 236 | + |
| 237 | +class PopenPoolExecutor: |
| 238 | + """An parallel executor backed by Popen processes. |
| 239 | +
|
| 240 | + Parameters |
| 241 | + ---------- |
| 242 | + max_worker : int |
| 243 | + Maximum number of workers |
| 244 | +
|
| 245 | + timeout : float |
| 246 | + Timeout value for each function submit. |
| 247 | + """ |
| 248 | + |
| 249 | + def __init__(self, max_workers, timeout=None): |
| 250 | + # Use an internal thread pool to send to popen workers |
| 251 | + self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) |
| 252 | + self._timeout = timeout |
| 253 | + self._worker_map = {} |
| 254 | + self._lock = threading.Lock() |
| 255 | + |
| 256 | + def __del__(self): |
| 257 | + self._lock.acquire() |
| 258 | + for worker in self._worker_map.values(): |
| 259 | + try: |
| 260 | + worker.kill() |
| 261 | + except ImportError: |
| 262 | + pass |
| 263 | + self._lock.release() |
| 264 | + self._threadpool.shutdown() |
| 265 | + |
| 266 | + def _worker_run(self, fn, args, kwargs): |
| 267 | + """Internal thread runner.""" |
| 268 | + self._lock.acquire() |
| 269 | + tid = threading.get_ident() |
| 270 | + if tid not in self._worker_map: |
| 271 | + proc = PopenWorker() |
| 272 | + self._worker_map[tid] = proc |
| 273 | + else: |
| 274 | + proc = self._worker_map[tid] |
| 275 | + self._lock.release() |
| 276 | + |
| 277 | + proc.send(fn, args, kwargs, self._timeout) |
| 278 | + return proc.recv() |
| 279 | + |
| 280 | + def _worker_run_with_error_catching(self, fn, args, kwargs) -> MapResult: |
| 281 | + # pylint: disable=broad-except |
| 282 | + try: |
| 283 | + return MapResult(status=StatusKind.COMPLETE, value=self._worker_run(fn, args, kwargs)) |
| 284 | + except TimeoutError as exception: |
| 285 | + return MapResult(status=StatusKind.TIMEOUT, value=exception) |
| 286 | + except Exception as exception: |
| 287 | + return MapResult(status=StatusKind.EXCEPTION, value=exception) |
| 288 | + |
| 289 | + def submit(self, fn, *args, **kwargs) -> concurrent.futures.Future: |
| 290 | + """Submit a new function job to the pool |
| 291 | +
|
| 292 | + Parameters |
| 293 | + ---------- |
| 294 | + fn : function |
| 295 | + The function to be invoked. |
| 296 | +
|
| 297 | + args : list |
| 298 | + Positional argument. |
| 299 | +
|
| 300 | + kwargs : dict |
| 301 | + Keyword arguments |
| 302 | +
|
| 303 | + Returns |
| 304 | + ------- |
| 305 | + future : concurrent.futures.Future |
| 306 | + A future that can be used to access the result. |
| 307 | + """ |
| 308 | + # pylint: disable=unnecessary-lambda |
| 309 | + worker = lambda *args: self._worker_run(*args) |
| 310 | + return self._threadpool.submit(worker, fn, args, kwargs) |
| 311 | + |
| 312 | + def map_with_error_catching(self, fn, iterator): |
| 313 | + """Same as map, but catches exceptions and return them instead. |
| 314 | +
|
| 315 | + Parameters |
| 316 | + ---------- |
| 317 | + fn : function |
| 318 | + The function to be invoked. |
| 319 | +
|
| 320 | + iterator : Iterator |
| 321 | + Input iterator. |
| 322 | +
|
| 323 | + Returns |
| 324 | + ------- |
| 325 | + out_iter : Iterator[MapResult] |
| 326 | + The result iterator. |
| 327 | + """ |
| 328 | + worker = lambda x: self._worker_run_with_error_catching(fn, (x,), None) |
| 329 | + return self._threadpool.map(worker, iterator) |
0 commit comments