from swift.torchacc_utils import patch_acc_model
from swift.trainers import TrainerFactory
from swift.trainers.utils import can_return_loss, find_labels
-from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_logger,
-                         get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
+from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_dist_setting,
+                         get_logger, get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
                          preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc)
from .accelerator import ta_accelerate
from .tuner import prepare_model
@@ -114,6 +114,25 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
    return {}


+def get_default_device_map():
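+    # Default device_map: None under ZeRO-3/FSDP, npu:<local_rank> on Ascend NPU,
+    # 'cpu'/'cuda:0' with zero/one CUDA device, cuda:<local_rank> for plain DDP, else 'auto'.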
+    if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
+        return None
+    local_rank = get_dist_setting()[1]
+    if is_torch_npu_available():
+        if local_rank >= 0:
+            return f'npu:{local_rank}'
+        else:
+            return 'npu:0'
+    if torch.cuda.device_count() == 0:
+        return 'cpu'
+    elif torch.cuda.device_count() == 1:
+        return 'cuda:0'
+    elif is_dist() and not is_ddp_plus_mp():
+        return f'cuda:{local_rank}'
+    else:
+        return 'auto'
+
+
def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):

    if args.gpu_memory_fraction is not None:
@@ -128,21 +147,15 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
                f'world_size: {args.world_size}, local_world_size: {args.local_world_size}')

    # Loading Model and Tokenizer
-    if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
-        model_kwargs = {'device_map': None}
-    elif is_torch_npu_available():
-        model_kwargs = {'device_map': args.local_rank if args.local_rank >= 0 else 0}
-    elif args.device_map_config is not None:
-        model_kwargs = {'device_map': args.device_map_config}
-    else:
-        model_kwargs = {'low_cpu_mem_usage': True}
-        if is_dist() and not is_ddp_plus_mp():
-            model_kwargs['device_map'] = {'': args.local_rank}
-        elif torch.cuda.device_count() == 1:
-            model_kwargs['device_map'] = 'cuda:0'
-        elif not use_torchacc():
-            model_kwargs['device_map'] = 'auto'
-
+    model_kwargs = {}
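+    # device_map is only chosen on the non-torchacc path.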
+    if not use_torchacc():
+        if args.device_map_config is not None:
+            device_map = args.device_map_config
+        else:
+            device_map = get_default_device_map()
+        model_kwargs['device_map'] = device_map
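+        # Request low_cpu_mem_usage only for the sharded 'auto' placement.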
+        if device_map == 'auto':
+            model_kwargs['low_cpu_mem_usage'] = True
    if args.device_max_memory:
        n_gpu = torch.cuda.device_count()
        assert len(args.device_max_memory) == n_gpu // args.local_world_size
@@ -354,7 +367,6 @@ def prepare_dataset(args, template: Template, msg: Optional[Dict[str, Any]] = No
                           f'Setting args.preprocess_num_proc to: {args.preprocess_num_proc}')
        else:
            template.model = None
-        logger.info(f'Using num_proc: {args.preprocess_num_proc}')
        td0, tkwargs0 = template.encode(train_dataset[0])
        print_example(td0, tokenizer, tkwargs0)
        train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=args.streaming)