Skip to content

Commit addcb24

Browse files
metal lowbit kernels: add meta kernels
Differential Revision: D65234032 Pull Request resolved: #1262
1 parent 657ebbb commit addcb24

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

torchao/experimental/ops/mps/register.mm

Lines changed: 58 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,25 +17,15 @@
1717

1818
// LowBit Quantized Linear on MPS Backend
1919
template <int nbit>
20-
Tensor linear_mps_kernel(
20+
void check_linear_mps_args(
2121
const Tensor& A,
2222
const Tensor& B,
2323
int64_t group_size,
2424
const Tensor& S,
2525
const Tensor& Z) {
26-
auto M = A.size(0);
2726
auto N = B.size(0);
2827
auto K = A.size(1);
2928

30-
TORCH_CHECK(
31-
A.is_mps(), __func__, "A is on ", A.device(), " but expected on mps");
32-
TORCH_CHECK(
33-
B.is_mps(), __func__, "B is on ", B.device(), " but expected on mps");
34-
TORCH_CHECK(
35-
S.is_mps(), __func__, "S is on ", S.device(), " but expected on mps");
36-
TORCH_CHECK(
37-
Z.is_mps(), __func__, "Z is on ", Z.device(), " but expected on mps");
38-
3929
TORCH_CHECK(
4030
A.dtype() == at::kBFloat16 || A.dtype() == at::kHalf ||
4131
A.dtype() == at::kFloat,
@@ -75,6 +65,29 @@ Tensor linear_mps_kernel(
7565
": expect Z to be 2d tensor with shape [:, ",
7666
N,
7767
"]");
68+
}
69+
70+
template <int nbit>
71+
Tensor linear_mps_kernel(
72+
const Tensor& A,
73+
const Tensor& B,
74+
int64_t group_size,
75+
const Tensor& S,
76+
const Tensor& Z) {
77+
TORCH_CHECK(
78+
A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
79+
TORCH_CHECK(
80+
B.is_mps(), __func__, ": B is on ", B.device(), " but expected on mps");
81+
TORCH_CHECK(
82+
S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
83+
TORCH_CHECK(
84+
Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
85+
86+
check_linear_mps_args<nbit>(A, B, group_size, S, Z);
87+
88+
auto M = A.size(0);
89+
auto N = B.size(0);
90+
auto K = A.size(1);
7891

7992
auto C = at::empty({M, N}, A.options());
8093

@@ -93,6 +106,30 @@ Tensor linear_mps_kernel(
93106
return C;
94107
}
95108

109+
template <int nbit>
110+
Tensor linear_mps_kernel_meta(
111+
const Tensor& A,
112+
const Tensor& B,
113+
int64_t group_size,
114+
const Tensor& S,
115+
const Tensor& Z) {
116+
TORCH_CHECK(
117+
A.is_meta(), __func__, ": A is on ", A.device(), " but expected on meta");
118+
TORCH_CHECK(
119+
B.is_meta(), __func__, ": B is on ", B.device(), " but expected on meta");
120+
TORCH_CHECK(
121+
S.is_meta(), __func__, ": S is on ", S.device(), " but expected on meta");
122+
TORCH_CHECK(
123+
Z.is_meta(), __func__, ": Z is on ", Z.device(), " but expected on meta");
124+
125+
check_linear_mps_args<nbit>(A, B, group_size, S, Z);
126+
127+
auto M = A.size(0);
128+
auto N = B.size(0);
129+
130+
return at::empty({M, N}, A.options()).to("meta");
131+
}
132+
96133
// LowBit Packing on CPU Backend
97134
template <int nbit>
98135
Tensor pack_weights_cpu_kernel(const Tensor& W) {
@@ -155,4 +192,14 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
155192
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
156193
}
157194

195+
// Register the meta-backend implementations for every supported bit width
// (1..7). These let shape inference and tracing (FakeTensor, torch.compile)
// run the ops without an MPS device. The op-name strings must stay in sync
// with the MPS registrations above.
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl("_linear_fp_act_1bit_weight", &linear_mps_kernel_meta<1>);
m.impl("_linear_fp_act_2bit_weight", &linear_mps_kernel_meta<2>);
m.impl("_linear_fp_act_3bit_weight", &linear_mps_kernel_meta<3>);
m.impl("_linear_fp_act_4bit_weight", &linear_mps_kernel_meta<4>);
m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel_meta<5>);
m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel_meta<6>);
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel_meta<7>);
}
204+
158205
} // namespace torchao::kernels::mps::lowbit::aten

0 commit comments

Comments
 (0)