 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)
 
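This hunk makes `kernelize` carry two things side by side: the concrete `torch.device` (handed to `create_state` later in the diff) and the coarser `Device` type used for mapping lookups. A minimal usage sketch, assuming the public `kernelize`/`Mode` API this module exports (the model and device are illustrative):

```python
import torch.nn as nn

from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(4, 4))

# Passing a string now yields both a Device type for mapping lookups
# and a concrete torch.device for state creation.
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)

# Omitting `device` infers both from the model's parameters.
model = kernelize(model.to("cuda"), mode=Mode.INFERENCE)
```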
@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
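Every miss branch above now threads the concrete device into `_replace_forward`, so fallback replacement and kernel replacement share one code path. A sketch of the fallback contract, assuming the `use_kernel_forward_from_hub` decorator exported by this library (layer and mapping names are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

from kernels import Mode, kernelize, use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x):
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

model = nn.Sequential(SiluAndMul())

# With use_fallback=True (the default), a missing mapping keeps the
# original forward; with use_fallback=False the same miss raises:
# ValueError: No layer mapping for `SiluAndMul`
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```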
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
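The validator therefore accepts exactly two shapes of layer: a stateless one whose `forward` matches the reference, or a stateful one whose `forward_with_state`, once `state` is dropped, matches it. A sketch of a layer that passes these checks (class name and state contents are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StatefulSiluAndMul(nn.Module):
    # The only extra members are the whitelisted ones.
    can_torch_compile: bool = True

    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> dict:
        # Built once per module at kernelize time, on the target device.
        return {"scratch": torch.zeros((), device=device)}

    def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
        # Dropping `state` leaves (self, x), matching the reference forward.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]
```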
@@ -1074,15 +1093,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
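Splitting the old helper in two keeps one source of truth: `_find_device` returns the raw `torch.device` of the first parameter, and `_find_device_type` derives the mapping key from it, refining `cuda` to the ROCm variant when `torch.version.hip` is set. Roughly, on a CUDA build (illustrative values for these private helpers):

```python
import torch.nn as nn

model = nn.Linear(4, 4).to("cuda")

_find_device(model)       # device(type='cuda', index=0) — concrete device
_find_device_type(model)  # Device(type='cuda'), or the ROCm refinement
                          # when _is_rocm_platform() is true
```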
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
             logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
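For a stateful layer, `create_state` runs once per module at kernelize time on the concrete device, and the inner `forward` closure threads that state into every call; stateless layers keep the old direct method binding. A standalone sketch of the closure semantics, exercising the private `_replace_forward` with a hypothetical layer:

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> dict:
        # Runs once, not on every forward call.
        return {"scale": torch.full((), 2.0, device=device)}

    def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
        return x * state["scale"]

module = nn.Identity()
_replace_forward(torch.device("cpu"), module, Scale)
module(torch.ones(2))  # tensor([2., 2.]) — state captured by the closure
```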
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
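Detection is purely structural, so a layer that implements only one half of the stateful protocol fails fast instead of being silently treated as stateless. Illustrative classes:

```python
import torch.nn as nn

class Stateless(nn.Module):
    def forward(self, x):   # overrides an existing nn.Module member,
        return x            # so it is not a "unique" member

class HalfStateful(nn.Module):
    def forward_with_state(self, state, x):
        return x

_is_stateful_layer(Stateless)     # False
_is_stateful_layer(HalfStateful)  # TypeError: `create_state` is missing
```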