@@ -45,14 +45,17 @@ class GraphCaptureContext:
4545
4646
4747def _split_tensor_dict (
48- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]]
49- ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
48+ tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
49+ prefix : str = "" ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
5050 """Split the tensor dictionary into two parts:
5151 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
5252 by its metadata.
5353 2. A list of tensors.
54+
55+ If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
56+ metadata will be "key1%key2".
5457 """
55- metadata_list = []
58+ metadata_list : List [ Tuple [ str , Any ]] = []
5659 tensor_list = []
5760 for key , value in tensor_dict .items ():
5861 if isinstance (value , torch .Tensor ):
@@ -62,13 +65,31 @@ def _split_tensor_dict(
6265 # receiving side will set the device index.
6366 device = value .device .type
6467 metadata_list .append (
65- (key , TensorMetadata (device , value .dtype , value .size ())))
68+ (prefix + key , TensorMetadata (device , value .dtype ,
69+ value .size ())))
6670 tensor_list .append (value )
71+ elif isinstance (value , dict ):
72+ if len (value ) == 0 :
73+ metadata_list .append ((prefix + key , value ))
74+ inner_metadata_list , inner_tensor_list = _split_tensor_dict (
75+ value , prefix + key + "%" )
76+ metadata_list .extend (inner_metadata_list )
77+ tensor_list .extend (inner_tensor_list )
6778 else :
68- metadata_list .append ((key , value ))
79+ metadata_list .append ((prefix + key , value ))
6980 return metadata_list , tensor_list
7081
7182
83+ def _update_nested_dict (nested_dict , flattened_key , value ):
84+ key_splits = flattened_key .split ("%" )
85+ cur_dict = nested_dict
86+ for k in key_splits [:- 1 ]:
87+ if k not in cur_dict :
88+ cur_dict [k ] = {}
89+ cur_dict = cur_dict [k ]
90+ cur_dict [key_splits [- 1 ]] = value
91+
92+
7293class GroupCoordinator :
7394 """
7495 PyTorch ProcessGroup wrapper for a group of processes.
@@ -512,7 +533,7 @@ def broadcast_tensor_dict(
512533 device = value .device )
513534 if tensor .numel () == 0 :
514535 # Skip broadcasting empty tensors.
515- tensor_dict [ key ] = tensor
536+ _update_nested_dict ( tensor_dict , key , tensor )
516537 continue
517538 if tensor .is_cpu :
518539 # use metadata_group for CPU tensors
@@ -528,9 +549,9 @@ def broadcast_tensor_dict(
528549 group = group ,
529550 async_op = True )
530551 async_handles .append (handle )
531- tensor_dict [ key ] = tensor
552+ _update_nested_dict ( tensor_dict , key , tensor )
532553 else :
533- tensor_dict [ key ] = value
554+ _update_nested_dict ( tensor_dict , key , value )
534555 for async_handle in async_handles :
535556 async_handle .wait ()
536557 return tensor_dict
0 commit comments