 import torch
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
+
+
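+# AITER kernels are used only on ROCm, and only when both the global
+# VLLM_ROCM_USE_AITER switch and the op-specific
+# VLLM_ROCM_USE_AITER_RMSNORM switch are enabled.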
+def is_rocm_aiter_rmsnorm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
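+# Out-of-place RMSNorm via vLLM's built-in custom op: writes the normalized
+# result into a freshly allocated tensor and leaves x untouched.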
+def rms_norm(x: torch.Tensor, weight: torch.Tensor,
+             variance_epsilon: float) -> torch.Tensor:
+    from vllm import _custom_ops as ops
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        weight,
+        variance_epsilon,
+    )
+    return out
+
+
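+# In-place fused residual-add + RMSNorm via vLLM's built-in custom op:
+# mutates both x and residual and returns them.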
+def fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    from vllm import _custom_ops as ops
+    ops.fused_add_rms_norm(
+        x,
+        residual,
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
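+# RMSNorm backed by the ROCm AITER kernel. aiter is imported lazily so the
+# module still imports cleanly where the package is unavailable.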
+def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
+                        variance_epsilon: float) -> torch.Tensor:
+    import aiter as rocm_aiter
+    return rocm_aiter.rms_norm(x, weight, variance_epsilon)
+
+
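+# Fused residual-add + RMSNorm backed by the ROCm AITER kernel; see the
+# signature note inside the function.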
+def rocm_aiter_fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    import aiter as rocm_aiter
+
+    # NOTE: assumes the signature rmsnorm2d_fwd_with_add(out, input,
+    # residual_in, residual_out, weight, epsilon); passing x and residual
+    # as both inputs and outputs makes the op update them in place.
+    rocm_aiter.rmsnorm2d_fwd_with_add(
+        x,  # output
+        x,  # input
+        residual,  # residual input
+        residual,  # residual output
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
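+# Selects the RMSNorm implementation for the CUDA/ROCm path, preferring the
+# AITER kernels whenever they are enabled.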
+def dispatch_cuda_rmsnorm_func(add_residual: bool):
+    if add_residual:
+        if is_rocm_aiter_rmsnorm_enabled():
+            return rocm_aiter_fused_add_rms_norm
+        return fused_add_rms_norm
+
+    if is_rocm_aiter_rmsnorm_enabled():
+        return rocm_aiter_rms_norm
+    return rms_norm
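+# Illustrative usage (mirrors forward_cuda below):
+#   norm_func = dispatch_cuda_rmsnorm_func(add_residual=True)
+#   x, residual = norm_func(x, residual, weight, variance_epsilon)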
 
 
 @CustomOp.register("rms_norm")
@@ -81,24 +151,14 @@ def forward_cuda
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
-        from vllm import _custom_ops as ops
+        add_residual = residual is not None
+        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
 
-        if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
-                residual,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
-            x,
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
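+        # The residual variant mutates x and residual in place; the plain
+        # variant returns a newly allocated tensor.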
+        if add_residual:
+            return norm_func(x, residual, self.weight.data,
+                             self.variance_epsilon)
+        else:
+            return norm_func(x, self.weight.data, self.variance_epsilon)
 
 
     def forward_hpu(
         self,