@@ -50,19 +50,65 @@ def __init__(
5050 self .input_kwargs_to_split = set (input_kwargs_to_split )
5151
5252 def forward (self , * args , ** kwargs ) -> Union [torch .Tensor , Tuple [torch .Tensor ]]:
53- r"""Forward method of `SplitInferenceModule`.
53+ r"""Forward method for the `SplitInferenceModule`.
5454
55- All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
56- split that are specified in `inputs_to_split` when initializing the module.
55+ This method processes the input by splitting specified keyword arguments along a given dimension, running the
56+ underlying module on each split, and then concatenating the results. The splitting is controlled by the
57+ `split_size` and `split_dim` parameters specified during initialization.
58+
59+ Args:
60+ *args (`Any`):
61+ Positional arguments that are passed directly to the `module` without modification.
62+ **kwargs (`Dict[str, Any]`):
63+ Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
64+ entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
65+ arguments are passed unchanged.
66+
67+ Returns:
68+ `Union[torch.Tensor, Tuple[torch.Tensor]]`:
69+ The outputs are the same as those obtained by calling the underlying module directly, without wrapping
70+ it in `SplitInferenceModule`.
71+ - If the underlying module returns a single tensor, the result will be a single concatenated tensor
72+ along the same `split_dim` after processing all splits.
73+ - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
74+ along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
75+
76+ Workflow:
77+ 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
78+ `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
79+ 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
80+ that were passed.
81+ 3. The output tensors from each split are concatenated back together along `split_dim` before returning.
82+
83+ Example:
84+ ```python
85+ >>> import torch
86+
87+ >>> model = torch.nn.Linear(1000, 1000)
88+ >>> split_module = SplitInferenceModule(
89+ ... model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]
90+ ... )
91+
92+ >>> input_tensor = torch.randn(42, 1000)
93+ >>> # Will split the tensor into 21 slices of shape [2, 1000].
94+ >>> output = split_module(input=input_tensor)
95+ ```
96+
97+ This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
98+ them into smaller chunks, processing each chunk separately, and then reassembling the results.
99+
100+ It is also possible to nest `SplitInferenceModule` across different split dimensions.
57101 """
58102 split_inputs = {}
59103
104+ # 1. Split inputs that were specified during initialization and also present in passed kwargs
60105 for key in list (kwargs .keys ()):
61106 if key not in self .input_kwargs_to_split or not torch .is_tensor (kwargs [key ]):
62107 continue
63108 split_inputs [key ] = torch .split (kwargs [key ], self .split_size , self .split_dim )
64109 kwargs .pop (key )
65110
111+ # 2. Invoke forward pass across each split
66112 results = []
67113 for split_input in zip (* split_inputs .values ()):
68114 inputs = dict (zip (split_inputs .keys (), split_input ))
@@ -71,6 +117,7 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
71117 intermediate_tensor_or_tensor_tuple = self .module (* args , ** inputs )
72118 results .append (intermediate_tensor_or_tensor_tuple )
73119
120+ # 3. Concatenate split results to obtain final outputs
74121 if isinstance (results [0 ], torch .Tensor ):
75122 return torch .cat (results , dim = self .split_dim )
76123 elif isinstance (results [0 ], tuple ):
0 commit comments