diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index 9032b79..3763abc 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -17,8 +17,10 @@ from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+        device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)
@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,7 +1093,7 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
@@ -1082,7 +1101,13 @@ def _find_device(model: "nn.Module") -> torch.device:
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
         _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
diff --git a/tests/test_layer.py b/tests/test_layer.py
index 7bfffca..b807f45 100644
--- a/tests/test_layer.py
+++ b/tests/test_layer.py
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.testing import assert_close
 
 from kernels import (
     CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0
 
 
+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
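Note (illustrative, not part of the patch): the sketch below shows what a hub layer implementing the stateful protocol introduced here could look like, i.e. a class exposing `create_state` and `forward_with_state` as consumed by `_replace_forward`. It is a hypothetical example, not the actual `StatefulReLU` from `kernels-test/state-test`; the `@staticmethod` form of `create_state`, the use of `module.hidden_size`, and the tensor-valued state are assumptions made for illustration only.

import torch
import torch.nn as nn
import torch.nn.functional as F


class StatefulReLU(nn.Module):
    # Optional capability flags permitted by `_validate_layer`.
    can_torch_compile: bool = True
    has_backward: bool = True

    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> torch.Tensor:
        # Called once per module at kernelize time (see `_replace_forward`): derive
        # per-module state from the wrapped module, e.g. its `hidden_size` attribute.
        # (Assumed state; the real hub kernel may store something else entirely.)
        return torch.ones(module.hidden_size, device=device)

    def forward_with_state(self, state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # `self` is the original module being kernelized; `state` is whatever
        # `create_state` returned. Multiplying by ones keeps the output equal to a
        # plain ReLU, consistent with the reference check in `test_stateful_layer`.
        return F.relu(x) * state

With such a layer on the Hub, `kernelize` binds `forward_with_state` (closing over the pre-built state) in place of the module's `forward`, so call sites keep using `model(x)` unchanged.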