diff --git a/examples/awq/llama_example.py b/examples/awq/llama_example.py
index 2994749278..38ed9324be 100644
--- a/examples/awq/llama_example.py
+++ b/examples/awq/llama_example.py
@@ -1,17 +1,8 @@
-import lm_eval
-from compressed_tensors.quantization import (
-    QuantizationArgs,
-    QuantizationScheme,
-    QuantizationStrategy,
-    QuantizationType,
-)
 from datasets import load_dataset
-from lm_eval.utils import make_table
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from llmcompressor import oneshot
 from llmcompressor.modifiers.awq import AWQModifier
-from llmcompressor.modifiers.quantization import QuantizationModifier

 # Select model and load it.
 MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -61,23 +52,7 @@ def tokenize(sample):

 # Configure the quantization algorithm to run.
 recipe = [
-    AWQModifier(bits=4, symmetric=False),
-    QuantizationModifier(
-        ignore=["lm_head"],
-        config_groups={
-            "group_0": QuantizationScheme(
-                targets=["Linear"],
-                weights=QuantizationArgs(
-                    num_bits=4,
-                    type=QuantizationType.INT,
-                    dynamic=False,
-                    symmetric=False,
-                    strategy=QuantizationStrategy.GROUP,
-                    group_size=128,
-                ),
-            )
-        },
-    ),
+    AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]),
 ]

 # Apply algorithms.
@@ -101,21 +76,3 @@ def tokenize(sample):
 SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
-
-#
-# 2) Evaluate model on wikitext perplexity
-#
-
-results = lm_eval.simple_evaluate(
-    model="hf",
-    model_args={
-        "pretrained": SAVE_DIR,
-        "add_bos_token": True,
-        "dtype": "bfloat16",
-        "gpu_memory_utilization": 0.5,
-    },
-    tasks=["wikitext"],
-    num_fewshot=5,
-    batch_size="auto",
-)
-print(make_table(results))
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index d3b00f3299..4201911e96 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,20 +1,24 @@
 import inspect
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
+from compressed_tensors.quantization import disable_quantization
 from compressed_tensors.utils import (
     align_module_device,
     get_execution_device,
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import ConfigDict
+from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
 from tqdm import tqdm

-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
+from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
+from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
@@ -29,7 +33,7 @@
 # TODO (Brian INFERENG-531) Add support for offloaded models


-class AWQModifier(Modifier):
+class AWQModifier(Modifier, QuantizationMixin):
     """
     Implements the AWQ (Activation-Weighted Quantization) algorithm,
     as described in https://arxiv.org/pdf/2306.00978. The algorithm
@@ -49,32 +53,50 @@ class AWQModifier(Modifier):
     example recipe:
     ```yaml
     AWQModifier:
-      bits: 4
       mappings:
         - smooth_layer: "re:.*self_attn_layer_norm"
           balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
         - smooth_layer: "re:.*final_layer_norm"
           balance_layers: ["re:.*fc1"]
       ]
-      ignore: ["model.decoder.final_layer_norm"]
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          targets:
+            - "Linear"
+          input_activations: null
+          output_activations: null
+          weights:
+            num_bits: 4
+            type: int
+            symmetric: false
+            strategy: group
+            group_size: 128
     ```

     Lifecycle:
         - on_initialize
             - resolve mappings
             - capture kwargs needed for forward passes into modules
-            - capture input activations to balance layers
-                - register hook to capture inputs and offload to cpu
-                - run calibration dataset through, to capture inputs
-                - clear hooks
-                - concatenate activations across all batches
-            - apply smooothing
+        - on_start
+            - set up activation cache hooks to capture input activations
+              to balance layers
+        - on sequential epoch end
+            - apply smoothing to each smoothing layer
+                - consume cached activations across all batches
+                - clear cached activations as they are used
                 - find best smoothing scale for each smoothing layer
-                - apply
-                - move to next smoothing layer
+                - apply to model weights
+                - raise error if any unused activations remain
+        - on_end
+            - re-run logic of sequential epoch end (in case of basic pipeline)
+            - set scales and zero points
+            - remove activation hooks
         - on_finalize
             - clear resolved mappings and captured activations

+    :param sequential_targets: list of module names to compress in
+        the same calibration pass
     :param mappings: list activation layers to smooth, and which layers to
         scale the output such that activations are smoothed.
         Each entry of the mapping list should be a list itself, in which the first
@@ -96,53 +118,165 @@ class AWQModifier(Modifier):

     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

+    # User-provided vars (in addition to QuantizationMixin args)
+    sequential_targets: Union[str, List[str], None] = None
     mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"]
-    ignore: List[str] = []
-    group_size: int = 128
     max_chunk_memory: int = 1024 * 1024 * 1024
-    num_bits: int = 4
-    symmetric: bool = False
     duo_scaling: bool = True

-    _resolved_mappings: List[ResolvedMapping] = []
-    _scales: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = {}
-    _module_kwargs: Dict = {}
+    # Private vars set during validation
+    _num_bits: Optional[int] = PrivateAttr(default=None)
+    _symmetric: Optional[bool] = PrivateAttr(default=None)
+    _group_size: Optional[int] = PrivateAttr(default=None)
+
+    # Private vars set during initialization, cleared during finalization
+    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _activations: Dict[str, List[torch.Tensor]] = PrivateAttr(default_factory=dict)
+    _activation_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
+    _module_kwargs: Dict = PrivateAttr(default_factory=dict)
+
+    @model_validator(mode="after")
+    def validate_model_after(model: "AWQModifier") -> "AWQModifier":
+        """
+        Confirm only one configuration for group_size, symmetric, and num_bits,
+        as AWQ algorithm depends on it
+        Confirm no activation quantization, as AWQ only works with WNA16
+        """
+        config = model.resolve_quantization_config()
+
+        num_bits_set = set(
+            group.weights.num_bits
+            for group in config.config_groups.values()
+            if group.weights is not None
+        )
+        assert (
+            len(num_bits_set) == 1
+        ), "In AWQ, all config groups must use the same configuration for num_bits"
+
+        model._num_bits = next(iter(num_bits_set))
+
+        symmetric_set = set(
+            group.weights.symmetric
+            for group in config.config_groups.values()
+            if group.weights is not None
+        )
+        assert (
+            len(symmetric_set) == 1
+        ), "In AWQ, all config groups must use the same configuration for symmetric"
+
+        model._symmetric = next(iter(symmetric_set))
+
+        group_size_set = set(
+            group.weights.group_size
+            for group in config.config_groups.values()
+            if group.weights is not None
+        )
+        assert (
+            len(group_size_set) == 1
+        ), "In AWQ, all config groups must use the same configuration for group_size"
+
+        model._group_size = next(iter(group_size_set))
+
+        in_num_bits_set = set(
+            group.input_activations.num_bits
+            for group in config.config_groups.values()
+            if group.input_activations is not None
+        )
+        assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, (
+            "AWQ activations must be 16-bit precision, "
+            f"input activations {in_num_bits_set} not allowed"
+        )
+
+        out_num_bits_set = set(
+            group.output_activations.num_bits
+            for group in config.config_groups.values()
+            if group.output_activations is not None
+        )
+        assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, (
+            "AWQ activations must be 16-bit precision, "
+            f"output activations {out_num_bits_set} not allowed"
+        )
+
+        return model

     def on_initialize(self, state: State, **kwargs) -> bool:
         """
-        Initialize and run AWQ on the given state
+        Initialize AWQ on the given state
+        Initialize quantization, resolve mappings, cache module kwargs

         :param state: state to run AWQ on
         :return: True on a successful run, False otherwise
         """
+        # apply config to model and prepare calibration hooks
+        if QuantizationMixin.has_config(self):
+            QuantizationMixin.initialize_quantization(self, state.model)
+
         self._set_resolved_mappings(state.model)

-        with calibration_forward_context(state.model):
-            self._set_module_kwargs(state.model, state.data.calib)
+        self._set_module_kwargs(state.model, state.data.calib)
+
+        return True

-        self._setup_scale_hooks()
-        with calibration_forward_context(state.model):
-            self._calibrate(state.model, state.data.calib)
-        self.remove_hooks()
-        self._concat_collected_activations()
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True

-        with calibration_forward_context(state.model):
+        # register quantization calibration hooks
+        # assume quantization has been initialized by this modifier or one before it
+        QuantizationMixin.start_calibration(self, state.model)
+        # Unlike qmod, do not quantize as we calibrate
+        # This choice does not seem to have a meaningful impact on accuracy
+        state.model.apply(disable_quantization)
+
+        self._setup_activation_cache_hooks()
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            # Run smoothing in case of sequential pipeline
             self._apply_smoothing(state.model)

-        return True
+        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
+            # Run smoothing in case of basic pipeline
+            self._apply_smoothing(state.model)
+
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        """
+        Finish calibrating by setting scales and zero-points,
+        removing observers and calibration hooks
+        """
+        self._assert_all_activations_consumed()
+
+        self.ended_ = True
+
+        modules = list(state.model.modules())
+        for module in tqdm(modules, desc="Calibrating weights"):
+            update_weight_zp_scale(module)
+
+        QuantizationMixin.end_calibration(self, state.model)
+
+        # remove activation hooks
+        self.remove_hooks(self._activation_hooks)
+        self._activation_hooks.clear()

     def on_finalize(self, state: State, **kwargs) -> bool:
         """
-        Clean up by clearing the scale and mapping data
+        Clean up by clearing the activations and mapping data

         :param state: unused
         :return: True
         """
-        if self._scales is not None:
-            self._scales.clear()
-        if self._resolved_mappings is not None:
-            self._resolved_mappings.clear()
+        if not self.ended_:
+            self.on_end(state, None)
+
+        self._activations.clear()
+        self._resolved_mappings.clear()

         return True
@@ -163,7 +297,10 @@ def _set_resolved_mappings(self, model: Module) -> None:
         for mapping in self.mappings:
             to_smooth_layers = get_layers(mapping.smooth_layer, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
-                if layer_name not in self.ignore:
+                # always exclude `.weight_observer`, only want `.weight`
+                if layer_name not in self.ignore and not layer_name.endswith(
+                    "_observer"
+                ):
                     balance_layers, balance_names = [], []
                     for balance_suffix in mapping.balance_layers:
                         # find the submodule that matches the activation layer
@@ -220,64 +357,36 @@ def _set_resolved_mappings(self, model: Module) -> None:
         self._resolved_mappings = resolved_mappings
         return

-    def _setup_scale_hooks(self) -> None:
+    def _setup_activation_cache_hooks(self) -> None:
         """
         Attach a forward hook to each activation we want to smooth. This allows us to
         calculate the dynamic range during calibration
         """

-        def create_hook_fn(layer_name):
-            def hook_fn(module, inp, out):
-                inp = inp[0].cpu().detach()
-
-                if layer_name in self._scales:
-                    self._scales[layer_name].append(inp)
+        def create_cache_activation_hook(smooth_layer_name):
+            def cache_activation_hook_fn(
+                _module: torch.nn.Module,
+                args: Tuple[torch.Tensor, ...],
+                _output: torch.Tensor,
+            ):
+                # Assume that first argument is the input
+                inp = args[0].cpu().detach()
+
+                if smooth_layer_name in self._activations:
+                    self._activations[smooth_layer_name].append(inp)
                 else:
-                    self._scales[layer_name] = [inp]
+                    self._activations[smooth_layer_name] = [inp]

-            return hook_fn
+            return cache_activation_hook_fn

         for mapping in self._resolved_mappings:
-            name = mapping.smooth_name
-            # storing inps to first balance layer
-            # is enough, as other balance layers
-            # get the same input
+            # storing inputs to first balance layer is sufficient
+            # other balance layers get the same input
             layer = mapping.balance_layers[0]
-            self.register_hook(layer, create_hook_fn(name), "forward")
-
-    @torch.no_grad()
-    def _calibrate(self, model: Module, calibration_dataloader: List) -> None:
-        """
-        Catch the output dynamic ranges of each layer that will be smoothed by running
-        forward passes with calibration_dataloader
-        """
-        class_name = self.__class__.__name__.replace("PyTorch", "")
-        logger.info(
-            f"Running {class_name} calibration with "
-            f"{len(calibration_dataloader)} samples..."
-        )
-        if not calibration_dataloader:
-            raise ValueError(
-                "Calibration data loader not set, must populate the calib_data field of"
-                " CompressionSession to run the AWQ modifier"
+            hook = self.register_hook(
+                layer, create_cache_activation_hook(mapping.smooth_name), "forward"
             )
-
-        run_calibration_forward(
-            model,
-            calibration_dataloader,
-        )
-
-    def _concat_collected_activations(self) -> None:
-        """
-        Concatenate the collected activation values from each forward pass into a single
-        tensor for each layer
-
-        :postcondition: each layer in self._scales will have a single tensor containing
-            all the activation values seen during calibration
-        """
-        for mapping in self._resolved_mappings:
-            name = mapping.smooth_name
-            self._scales[name] = torch.cat(self._scales[name], dim=0)
+            self._activation_hooks.add(hook)

     @torch.no_grad()
     def _apply_smoothing(self, model: Module) -> None:
@@ -288,13 +397,17 @@ def _apply_smoothing(self, model: Module) -> None:

         :param model: model to apply smoothing to
         """
-        logger.info("Smoothing activation scales...")
-        for mapping in tqdm(self._resolved_mappings):
-            smooth_layer = mapping.smooth_layer
-            balance_layers = mapping.balance_layers
+        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
+            # NOTE: When using SequentialPipeline, not all the mappings
+            # will have cached activations in the segment being updated
+            if mapping.smooth_name not in self._activations:
+                continue

-            activations = self._scales[mapping.smooth_name]
+            activations = torch.cat(self._activations[mapping.smooth_name], dim=0)
+            del self._activations[mapping.smooth_name]

+            smooth_layer = mapping.smooth_layer
+            balance_layers = mapping.balance_layers
             module2inspect = mapping.parent

             # [STEP 1]: Compute per-channel mean of normalised weights
@@ -302,7 +415,7 @@ def _apply_smoothing(self, model: Module) -> None:
             weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
             org_shape = weight.shape
             # The weights are reshaped to be organised by quantization group
-            weight = weight.view(-1, self.group_size)
+            weight = weight.view(-1, self._group_size)
             # Calculates the relative magnitude of the weights within
             # each of the quantization groups, and rescales each group
             # individually so that each group has weights on a 0-1 scale.
@@ -336,20 +449,22 @@ def _apply_smoothing(self, model: Module) -> None:

             x_mean = (x_sum / num_elements).to(inp.dtype)

-            # [STEP 3]: Compute output of module
-            fp16_output = self._forward_input_with_kwargs(
-                module=module2inspect,
-                inputs=inp,
-                input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
-            )
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max
-            )
+            with calibration_forward_context(model), HooksMixin.disable_hooks():
+                # [STEP 3]: Compute output of module
+                fp16_output = self._forward_input_with_kwargs(
+                    module=module2inspect,
+                    inputs=inp,
+                    input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
+                )
+                fp16_output = fp16_output.clip(
+                    torch.finfo(fp16_output.dtype).min,
+                    torch.finfo(fp16_output.dtype).max,
+                )

-            # [STEP 4]: Compute loss
-            best_scales = self._compute_best_scale(
-                inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
-            )
+                # [STEP 4]: Compute loss
+                best_scales = self._compute_best_scale(
+                    inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
+                )

             scales = best_scales

@@ -389,6 +504,8 @@ def smooth(module):
             smooth(layer)
         smooth(smooth_layer)

+        self._assert_all_activations_consumed()
+
     def _compute_best_scale(
         self,
         x: torch.Tensor,
@@ -438,17 +555,17 @@ def _compute_best_scale(
             scales[torch.isnan(scales)] = 1

             # Q(W * s)
-            for fc in linears2scale:
-                with align_module_device(fc):
-                    fc.weight.mul_(_scalesview)
+            for linear in linears2scale:
+                with align_module_device(linear):
+                    linear.weight.mul_(_scalesview)
                     update_offload_parameter(
-                        fc,
+                        linear,
                         "weight",
                         _pseudo_quantize_tensor(
-                            w=fc.weight.data,
-                            symmetric=self.symmetric,
-                            bit_width=self.num_bits,
-                            group_size=self.group_size,
+                            w=linear.weight.data,
+                            symmetric=self._symmetric,
+                            bit_width=self._num_bits,
+                            group_size=self._group_size,
                         )[0]
                         / _scalesview,
                     )
@@ -458,7 +575,8 @@ def _compute_best_scale(
                 module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
             )
             int_w_output = int_w_output.clip(
-                torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max
+                torch.finfo(int_w_output.dtype).min,
+                torch.finfo(int_w_output.dtype).max,
             )

             # compute mean squared error (L2 norm)
@@ -503,7 +621,7 @@ def _compute_loss(
         fp16_chunks = torch.split(fp16_output_flat, chunk_size)
         int_w_chunks = torch.split(int_w_output_flat, chunk_size)

-        # Compute the loss for each chunk
+        # Compute the MSE loss for each chunk
         for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
             chunk_loss = (
                 (fp16_chunk.to(device) - int_w_chunk.to(device))
@@ -519,6 +637,14 @@ def _compute_loss(

         return loss

+    def _assert_all_activations_consumed(self):
+        """
+        Confirm all activations have been consumed
+        If not, something has gone wrong
+        """
+        if len(self._activations) > 0:
+            raise RuntimeError("Some cached activations were not used")
+
     def _set_module_kwargs(self, model, dataloader) -> None:
         _, modules = next(iter(get_layers("re:.*layers", model).items()))

@@ -556,7 +682,8 @@ def forward(self, *args, **kwargs):
         # patch layer 0 to catch input and kwargs
         modules[0] = Catcher(modules[0])
         try:
-            model(samples.to(next(model.parameters()).device))
+            with calibration_forward_context(model):
+                model(samples.to(next(model.parameters()).device))
         except ValueError:  # work with early exit
             pass
         modules[0] = modules[0].module  # restore
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 53fccf37ee..4f44552862 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -49,7 +49,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
     |                num_bits: 8
     |                type: "int"
     |                symmetric: true
-    |                strategy: "tensor"
+    |                strategy: group
     |                group_size: 128
     |            actorder: False

@@ -98,7 +98,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
     """

     # gptq modifier arguments
-    sequential_update: bool = True  # DEPRECIATED
+    sequential_update: bool = True  # DEPRECATED
     sequential_targets: Union[str, List[str], None] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 75e5e89359..1eec560d60 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -50,6 +50,10 @@ class QuantizationMixin(HooksMixin):
         - Remove calibration hooks
         - Apply freeze status
         - Keep quantization enabled for future steps
+    NOTE: QuantizationMixin does not update scales and zero-points on its own,
+    as this is not desired for all Modifiers inheriting from it. Modifier must
+    explicitly call `update_weight_zp_scale`.
+    See QuantizationModifier.on_start method for example

     :param config_groups: dictionary specifying quantization schemes to apply to
         target modules. Modules not matching a scheme target will NOT be quantized.
diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py
index 8ae9945a70..df043c611b 100644
--- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py
+++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py
@@ -1,16 +1,10 @@
-from itertools import cycle
-from typing import Callable, Dict, Optional
+from typing import Dict

 import torch
 from torch.nn import Module
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from llmcompressor.pytorch.utils import tensors_module_forward, tensors_to_device

 __all__ = [
     "apply_pad_mask_to_batch",
-    "run_calibration_forward",
     "is_moe_model",
 ]

@@ -34,54 +28,6 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T
     return batch


-def run_calibration_forward(
-    model: Module,
-    calibration_dataloader: DataLoader,
-    num_calibration_steps: Optional[int] = None,
-    calibration_function: Optional[Callable] = None,
-    device: Optional[str] = None,
-    mask_padding: bool = False,
-):
-    """
-    Helper function used by one-shot modifiers, runs calibration data through a model to
-    update modifier statistics and trigger hooks
-
-    :param model: PyTorch model to run
-    :param calibration_dataloader: data to use for calibration
-    :param num_calibration_steps: number of items in calibration_dataloader to process,
-        None or a negative number to process all available data
-    :param calibration_function: option to pass a custom forward function for model
-    :param device: option to move the model to a specific device before calibration
-    :param mask_padding: whether to zero out padding tokens during calibration
-    """
-    model.eval()
-
-    forward_fn: Callable = (
-        calibration_function if calibration_function else tensors_module_forward
-    )
-
-    # move model to optional specified device if it is not already there
-    model_device = next(model.parameters()).device
-    if device is not None and model_device != device:
-        model.to(device)
-        model_device = next(model.parameters()).device
-    _dataloader = (
-        calibration_dataloader
-        if num_calibration_steps is None
-        else cycle(calibration_dataloader)
-    )
-
-    # run through the calibration data
-    for batch_idx, batch in enumerate(tqdm(_dataloader)):
-        if num_calibration_steps and batch_idx >= num_calibration_steps:
-            break
-        if mask_padding:
-            batch = apply_pad_mask_to_batch(batch)
-        batch = tensors_to_device(batch, model_device)
-        with torch.no_grad():
-            forward_fn(batch, module=model)
-
-
 def is_moe_model(model: Module) -> bool:
     """
     Check if the model is a mixture of experts model
diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index e35cad8278..794135273c 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -75,7 +75,7 @@ def from_dataloader(
         for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache"):
             intermediate = {}
             for key, value in batch.items():
-                if mask_padding and key == "input_ids":
+                if mask_padding and (key == "input_ids") and "attention_mask" in batch:
                     value = cls._mask_padding(value, batch["attention_mask"])
                 intermediate[key] = IntermediateValue(value=value, device=model_device)

diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index 79e37ff6eb..9cb2f37087 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -73,12 +73,11 @@ def __call__(
             calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
             prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"

-            # do an preliminary pass to trigger modifier hooks
+            # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = intermediates.fetch(batch_idx)
                 layer(**inputs)

-            # trigger compression
             LifecycleCallbacks.sequential_epoch_end()

             # this pass does not trigger modifier hooks
@@ -98,5 +97,5 @@ def __call__(
                     intermediates.delete(batch_idx)
                     intermediates.update(batch_idx, output)

-    # redudant, finish any remaining compression
+    # redundant, finish any remaining compression
     LifecycleCallbacks.calibration_epoch_end()
diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py
index ea27df5351..77d6e79ab1 100644
--- a/src/llmcompressor/pipelines/registry.py
+++ b/src/llmcompressor/pipelines/registry.py
@@ -17,7 +17,7 @@

 __all__ = ["CalibrationPipeline"]

-SEQUENTIAL_MODIFIERS = (GPTQModifier, SparsityModifierBase)
+SEQUENTIAL_MODIFIERS = (AWQModifier, GPTQModifier, SparsityModifierBase)


 class CalibrationPipeline(ABC, RegistryMixin):
@@ -60,15 +60,6 @@ def from_modifiers(

     @staticmethod
     def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
-        if any(isinstance(modifier, AWQModifier) for modifier in modifiers):
-            if len(modifiers) > 1:
-                logger.warning(
-                    "AWQ does not currently support sharing a data pipeline with other "
-                    "modifiers. Inferring `independent` calibration pipeline"
-                )
-                return "independent"
-            return "datafree"
-
         if any(isinstance(modifier, SEQUENTIAL_MODIFIERS) for modifier in modifiers):
             return "sequential"

diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 80eb3739d6..b7fd47006e 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -72,12 +72,11 @@ def __call__(
             calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
             prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

-            # do an preliminary pass to trigger modifier hooks
+            # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                 subgraph.forward(model, **inputs)

-            # trigger compression
             LifecycleCallbacks.sequential_epoch_end()

             # this pass does not trigger modifier hooks
@@ -91,5 +90,5 @@ def __call__(
                 intermediates.update(batch_idx, output)
                 intermediates.delete(batch_idx, subgraph.consumed_names)

-    # redudant, finish any remaining compression
+    # redundant, finish any remaining compression
     LifecycleCallbacks.calibration_epoch_end()
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 251add493e..929f1da220 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -1,5 +1,7 @@
 import pytest
 import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from pydantic import ValidationError

 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
 from llmcompressor.modifiers.awq.base import _sanitize_kwargs
@@ -17,6 +19,7 @@ def test_awq_is_registered():
         type_="AWQModifier",
         allow_experimental=False,
         allow_registered=True,
+        scheme="W4A16_ASYM",
     )

     assert isinstance(modifier, AWQModifier), "AWQModifier not registered"
@@ -35,7 +38,8 @@ def test_set_resolved_mappings():
                 "re:.*up_proj",
                 ["re:.*down_proj"],
             ),
-        ]
+        ],
+        scheme="W4A16_ASYM",
     )
     self_attn = torch.nn.ModuleDict(
         {
@@ -84,7 +88,8 @@ def test_set_resolved_mappings():
     awq = AWQModifier(
         mappings=[
             AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
-        ]
+        ],
+        scheme="W4A16_ASYM",
     )
     model = torch.nn.ModuleDict(
         {
@@ -102,6 +107,66 @@ def test_set_resolved_mappings():
     assert len(awq._resolved_mappings) == 0


+@pytest.mark.unit
+def test_validate():
+    with pytest.raises(ValidationError):
+        AWQModifier(scheme="W8A8")
+
+    with pytest.raises(ValidationError):
+        AWQModifier(
+            config_groups={
+                "group_0": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=4,
+                        group_size=64,
+                    ),
+                ),
+                "group_1": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=4,
+                        group_size=128,
+                    ),
+                ),
+            }
+        )
+
+    with pytest.raises(ValidationError):
+        AWQModifier(
+            config_groups={
+                "group_0": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=4,
+                        group_size=128,
+                    ),
+                ),
+                "group_1": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=8,
+                        group_size=128,
+                    ),
+                ),
+            }
+        )
+
+    # valid configuration
+    AWQModifier(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
+            ),
+            "group_1": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
+            ),
+        }
+    )
+
+
 @pytest.mark.unit
 def test_sanitize_kwargs():
     module = torch.nn.Linear(10, 20)
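
For reference, the example and test changes above admit two equivalent ways to configure the reworked AWQModifier: the preset scheme string used in the revised llama_example.py, and the explicit config_groups form exercised by test_validate. The sketch below is illustrative only and is not part of the patch; it restates constructions that already appear in the diff.

```python
# Two equivalent AWQModifier configurations taken from this patch
# (examples/awq/llama_example.py and tests/llmcompressor/modifiers/awq/test_base.py).
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from llmcompressor.modifiers.awq import AWQModifier

# 1) Preset scheme string, as used in the updated example script.
awq_preset = AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"])

# 2) Explicit config_groups, as exercised in test_validate. The new
#    model_validator requires a single num_bits / symmetric / group_size
#    setting across all groups and 16-bit activations.
awq_explicit = AWQModifier(
    ignore=["lm_head"],
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
        )
    },
)
```

With AWQModifier added to SEQUENTIAL_MODIFIERS in pipelines/registry.py, either form is calibrated under the sequential pipeline by default rather than the previous independent/datafree path.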