Add quantize_ nn.Parameter support #3083
Conversation
Summary: This PR adds a simple 2d and 3d MoE implementation and tests `quantize_` on them to see if we get the same results.
Test Plan: `pytest test/prototype/test_parameter.py -k test_quantize_parameter`
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3083
Note: Links to docs will display an error until the docs builds have been completed.
❌ 8 New Failures as of commit f68f572 with merge base 5346f0e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?
Would it work to stick with:

def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ... new behavior for parameters, apply parameter swap config ...
    elif has_parameter(model, fqn + '.weight'):
        ... old behavior, apply parameter swap config ...
    elif has_module(model, fqn):
        ... old behavior, apply module swap ...
Yeah, we can do this. Do you think we should keep the
Yes I believe so, especially in the case of the Config object itself. We attach everything to the weight parameter for nn.Linear, so this allows us to specify the parameter name instead of assuming it's "weight". The only thing that does not map cleanly IMO is the
I think we should define the transform for parameters as the base case (aka
IMO we should change the current name and keep the old name for BC: ParamOrModuleFqnToConfig = ...
# for bc
ModuleFqnToConfig = ParamOrModuleFqnToConfig
To me it seems that the transform has to be for modules, because it is inplace. The user can target a parameter if they want to, but the transform function always runs on the module that owns the parameter.
torchao/quantization/quant_api.py
Outdated
# skip if not direct child
if "." not in name:
    for pattern in config.param_fqn_to_config:
        if re.match(pattern, f"{fqn}.{name}"):
so it applies to all params, regardless of what it is? e.g. bias? should we be more specific in what people are configuring?
I think we should consider the regex syntax separately, I can remove from this PR.
One thing I would like would be for quantize_ to log the modules/params it's swapping so it's easy to see what the difference is.
Does this mean we need to refactor all supported configs to use this structure?
torchao/quantization/quant_api.py
Outdated
class ModuleOrParamFqnToConfig(AOBaseConfig):
    """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs).

    This extends the functionality of ModuleFqnToConfig to support parameter-level quantization configurations
nit: comment seems stale
@jcaip would this be simpler than having two transform registration systems?
cc @vkuzo Hmm, I think the pseudocode mentioned here vs the logic in the PR and having two transform registration systems are a bit orthogonal. It's possible to have one registration system with the logic in the PR as well. I'm assuming your main concern is with having two registration systems? Let me know if that's not the case. IMO it's about the same complexity to have one registration system vs two. My main preference for having two registration systems is that it reduces the amount of work we have to do to enable other Configs for parameter quantization - we just need to add the decorator to our
yes, and even further IMO we should have a single "modify module inplace" paradigm instead of having one paradigm for modules and one for parameters
IMO we should go for the solution where the resulting code is the simplest, if that involves manual work that seems OK to me, and we can parallelize the conversions if you don't want to do them alone. Reducing the work to convert but ending up with two systems seems like trading dev time now for increased system complexity later.
OK I'll update the PR to use a single registration system.
One thing I want to point out is that it's difficult to support stuff like our vLLM integration, where we pass in a parameter that's not tied to any module, with a single "modify module inplace" paradigm.
I think "everything is parameters" is also a valid solution, I just don't think we should have both - let's pick one? |
torchao/quantization/quant_api.py
Outdated
`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
nit: use `typing.OrderedDict` since it's different from `collections.OrderedDict`
torchao/quantization/quant_api.py
Outdated
Raises:
    NotImplementedError: If a configuration type doesn't have a registered parameter handler.
"""
top_level_named_parameters_list = [
nit: is this the same as list(dict(mod_containing_param.named_parameters()).items())
torchao/quantization/quant_api.py
Outdated
for name, param in top_level_named_parameters_list:
    for pattern, param_config in config.module_or_param_fqn_to_config.items():
        full_param_fqn = f"{fqn}.{name}"
        if (pattern == full_param_fqn) or (
btw, if we want exact match (`==`) to take precedence, I think it has to be a separate check,
if pattern == full_param_fqn:
    ...
elif pattern.startswith("re:") and ...:
    ...
A test with a model containing a linear1 module and a config of
{"re:linear.*": config1, "linear1": config2}
checking that linear1 is quantized with config2 instead of config1 should catch it
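A minimal sketch of such a test, assuming the fqn-to-config API from this PR (shown with the final `FqnToConfig` name); the toy model, import paths, and the specific float8 configs are illustrative:
```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    PerTensor,
    quantize_,
)
# FqnToConfig is the config class introduced in this PR; the import path is assumed
from torchao.quantization import FqnToConfig


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 128, bias=False)
        self.linear2 = torch.nn.Linear(128, 128, bias=False)


model = ToyModel().to(torch.bfloat16).to("cuda")
quantize_(
    model,
    FqnToConfig({
        "re:linear.*": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        "linear1": Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
    }),
)

# exact match should take precedence over the regex, so linear1 ends up
# per-tensor (a single scale), while linear2 only matches the regex and is per-row
assert model.linear1.weight.scale.numel() == 1
assert model.linear2.weight.scale.numel() == 128
```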
test/quantization/test_quant_api.py
Outdated
"0": Float8DynamicActivationFloat8WeightConfig( | ||
granularity=PerRow(), | ||
), | ||
"re:.*weight": Float8DynamicActivationFloat8WeightConfig( |
we should test the reverse order I think, to make sure `0` takes precedence
quantize_(
    model,
    quant_config,
)
checks?
torchao/quantization/quant_api.py
Outdated
`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
also to correct the naming, we can add a `module_or_param_fqn_to_config` field and use that for version 2, and go through the normal version update path like other configs as well I think
how about just fqn_to_config
yeah sounds good
@jcaip can you add `ModuleOrParamFqnToConfig` to torchao docs as well? I would like to link to it in transformer docs
Looks good overall, just one main question about how the default `filter_fn` interacts with the config
(fqn-configuration)=
### 3. FQN Configuration

For granular control, use `ModuleFqnToConfig`:
Looks like we also document this in serving.md, can you update that doc as well?
assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor)
assert model.shared_expert.gate_proj.weight.scale.numel() == 1

def test_quantize_modle_exact_match_preference(self):
nit: typo modle
""" | ||
torch._C._log_api_usage_once("torchao.quantization.quantize_") | ||
|
||
filter_fn = _is_linear if filter_fn is None else filter_fn |
is this default `filter_fn` going to have unexpected consequences if people are using `FqnToConfig`? E.g. let's say someone literally just wants to quantize a very specific parameter:
quantize_(model, FqnToConfig({"layers.0.some.parameter": Int4WeightOnlyConfig()}))
If I'm reading the code correctly, right now we do the replacement if either (1) we match the filter_fn, or (2) we match the fqn. Would the above unexpectedly quantize all the other linear layers in the model?
In this case, replacement won't do anything as the other linear layers aren't specified in the config. I can add a test for this though.
Yeah I think it would be good to verify this, from the code it seems we do the replacement if we match either the `filter_fn` or the config (not and). Would also be good to clearly document the semantics of `filter_fn` in the docstring in this case
yeah, I think the semantics should be:
- if both fqn_to_config and filter_fn are specified, both have to match for the config to be applied (AND, not OR)
- else, use whichever one is specified
it seems like we should consider breaking BC here and change the default `filter_fn` to `is_linear`, so that if the user passes in `filter_fn == None` then only `fqn_to_config` is applied?
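A rough sketch of the proposed semantics; `_should_transform` and `_fqn_matches_config` are hypothetical helpers standing in for the PR's actual matching logic, not functions from torchao:
```python
def _should_transform(module, fqn, config, filter_fn=None):
    # hypothetical helper: True if this fqn (or one of the module's parameters)
    # matches an entry in the fqn_to_config mapping
    matches_config = _fqn_matches_config(module, fqn, config)
    if filter_fn is not None:
        # both the user-supplied filter and the fqn config must match (AND, not OR)
        return filter_fn(module, fqn) and matches_config
    # filter_fn == None: fqn_to_config alone decides what gets transformed
    return matches_config
```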
In my mind, if someone specifies an fqn in the config, it's pretty clear that they want to quantize it. So I think AND is kind of a footgun here, especially if the default `filter_fn` is `is_linear`, i.e. a first-time user wants to quantize a parameter, adds an entry to FqnToConfig, and the new param doesn't get quantized because the default filter_fn is is_linear. I guess we can just throw a warning in this instance though.
cc @jerryzh168 what do you think? I'll defer to whatever's most popular with the team.
Sounds good to me, I'll update the PR
agreed on removing `filter_fn` longer term
I think it is used pretty widely though, so maybe not in this PR and we do it separately with a proper deprecation? We can punt in this PR by just throwing an exception if `fqn_to_config` is provided along with a non-default `filter_fn`.
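A sketch of what that punt could look like inside `quantize_`; the exact check and error message are illustrative, not the PR's code:
```python
# hypothetical guard near the top of quantize_(model, config, filter_fn=None)
if isinstance(config, FqnToConfig) and filter_fn is not None:
    raise ValueError(
        "Passing a custom filter_fn together with FqnToConfig is not supported; "
        "target modules/parameters via fqn or regex keys in the config instead."
    )
```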
`filter_fn` has a lot of internal uses, and it's how many users apply quantization/QAT to linear and embedding separately today. We should do a careful deprecation of this and make sure existing use cases have a good alternative
@andrewor14 , any thoughts on "We can punt in this PR by just throwing an exception if fqn_to_config is provided along with a non-default filter_fn."?
> We can punt in this PR by just throwing an exception if fqn_to_config is provided along with a non-default filter_fn

Yeah sounds good to me
torchao/quantization/quant_api.py
Outdated
regex patterns (as strings) to quantization configurations.
The patterns can be one of the follows:
(1). fully qualified name (fqn) of module or paramter or
typo: paramter
torchao/quantization/quant_api.py
Outdated
`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
the docstring still references the old arg name I think
torchao/quantization/quant_api.py
Outdated
torch._C._log_api_usage_once("torchao.quantization.FqnToConfig")
if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0:
    warnings.warn(
        "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used"
I feel this is going to be a silent error for some users, should we just ban this case for simplicity? It's not for BC
Yeah, we should just ValueError here.
torchao/quantization/quant_api.py
Outdated
warnings.warn(
    "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used"
)
if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0:
nit: if you throw an error above then this can become:
if len(self.module_fqn_to_config) > 0:
    assert len(self.fqn_to_config) == 0
    self.fqn_to_config = self.module_fqn_to_config
and you don't need the rest of the cases (probably don't need to update `self.module_fqn_to_config` to match `self.fqn_to_config`?)
torchao/quantization/quant_api.py
Outdated
return handler(module, c)

return module

def select_module_if_filter_fn_or_contains_params_matching_pattern(
private?
Looks good to me! I'll let Jerry/Vasiliy stamp since they reviewed this in more detail
torchao/quantization/quant_api.py
Outdated
Args:
    fqn (str): The fully qualified name to match against the config patterns.
    config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs.
torchao/quantization/quant_api.py
remove?
""" | ||
torch._C._log_api_usage_once("torchao.quantization.quantize_") | ||
|
||
filter_fn = _is_linear if filter_fn is None else filter_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think it would be good to verify this, from the code it seems we do the replacement if we match either the filter_fn
or the config
(not and). Would also be good to clearly document the semantics of filter_fn
in the docstring in this case
return found, c


def _select_module_if_filter_fn_or_contains_params_matching_pattern(
IMO should be AND, not OR
_module_fqn_to_config_handler,
filter_fn,
_fqn_to_config_handler,
partial(
seems like we are passing one callable and one callable wrapping a callable into a function, which seems a bit hard to follow. Have we considered just writing this directly instead?
I can write this as a lambda, if that's a bit clearer to you?
lambda mod, fqn: filter_fn(mod, fqn) and select_with_module(mod, fqn, config=config)
This PR adds support for quantizing `nn.Parameter` to `quantize_`. `ModuleFqnToConfig` has been renamed to `FqnToConfig`, which now accepts both module fqns and parameter fqns. `ModuleFqnToConfig` has been aliased to maintain BC.

API examples

For example, for a toy nn.Linear model:
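The sketch below is illustrative rather than taken verbatim from the PR; the toy model, fqn keys, and import path for `FqnToConfig` (the class introduced here) are assumptions:
```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
# FqnToConfig is introduced in this PR; the import path is assumed
from torchao.quantization import FqnToConfig


class ToyLinearModel(torch.nn.Module):
    def __init__(self, k=256, n=256):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = ToyLinearModel().to(torch.bfloat16).to("cuda")

quantize_(
    model,
    FqnToConfig({
        # exact parameter fqn
        "linear1.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        # regex over fqns, prefixed with "re:"
        "re:linear2.*": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    }),
)
```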
The keys to `FqnToConfig` can be one of the following (in order of precedence): the exact fqn of a module or parameter, a regex of a module fqn (prefixed with `re:`), or a regex of a parameter fqn (prefixed with `re:`).

To enable support for parameter fqns for a particular config, we need to add the `parameter_name` kwarg into the config signature and update `CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS`. See the changes here for more details. `Float8DynamicActivationFloat8WeightConfig` has been enabled by this PR, but other configs will throw a `NotImplementedError`.

Test Plan
How do our configs translate for MoEs?
Currently, we define a bunch of configs that are for dense nn.Linear modules, how do these configs translate in the case of MoE inference?
Some background on MoE inference
There are two ways that forwards is implemented for MoE:
1. As a for loop of 2d `nn.Linear` matmuls - in this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen here. In this case, I argue that the semantics of the configs do not change at all from the normal `nn.Linear` case, as we are just doing a bunch of normal 2d linear matmuls.
2. As a single 3d matmul (bmm) - for this case, we'd need to add additional op support (bmm) for forwards (see the sketch below). Depending on whether the subclass is an AQT subclass or non-AQT subclass this will be added differently.
I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).
Below is a breakdown of what Configs map to AQT / non-AQT subclasses:
For these, the majority of the semantics remain the same; the only semantic that really changes is `PerRow` granularity, and there's a very natural extension of `PerRow` to the 3d case (apply on the last dimension). I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.
Float8DynamicActivationFloat8WeightConfig
- `activation_dtype`, `weight_dtype`, `activation_value_lb`, `activation_value_ub` all do not change meaning semantically.
- `granularity=PerTensor()` does not change semantic meaning - we still use a single tensor to scale the entire weight tensor.
- `granularity=PerRow()` does change meaning - we now calculate a scale for each row along the last dimension [-1], i.e. for a weight of (E, N, K) we would expect PerRow to create scales of block size (1, 1, K) (see the sketch after this list).
- `mm_config`, `kernel_preference`, and `set_inductor_config` stay the same as well.
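A sketch of that PerRow extension to a 3d weight, assuming a float8 rowwise scheme; it mirrors the (1, 1, K) block size described above but is not the exact torchao kernel code:
```python
import torch

E, N, K = 4, 32, 16
w = torch.randn(E, N, K, dtype=torch.bfloat16)

# one scale per row along the last dim -> scales of shape (E, N, 1),
# i.e. each scale covers a (1, 1, K) block of the weight
amax = w.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
scale = amax / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)

assert scale.shape == (E, N, 1)
```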
stay the same as well.Float8StaticActivationFloat8WeightConfig
scale
should be passed in as a 3d tensor instead of a 2d tensor in the case ofPerRow
granularityFloat8DynamicActivationInt4WeightConfig
int4_packing_format - Only "preshuffled" is supported and Int4PreshuffledTensor supports 3d weights.
Int4WeightOnlyConfig
- `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm`, `set_inductor_config` are the only things that are set for the v2 config. I don't think the semantics of these change, although there are some packing formats that do not support 3d weights; it looks like `Int4PackingFormat.PLAIN_INT32` and `Int4PackingFormat.MARLIN_SPARSE` do not.