@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
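The hunks below mechanically replace the legacy xm.mark_step() barrier with the top-level torch_xla.sync(wait=False) entry point, which is why import torch_xla is added above. A minimal sketch of the two spellings, assuming a torch_xla release that ships torch_xla.sync(); the device setup and tensor shapes are illustrative only:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(8, 8, device=device)
y = torch.matmul(x, x)  # traced lazily; nothing has executed on the device yet

# Legacy barrier: cut the traced graph here and dispatch it asynchronously.
# xm.mark_step()

# Spelling used throughout this commit; wait=False keeps the dispatch
# asynchronous, so it behaves like mark_step() rather than a blocking sync.
torch_xla.sync(wait=False)

# Block until every dispatched graph has finished executing on the device.
xm.wait_device_ops()

Because the dispatch stays asynchronous, the existing xm.wait_device_ops() calls are kept wherever the runner needs the device to be idle.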
@@ -846,10 +847,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             #    (feature_size, hidden_size) in case the feature size is dynamic
             #    depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
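Note on the hunk above: the comments explain that get_multimodal_embeddings() can return per-item tensors whose feature size varies, so the call is bracketed by barriers to keep that potentially recompiling graph separate from the rest of the step. A rough sketch of the bracketing pattern, with run_encoder() as a hypothetical stand-in for the encoder call:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def run_encoder(inputs: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for get_multimodal_embeddings(): the output's
    # leading dimension depends on the input, so each new shape can trigger
    # a fresh XLA compilation.
    return inputs.repeat(2, 1)

outputs = []
for n in (3, 5, 3):
    batch = torch.randn(n, 16, device=device)
    torch_xla.sync(wait=False)   # flush whatever was traced before the encoder
    outputs.append(run_encoder(batch))
    torch_xla.sync(wait=False)   # cut the encoder graph so later tracing stays separate
xm.wait_device_ops()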
@@ -952,7 +953,7 @@ def execute_model(
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ def execute_model(
                 end_index = self._prepare_inputs(scheduler_output, start_index)
             input_ids, inputs_embeds = self._get_model_inputs(
                 self.input_ids, mm_embeds)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             # Run the decoder
             with set_forward_context(
                     attn_metadata,
@@ -1183,7 +1184,7 @@ def load_model(self) -> None:

         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,

     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ def _precompile_mm_encoder(self) -> None:
                 num_items,
             )
             # Run multimodal encoder.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             mm_embeds = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             num_patches = mm_embeds[0].shape[0]
             items_size = num_patches * num_items

@@ -1325,7 +1326,7 @@ def _precompile_mm_encoder(self) -> None:
                 a, b = self._get_model_inputs(placeholders_ids,
                                               [mm_embeds])
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ def _precompile_mm_encoder(self) -> None:
                 placeholders_ids = placeholders_ids.to(self.device)
                 a, b = self._get_model_inputs(placeholders_ids, [])
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

         xm.wait_device_ops()
         end = time.perf_counter()
@@ -1532,11 +1533,11 @@ def profile_run(
             # Isolate encoder graph from post-processing to minimize
             # impact of recompilation until it's fixed.
             start = time.perf_counter()
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             dummy_encoder_outputs = \
                 self.model.get_multimodal_embeddings(
                     **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
@@ -1559,7 +1560,7 @@ def profile_run(
         self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                         self.num_blocks_per_most_len_req)

-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         self.encoder_cache.clear()
         gc.collect()
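profile_run and _precompile_mm_encoder both time a compile-and-execute window by placing a barrier before and after the traced call and then blocking on xm.wait_device_ops() before reading the clock. A self-contained sketch of that timing pattern, with dummy_forward() as a hypothetical placeholder for the profiled computation:

import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def dummy_forward() -> torch.Tensor:
    # Hypothetical placeholder for the computation being timed.
    x = torch.randn(128, 128, device=device)
    return torch.relu(x @ x)

start = time.perf_counter()
torch_xla.sync(wait=False)   # keep earlier tracing out of this timing window
out = dummy_forward()
torch_xla.sync(wait=False)   # dispatch the profiled graph
xm.wait_device_ops()         # block so the elapsed time covers compile + execution
end = time.perf_counter()
print(f"Compile and run took {end - start:.3f}s")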
@@ -1927,11 +1928,11 @@ def _tpu_set_lora(
         # to a tensor doesn't seem to work anymore. This might be fixed with a
         # later release of torch_xla.
         self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     def _tpu_reset_lora(self, index: int):
         self._original_reset_lora(index)
-        xm.mark_step()
+        torch_xla.sync(wait=False)

     for _, module in model.named_modules():
         if isinstance(module, BaseLayerWithLoRA):
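The final hunk sits inside the helper that rebinds set_lora/reset_lora on every BaseLayerWithLoRA module so a barrier runs right after each LoRA weight update. A rough sketch of that rebinding pattern, using a hypothetical FakeLoRALayer and a simplified set_lora signature rather than vLLM's actual classes:

import types

import torch.nn as nn
import torch_xla


class FakeLoRALayer(nn.Module):
    # Hypothetical stand-in for BaseLayerWithLoRA with a simplified signature.

    def set_lora(self, index, lora_a, lora_b):
        pass  # a real layer would copy the LoRA weights into device buffers here


def _patched_set_lora(self, index, lora_a, lora_b):
    self._original_set_lora(index, lora_a, lora_b)
    torch_xla.sync(wait=False)  # flush the weight-update graph right away


model = nn.Sequential(FakeLoRALayer())
for _, module in model.named_modules():
    if isinstance(module, FakeLoRALayer):
        module._original_set_lora = module.set_lora
        module.set_lora = types.MethodType(_patched_set_lora, module)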