|
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=invalid-name |
18 | 18 | """Internal DiscoWorker for Disco ProcessSession.""" |
| 19 | + |
19 | 20 | import os |
20 | 21 | import sys |
21 | 22 |
|
|
31 | 32 |
|
32 | 33 | def main(): |
33 | 34 | """Main worker function""" |
34 | | - if len(sys.argv) != 5: |
35 | | - print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>") |
| 35 | + |
| 36 | + if len(sys.argv) == 5 or len(sys.argv) == 6: |
| 37 | + *args, read_fd, write_fd = map(int, sys.argv[1:]) |
| 38 | + else: |
| 39 | + print( |
| 40 | + f"Expected exactly either 4 or 5 arguments, " |
| 41 | + f"but received {len(sys.argv)-1} arguments.: {sys.argv}" |
| 42 | + ) |
| 43 | + # The <num_groups> argument was added in |
| 44 | + # https://github.com/apache/tvm/pull/17180. This script |
| 45 | + # currently checks the number of arguments present, to |
| 46 | + # determine whether `num_groups` was provided. This allows |
| 47 | + # the worker.py script provided by MLC-LLM to be compatible |
| 48 | + # with either pre-17180 or post-17180 arguments. |
| 49 | + # |
| 50 | + # After the TVM version used by MLC-LLM includes #17180, the |
| 51 | + # usage can be updated to always require `len(sys.argv)==6`. |
| 52 | + print("Usage (without num groups): <worker_id> <num_workers> <read_fd> <write_fd>") |
| 53 | + print( |
| 54 | + "Usage (with num groups): <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>" |
| 55 | + ) |
36 | 56 | return |
37 | 57 |
|
38 | | - worker_id = int(sys.argv[1]) |
39 | | - num_workers = int(sys.argv[2]) |
40 | 58 | if sys.platform == "win32": |
41 | 59 | import msvcrt # pylint: disable=import-outside-toplevel,import-error |
42 | 60 |
|
43 | | - reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) |
44 | | - writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) |
| 61 | + reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY) |
| 62 | + writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY) |
45 | 63 | else: |
46 | | - reader = int(sys.argv[3]) |
47 | | - writer = int(sys.argv[4]) |
| 64 | + reader = read_fd |
| 65 | + writer = write_fd |
48 | 66 |
|
49 | 67 | worker_func = get_global_func("runtime.disco.WorkerProcess") |
50 | | - worker_func(worker_id, num_workers, reader, writer) |
| 68 | + worker_func(*args, reader, writer) |
51 | 69 |
|
52 | 70 |
|
53 | 71 | if __name__ == "__main__": |
|
0 commit comments