```diff
@@ -19,6 +19,7 @@
     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
```
```diff
@@ -868,9 +869,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
+        device = torch.device(device)
         device_type = Device(type=device)
     else:
         device_type = Device(device.type)
```
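For orientation, here is a minimal sketch of the device-resolution flow introduced above. `Device`, `_find_device`, and `_find_device_type` are the library's own helpers; they are replaced with plain `torch.device` attributes here purely to show the control flow, so treat this as an illustration rather than the actual implementation.

```python
import torch
import torch.nn as nn


def resolve_device(model: nn.Module, device=None):
    # Simplified stand-in for the branch above: kernelize now tracks both a
    # concrete torch.device and a device *type* used for kernel lookup.
    if device is None:
        device = next(model.parameters()).device  # what _find_device does
        device_type = device.type                 # _find_device_type, minus the ROCm refinement
    elif isinstance(device, str):
        device = torch.device(device)             # strings are normalized to torch.device now
        device_type = device.type
    else:
        device_type = device.type
    return device, device_type


print(resolve_device(nn.Linear(2, 2), "cpu"))  # (device(type='cpu'), 'cpu')
```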
```diff
@@ -884,7 +889,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +903,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
@@ -909,7 +914,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +924,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +937,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
@@ -951,6 +956,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
```
```diff
@@ -1037,19 +1043,26 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
```
```diff
@@ -1074,15 +1087,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
```
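The split keeps `_find_device` purely mechanical, while `_find_device_type` retains the platform refinement. That refinement is needed because ROCm builds of PyTorch also expose tensors on a `"cuda"` device, so CUDA and ROCm must be told apart from the build itself, as this small illustrative check shows.

```python
import torch

# On both CUDA and ROCm builds, an accelerator tensor reports type "cuda" ...
if torch.cuda.is_available():
    print(torch.empty(1, device="cuda").device.type)  # "cuda"

# ... so the platform is detected from the build instead (see _is_rocm_platform).
print(torch.version.hip is not None)  # True only on ROCm builds
```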
```diff
@@ -1103,6 +1122,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1148,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(module, device)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = forward
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
```
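The stateful branch of `_replace_forward` computes state once per module at kernelize time and closes over it, so later calls pay no setup cost. Below is a minimal sketch of a layer that would take this path; the class and its cached epsilon tensor are invented for illustration, but the `create_state`/`forward_with_state` pair is exactly the contract checked by `_is_stateful_layer`.

```python
import torch
import torch.nn as nn


class CachedRMSNorm(nn.Module):
    """Hypothetical stateful kernel layer."""

    @staticmethod
    def create_state(module: nn.Module, device: torch.device):
        # Called once by _replace_forward; device-dependent or expensive
        # setup belongs here rather than in the per-call hot path.
        return {"eps": torch.tensor(1e-6, device=device)}

    def forward_with_state(self, state, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + state["eps"])
```

The wrapper then forwards every call as `layer.forward_with_state(module, state, *args, **kwargs)`, reusing the same `state` object each time.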
```diff
@@ -1179,3 +1209,21 @@ def _get_layer_memoize(
         _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
```
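Finally, a quick sketch of the new helper's contract (the class is made up): declaring `forward_with_state` without `create_state` is rejected, while an ordinary layer is simply reported as stateless.

```python
import torch.nn as nn


class Broken(nn.Module):
    def forward_with_state(self, state, x):
        return x


_is_stateful_layer(nn.Linear)  # False: no state-related members
_is_stateful_layer(Broken)
# TypeError: Stateful layer `Broken` must implement both `create_state`
# and `forward_with_state` or neither.
```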