@@ -312,6 +312,9 @@ class Event:
312312 _instruction_id : Optional [int ] = None
313313
314314 _delegate_metadata_parser : Optional [Callable [[List [str ]], Dict [str , Any ]]] = None
315+ _delegate_time_scale_converter : Optional [
316+ Callable [[Union [int , str ], Union [int , float ]], Union [int , float ]]
317+ ] = None
315318
316319 @cached_property
317320 def delegate_debug_metadatas (self ) -> Union [List [str ], Dict [str , Any ]]:
@@ -391,6 +394,9 @@ def _gen_from_inference_events(
391394 delegate_metadata_parser : Optional [
392395 Callable [[List [str ]], Dict [str , Any ]]
393396 ] = None ,
397+ delegate_time_scale_converter : Optional [
398+ Callable [[Union [int , str ], Union [int , float ]], Union [int , float ]]
399+ ] = None ,
394400 ) -> "Event" :
395401 """
396402 Given an EventSignature and a list of Events with that signature,
@@ -411,6 +417,7 @@ def _gen_from_inference_events(
411417 name = "" ,
412418 _instruction_id = signature .instruction_id ,
413419 _delegate_metadata_parser = delegate_metadata_parser ,
420+ _delegate_time_scale_converter = delegate_time_scale_converter ,
414421 )
415422
416423 # Populate fields from profile events
@@ -476,14 +483,31 @@ def _populate_profiling_related_fields(
476483 f"Expected exactly one profile event per InstructionEvent when generating Inspector Event, but got { len (profile_events )} "
477484 )
478485
486+ profile_event = profile_events [0 ]
487+
479488 # Scale factor should only be applied to non-delegated ops
480- scale_factor_updated = 1 if ret_event .is_delegated_op else scale_factor
489+ if (
490+ ret_event .is_delegated_op
491+ and ret_event ._delegate_time_scale_converter is not None
492+ ):
493+ scaled_time = ret_event ._delegate_time_scale_converter (
494+ ret_event .name ,
495+ profile_event .end_time ,
496+ # pyre-ignore
497+ ) - ret_event ._delegate_time_scale_converter (
498+ ret_event .name , profile_event .start_time
499+ )
500+ elif not ret_event .is_delegated_op :
501+ scaled_time = (
502+ float (profile_event .end_time - profile_event .start_time )
503+ / scale_factor
504+ )
505+ else :
506+ scaled_time = float (
507+ profile_event .end_time - profile_event .start_time
508+ )
481509
482- profile_event = profile_events [0 ]
483- data .append (
484- float (profile_event .end_time - profile_event .start_time )
485- / scale_factor_updated
486- )
510+ data .append (scaled_time )
487511 delegate_debug_metadatas .append (
488512 profile_event .delegate_debug_metadata
489513 if profile_event .delegate_debug_metadata
@@ -646,6 +670,9 @@ def _gen_from_etdump(
646670 delegate_metadata_parser : Optional [
647671 Callable [[List [str ]], Dict [str , Any ]]
648672 ] = None ,
673+ delegate_time_scale_converter : Optional [
674+ Callable [[Union [int , str ], Union [int , float ]], Union [int , float ]]
675+ ] = None ,
649676 ) -> List ["EventBlock" ]:
650677 """
651678 Given an etdump, generate a list of EventBlocks corresponding to the
@@ -743,6 +770,7 @@ class GroupedRunInstances:
743770 scale_factor ,
744771 output_buffer ,
745772 delegate_metadata_parser ,
773+ delegate_time_scale_converter ,
746774 )
747775 for signature , instruction_events in run_group .items ()
748776 ]
@@ -875,6 +903,9 @@ def __init__(
875903 delegate_metadata_parser : Optional [
876904 Callable [[List [str ]], Dict [str , Any ]]
877905 ] = None ,
906+ delegate_time_scale_converter : Optional [
907+ Callable [[Union [int , str ], Union [int , float ]], Union [int , float ]]
908+ ] = None ,
878909 enable_module_hierarchy : bool = False ,
879910 ) -> None :
880911 r"""
@@ -930,6 +961,7 @@ def __init__(
930961 self ._target_time_scale ,
931962 output_buffer ,
932963 delegate_metadata_parser = delegate_metadata_parser ,
964+ delegate_time_scale_converter = delegate_time_scale_converter ,
933965 )
934966
935967 # Connect ETRecord to EventBlocks
0 commit comments