@@ -45,7 +45,7 @@ class GraphCaptureContext:
4545
4646
4747def _split_tensor_dict (
48- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
48+ tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
4949 prefix : str = "" ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
5050 """Split the tensor dictionary into two parts:
5151 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:
473473
474474 def broadcast_tensor_dict (
475475 self ,
476- tensor_dict : Optional [Dict [Any , Union [torch .Tensor , Any ]]] = None ,
476+ tensor_dict : Optional [Dict [str , Union [torch .Tensor , Any ]]] = None ,
477477 src : int = 0 ,
478478 group : Optional [ProcessGroup ] = None ,
479479 metadata_group : Optional [ProcessGroup ] = None
480- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
480+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
481481 """Broadcast the input tensor dictionary.
482482 NOTE: `src` is the local rank of the source rank.
483483 """
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(
558558
559559 def send_tensor_dict (
560560 self ,
561- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
561+ tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
562562 dst : Optional [int ] = None
563- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
563+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
564564 """Send the input tensor dictionary.
565565 NOTE: `dst` is the local rank of the destination rank.
566566 """
@@ -599,7 +599,7 @@ def send_tensor_dict(
599599 def recv_tensor_dict (
600600 self ,
601601 src : Optional [int ] = None
602- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
602+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
603603 """Recv the input tensor dictionary.
604604 NOTE: `src` is the local rank of the source rank.
605605 """
@@ -615,15 +615,15 @@ def recv_tensor_dict(
615615 assert src < self .world_size , f"Invalid src rank ({ src } )"
616616
617617 recv_metadata_list = self .recv_object (src = src )
618- tensor_dict = {}
618+ tensor_dict : Dict [ str , Any ] = {}
619619 for key , value in recv_metadata_list :
620620 if isinstance (value , TensorMetadata ):
621621 tensor = torch .empty (value .size ,
622622 dtype = value .dtype ,
623623 device = value .device )
624624 if tensor .numel () == 0 :
625625 # Skip broadcasting empty tensors.
626- tensor_dict [ key ] = tensor
626+ _update_nested_dict ( tensor_dict , key , tensor )
627627 continue
628628 if tensor .is_cpu :
629629 # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ def recv_tensor_dict(
633633 else :
634634 # use group for GPU tensors
635635 torch .distributed .recv (tensor , src = src , group = group )
636- tensor_dict [ key ] = tensor
636+ _update_nested_dict ( tensor_dict , key , tensor )
637637 else :
638- tensor_dict [ key ] = value
638+ _update_nested_dict ( tensor_dict , key , value )
639639 return tensor_dict
640640
641641 def barrier (self ):
0 commit comments