Skip to content

Commit 3b82b5d

Browse files
committed
fix cpu infer device_map (#2103)
1 parent 4bd62ea commit 3b82b5d

File tree

3 files changed

+39
-28
lines changed

3 files changed

+39
-28
lines changed

swift/llm/infer.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,15 @@ def prepare_model_template(args: InferArguments,
135135
device_map: Optional[str] = None,
136136
verbose: bool = True,
137137
automodel_class=None) -> Tuple[PreTrainedModel, Template]:
138-
139-
model_kwargs = {}
138+
from .sft import get_default_device_map
140139
if is_torch_npu_available():
141-
logger.info(f'device_count: {torch.npu.device_count()}')
142-
if device_map is None:
143-
device_map = 'npu:0'
140+
print(f'device_count: {torch.npu.device_count()}')
144141
else:
145-
logger.info(f'device_count: {torch.cuda.device_count()}')
146-
if device_map is None:
147-
device_map = 'auto' if torch.cuda.device_count() > 1 else 'cuda:0'
142+
print(f'device_count: {torch.cuda.device_count()}')
143+
model_kwargs = {}
144+
if device_map is not None:
145+
device_map = get_default_device_map()
146+
model_kwargs['device_map'] = device_map
148147
if device_map == 'auto':
149148
model_kwargs['low_cpu_mem_usage'] = True
150149
model_kwargs['device_map'] = device_map

swift/llm/sft.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from swift.torchacc_utils import patch_acc_model
1616
from swift.trainers import TrainerFactory
1717
from swift.trainers.utils import can_return_loss, find_labels
18-
from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_logger,
19-
get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
18+
from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_dist_setting,
19+
get_logger, get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
2020
preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc)
2121
from .accelerator import ta_accelerate
2222
from .tuner import prepare_model
@@ -114,6 +114,25 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
114114
return {}
115115

116116

117+
def get_default_device_map():
118+
if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
119+
return None
120+
local_rank = get_dist_setting()[1]
121+
if is_torch_npu_available():
122+
if local_rank >= 0:
123+
return f'npu:{local_rank}'
124+
else:
125+
return 'npu:0'
126+
if torch.cuda.device_count() == 0:
127+
return 'cpu'
128+
elif torch.cuda.device_count() == 1:
129+
return 'cuda:0'
130+
elif is_dist() and not is_ddp_plus_mp():
131+
return f'cuda:{local_rank}'
132+
else:
133+
return 'auto'
134+
135+
117136
def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
118137

119138
if args.gpu_memory_fraction is not None:
@@ -128,21 +147,15 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
128147
f'world_size: {args.world_size}, local_world_size: {args.local_world_size}')
129148

130149
# Loading Model and Tokenizer
131-
if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
132-
model_kwargs = {'device_map': None}
133-
elif is_torch_npu_available():
134-
model_kwargs = {'device_map': args.local_rank if args.local_rank >= 0 else 0}
135-
elif args.device_map_config is not None:
136-
model_kwargs = {'device_map': args.device_map_config}
137-
else:
138-
model_kwargs = {'low_cpu_mem_usage': True}
139-
if is_dist() and not is_ddp_plus_mp():
140-
model_kwargs['device_map'] = {'': args.local_rank}
141-
elif torch.cuda.device_count() == 1:
142-
model_kwargs['device_map'] = 'cuda:0'
143-
elif not use_torchacc():
144-
model_kwargs['device_map'] = 'auto'
145-
150+
model_kwargs = {}
151+
if not use_torchacc():
152+
if args.device_map_config is not None:
153+
device_map = args.device_map_config
154+
else:
155+
device_map = get_default_device_map()
156+
model_kwargs['device_map'] = device_map
157+
if device_map == 'auto':
158+
model_kwargs['low_cpu_mem_usage'] = True
146159
if args.device_max_memory:
147160
n_gpu = torch.cuda.device_count()
148161
assert len(args.device_max_memory) == n_gpu // args.local_world_size
@@ -354,7 +367,6 @@ def prepare_dataset(args, template: Template, msg: Optional[Dict[str, Any]] = No
354367
f'Setting args.preprocess_num_proc to: {args.preprocess_num_proc}')
355368
else:
356369
template.model = None
357-
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
358370
td0, tkwargs0 = template.encode(train_dataset[0])
359371
print_example(td0, tokenizer, tkwargs0)
360372
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=args.streaming)

swift/llm/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def _map_mp(dataset: HfDataset, map_func: MapFunc, num_proc: int) -> List[Dict[s
299299
# Solving the unordered problem
300300
data = [None] * len(dataset)
301301
num_proc = min(num_proc, len(dataset))
302-
for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset)):
302+
for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset), desc=f'Map (num_proc={num_proc})'):
303303
data[d[0]] = d[1]
304304
return data
305305

@@ -314,7 +314,7 @@ def dataset_map(dataset: DATASET_TYPE,
314314
single_map = partial(_single_map, map_func=map_func)
315315
if num_proc == 1:
316316
data = []
317-
for d in tqdm(dataset):
317+
for d in tqdm(dataset, desc='Map'):
318318
d = single_map(d)
319319
data.append(d)
320320
else:

0 commit comments

Comments
 (0)