Skip to content

Conversation

@Lunderberg
Copy link
Contributor

Prior to this commit, FuseOps and FuseOpsByPattern exposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of CodegenJSON, which requires all arguments to be tensors, or tuple of tensors.

Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes arg: R.Tensor([N+1]) and returns R.add(arg, R.const(1)) cannot infer N. However, all occurrences of N occur as part of the expression N+1, and the value of N+1 can be inferred. Therefore, if we replace N+1 with M, the additional ShapeTuple argument isn't required.

In addition, prior to this commit, the CompositeFunctionAnnotator visited the body of functions without the parameters being considered in-scope. As a result, EraseToWellDefined would remove known shapes from the function body's StructInfo.

@tqchen
Copy link
Member

tqchen commented Jan 22, 2024

Ideally we don't want to change FuseOps behavior, since in cases where expressions are intermediate (e.g. intermediate compute include values that contains exprs like n * 4).

This is because we should get maybe we should look into compose them? FuseOps first then rewrite signatures

@Lunderberg
Copy link
Contributor Author

I could see having a post-processing pass to update the signature, maybe as an extension of RemoveUnusedParameters. There would still need to be an update to FuseOps to have the fused functions marked as private, since the post-processing step would only be allowed to update the signature of internal functions.

Though, could you expand on what you mean by intermediate expressions? In either case, whether implemented in FuseOps or in a post-processing pass, I think intermediate expressions would be handled correctly. If an expression n*4 can be inferred from the tensor shapes, but n+42 also appears in the fused function, then there would still be a shape expr used to expose n to the fused function.

@Lunderberg Lunderberg force-pushed the unity_fuse_ops_symbolic_var branch from b34ffe9 to b014332 Compare January 22, 2024 20:58
@Lunderberg Lunderberg force-pushed the unity_fuse_ops_symbolic_var branch from b014332 to 3556c4f Compare February 14, 2024 15:50
@Lunderberg
Copy link
Contributor Author

Rebased onto main to resolve conflicts.

For long-term, I think I agree that it would be cleaner and more general-purpose to have the functionality separated out into three distinct passes:

  1. FuseOps, with the first commit in this PR to preserve symbolic variables in the ret_struct_info.
  2. A not-yet-existing HoistCommonSubexpressions, which would recognize that a symbolic variable is always used within a specific expression, and would hoist it to the calling scope.
  3. Applying the RemoveUnusedParameters to remove the no-longer-required R.shape param.

@Lunderberg
Copy link
Contributor Author

I've separated the first commit of this PR branch into an independent PR (#16637), as the bugfix it provides is independent of the concerns raised, and does not require the not-yet-implemented HoistCommonSubexpressions transform.

Prior to this commit, `FuseOps` and `FuseOpsByPattern` exposed
a symbolic variable to the fused function if it was used within the
fused function, but wasn't inferable from other parameter shapes.
While this prevents undefined symbolic variables, it can cause issues
for downstream use of `CodegenJSON`, which requires all arguments to
be tensors, or tuple of tensors.

Frequently, all uses of a non-inferable symbolic shape occur within a
symbolic expression that can be inferred.  For example, a function
that takes `arg: R.Tensor([N+1])` and returns `R.add(arg, R.const(1))`
cannot infer `N`.  However, all occurrences of `N` occur as part
of the expression `N+1`, and the value of `N+1` can be inferred.
Therefore, if we replace `N+1` with `M`, the additional `ShapeTuple`
argument isn't required.
@Lunderberg
Copy link
Contributor Author

New functionality implemented in #16450, which would hoist out the common subexpressions. After it lands, this PR can be updated to make use of it.

@tqchen tqchen closed this Feb 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants