@@ -35,21 +35,32 @@ class Target(Enum):
 
     # AUTO target will automatically select a packing format
     # based on the available hardware.
-    # TODO: in future, add the ability to specify specific
-    # hardware targets
     AUTO = auto()
+    UNIVERSAL = auto()
+    KLEIDIAI = auto()
 
     # ATEN target will use the ATen operator
     ATEN = auto()
 
 
+_TARGET_AND_STR = [
+    (Target.AUTO, "auto"),
+    (Target.ATEN, "aten"),
+    (Target.UNIVERSAL, "universal"),
+    (Target.KLEIDIAI, "kleidiai"),
+]
+
+
+def target_to_str(target: Target) -> str:
+    mapping = {t: s for t, s in _TARGET_AND_STR}
+    return mapping[target]
+
+
 def target_from_str(target: str) -> Target:
-    if target.lower() == "auto":
-        return Target.AUTO
-    elif target.lower() == "aten":
-        return Target.ATEN
-    else:
-        raise ValueError(f"Invalid target: {target}")
+    str_to_target = {s: t for t, s in _TARGET_AND_STR}
+    if target.lower() in str_to_target:
+        return str_to_target[target.lower()]
+    raise ValueError(f"Invalid target: {target}")
 
 
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
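
The hunk above replaces the hard-coded if/elif string matching with a single
_TARGET_AND_STR table that drives both directions of the conversion. A minimal
usage sketch, assuming Target, target_to_str, and target_from_str are imported
from the module patched here (the exact import path is not shown in the diff):

    # Sketch only: names come from the hunk above; the import is assumed.
    assert target_to_str(Target.KLEIDIAI) == "kleidiai"
    assert target_from_str("KLEIDIAI") is Target.KLEIDIAI  # lookup lower-cases its input
    try:
        target_from_str("cuda")
    except ValueError as err:
        print(err)  # Invalid target: cuda
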
@@ -146,10 +157,9 @@ def from_plain(
 ):
     assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
     assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-    assert layout.target in {
-        Target.AUTO,
-        Target.ATEN,
-    }, f"Unexpected target: {layout.target}"
+    assert layout.target in [
+        t for t, _ in _TARGET_AND_STR
+    ], f"Unexpected target: {layout.target}"
 
     n, k = int_data.shape
     if layout.target == Target.ATEN:
@@ -174,7 +184,7 @@ def from_plain(
         zero_point.reshape(-1).to(torch.int8) if layout.has_weight_zeros else None,
         layout.group_size,
         bias if layout.has_bias else None,
-        None,  # target; if not passed, a packing format is chosen on the C++ side
+        target_to_str(layout.target) if layout.target != Target.AUTO else None,
     ]
 
     packed_weight = getattr(
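
The replaced argument makes the chosen target explicit when packing: AUTO still
passes None so the packing format is selected on the C++ side, while any other
target is forwarded as its lowercase string. A hedged restatement of that rule
(the helper name here is hypothetical, not part of the diff):

    # Hypothetical helper mirroring the conditional in the hunk above.
    def _packing_target_arg(target: Target):
        # AUTO defers packing-format selection to the C++ side via None.
        return target_to_str(target) if target != Target.AUTO else None

    assert _packing_target_arg(Target.AUTO) is None
    assert _packing_target_arg(Target.UNIVERSAL) == "universal"
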
@@ -223,7 +233,7 @@ def _linear_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-    def _impl_2d_auto(input_tensor, weight_tensor):
+    def _impl_2d_non_aten(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
 
@@ -272,8 +282,8 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     if target == Target.ATEN:
         assert TORCH_VERSION_AT_LEAST_2_6, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
-    elif target == Target.AUTO:
-        _impl_2d = _impl_2d_auto
+    else:
+        _impl_2d = _impl_2d_non_aten
 
     if input_tensor.dim() == 2:
         res = _impl_2d(input_tensor, weight_tensor)
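
With the rename, Target.ATEN keeps its dedicated kernel path and every other
target (AUTO, UNIVERSAL, KLEIDIAI) now shares the non-ATen implementation. A
sketch of the resulting dispatch (the helper name is hypothetical):

    # Hypothetical helper restating the if/else dispatch in the hunk above.
    def _select_impl_2d(target: Target):
        if target == Target.ATEN:
            return _impl_2d_aten  # requires torch >= 2.6.0
        return _impl_2d_non_aten  # AUTO, UNIVERSAL, and KLEIDIAI take this path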