@@ -152,9 +152,15 @@ def skip_forward(
152152 if hasattr (module , 'skip_forward' ):
153153 module .forward = module .skip_forward
154154 remove_weights (module , ignore_modules )
155+ else :
156+ logger .warning (
157+ f"Fail to skip forward since { module .__class__ .__name__ } "
158+ f"does not have `skip_forward`." )
155159
156160
157161def forward_after_recv (forward_fn ):
162+ if hasattr (forward_fn , "__wrapped_by_forward_after_recv__" ):
163+ return forward_fn
158164
159165 def forward_after_recv_fn (
160166 position_ids ,
@@ -176,10 +182,13 @@ def forward_after_recv_fn(
176182 ** kwargs ,
177183 )
178184
185+ forward_after_recv_fn .__wrapped_by_forward_after_recv__ = True
179186 return forward_after_recv_fn
180187
181188
182189def forward_before_send (forward_fn ):
190+ if hasattr (forward_fn , "__wrapped_by_forward_before_send__" ):
191+ return forward_fn
183192
184193 def forward_before_send_fn (
185194 position_ids ,
@@ -204,6 +213,7 @@ def forward_before_send_fn(
204213 pp_send (hidden_states )
205214 return output
206215
216+ forward_before_send_fn .__wrapped_by_forward_before_send__ = True
207217 return forward_before_send_fn
208218
209219
@@ -411,6 +421,8 @@ def __pp_init__(self):
411421 for module in self .epilogue :
412422 skip_forward (module )
413423
424+ self .model .__pp_init__ ()
425+
414426 def __post_init__ (self ):
415427 # 1. mixed precision
416428 quant_config_dict = self .model_config .quant_config_dict
0 commit comments