Skip to content

Commit 7a62042

Browse files
hhaAndroidyumion
authored andcommitted
Fix accepting an unexpected argument local-rank in PyTorch 2.0 (open-mmlab#10050)
1 parent 88b77ff commit 7a62042

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tools/test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def parse_args():
5555
default='none',
5656
help='job launcher')
5757
parser.add_argument('--tta', action='store_true')
58-
parser.add_argument('--local_rank', type=int, default=0)
58+
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
59+
# will pass the `--local-rank` parameter to `tools/train.py` instead
60+
# of `--local_rank`.
61+
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
5962
args = parser.parse_args()
6063
if 'LOCAL_RANK' not in os.environ:
6164
os.environ['LOCAL_RANK'] = str(args.local_rank)

tools/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@ def parse_args():
4646
choices=['none', 'pytorch', 'slurm', 'mpi'],
4747
default='none',
4848
help='job launcher')
49-
parser.add_argument('--local_rank', type=int, default=0)
5049
parser.add_argument(
5150
'--mmyolo',
5251
action='store_true',
5352
default=False,
5453
help='if using mmyolo model, set --mmyolo')
54+
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
55+
# will pass the `--local-rank` parameter to `tools/train.py` instead
56+
# of `--local_rank`.
57+
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
5558
args = parser.parse_args()
5659
if 'LOCAL_RANK' not in os.environ:
5760
os.environ['LOCAL_RANK'] = str(args.local_rank)

0 commit comments

Comments
 (0)