88 changes: 71 additions & 17 deletions src/kernels/layer.py
@@ -17,8 +17,10 @@
from typing import (
TYPE_CHECKING,
Dict,
Mapping,
Optional,
Protocol,
Set,
Tuple,
Type,
Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

if device is None:
device_type = _find_device(model)
device = _find_device(model)
device_type = _find_device_type(model)
elif isinstance(device, str):
_validate_device_type(device)
import torch

device_type = Device(type=device)
device = torch.device(device)
else:
device_type = Device(device.type)

@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
layer_name = module_class.kernel_layer_name

if _DISABLE_KERNEL_MAPPING:
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
continue

kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
continue

# Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
continue

repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError(
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
)
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
continue

repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
continue

repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)

_conditionally_replace_forward(
device=device,
module=module,
layer=layer,
mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
raise TypeError(f"{repo} must not override nn.Module constructor.")

# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
unique_members = _unique_layer_members(cls)
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
if not difference <= {"can_torch_compile", "has_backward"}:
if not unique_members <= {
"can_torch_compile",
"create_state",
"has_backward",
"forward_with_state",
}:
raise TypeError(
f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
)

# Check whether the forward signatures are similar.
params = inspect.signature(cls.forward).parameters
ref_params = inspect.signature(check_cls.forward).parameters

params: Mapping[str, inspect.Parameter]
if _is_stateful_layer(cls):
params = inspect.signature(cls.forward_with_state).parameters
# Get rid of the mappingproxy.
params = params.copy()
# Remove the state to be able to compare with forward.
del params["state"]
else:
params = inspect.signature(cls.forward).parameters

if len(params) != len(ref_params):
raise TypeError(
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1093,21 @@ def _is_rocm_platform():
return torch.version.hip is not None


def _find_device(model: "nn.Module") -> Device:
def _find_device(model: "nn.Module") -> torch.device:
try:
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)

dev_type = param.device.type
return param.device


def _find_device_type(model: "nn.Module") -> Device:
device = _find_device(model)

dev_type = device.type
if dev_type == "cuda":
# Refine based on actual platform
if _is_rocm_platform():
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:

def _conditionally_replace_forward(
*,
device: "torch.device",
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
logging.info("Layer does not support torch.compile, using fallback")
if needs_fallback_for_backward:
logging.info("Layer does not support backward, using fallback")
_replace_forward(module, module_class)
_replace_forward(device, module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
else:
_replace_forward(module, layer)
_replace_forward(device, module, layer)


def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
def _replace_forward(
device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
):
if _is_stateful_layer(layer):
state = layer.create_state(device, module) # type: ignore[attr-defined]

def forward(self, *args, **kwargs):
return layer.forward_with_state(self, state, *args, **kwargs)

module.forward = MethodType(forward, module)
else:
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]


def _validate_layer_has_mode(
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
_CACHED_LAYER[repo] = layer

return layer


def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
import torch.nn as nn

torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(layer)}
return cls_members - torch_module_members


def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
unique = _unique_layer_members(layer)
is_stateful = "forward_with_state" in unique
if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
raise TypeError(
f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
)
return is_stateful
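
For reference, a minimal sketch of a Hub layer that opts into the stateful protocol added above. The class body here is an illustrative assumption (only the repo kernels-test/state-test and the layer name StatefulReLU appear in the new test); what the diff actually requires is that `create_state(device, module)` and `forward_with_state(self, state, ...)` are either both present or both absent, and that `forward_with_state` matches the reference `forward` signature once the `state` parameter is removed.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class StatefulReLU(nn.Module):
        # Hypothetical layer body; the real kernels-test/state-test implementation
        # may differ. create_state is invoked on the class as
        # `layer.create_state(device, module)`, so it takes no `self`.
        @staticmethod
        def create_state(device: torch.device, module: nn.Module) -> dict:
            # Build per-module state once at kernelize time, e.g. a scratch buffer
            # sized from an attribute of the wrapped module (hidden_size assumed).
            return {"scratch": torch.empty(module.hidden_size, device=device)}

        def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
            # _replace_forward binds a closure over `state`, so the kernelized
            # module keeps its plain forward(x) signature.
            return F.relu(x)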
42 changes: 42 additions & 0 deletions tests/test_layer.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.testing import assert_close

from kernels import (
CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
assert linear.n_calls == 0


def test_stateful_layer(device):
@use_kernel_forward_from_hub("ReluWithHiddenSize")
class ReluWithHiddenSize(nn.Module):
hidden_size: int

def __init__(self, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x)

model = ReluWithHiddenSize(hidden_size=64).to(device)
x = torch.randn((32, 64), device=device)
y_ref = model(x)

with use_kernel_mapping(
{
"ReluWithHiddenSize": {
"cuda": LayerRepository(
repo_id="kernels-test/state-test",
layer_name="StatefulReLU",
),
"xpu": LayerRepository(
repo_id="kernels-test/state-test",
layer_name="StatefulReLU",
),
}
},
inherit_mapping=False,
):
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)

y = model(x)
assert_close(y, y_ref)

model = torch.compile(model, fullgraph=True)
y = model(x)
assert_close(y, y_ref)


@pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
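
As a usage sketch based on the new test above (the repo id and layer name are taken from that test, the rest is illustrative): passing `device` as a string to `kernelize` now validates the device type, builds a `Device(type=...)` for repository selection, and hands a `torch.device` to `create_state` of any stateful layer.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from kernels import (
        LayerRepository,
        Mode,
        kernelize,
        use_kernel_forward_from_hub,
        use_kernel_mapping,
    )


    @use_kernel_forward_from_hub("ReluWithHiddenSize")
    class ReluWithHiddenSize(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.relu(x)


    model = ReluWithHiddenSize(hidden_size=64).to("cuda")

    with use_kernel_mapping(
        {
            "ReluWithHiddenSize": {
                "cuda": LayerRepository(
                    repo_id="kernels-test/state-test",
                    layer_name="StatefulReLU",
                ),
            }
        },
        inherit_mapping=False,
    ):
        # Device given as a string: the kernelized module's forward closes over
        # state created via StatefulReLU.create_state(torch.device("cuda"), model).
        model = kernelize(model, mode=Mode.INFERENCE, device="cuda")

    y = model(torch.randn((32, 64), device="cuda"))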