
Commit fd6cbc8

Add support for stateful layers
1 parent 055a953 commit fd6cbc8

2 files changed: +108, -17 lines

src/kernels/layer.py

Lines changed: 65 additions & 17 deletions
@@ -19,6 +19,7 @@
     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,9 +869,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
+        device = torch.device(device)
         device_type = Device(type=device)
     else:
         device_type = Device(device.type)
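With this change, kernelize resolves both a concrete torch.device (later handed to stateful layers via create_state) and a Device type used for the mapping lookup. A minimal usage sketch that mirrors the new test below; the local module is a stand-in, the mapping and repository names are taken from that test, and the imports assume these names are exported from kernels as the test module uses them:

import torch
import torch.nn as nn
import torch.nn.functional as F

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("ReluWithHiddenSize")
class ReluWithHiddenSize(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x)


model = ReluWithHiddenSize(hidden_size=64).to("cuda")

with use_kernel_mapping(
    {
        "ReluWithHiddenSize": {
            "cuda": LayerRepository(
                repo_id="kernels-test/state-test",
                layer_name="StatefulReLU",
            ),
        }
    }
):
    # `device` may be omitted (it is then inferred from the model parameters),
    # given as a string, or given as a torch.device; a stateful hub layer
    # receives the resolved torch.device through `create_state`.
    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")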
@@ -884,7 +889,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +903,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +914,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +924,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +937,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +956,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1043,26 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1087,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
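`_find_device` now returns the raw torch.device of the first parameter (so state can be allocated on it), while the new `_find_device_type` keeps the old behaviour of reducing that device to a Device type for the mapping lookup. A small illustration of the two values involved (the model is a placeholder):

import torch.nn as nn

model = nn.Linear(8, 8).to("cuda")

# What _find_device returns: the concrete device, e.g. device(type='cuda', index=0).
param_device = next(model.parameters()).device

# What _find_device_type starts from: the bare device kind, e.g. "cuda", which it
# then refines (for example distinguishing ROCm from CUDA via _is_rocm_platform).
device_kind = param_device.type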
@@ -1103,6 +1122,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1148,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(module, device)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = forward
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
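The stateful branch of `_replace_forward` calls `layer.create_state(module, device)` once at kernelize time and then routes every call through `layer.forward_with_state(..., state, ...)`. A hypothetical hub layer written against that calling convention; the class, its zero-bias state, and the `hidden_size` attribute on the wrapped module are illustrative assumptions, not an actual kernel from the Hub:

import torch
import torch.nn as nn


class StatefulReLUSketch(nn.Module):
    # Flags consulted elsewhere in layer.py when deciding whether the kernel
    # may be used under Mode.TORCH_COMPILE / Mode.TRAINING.
    can_torch_compile: bool = True
    has_backward: bool = True

    def create_state(self, device: torch.device) -> torch.Tensor:
        # Invoked unbound as `layer.create_state(module, device)`, so `self` is
        # the module being kernelized; the state is built once on the resolved device.
        return torch.zeros(self.hidden_size, device=device)

    def forward_with_state(self, state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Receives the captured state on every call instead of recomputing it.
        return torch.relu(x) + state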
@@ -1179,3 +1209,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
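A quick illustration of what the two new helpers compute for a layer that opts into state (the class is hypothetical):

import inspect

import torch.nn as nn


class StatefulSketch(nn.Module):
    def create_state(self, device): ...

    def forward_with_state(self, state, x): ...


torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(StatefulSketch)}
print(cls_members - torch_module_members)
# {'create_state', 'forward_with_state'}: this is what _unique_layer_members
# returns, and _is_stateful_layer reports True because forward_with_state is
# present together with create_state.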

tests/test_layer.py

Lines changed: 43 additions & 0 deletions
@@ -1,10 +1,12 @@
+from pathlib import Path
 import sys
 from contextlib import nullcontext

 import pytest
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.testing import assert_close

 from kernels import (
     CUDAProperties,
@@ -321,6 +323,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0


+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
