@Lunderberg
Contributor

Prior to this commit, relax.op.split did not provide a known output shape when the split was performed over a dynamic-size axis. This commit updates the shape inference to provide correct shapes in this case. A test case is also added for CombineParallelMatmul to show the intended usage of this feature and to ensure that the split outputs from a combined matmul still have correct shape information.
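
As a rough illustration of the shapes involved, the following pure-Python sketch computes the length of each output section when an axis is split at explicit indices. It is not TVM code: the function name `split_section_lengths` is hypothetical, and the clamping of indices to the axis bounds follows NumPy slicing semantics, which is an assumption about how relax.op.split behaves.

```python
def split_section_lengths(axis_len, indices):
    """Return the length of each section when splitting at `indices`.

    Hypothetical helper for illustration only. Indices are clamped to
    [0, axis_len], mirroring NumPy slicing (an assumption, not TVM's code).
    """
    # Section boundaries: start of axis, each clamped index, end of axis.
    bounds = [0] + [min(max(i, 0), axis_len) for i in indices] + [axis_len]
    # Each section spans two consecutive boundaries.
    return [max(end - start, 0) for start, end in zip(bounds, bounds[1:])]


# An axis of length 10 split at indices 2 and 5 yields three sections.
lengths_long = split_section_lengths(10, [2, 5])
# A shorter axis shows why clamping matters: the last section is empty.
lengths_short = split_section_lengths(3, [2, 5])
```

When the axis length is a symbolic variable rather than an integer, the same `min`/`max` terms remain in the inferred shape unless the analyzer knows enough about the variable to remove them.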

@Lunderberg force-pushed the unity_infer_split_dynamic_shapes branch from 81084c4 to b15daba on January 22, 2024 17:58
@Lunderberg changed the base branch from unity to main on January 22, 2024 17:58
@Lunderberg force-pushed the unity_infer_split_dynamic_shapes branch from b15daba to ed7c05d on January 23, 2024 16:21
Contributor

@slyubomirsky left a comment

This is a clear improvement and the changes to the code are straightforward.

Comment on lines 2266 to 2271
# Spliting a dynamic axis at specific indices is allowed. The
# algebraic form here isn't the cleanest, primarily because the
# test case doesn't know that `n` is a shape variable. When
# occurring in a relax function, `n` would be marked with
# `analyzer_.MarkGlobalNonNegValue`, which would make the shapes
# simplify to `[(2,16), (3,16), (n,16)]`.
Contributor

Is there any way to test that directly? Maybe the BlockBuilder can expose its analyzer for testing.

Contributor

Also note the typo in the first word :)

Contributor Author

Thank you. I've updated the PR to fix the typo and to test the simplified expressions directly. Rather than exposing the block builder's analyzer, the test case now calls bb.begin_scope. That way, failures resulting from changes to the BlockBuilderImpl::BeginScope implementation would also be caught.

auto it = shape_var_map.find(shape_var);
if (it == shape_var_map.end()) {
  shape_var_map.Set(shape_var, shape_expr);
  analyzer_.MarkGlobalNonNegValue(shape_var);
Contributor

Might be worth adding a comment noting that marking the shape variable as non-negative results in much simpler symbolic shapes in many cases.

Contributor Author

Good call, and I've added a comment on it.
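
To make the effect of the non-negativity hint concrete, here is a small pure-Python check. It does not use TVM's analyzer. It assumes a hypothetical axis of length `n + 5` split at indices 2 and 5 (the actual shape used in the test is not shown in this thread), and compares the clamped section lengths against the simplified form `[2, 3, n]` by brute force.

```python
def unsimplified(n):
    """Section lengths with explicit clamping, for an assumed axis of n + 5."""
    axis_len = n + 5  # hypothetical axis length, chosen for illustration
    first = min(2, axis_len)
    second = min(5, axis_len) - min(2, axis_len)
    third = axis_len - min(5, axis_len)
    return [first, second, third]


def simplified(n):
    """The form one would hope to get once n is known to be non-negative."""
    return [2, 3, n]


# For every non-negative n checked, the two forms agree.
agree_nonneg = all(unsimplified(n) == simplified(n) for n in range(0, 100))

# Without the non-negativity fact the rewrite is not valid in general.
differs_when_negative = unsimplified(-4) != simplified(-4)
```

The two forms only coincide when `n >= 0`, which is the kind of fact that `analyzer_.MarkGlobalNonNegValue` records for shape variables.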

@Lunderberg merged commit cb6e4ee into apache:main on Feb 6, 2024
@Lunderberg deleted the unity_infer_split_dynamic_shapes branch on February 6, 2024 14:37
2 participants