diff --git a/swift/llm/infer.py b/swift/llm/infer.py index 660e721806..3e1a1439ed 100644 --- a/swift/llm/infer.py +++ b/swift/llm/infer.py @@ -141,7 +141,7 @@ def prepare_model_template(args: InferArguments, else: print(f'device_count: {torch.cuda.device_count()}') model_kwargs = {} - if device_map is not None: + if device_map is None: device_map = get_default_device_map() model_kwargs['device_map'] = device_map if device_map == 'auto':