[SLM] Allow modules to define pre-processing of weights #16757
Conversation
Prior to this commit, the parameter specification for SLM tensor needed to be passed as a `nn.spec.Tensor`. As this object is only used to construct a `relax.TensorStructInfo`, and has the same fields as a `relax.TensorStructInfo`, this commit allows the parameter specification to be passed as a `relax.TensorStructInfo`.
Prior to this commit, a `nn.spec.Tensor`'s shape had special handling to ensure that symbolic variables were not reused across multiple functions. This commit updates this to instead be performed using the `CopyWithNewVars` function.
Prior to this commit, the weights used by `nn.Module` instances were required to be `nn.Parameter` instances. This commit allows the weights to instead be `nn.Tensor` instances, defined in terms of other `nn.Parameter` weights. This allows a model to define both the original weights that would be present in an external checkpoint (e.g. a PyTorch or Safetensors file), and the pre-processing that should be performed on those weights.
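As a rough illustration of the intended usage (a hypothetical sketch, not code from this PR; the module name and layout are made up), a module can keep the checkpoint-level weight as an `nn.Parameter` and expose a derived `nn.Tensor` computed from it:

```python
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class TransposedLinear(nn.Module):
    """Hypothetical module whose checkpoint stores the weight as (out, in)."""

    def __init__(self, in_features: int, out_features: int):
        # Weight in the layout found in the external checkpoint.
        self.weight = nn.Parameter((out_features, in_features), dtype="float32")
        # Pre-processed weight: an nn.Tensor defined in terms of the nn.Parameter.
        self.weight_t = op.permute_dims(self.weight)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        # The compute graph uses the pre-processed weight directly.
        return op.matmul(x, self.weight_t)
```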
Force-pushed from 78e206c to dcecf81.
Separated out the functionality of this PR from the portions of #16737 that require additional discussion. This PR is now independent, and ready for review.
```python
builder = BlockBuilder()
with builder.function("dummy_scope", params=[]):
    expr = builder.normalize(expr)
    builder.emit_func_output([])
```
Do we have a use case for using it outside the block builder? Emitting a function is not equivalent to emitting the expression itself, unless the function is evaluated.
The expression isn't actually being generated in the `builder.function("dummy_scope")` scope here, because it uses `builder.normalize` and not `builder.emit`. The purpose is to infer the StructInfo to be used, while skipping the flattening of nested expressions.
The only reason the function scope is entered at all is that the `builder.normalize` method requires an active scope. This is something I hope to change in the future.
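A minimal standalone sketch of that pattern (assuming plain `tvm.relax` APIs, outside of this PR's code) shows `normalize` populating struct info without emitting a binding:

```python
from tvm import relax
from tvm.relax import BlockBuilder

# An expression whose StructInfo has not been inferred yet.
x = relax.Var("x", relax.TensorStructInfo([1, 4], "float32"))
expr = relax.op.permute_dims(x)

builder = BlockBuilder()
with builder.function("dummy_scope", params=[]):
    # normalize() infers the StructInfo of the expression without emitting a binding.
    expr = builder.normalize(expr)
    # Close the scope that normalize() requires; the dummy function is never retrieved.
    builder.emit_func_output([])

print(expr.struct_info)  # R.Tensor((4, 1), dtype="float32")
```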
```python
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names, spec.method_specs):
    params = _params()  # Re-initialize so symbolic shapes not shared across methods
```
Personally, this is preferable given its explicitness and simplicity (we don't need to copy the whole function).
Unfortunately, generating fresh Relax variables for the parameters breaks any expressions that have been defined in terms of the previous Relax variables. (See this test case for an example.) The expressions held internally to represent the pre-processed weights are specified in terms of those initial Relax variables, so replacing them with fresh Relax variables leaves the original variables undefined. Using `copy_with_new_vars` afterward provides unique symbolic shapes in each function, without breaking the expressions used within each function.
Long-term, I'm planning to expand the `tir.transform.ConvertSSA` pass to handle both TIR and Relax functions. That way, it can be applied as a post-processing step to ensure all symbolic variables are unique to a single function, without needing to handle it before that.
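For reference, a small sketch of the `copy_with_new_vars` behavior being relied on here (assuming the `tvm.relax.utils` helper and an illustrative TVMScript function):

```python
from tvm.relax.utils import copy_with_new_vars
from tvm.script import relax as R


@R.function
def func(
    w: R.Tensor(("vocab_size", 4), dtype="float32")
) -> R.Tensor(("vocab_size", 4), dtype="float32"):
    return w


# The copy is structurally equivalent, but its Relax parameters and the symbolic
# variable behind "vocab_size" are fresh objects, so nothing is shared with `func`.
copied = copy_with_new_vars(func)
assert not copied.params[0].same_as(func.params[0])
```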
I also agree that the original logic was simpler. The intermediate computation sharing is a mechanism that we would like to introduce; it would be good to discuss alternative ways to achieve the same goal, e.g. defining a separate parameter mapping function explicitly rather than having things happen implicitly.
```python
def _get_var(shape_var: tir.Var) -> tir.Var:
    name = shape_var.name
    if name in str2var_params:
        return str2var_params[name]
```
Removing this would cause issues with the symbolic shape tying/remapping, which allows the simplicity of mapping strings to variables. We should preserve that behavior.
Can you describe the behavior that should be preserved?
The main functionality of this step was to ensure symbolic variables are not duplicated across multiple functions, and this functionality is still implemented through `copy_with_new_vars`. The mapping of strings to variables is still handled here, in the `_get_var` function inside `_method_spec_to_input`.
Hi @Lunderberg, this PR causes a compilation failure in the MLC LLM CI for the Phi model: https://ci.mlc.ai/blue/organizations/jenkins/mlc-llm/detail/main/373/
I dug a bit into this. To describe the issue: say the Phi model has hidden size 2560; the final linear layer of the Phi model then has weight shape `("vocab_size", 2560)` and bias shape `("vocab_size",)`.
Prior to this PR, both "vocab_size" dimensions in the weight and bias were remapped to the same `tir.Var`, so that for any `x`, `op.matmul(x, op.transpose(weight)) + bias` could have its output shape successfully inferred, expressed in terms of `vocab_size`.
Since this PR, the "vocab_size" vars in the weight and bias shapes no longer match. Therefore, the result shape cannot be symbolically inferred and falls back to `R.Tensor(dtype="float16", ndim=3)`, which causes the assertion failure you can see in the CI:
```
In matmul, x.sinfo = R.Tensor((1, 1, vocab_size), dtype="float16"), b.sinfo = R.Tensor((vocab_size,), dtype="float16")
In wrap nested, expr = R.add(matmul128, lm_head_linear_bias1)
after emitting, expr = add192, expr.sinfo = R.Tensor(dtype="float16", ndim=3)
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/modules.py", line 139, in forward
    x = x + self.bias
        ~~^~~~~~~~~~~
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/_tensor_op.py", line 44, in __add__
    return _op().add(self, other)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/op.py", line 105, in add
    return wrap_nested(_op.add(a._expr, b._expr), name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", line 614, in wrap_nested
    return Tensor(_expr=expr)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", line 117, in __init__
    _check_tensor(_expr)
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", line 112, in _check_tensor
    assert expr.struct_info.shape is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```
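For context, a tiny sketch of the underlying mismatch (illustrative only): two `tir.Var` objects with the same name are still distinct symbols, so shape inference cannot equate them.

```python
from tvm import relax, tir

n1 = tir.Var("vocab_size", "int64")
n2 = tir.Var("vocab_size", "int64")

matmul_sinfo = relax.TensorStructInfo([1, 1, n1], "float16")
bias_sinfo = relax.TensorStructInfo([n2], "float16")

# The two variables print identically but compare by reference, so broadcasting
# (1, 1, n1) with (n2,) has no common symbolic result, and the inferred shape
# degrades to R.Tensor(dtype="float16", ndim=3).
assert not n1.same_as(n2)
```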
Thank you for digging into it, and I'll look into it on Monday.
My first impression is that we shouldn't support this behavior while exporting a module. If we have a dynamic `Linear` module with dynamic shape parameters `input_features` and `output_features`, we should not assume that all such instances of the dynamic `Linear` have the same `input_features` and `output_features`.
I think we should normalize all `str` shapes to the same `tir.Var`, but anything already specified as distinct `tir.Var` shapes should remain separate. This could later be moved to an improved struct inference, which could provide `R.assert_op` on symbolic shapes, rather than falling back to `R.Tensor(ndim=3)`.
Having the ability to use common dynamic variables like `vocab_size` is a pretty common need, so we will likely need to support this behavior. The `in_features` and `out_features` are usually static in most cases.
Thank you so much, @Lunderberg. I just made a regression test for this behavior, and we can add it to test_frontend_nn_module.py later on. The test basically reflects what I described above in words. Please let me know if the test makes sense to you!
```python
# Assuming the usual imports at the top of test_frontend_nn_module.py
from tvm.ir import assert_structural_equal
from tvm.relax.frontend.nn import modules, spec
from tvm.script import relax as R, tir as T


def test_linear_dynamic_shape():
    @R.function
    def forward(
        x: R.Tensor((1, 4), dtype="float32"),
        _io: R.Object,
        weight: R.Tensor(("n", 4), dtype="float32"),
        bias: R.Tensor(("n",), dtype="float32"),
    ) -> R.Tuple(R.Tensor((1, "n"), dtype="float32"), R.Tuple(R.Object)):
        n = T.int64()
        R.func_attr({"num_input": 2})
        with R.dataflow():
            permute_dims: R.Tensor((4, n), dtype="float32") = R.permute_dims(weight, axes=None)
            matmul: R.Tensor((1, n), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, n), dtype="float32") = R.add(matmul, bias)
            gv1: R.Tuple(R.Tensor((1, n), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
            R.output(gv1)
        return gv1

    mod = modules.Linear(in_features=4, out_features="n", bias=True)
    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4), "float32")}}, debug=True)
    assert_structural_equal(tvm_mod["forward"], forward, True)
```
@MasterJH5574 Thank you, and that helps to reproduce the issue. Does the issue only involve nn.Linear, or would it also apply to dynamic shapes that span across multiple independent nn.Module subclasses? If the issue is specific to nn.Linear, can you take a look at #16781? This avoids the duplicate default conversion from str to a fresh TIR variable in both the weights and bias.
Overall, I think de-duplicating str values makes sense, because Python str objects have value equality. De-duplicating TIR variables by their name doesn't make sense, because TIR variables have reference equality. I think we could have a good intermediate implementation if we de-duplicate str values prior to conversion into TIR variables.
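A minimal sketch of that intermediate approach (a hypothetical helper, not the actual implementation): identical strings are collapsed to one `tir.Var` before conversion, while caller-provided ints and `tir.Var` objects pass through unchanged.

```python
from typing import Dict, List, Union

from tvm import tir


def dedup_str_dims(dims: List[Union[int, str, tir.Var]], str2var: Dict[str, tir.Var]):
    """Map equal strings to a single tir.Var; leave ints and tir.Vars as-is."""
    out = []
    for dim in dims:
        if isinstance(dim, str):
            if dim not in str2var:
                str2var[dim] = tir.Var(dim, "int64")
            out.append(str2var[dim])
        else:
            out.append(dim)
    return out


str2var: Dict[str, tir.Var] = {}
weight_shape = dedup_str_dims(["vocab_size", 2560], str2var)
bias_shape = dedup_str_dims(["vocab_size"], str2var)
assert weight_shape[0].same_as(bias_shape[0])  # the two "vocab_size" dims now unify
```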
@MasterJH5574 Can you test against #16785? It re-implements the same functionality that I require on my side, while maintaining the handling of dynamic shapes specified by strings.
The Mixtral compile using mlc-llm with this PR failed, resulting in the following error:
…)" This reverts commit 1cccc3b.
|
Temporarily reverting this PR for now, as the missing symbolic var remapping logic caused some compatibility issues. We can follow up with a test case that covers the name remapping. I would also love to discuss the design of the pre-processing part: since nn.Parameter is more common in most usages, it might be possible for us to think about an alternative, maybe explicitly creating a remapping function in the module?
Having an explicit remapping function could also work well. I wanted to avoid having any changes to the …
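For discussion, one hypothetical shape such an explicit remapping could take (names and API entirely invented, just to make the alternative concrete): the module keeps its `nn.Parameter` in the layout that `forward` consumes and declares the checkpoint-to-runtime transform as a named method.

```python
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class LinearWithPreprocess(nn.Module):
    """Hypothetical module with an explicit checkpoint-remapping hook."""

    def __init__(self, in_features: int, out_features: int):
        # Parameter in the layout that forward() consumes: (in, out).
        self.weight_t = nn.Parameter((in_features, out_features), dtype="float32")

    @staticmethod
    def remap_checkpoint(checkpoint_weight: nn.Tensor) -> nn.Tensor:
        # Explicit mapping from the checkpoint layout (out, in) to (in, out).
        return op.permute_dims(checkpoint_weight)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return op.matmul(x, self.weight_t)
```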
Resolve a breakage introduced in apache#16757. Prior to apache#16757, distinct TIR variables were unified if they had the same name. This commit avoids using distinct TIR variables to represent the same user input.
Prior to this commit, the weights used by `nn.Module` instances were required to be `nn.Parameter` instances. This commit allows the weights to instead be `nn.Tensor` instances, defined in terms of other `nn.Parameter` weights. This allows a model to define both the original weights that would be present in an external checkpoint (e.g. a Pytorch or Safetensors file), and the pre-processing that should be performed on those weights. This is a re-implementation of apache#16757, which was reverted in apache#16777. The re-implementation preserves the handling of dynamic shapes specified as python strings, enabling the test cases that were added in apache#16784.
* [SLM] Allow TensorStructInfo to specify parameter in export

  Prior to this commit, the parameter specification for SLM tensor needed to be passed as a `nn.spec.Tensor`. As this object is only used to construct a `relax.TensorStructInfo`, and has the same fields as a `relax.TensorStructInfo`, this commit allows the parameter specification to be passed as a `relax.TensorStructInfo`.

* Resolve breakage in unit tests

* [SLM] Use `CopyWithNewVars` to de-duplicate symbolic variables

  Prior to this commit, a `nn.spec.Tensor`'s shape had special handling to ensure that symbolic variables were not reused across multiple functions. This commit updates this to instead be performed using the `CopyWithNewVars` function.

* [SLM] Allow modules to define pre-processing of weights

  Prior to this commit, the weights used by `nn.Module` instances were required to be `nn.Parameter` instances. This commit allows the weights to instead be `nn.Tensor` instances, defined in terms of other `nn.Parameter` weights. This allows a model to define both the original weights that would be present in an external checkpoint (e.g. a PyTorch or Safetensors file), and the pre-processing that should be performed on those weights.

* Undo portions that would introduce R.Tensor to nn.Module

* Remove unit tests that were related to TensorStructInfo
(apache#16777) Revert "[SLM] Allow modules to define pre-processing of weights (apache#16757)". This reverts commit 1cccc3b.