 
 // LowBit Quantized Linear on MPS Backend
 template <int nbit>
-Tensor linear_mps_kernel(
+void check_linear_mps_args(
     const Tensor& A,
     const Tensor& B,
     int64_t group_size,
     const Tensor& S,
     const Tensor& Z) {
-  auto M = A.size(0);
   auto N = B.size(0);
   auto K = A.size(1);
 
-  TORCH_CHECK(
-      A.is_mps(), __func__, " A is on ", A.device(), " but expected on mps");
-  TORCH_CHECK(
-      B.is_mps(), __func__, " B is on ", B.device(), " but expected on mps");
-  TORCH_CHECK(
-      S.is_mps(), __func__, " S is on ", S.device(), " but expected on mps");
-  TORCH_CHECK(
-      Z.is_mps(), __func__, " Z is on ", Z.device(), " but expected on mps");
-
   TORCH_CHECK(
       A.dtype() == at::kBFloat16 || A.dtype() == at::kHalf ||
           A.dtype() == at::kFloat,
@@ -75,6 +65,29 @@ Tensor linear_mps_kernel(
       ": expect Z to be 2d tensor with shape [:, ",
       N,
       "]");
+}
+
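+// MPS implementation: checks that every tensor lives on the MPS device, runs
+// the shared argument validation, and computes the [M, N] output tensor C.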
+template <int nbit>
+Tensor linear_mps_kernel(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z) {
+  TORCH_CHECK(
+      A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
+  TORCH_CHECK(
+      B.is_mps(), __func__, ": B is on ", B.device(), " but expected on mps");
+  TORCH_CHECK(
+      S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
+  TORCH_CHECK(
+      Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
+
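+  // Shape and dtype validation shared with the meta kernel below.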
+  check_linear_mps_args<nbit>(A, B, group_size, S, Z);
+
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
 
   auto C = at::empty({M, N}, A.options());
 
@@ -93,6 +106,30 @@ Tensor linear_mps_kernel(
   return C;
 }
 
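+// Meta-backend implementation: performs the same argument checks but does no
+// computation; it returns an empty [M, N] tensor on the meta device so output
+// shapes and dtypes can be inferred without touching the GPU.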
+template <int nbit>
+Tensor linear_mps_kernel_meta(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z) {
+  TORCH_CHECK(
+      A.is_meta(), __func__, ": A is on ", A.device(), " but expected on meta");
+  TORCH_CHECK(
+      B.is_meta(), __func__, ": B is on ", B.device(), " but expected on meta");
+  TORCH_CHECK(
+      S.is_meta(), __func__, ": S is on ", S.device(), " but expected on meta");
+  TORCH_CHECK(
+      Z.is_meta(), __func__, ": Z is on ", Z.device(), " but expected on meta");
+
+  check_linear_mps_args<nbit>(A, B, group_size, S, Z);
+
+  auto M = A.size(0);
+  auto N = B.size(0);
+
+  return at::empty({M, N}, A.options()).to("meta");
+}
+
 // LowBit Packing on CPU Backend
 template <int nbit>
 Tensor pack_weights_cpu_kernel(const Tensor& W) {
@@ -155,4 +192,14 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
   m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
 }
 
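+// Register the shape-only meta kernels for the same low-bit linear ops under
+// the Meta dispatch key.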
+TORCH_LIBRARY_IMPL(torchao, Meta, m) {
+  m.impl("_linear_fp_act_1bit_weight", &linear_mps_kernel_meta<1>);
+  m.impl("_linear_fp_act_2bit_weight", &linear_mps_kernel_meta<2>);
+  m.impl("_linear_fp_act_3bit_weight", &linear_mps_kernel_meta<3>);
+  m.impl("_linear_fp_act_4bit_weight", &linear_mps_kernel_meta<4>);
+  m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel_meta<5>);
+  m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel_meta<6>);
+  m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel_meta<7>);
+}
+
 } // namespace torchao::kernels::mps::lowbit::aten