Skip to content

Commit db920dd

Browse files
authored
[COLLAGE] Add more customization to support more targets (#13450)
* [COLLAGE] Add more customization to support more targets 1. Added custom cost module to provide a provision to incorporate custom cost estimator python function instead using default cost function. eg: cost_estimator = CustomCostEstimator(py_fn_estimator="tvm.relay.collage.opencl_cost_estimator") mod = CollagePartition(config, cost_estimator=cost_estimator)(mod) 2. Added provision to select BYOC fusion style for all compiler target. eg : config = { "relay.collage.byoc_fusion_style": ["compiler.NoFusion", "compiler.TVMFusion"]} ctxt = tvm.transform.PassContext(config=config) * Fix the lint errors * Fix the lint error whitespace * Fix the lint error tabs * Fix the lint error tabs * Fix the lint error tabs * move the clml collage test case to test_clml * Fix lint error whitespace * Fix the import error * Fix the envirnoment var and import * Add comments * Add clml preprocess module in cost estimator * Fix whitespace lint error * Fix whitespace lint error * Fix whitespace lint error * Fix the comments and removed unwanted code * Fix whitespace error * Removed Todo comments * Removed TODO comments * Updated naming convension * Fix typo error * Fixe the typo error * Corrected typo error * Corrected typo error * Removed unused and fix typo error * Removed redundent code and optimize the code * Fix the lint error * Fix whitespace lint error * Removed Prints in file * Fix lint error * Fix lint error * Removed runner template in test script * Fix the lint error * Fix lint error * Fix lint error * Fix the lint error * Fix the lint error Co-authored-by: kvegiraj <[email protected]>
1 parent 1265eb9 commit db920dd

File tree

11 files changed

+567
-5
lines changed

11 files changed

+567
-5
lines changed

python/tvm/relay/collage/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@
2121
WARMUP_MIN_REPEAT_MS,
2222
CostEstimator,
2323
MockCostEstimator,
24+
CustomCostEstimator,
2425
)

python/tvm/relay/collage/collage.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def __init__(self, target_costs, max_estimates=0):
5252
self.__init_handle_by_constructor__(_ffi_api.MockCostEstimator, target_costs, max_estimates)
5353

5454

55+
@register_object("relay.collage.CustomCostEstimator")
56+
class CustomCostEstimator(Object):
57+
"""CustomEstimator class"""
58+
59+
def __init__(self, py_fn_estimator="tvm.relay.collage.estimate_seconds_custom"):
60+
self.__init_handle_by_constructor__(_ffi_api.CustomCostEstimator, py_fn_estimator)
61+
62+
5563
def arg_for(arg_type, device):
5664
"""Returns a test argument of Relay arg_type on device"""
5765
assert isinstance(arg_type, tvm.ir.TensorType)

python/tvm/relay/op/contrib/clml.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tvm._ffi import register_func
2424
from tvm.relay import transform
2525
from tvm.relay.build_module import bind_params_by_name
26+
from tvm.relay import function as _function
2627
from tvm.relay.expr_functor import ExprMutator
2728
from tvm.relay.expr import Call, TupleGetItem
2829

@@ -161,6 +162,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
161162
return preprocessed_mod
162163

163164

165+
def preprocess_for_clml(mod):
166+
"""Preprocessing pass to alter the layouts for CLML compiler target"""
167+
168+
for _var in mod.get_global_vars():
169+
if _var.name_hint == "main":
170+
continue
171+
fn = mod[_var.name_hint]
172+
if "Compiler" in fn.attrs.keys() and fn.attrs["Compiler"] == "clml":
173+
new_fn = fn.body
174+
clml_mod = tvm.IRModule.from_expr(new_fn)
175+
with tvm.transform.PassContext(opt_level=3):
176+
clml_mod = preprocess_module(clml_mod)
177+
new_body = clml_mod["main"].body
178+
mod[_var.name_hint] = _function.Function(
179+
fn.params, new_body, fn.ret_type, fn.type_params, fn.attrs
180+
)
181+
return mod
182+
183+
164184
@register_pattern_table("clml")
165185
def clml_pattern_table():
166186
"""Get the CLML pattern table."""

src/relay/collage/collage_partitioner.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ namespace {
5555

5656
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.tvm_max_depth", Integer);
5757
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_max_depth", Integer);
58-
58+
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_fusion_style", Array<String>);
5959
/*!
6060
* \brief Represents the overall expression after some number of non-overlapping candidate
6161
* partitions have been applied.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/collage/custom_cost_estimator.cc
22+
* \brief A custom CostEstimator to support alternative cost functions.
23+
*/
24+
25+
#include "./custom_cost_estimator.h"
26+
27+
#include <tvm/relay/expr_functor.h>
28+
29+
namespace tvm {
30+
namespace relay {
31+
namespace collage {
32+
33+
TVM_REGISTER_OBJECT_TYPE(CustomCostEstimatorNode);
34+
35+
Cost CustomCostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
36+
static const runtime::PackedFunc* estimate_seconds = runtime::Registry::Get(py_fn_estimator_);
37+
ICHECK(estimate_seconds);
38+
const double value = (*estimate_seconds)(mod, target);
39+
if (std::isinf(value)) {
40+
return Cost::Invalid();
41+
} else if (std::isnan(value)) {
42+
return Cost::Unknown();
43+
} else {
44+
return Cost::Value(value);
45+
}
46+
}
47+
48+
CustomCostEstimator::CustomCostEstimator(String py_fn_estimator) {
49+
auto node = make_object<CustomCostEstimatorNode>();
50+
node->py_fn_estimator_ = std::move(py_fn_estimator);
51+
data_ = std::move(node);
52+
}
53+
54+
TVM_REGISTER_GLOBAL("relay.collage.CustomCostEstimator").set_body_typed([](String py_fn_estimator) {
55+
return CustomCostEstimator(std::move(py_fn_estimator));
56+
});
57+
58+
} // namespace collage
59+
} // namespace relay
60+
} // namespace tvm
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/collage/custom_cost_estimator.cc
22+
* \brief A custom CostEstimator to support target-specific cost functions.
23+
*/
24+
25+
#ifndef TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_
26+
#define TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_
27+
28+
#include <tvm/relay/function.h>
29+
30+
#include "./cost.h"
31+
#include "./cost_estimator.h"
32+
33+
namespace tvm {
34+
namespace relay {
35+
namespace collage {
36+
37+
/*!
38+
* \brief A cost estimator that uses a target-specific cost function.
39+
*/
40+
class CustomCostEstimatorNode : public CostEstimatorNode {
41+
public:
42+
Cost Estimate(const IRModule& mod, const Target& target) const override;
43+
44+
static constexpr const char* _type_key = "relay.collage.CustomCostEstimator";
45+
TVM_DECLARE_FINAL_OBJECT_INFO(CustomCostEstimatorNode, CostEstimatorNode);
46+
47+
protected:
48+
/*!
49+
* \brief Python implemented cost function name.
50+
*/
51+
String py_fn_estimator_;
52+
53+
friend class CustomCostEstimator;
54+
};
55+
56+
class CustomCostEstimator : public CostEstimator {
57+
public:
58+
explicit CustomCostEstimator(String py_fn_estimator);
59+
60+
TVM_DEFINE_OBJECT_REF_METHODS(CustomCostEstimator, CostEstimator, CustomCostEstimatorNode);
61+
};
62+
63+
} // namespace collage
64+
} // namespace relay
65+
} // namespace tvm
66+
67+
#endif // TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_

src/relay/collage/gather_partition_specs.cc

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,9 @@ PartitionRule MakeTVMPartitionRule() {
8989
}
9090

9191
/*!
92-
* \brief Returns the fusion style for \p compiler.
93-
*
94-
* TODO(mbs): Defer to per-BYOC integration definition.
92+
* \brief Returns the fusion style for default compiler.
9593
*/
96-
BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
94+
BYOCStyle DefaultBYOCFusionStyleForCompiler(const String& compiler) {
9795
if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn") {
9896
return kNoFusionBYOCStyle;
9997
} else if (compiler == "tensorrt") {
@@ -103,6 +101,35 @@ BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
103101
}
104102
}
105103

104+
/*!
105+
* \brief Returns the fusion style for given compiler.
106+
*/
107+
BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
108+
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
109+
std::string config_key = "relay.collage.byoc_fusion_style";
110+
Optional<Array<String>> byoc_configs = ctxt->GetConfig(config_key, Optional<Array<String>>());
111+
BYOCStyle byoc_fusion_style = DefaultBYOCFusionStyleForCompiler(compiler);
112+
if (!byoc_configs) {
113+
return byoc_fusion_style;
114+
}
115+
for (auto config_ : byoc_configs.value()) {
116+
std::vector<std::string> byoc_cfg = SplitString(config_, ".");
117+
if (byoc_cfg[0] == compiler) {
118+
if (byoc_cfg[1] == "NoFusion") {
119+
byoc_fusion_style = kNoFusionBYOCStyle;
120+
} else if (byoc_cfg[1] == "TVMFusion") {
121+
byoc_fusion_style = kTVMFusionBYOCStyle;
122+
} else if (byoc_cfg[1] == "ArbitraryFusion") {
123+
byoc_fusion_style = kArbitraryFusionBYOCStyle;
124+
} else {
125+
ICHECK(false) << "Invalid fusion name for compiler " << byoc_cfg[0] << " in pass context";
126+
}
127+
break;
128+
}
129+
}
130+
return byoc_fusion_style;
131+
}
132+
106133
/*!
107134
* \brief Returns the primitive combiner rules which allow for any touching candidates
108135
* to be fused provided they don't have kind \p kOpaque.

src/relay/collage/utils.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ bool MustBeLowered(const Expr& expr) {
134134
return false;
135135
}
136136

137+
std::vector<std::string> SplitString(std::string stmt, const char* del) {
138+
std::vector<std::string> str_tokens;
139+
int start = 0;
140+
int end = stmt.find(del, 0);
141+
str_tokens.emplace_back(stmt.substr(start, end));
142+
while (end != -1) {
143+
stmt = stmt.substr(end + 1, stmt.size());
144+
end = stmt.find(del, 0);
145+
str_tokens.emplace_back(stmt.substr(start, end));
146+
}
147+
return str_tokens;
148+
}
149+
137150
} // namespace collage
138151
} // namespace relay
139152
} // namespace tvm

src/relay/collage/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/runtime/container/string.h>
3232

3333
#include <string>
34+
#include <vector>
3435

3536
namespace tvm {
3637
namespace relay {
@@ -79,6 +80,11 @@ bool IsSpecialOp(const OpNode* op_node);
7980
*/
8081
bool MustBeLowered(const Expr& expr);
8182

83+
/*!
84+
* \brief Returns the list of split strings of given statement with delimiter.
85+
*/
86+
std::vector<std::string> SplitString(std::string stmt, const char* del);
87+
8288
} // namespace collage
8389
} // namespace relay
8490
} // namespace tvm

0 commit comments

Comments
 (0)