Skip to content

Commit c3aa71a

Browse files
authored
[Unity][Analysis] Add utility for collecting compile-time bindings (#16312)
Whether an optimizations should be performed may depend on when the variables in an expression are known. For example, consider a LoRA-adjusted model, with base weights `W` of shape `[m,n]`, LoRA components `A` and `B` with shapes `[r,n]` and `[m,r]` respectively, and activations `x` with shape `[n,1]`. The LoRA-adjusted matmul could be computed either as `(W + B*A)*x` or as `(W*x + B*(A*x))`. If `A` and `B` are provided at run-time, then computing `(W + B*(A*x))` requires significantly fewer computations. * `(W + B*A)*x`: `m*n*(2*r + 3)` operations 1. `B*A`: `2*m*n*r` operations using a naive matmul 2. Adding `W` to (1): `m*n` operations 3. Multiplying `x` by (2): `2*m*n` operations * `(W*x + B*(A*x))`: (2*m*n + r*(2*n + 2*m + 1)) 1. `W*x`: `2*m*n` operations 2. `A*x`: `2*r*n` operations 3. Multiplying `B` by (2): `2*m*r` operations 4. Adding (1) and (3)`: `m` operations However, if `A` and `B` are known at compile-time, then computing `(W + B*A)*x` groups all compile-time values together, allowing them to be computed earlier (i.e. using `LiftTransformParams`) * `(W + B*A)*x`: `2*m*n` operations 1. `B*A`: 0 operations, computed at compile-time 2. Adding `W` to (1): 0 operations, computed at compile-time 3. Multiplying `x` by (2): `2*m*n` operations Since the choice of optimized expression depends on which parameters can be computed at compile-time, it is useful to have a utility that identifies values that can be computed at compile-time.
1 parent 49fc613 commit c3aa71a

File tree

5 files changed

+383
-0
lines changed

5 files changed

+383
-0
lines changed

include/tvm/relax/analysis.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,21 @@ TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
533533
TVM_DLL Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
534534
const Function& fn, Array<tir::IndexMap> write_buffer_transformations);
535535

536+
/* \brief Collect variables whose value can be computed at compile-time
537+
*
538+
* If a function has the `kNumInput` attribute, then the first
539+
* `kNumInput` parameters are provided at run-time, while all
540+
* remaining parameters may be known at compile-time. This utility
541+
* collects all variable bindings that only depend, directly or
542+
* indirectly, on the parameters known at compile-time.
543+
*
544+
* \param func The relax::Function to analyze
545+
*
546+
* \return The set of variables that can be computed at compile-time,
547+
* in order of their occurrence within the function.
548+
*/
549+
TVM_DLL Array<Var> ComputableAtCompileTime(const Function& func);
550+
536551
} // namespace relax
537552
} // namespace tvm
538553

python/tvm/relax/analysis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
all_global_vars,
2222
all_vars,
2323
bound_vars,
24+
computable_at_compile_time,
2425
contains_impure_call,
2526
definable_tir_vars_in_struct_info,
2627
defined_symbolic_vars,

python/tvm/relax/analysis/analysis.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,28 @@ def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]:
528528
with any other, it will be a singleton in this list.
529529
"""
530530
return _ffi_api.detect_recursion(mod) # type: ignore
531+
532+
533+
def computable_at_compile_time(func: Function) -> List[Var]:
534+
"""Collect variables whose value can be computed at compile-time
535+
536+
If a function has the `kNumInput` attribute, then the first
537+
`kNumInput` parameters are provided at run-time, while all
538+
remaining parameters may be known at compile-time. This utility
539+
collects all variable bindings that only depend, directly or
540+
indirectly, on the parameters known at compile-time.
541+
542+
Parameters
543+
----------
544+
func: Function
545+
546+
The `relax.Function` to analyze
547+
548+
Returns
549+
-------
550+
ret: List[Var]
551+
552+
The set of variables that can be computed at compile-time, in
553+
order of their occurrence within the function.
554+
"""
555+
return _ffi_api.computable_at_compile_time(func) # type: ignore
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file computable_at_compile_time.cc
22+
*
23+
* \brief Utilities for identifying potentially compile-time variables
24+
*/
25+
26+
#include <tvm/relax/analysis.h>
27+
#include <tvm/relax/expr_functor.h>
28+
29+
#include "../../support/ordered_set.h"
30+
31+
namespace tvm {
32+
namespace relax {
33+
34+
namespace {
35+
class CompileTimeCollector : ExprVisitor {
36+
public:
37+
static Array<Var> Collect(const Function& func) {
38+
CompileTimeCollector visitor;
39+
visitor(func);
40+
return Array<Var>(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end());
41+
}
42+
43+
private:
44+
void VisitExpr_(const FunctionNode* func) override {
45+
if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
46+
size_t num_input = opt_num_input.value()->value;
47+
for (size_t i = num_input; i < func->params.size(); i++) {
48+
MarkAsKnown(func->params[i]);
49+
}
50+
}
51+
52+
ExprVisitor::VisitExpr_(func);
53+
}
54+
55+
void VisitBinding(const Binding& binding) override {
56+
Expr value = GetBoundValue(binding);
57+
bool can_compute_at_compile_time = [&]() {
58+
for (const auto& relax_var : FreeVars(value)) {
59+
if (!known_relax_vars_.count(relax_var)) {
60+
return false;
61+
}
62+
}
63+
for (const auto& tir_var : FreeSymbolicVars(value)) {
64+
if (!known_tir_vars_.count(tir_var)) {
65+
return false;
66+
}
67+
}
68+
69+
return true;
70+
}();
71+
72+
if (can_compute_at_compile_time) {
73+
MarkAsKnown(binding->var);
74+
}
75+
76+
ExprVisitor::VisitBinding(binding);
77+
}
78+
79+
void MarkAsKnown(const Var& var) {
80+
known_relax_vars_.insert(var);
81+
for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
82+
known_tir_vars_.insert(tir_var);
83+
}
84+
}
85+
86+
support::OrderedSet<Var> known_relax_vars_;
87+
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> known_tir_vars_;
88+
};
89+
} // namespace
90+
91+
Array<Var> ComputableAtCompileTime(const Function& func) {
92+
return CompileTimeCollector::Collect(func);
93+
}
94+
95+
TVM_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time")
96+
.set_body_typed(ComputableAtCompileTime);
97+
98+
} // namespace relax
99+
} // namespace tvm

0 commit comments

Comments
 (0)