
Commit 1a825e1

HDCharles authored and amdfaa committed
gemlite integration in torchao (#1034)
* gemlite integration in torchao

Summary: This PR adds support for gemlite kernels in torchao via a tensor-subclass integration built around the gemlite_uintx_weight_only constructor. It covers int4 grouped and ungrouped asymmetric weight-only quantization and int8 symmetric ungrouped quantization for fp16 models. Tensor-parallel support through DTensor is included in this PR. In the process of integrating gemlite into AQT, I also made fixes to a few quant primitives that are now being used but previously were not.

Test Plan:
test_integration.py -k "test_gemlite_layout"
test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite"
See benchmarks.sh for gemlite benchmarks as well.

Squashed intermediate commits: new gemlite integration using pip install; tests ran; fixing gemlite to do int4 matmul instead of fp16 x fp16; running tests; more testing; AQT integration wip; WIP; testing on gemlite a100_int8_tuning branch; gemlite subclass testing bitpacking 8 bits; bug fixing stuff; hicham fixes; new benchmarks; testing gemlite 8 bit; WIP; tp support; wip; final

* fixing regressions
1 parent 463e506 commit 1a825e1
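For orientation before the diffs below, here is a minimal usage sketch of the new API (assuming a CUDA machine with gemlite installed and an fp16 model; the toy module and shapes are illustrative, not taken from the PR):

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only

# toy fp16 model; the gemlite kernels target fp16 models on CUDA
model = torch.nn.Sequential(torch.nn.Linear(1024, 512)).half().cuda()

# arguments follow the order used in the new tests: group_size, bit_width, packing_bitwidth
# (4-bit asymmetric weight-only quantization, group size 64, packed into 32-bit words)
quantize_(model, gemlite_uintx_weight_only(64, 4, 32))

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
y = model(x)  # linear now dispatches to the gemlite Triton kernel through the AQT subclass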

File tree

12 files changed: +577 −13 lines changed

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 28 additions & 0 deletions

@@ -19,6 +19,13 @@
 from torchao.quantization.quant_api import quantize_
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
+try:
+    import gemlite  # noqa: F401
+
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+
 
 class TestAffineQuantizedTensorParallel(DTensorTestBase):
     """Basic test case for tensor subclasses"""
@@ -139,8 +146,29 @@ def test_tp(self, dtype):
         return self._test_tp(dtype)
 
 
+class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
+    COMMON_DTYPES = [torch.float16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_tp_gemlite(self, dtype):
+        from torchao.quantization import gemlite_uintx_weight_only
+
+        for packing_bitwidth in [32, 8]:
+            for bit_width in [4, 8]:
+                for group_size in [64, 32, None] if bit_width == 4 else [None]:
+                    api = lambda: gemlite_uintx_weight_only(
+                        group_size, bit_width, packing_bitwidth
+                    )
+                    self.QUANT_METHOD_FN = staticmethod(api)
+                    return self._test_tp(dtype)
+
+
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

test/integration/test_integration.py

Lines changed: 34 additions & 0 deletions

@@ -96,6 +96,12 @@
 )
 from torchao.dtypes.utils import is_device
 
+try:
+    import gemlite
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -870,6 +876,10 @@ def _test_lin_weight_subclass_api_impl(
         ref_f = mod(x)
         api(mod)
 
+        # test get_plain()
+        if hasattr(mod[0].weight, "tensor_impl"):
+            mod[0].weight.tensor_impl.get_plain()
+
         test = mod(x)
         self.assertGreater(
             SQNR(ref_f, test),
@@ -930,6 +940,30 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater")
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_gemlite_layout(self, device, dtype):
+        if dtype != torch.float16:
+            self.skipTest(f"gemlite only works for fp16 dtype")
+        from torchao.quantization import gemlite_uintx_weight_only
+        if device == "cpu":
+            self.skipTest(f"gemlite is for cuda, not {device}")
+        for packing_bitwidth in [32, 8]:
+            for bit_width in [4, 8]:
+                for group_size in [64, 32, None] if bit_width == 4 else [None]:
+                    api = lambda mod: quantize_(mod, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
+                    for test_shape in [[1, 1024, 512], [16, 256, 1024], [128, 256, 1024]]:
+                        print(packing_bitwidth, bit_width, group_size, test_shape, dtype)
+                        self._test_lin_weight_subclass_api_impl(
+                            api,
+                            device,
+                            15,
+                            test_shape=test_shape,
+                            test_dtype=dtype,
+                        )
+
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")

torchao/_models/llama/benchmarks.sh

Lines changed: 15 additions & 0 deletions

@@ -91,6 +91,21 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured
 
+# gemlite benchmarks
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt
+
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32
+
 # 2:4 sparse model
 export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
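The numbers in the gemlite-… quantization strings above encode <packing_bitwidth>-<bit_width>-<group_size> (e.g. gemlite-8-4-64 is 4-bit weights, group size 64, packed into 8-bit words; a group size of None means ungrouped), and generate.py parses them right-to-left, as the next diff shows. A small decoding sketch under that assumption; the helper name is hypothetical:

def parse_gemlite_flag(quantization):
    # "gemlite-8-4-64"    -> packing_bitwidth=8,  bit_width=4, group_size=64
    # "gemlite-32-8-None" -> packing_bitwidth=32, bit_width=8, ungrouped
    args = quantization.split("-")
    bit_width = int(args[-2])
    group_size = None if args[-1] == "None" else int(args[-1])
    # packing_bitwidth is optional; default to 32 when only two values are given
    packing_bitwidth = int(args[-3]) if len(args) > 3 else 32
    return packing_bitwidth, bit_width, group_size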

torchao/_models/llama/generate.py

Lines changed: 39 additions & 3 deletions

@@ -171,7 +171,8 @@ def decode_n_tokens(
             )
             next_token, next_prob = next_token.clone(), next_prob.clone()
         input_pos += 1
-        new_tokens.append(next_token)
+        # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
+        new_tokens.append(next_token.clone())
         callback(new_tokens[-1])
         new_probs.append(next_prob)
         cur_token = next_token
@@ -368,6 +369,7 @@ def ffn_or_attn_only(mod, fqn):
         int8_weight_only,
         quantize_,
         uintx_weight_only,
+        gemlite_uintx_weight_only,
     )
 
     from torchao.quantization.granularity import PerRow, PerTensor
@@ -377,6 +379,39 @@ def ffn_or_attn_only(mod, fqn):
             from torchao.prototype.spinquant import apply_spinquant
 
             apply_spinquant(model)
+        if "gemlite" in quantization:
+            import os, pwd
+            import gemlite
+            from gemlite.core import GemLiteLinearTriton, set_autotune
+            _quant_args = quantization.split("-")
+            bit_width = int(_quant_args[-2])
+            group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1])
+            try:
+                packing_bitwidth = int(_quant_args[-3])
+            except:
+                # if only 2 inputs found, use default value
+                packing_bitwidth = 32
+
+            quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
+
+            # try to load gemlite kernel config
+            try:
+                GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+                print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            except:
+                print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+
+            print("running gemlite warmup")
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+            GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -959,7 +994,7 @@ def callback(x):
 
     parser = argparse.ArgumentParser(description="Your CLI description.")
    parser.add_argument(
-        "--prefill_size", type=int, default=0, help="Whether to run in ttft mode"
+        "--prefill_size", type=int, default=None, help="Whether to run in ttft mode"
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
@@ -993,7 +1028,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-            + "embed-int8wo, marlin_qqq"
+            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
        ),
    )
    parser.add_argument(
@@ -1053,6 +1088,7 @@ def callback(x):
    )
 
    args = parser.parse_args()
+    print(args)
    main(
        args.prefill_size,
        args.prompt,
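The load_config/cache_config calls in the gemlite branch above implement a simple warmup-and-cache pattern: run one generation so gemlite autotunes its Triton kernels for the shapes in the model, then persist the tuned configs so later runs can skip autotuning. A stripped-down sketch of the same pattern (the path and the elided warmup step are illustrative):

import os
from gemlite.core import GemLiteLinearTriton

config_path = "/tmp/my_gemlite.json"

# reuse previously autotuned kernel configs when a cache file exists
if os.path.exists(config_path):
    GemLiteLinearTriton.load_config(config_path)

# ... run a warmup forward pass / short generation here so any missing shapes get tuned ...

# persist the (possibly updated) configs for the next run
GemLiteLinearTriton.cache_config(config_path)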

torchao/_models/llama/model.py

Lines changed: 4 additions & 1 deletion

@@ -170,7 +170,10 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
-        dtype = self.output.weight.dtype
+        dtype = None
+        # module swaps can cause issues without this
+        if hasattr(self.output, "weight"):
+            dtype = self.output.weight.dtype
         # For quantized layers, dtype is encoded in scales
         if hasattr(self.output, "scales"):
             dtype = self.output.scales.dtype

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 13 additions & 2 deletions

@@ -225,6 +225,8 @@ def from_hp_to_intx(
                 else input_float.dtype
             )
             device = input_float.device
+            from torchao.dtypes.uintx import TensorCoreTiledLayout
+
             data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
                 input_float,
                 nbits=nbits,
@@ -233,7 +235,15 @@
                 compute_dtype=compute_dtype,
                 device=device,
                 verbose=False,
-                raw_output=False,
+                raw_output=not isinstance(
+                    _layout, (TensorCoreTiledLayout, PlainLayout)
+                ),
+                # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
+                # note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
+                # zero is preserved.
+                # TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version
+                # TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
+                # TODO change PlainLayout to use raw_output.
             )
             data = data.to(target_dtype)
         else:
@@ -251,7 +261,8 @@
                 zero_point_domain,
             )
             # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
-            if zero_point_domain is None:
+            # TODO should probably consolidate ZeroPointDomain.NONE and None
+            if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
                 zero_point = None
             data = quantize_affine(
                 input_float,
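The raw_output / zero_point comments in this hunk refer to two zero_point conventions: an integer-domain zero_point versus the midpoint-shifted float zero_point used by the TensorCoreTiledLayout (tinygemm) path. A self-contained sketch of how the two relate, written from the standard affine-dequantization definitions rather than copied from quant_primitives:

import torch

nbits = 4
quant_min, quant_max = 0, 2**nbits - 1
mid_point = (quant_max + quant_min + 1) / 2  # 8 for uint4

q = torch.randint(quant_min, quant_max + 1, (8,))
scale = torch.tensor(0.05)
zero_point_int = torch.tensor(3)

# integer-domain dequant: w = (q - zero_point_int) * scale
w_int = (q - zero_point_int) * scale

# float-domain dequant: w = (q - mid_point) * scale + zero_point_float,
# where the float zero_point absorbs scale * (mid_point - zero_point_int)
zero_point_float = (mid_point - zero_point_int) * scale
w_float = (q - mid_point) * scale + zero_point_float

assert torch.allclose(w_int, w_float)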

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,10 @@
     _linear_int8_act_int8_weight_block_sparse_check,
     _linear_int8_act_int8_weight_block_sparse_impl,
 )
+from torchao.dtypes.uintx.gemlite_layout import (
+    _linear_fp_act_int4_weight_gemlite_check,
+    _linear_fp_act_int4_weight_gemlite_impl,
+)
 from torchao.dtypes.uintx.marlin_qqq_tensor import (
     _linear_int8_act_int4_weight_marlin_qqq_check,
     _linear_int8_act_int4_weight_marlin_qqq_impl,
@@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches():
             _linear_int8_act_int4_weight_marlin_qqq_check,
             _linear_int8_act_int4_weight_marlin_qqq_impl,
         ),
+        (
+            _linear_fp_act_int4_weight_gemlite_check,
+            _linear_fp_act_int4_weight_gemlite_impl,
+        ),
     ]:
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
 
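For context on the registration above: each entry in the dispatch table pairs a check function with an implementation, and register_aqt_quantized_linear_dispatch adds that pair to the list consulted when F.linear is called with an AffineQuantizedTensor weight. Below is a hypothetical pair written as if added inside this module, where torch, AffineQuantizedTensor, and register_aqt_quantized_linear_dispatch are already in scope; the bodies are illustrative, not the real gemlite ones in gemlite_layout.py:

def _linear_fp16_act_my_layout_check(input_tensor, weight_tensor, bias):
    # claim the call only for cases this path can actually handle
    return (
        input_tensor.dtype == torch.float16
        and isinstance(weight_tensor, AffineQuantizedTensor)
    )


def _linear_fp16_act_my_layout_impl(input_tensor, weight_tensor, bias):
    # reference fallback for the sketch: dequantize, then run a normal matmul
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)


register_aqt_quantized_linear_dispatch(
    _linear_fp16_act_my_layout_check, _linear_fp16_act_my_layout_impl
)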
