Skip to content

Conversation

gs-olive
Copy link
Contributor

Description

  • Fix 2 bugs in linear-to-addmm lowering pass:
    • Lowering pass did not explore nested sub-blocks of a node, of the sort contained in prim::If when bias=None. As an example, these would be ignored:
      %4 : Tensor = prim::If(%true)
        block0():
          %res = aten::linear(%input, %weight, %biasNone)
          -> (%res)
        block1():
          %res = aten::linear(%input, %weight, %biasNone)
          -> (%res)
  • Lowering pass did not insert fused linear code inside sub-blocks of prim::If even when the original function call occurred within such a block. As an example, the following translates to the subsequent graph, which contains an invalid computation:
      %4 : Tensor = prim::If(%true)
        block0():
          %res = aten::linear(%input, %weight, %bias)
          -> (%res)
        block1():
          %res = aten::linear(%input, %invalid_weight, %bias)
          -> (%res)

=============== TRANSLATES TO ===============

  %13 : int = prim::Constant[value=1]()
  %14 : Tensor = aten::t(%weight)
  %15 : Tensor = aten::matmul(%input, %14)
  %16 : Tensor = trt::const(%bias)
  %17 : Tensor = aten::add(%16, %15, %13)
  %3 : bool = prim::Constant[value=1]()
  %4 : Tensor = aten::t(%weight)
  %8 : int = prim::Constant[value=1]()
  %9 : Tensor = aten::t(%4)
  %10 : Tensor = aten::matmul(%input, %9)
  %11 : Tensor = trt::const(%bias)
  %12 : Tensor = aten::add(%11, %10, %8)
  %5 : Tensor = prim::If(%3)
    block0():
      -> (%17)
    block1():
      -> (%12)

=============== LEADING TO ===============

%10 : Tensor = aten::matmul(%input, %9): last dimension of input0 = 7 and second to last dimension of input1 = 8 but must match.
  • The latter causes issues when the control-flow switches between two versions of aten::linear, only one of which is a valid operation. Thus, evaluating both branches can cause compilation to crash, as invalid Tensor shapes can be encountered
  • Update implementation to run recursively through all nested blocks within all nodes
  • Update implementation to remove the use of RegisterRewritePattern paradigm for Tensor biases, as the rewrite does not always place the subgraph in the desired location
  • Add regression test cases to isolate both bugs

Addresses 1st Bug in #1616

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified

@gs-olive gs-olive self-assigned this Jan 26, 2023
@github-actions github-actions bot added component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests labels Jan 26, 2023
@gs-olive gs-olive force-pushed the linear_to_addmm_bugfix branch from 0fd1b12 to efd796a Compare January 27, 2023 01:14
- Fix 2 bugs in linear-to-addmm lowering pass:
  - Lowering pass did not explore nested sub-blocks of a node, of the
sort contained in `prim::If` when `bias=None`
  - Lowering pass did not insert fused linear code inside sub-blocks of
`prim::If` even when the original function call occurred within such a
block
  - The latter causes issues when the control-flow switches between two
versions of `aten::linear`, only one of which is a valid operation.
Thus, evaluating both branches can cause compilation to crash, as
invalid Tensor shapes can be encountered
- Update implementation to run recursively through all nested blocks
within all nodes
- Update implementation to remove the use of `RegisterRewritePattern`
paradigm for Tensor biases, as the rewrite does not always place the
subgraph in the desired location
- Add regression test cases to isolate both bugs
Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@narendasan narendasan merged commit 3e422f5 into pytorch:main Feb 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants