@PatrikPerssonInceptron
Contributor

Problem

The dimension names in an ONNX model's shapes can contain expressions, such as in the shape int64[batch_size,past_sequence_length + sequence_length] of the attention mask of an LLM. Here the second dimension is the expression past_sequence_length + sequence_length, in which past_sequence_length and sequence_length should be separate variables that are added together. Currently, however, translating the graph creates a single new variable named "past_sequence_length + sequence_length" instead.

Fix

I added a simple parser that creates an individual size variable for each variable name and generates the resulting prim expression. Note that, to keep the parser simple, it evaluates expressions strictly left to right and does not account for operator precedence.
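To illustrate what left-to-right evaluation without precedence means, here is a minimal standalone sketch. It is not the code from this PR: it uses plain Python tuples instead of tvm.tir.SizeVar, and the function names (`tokenize`, `parse_left_to_right`, `evaluate`) and the supported operator set (`+ - * /`, with `/` treated as floor division) are assumptions made for the example.

```python
import operator
import re
from typing import Dict, List, Union

# Assumed operator set for this sketch; "/" is treated as floor division.
OPS = {"+": operator.add, "-": operator.sub,
       "*": operator.mul, "/": operator.floordiv}

# A token is an identifier, an integer literal, or one operator character.
_TOKEN = re.compile(r"\s*([A-Za-z_]\w*|\d+|[-+*/])")

Node = Union[str, int, tuple]


def tokenize(expr: str) -> List[str]:
    """Split a dimension string into identifiers, integers and operators."""
    expr = expr.rstrip()
    tokens, pos = [], 0
    while pos < len(expr):
        match = _TOKEN.match(expr, pos)
        if match is None:
            raise ValueError(f"unexpected character at {pos} in {expr!r}")
        tokens.append(match.group(1))
        pos = match.end()
    return tokens


def _operand(token: str) -> Node:
    if token in OPS:
        raise ValueError(f"expected a name or integer, got {token!r}")
    return int(token) if token.isdigit() else token


def parse_left_to_right(expr: str) -> Node:
    """Fold `a op b op c` into ((a op b) op c), ignoring precedence."""
    tokens = tokenize(expr)
    if len(tokens) % 2 == 0:
        raise ValueError(f"malformed expression: {expr!r}")
    node = _operand(tokens[0])
    for i in range(1, len(tokens), 2):
        op = tokens[i]
        if op not in OPS:
            raise ValueError(f"expected an operator, got {op!r}")
        node = (op, node, _operand(tokens[i + 1]))
    return node


def evaluate(node: Node, env: Dict[str, int]) -> int:
    """Evaluate a parsed tree given concrete values for each variable."""
    if isinstance(node, int):
        return node
    if isinstance(node, str):
        return env[node]
    op, lhs, rhs = node
    return OPS[op](evaluate(lhs, env), evaluate(rhs, env))


tree = parse_left_to_right("past_sequence_length + sequence_length")
print(tree)
print(evaluate(tree, {"past_sequence_length": 3, "sequence_length": 5}))
```

Because of the left-to-right rule, a string like `a + b * 2` is folded as `(a + b) * 2` in this sketch, which is the trade-off mentioned above.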

Test

I added regression tests to verify that ONNX shape dimension expressions are evaluated correctly.

Additional small fixes

When PrimValues are encountered in BinaryBase, their values are not always fully extracted before being converted to numpy arrays. I added a check that extracts the value from IntImm and FloatImm types before the conversion.
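The unpacking step can be pictured with the following sketch. It is an illustration only: the three dataclasses are hypothetical stand-ins for relax.PrimValue, tir.IntImm and tir.FloatImm so that the example runs without TVM, and `unwrap_scalar` is a made-up helper name, not a function from this PR.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class IntImm:
    """Stand-in for tir.IntImm (assumption: only `.value` matters here)."""
    value: int


@dataclass
class FloatImm:
    """Stand-in for tir.FloatImm."""
    value: float


@dataclass
class PrimValue:
    """Stand-in for relax.PrimValue, which wraps a prim expression."""
    value: Any


def unwrap_scalar(x: Any) -> Any:
    """Peel PrimValue, then IntImm/FloatImm, down to a plain Python number."""
    if isinstance(x, PrimValue):
        x = x.value
    if isinstance(x, (IntImm, FloatImm)):
        x = x.value
    return x


print(unwrap_scalar(PrimValue(IntImm(4))))
print(unwrap_scalar(PrimValue(FloatImm(0.5))))
```

Only after both layers are removed is the result a plain number that can safely be handed to an array constructor; unwrapping just the outer PrimValue would leave an immediate-node object behind.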

added a simple parser that can handle onnx variable names containing
expressions such as past_sequence_length + sequence_length where each
variable becomes a tvm.tir.SizeVar

updated binary base to completely unpack relax.PrimValue if it contains
tir.IntImm or tir.FloatImm

added regression tests

@Hzfengsy
Member

Thanks for the improvements

@Hzfengsy Hzfengsy merged commit d5b9f5c into apache:main Nov 10, 2024
19 checks passed
@PatrikPerssonInceptron PatrikPerssonInceptron deleted the feature/onnx-input-shape-computations branch November 18, 2024 08:01
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
… names (apache#17505)

* added a simple parser that can handle onnx variable names containing
expressions such as past_sequence_length + sequence_length where each
variable becomes a tvm.tir.SizeVar

* added doc strings

updated binary base to completely unpack relax.PrimValue if it contains
tir.IntImm or tir.FloatImm

added regression tests

* updated formatting