Skip to content

Commit 8eb5116

Browse files
committed
Update worker.py for compatibility with upstream TVM
This commit updates `mlc_llm.cli.worker` to be compatible with upstream TVM apache/tvm#17180, which adds a `num_groups` argument to the disco worker function. To de-couple this compatibility from a general TVM version bump, this commit has a check on the number of `worker.py` arguments provided, to determine whether the `num_groups` argument is present. After the TVM version used by MLC-LLM is updated to include the upstream changes, this check can be removed.
1 parent d54007b commit 8eb5116

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

python/mlc_llm/cli/worker.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""Internal DiscoWorker for Disco ProcessSession."""
19+
1920
import os
2021
import sys
2122

@@ -30,23 +31,40 @@
3031

3132
def main():
3233
"""Main worker function"""
33-
if len(sys.argv) != 5:
34-
print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
34+
35+
if len(sys.argv) == 5 or len(sys.argv) == 6:
36+
*args, read_fd, write_fd = map(int, sys.argv[1:])
37+
else:
38+
print(
39+
f"Expected exactly either 4 or 5 arguments, "
40+
f"but received {len(sys.argv)-1} arguments.: {sys.argv}"
41+
)
42+
# The <num_groups> argument was added in
43+
# https://github.com/apache/tvm/pull/17180. This script
44+
# currently checks the number of arguments present, to
45+
# determine whether `num_groups` was provided. This allows
46+
# the worker.py script provided by MLC-LLM to be compatible
47+
# with either pre-17180 or post-17180 arguments.
48+
#
49+
# After the TVM version used by MLC-LLM includes #17180, the
50+
# usage can be updated to always require `len(sys.argv)==6`.
51+
print("Usage (without num groups): <worker_id> <num_workers> <read_fd> <write_fd>")
52+
print(
53+
"Usage (with num groups): <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>"
54+
)
3555
return
3656

37-
worker_id = int(sys.argv[1])
38-
num_workers = int(sys.argv[2])
3957
if sys.platform == "win32":
4058
import msvcrt # pylint: disable=import-outside-toplevel,import-error
4159

42-
reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
43-
writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
60+
reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY)
61+
writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY)
4462
else:
45-
reader = int(sys.argv[3])
46-
writer = int(sys.argv[4])
63+
reader = read_fd
64+
writer = write_fd
4765

4866
worker_func = get_global_func("runtime.disco.WorkerProcess")
49-
worker_func(worker_id, num_workers, reader, writer)
67+
worker_func(*args, reader, writer)
5068

5169

5270
if __name__ == "__main__":

0 commit comments

Comments
 (0)