Skip to content

Commit 6f989e7

Browse files
committed
[CONTRIB] PopenPoolExecutor
PopenPoolExecutor implements a ProcessPoolExecutor backed by popen. - Only handles invoking functions in the tvm namespace. - Unlike multiprocessing, it does not require a __main__ block, which means it can run directly in a Jupyter notebook. - Comes with timeout and fault-tolerance support to time out long-running jobs and restart the process when an error happens. Recommended usage: create a pool and reuse it in a long-running job (e.g. autotuning) so that processes are reused when possible.
1 parent 5697440 commit 6f989e7

File tree

4 files changed

+532
-0
lines changed

4 files changed

+532
-0
lines changed

python/tvm/contrib/popen_pool.py

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Multiprocessing via Popen.
19+
20+
This module provides a multi-processing pool backed by Popen.
21+
with additional timeout support.
22+
"""
23+
import os
24+
import sys
25+
import struct
26+
import threading
27+
import subprocess
28+
import concurrent.futures
29+
from enum import IntEnum
30+
from collections import namedtuple
31+
import pickle
32+
33+
34+
def kill_child_processes(pid):
    """Recursively terminate every descendant process of a given process.

    Parameters
    ----------
    pid : int
        The process id whose children (and their children, transitively)
        should be killed.
    """
    # pylint: disable=import-outside-toplevel
    import psutil

    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # The parent is already gone; there is nothing to clean up.
        return

    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child exited between enumeration and the kill; ignore it.
            pass
55+
56+
57+
class StatusKind(IntEnum):
    """Status codes describing how a pool task ran and finished."""

    RUNNING = 0  # task has not finished yet
    COMPLETE = 1  # task finished and produced a value
    EXCEPTION = 2  # task raised an exception
    TIMEOUT = 3  # task exceeded its time budget
64+
65+
66+
class MapResult(namedtuple("MapResult", ["status", "value"])):
    """Outcome record returned by map_with_error_catching.

    Parameters
    ----------
    status : StatusKind
        How the underlying task finished.

    value : Any
        The task's return value on success, or the caught exception
        object otherwise.
    """

    # Keep instances tuple-sized: no per-instance __dict__.
    __slots__ = []
79+
80+
81+
class PopenWorker:
    """A subprocess worker via Popen.

    PopenWorker provides a low-level
    API to interact with a separate process via Popen.

    The worker process runs ``tvm.exec.popen_worker`` and communicates
    over a pair of OS pipes using length-prefixed cloudpickle messages
    (4-byte little-endian length, then the payload).
    """

    def __init__(self):
        # Subprocess handle; created lazily on the first send().
        self._proc = None

    def __del__(self):
        try:
            self.kill()
        except ImportError:
            # kill() -> kill_child_processes() imports psutil; during
            # interpreter shutdown that import can fail, so ignore it.
            pass

    def kill(self):
        """Kill the current running process and cleanup.

        Note
        ----
        The worker can start a new process when send is called again.
        """
        if self._proc is not None:
            # allow graceful shutdown: closing the pipes signals EOF
            # to the child before we resort to killing it.
            try:
                self._writer.close()
            except IOError:
                pass
            try:
                self._reader.close()
            except IOError:
                pass
            # kill all child processes recursively, since the worker may
            # have spawned its own subprocesses.
            kill_child_processes(self._proc.pid)
            try:
                self._proc.kill()
            except OSError:
                pass
            # Mark as dead so the next send() starts a fresh process.
            self._proc = None

    def _start(self):
        """Start a new subprocess if nothing is available"""
        if self._proc is not None:
            return

        # connect subprocess with a pair of pipes:
        # parent reads what the worker writes, and vice versa.
        main_read, worker_write = os.pipe()
        worker_read, main_write = os.pipe()

        cmd = [sys.executable, "-m", "tvm.exec.popen_worker"]
        if sys.platform == "win32":
            # pylint: disable=import-outside-toplevel
            import msvcrt

            # On Windows raw fds are not inheritable; convert them to OS
            # handles, mark those inheritable, and pass the handle numbers
            # on the command line instead of the fds.
            worker_read_handle = msvcrt.get_osfhandle(worker_read)
            worker_write_handle = msvcrt.get_osfhandle(worker_write)
            os.set_handle_inheritable(worker_read_handle, True)
            os.set_handle_inheritable(worker_write_handle, True)
            cmd += [str(worker_read_handle), str(worker_write_handle)]
            self._proc = subprocess.Popen(cmd, close_fds=False)
        else:
            # POSIX: pass the fd numbers directly and keep only these two
            # fds open in the child via pass_fds.
            cmd += [str(worker_read), str(worker_write)]
            self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write))

        # close worker side of the pipe in the parent so that EOF is
        # delivered correctly when either side exits.
        os.close(worker_read)
        os.close(worker_write)
        self._reader = os.fdopen(main_read, "rb")
        self._writer = os.fdopen(main_write, "wb")

    def send(self, fn, args=(), kwargs=None, timeout=None):
        """Send a new function task fn(*args, **kwargs) to the subprocess.

        Parameters
        ----------
        fn : function
            The function to be invoked.

        args : list
            Positional argument.

        kwargs : dict
            Keyword arguments

        timeout : float
            Timeout value when executing the function

        Note
        ----
        The caller must call recv before calling the next send in
        order to make sure the timeout and child process exit
        won't affect the later requests.
        """
        # use cloud pickle so that lambdas/closures can cross the
        # process boundary (plain pickle cannot serialize them).
        # pylint: disable=import-outside-toplevel
        import cloudpickle

        if self._proc is None:
            self._start()
        kwargs = {} if not kwargs else kwargs
        data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL)
        try:
            # length-prefixed frame: 4-byte little-endian size then payload.
            self._writer.write(struct.pack("<i", len(data)))
            self._writer.write(data)
            self._writer.flush()
        except IOError:
            # A broken pipe here means the child died; the failure is
            # surfaced to the caller by the subsequent recv().
            pass

    def _child_process_error(self):
        """Raise a child process error."""
        # kill and lazily restart the process in the next send.
        self.kill()
        return ChildProcessError("Subprocess terminated")

    def recv(self):
        """Receive the result of the last send.

        Returns
        -------
        result: object
            The result of the last send.

        Raises
        ------
        ChildProcessError: if the child process exited abnormally.
        TimeoutError: if timeout happens
        Exception: if other exception happens during the execution.
        """
        # pylint: disable=import-outside-toplevel
        import cloudpickle

        try:
            len_data = self._reader.read(4)
        except IOError:
            raise self._child_process_error()

        # Zero bytes read means EOF: the child closed its end (crashed
        # or exited) without sending a reply.
        if len(len_data) == 0:
            raise self._child_process_error()

        try:
            recv_bytes = struct.unpack("<i", len_data)[0]
            # The worker replies with a (StatusKind, value) pair.
            status, value = cloudpickle.loads(self._reader.read(recv_bytes))
        except IOError:
            raise self._child_process_error()

        if status == StatusKind.COMPLETE:
            return value
        if status == StatusKind.EXCEPTION:
            # The worker pickled the exception it hit; re-raise it here.
            raise value
        assert status == StatusKind.TIMEOUT
        # kill and lazily restart the process in the next send.
        self.kill()
        raise TimeoutError()
235+
236+
237+
class PopenPoolExecutor:
    """A parallel executor backed by Popen processes.

    Each thread in an internal thread pool lazily creates and owns one
    PopenWorker subprocess; submitted tasks are forwarded to the worker
    owned by whichever pool thread picks them up.

    Parameters
    ----------
    max_workers : int
        Maximum number of workers.

    timeout : float, optional
        Timeout value for each function submit, in seconds.
    """

    def __init__(self, max_workers, timeout=None):
        # Use an internal thread pool to send to popen workers
        self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self._timeout = timeout
        # Maps thread ident -> the PopenWorker owned by that thread.
        self._worker_map = {}
        self._lock = threading.Lock()

    def __del__(self):
        # `with` guarantees the lock is released even if a kill() raises;
        # a bare acquire()/release() pair would leak the lock on error.
        with self._lock:
            for worker in self._worker_map.values():
                try:
                    worker.kill()
                except ImportError:
                    # During interpreter shutdown PopenWorker.kill() may
                    # fail importing psutil; ignore it.
                    pass
        self._threadpool.shutdown()

    def _worker_run(self, fn, args, kwargs):
        """Internal thread runner: run fn on this thread's dedicated worker."""
        with self._lock:
            tid = threading.get_ident()
            proc = self._worker_map.get(tid)
            if proc is None:
                proc = PopenWorker()
                self._worker_map[tid] = proc

        proc.send(fn, args, kwargs, self._timeout)
        return proc.recv()

    def _worker_run_with_error_catching(self, fn, args, kwargs) -> "MapResult":
        """Like _worker_run, but wrap the outcome in a MapResult."""
        # pylint: disable=broad-except
        try:
            return MapResult(status=StatusKind.COMPLETE, value=self._worker_run(fn, args, kwargs))
        except TimeoutError as exception:
            return MapResult(status=StatusKind.TIMEOUT, value=exception)
        except Exception as exception:
            return MapResult(status=StatusKind.EXCEPTION, value=exception)

    def submit(self, fn, *args, **kwargs) -> concurrent.futures.Future:
        """Submit a new function job to the pool

        Parameters
        ----------
        fn : function
            The function to be invoked.

        args : list
            Positional argument.

        kwargs : dict
            Keyword arguments

        Returns
        -------
        future : concurrent.futures.Future
            A future that can be used to access the result.
        """
        # Dispatch the bound method directly; no wrapper lambda needed.
        return self._threadpool.submit(self._worker_run, fn, args, kwargs)

    def map_with_error_catching(self, fn, iterator):
        """Same as map, but catches exceptions and return them instead.

        Parameters
        ----------
        fn : function
            The function to be invoked.

        iterator : Iterator
            Input iterator.

        Returns
        -------
        out_iter : Iterator[MapResult]
            The result iterator.
        """
        worker = lambda x: self._worker_run_with_error_catching(fn, (x,), None)
        return self._threadpool.map(worker, iterator)

0 commit comments

Comments
 (0)