Skip to content

Commit d1af75a

Browse files
committed
Fixes
Change-Id: I4f5f2a298bd3bb379c7c8d179150358923b0dd66
1 parent 31f9aa0 commit d1af75a

File tree

7 files changed

+203
-163
lines changed

7 files changed

+203
-163
lines changed

python/tvm/contrib/ethosu/cascader/pareto.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
from . import _ffi_api
2323
from .plan import Plan
24-
from .proposal import Proposal
25-
from .tensor_config import MemoryRegion
2624

2725

2826
def _get_pareto_frontier(costs: List[List[float]]) -> List[bool]:
@@ -39,9 +37,3 @@ def _thin_vector(vec: List[Object], max_size: int) -> List[Object]:
3937

4038
def _pareto_cull_plans(plans: List[Plan], max_plans: int) -> List[Plan]:
4139
return list(_ffi_api.ParetoCullPlans(plans, max_plans))
42-
43-
44-
def pareto_cull_proposals(
45-
proposals: List[Proposal], cascade_region: MemoryRegion, max_proposals: int
46-
) -> List[Proposal]:
47-
return list(_ffi_api.ParetoCullProposals(proposals, cascade_region, max_proposals))

python/tvm/contrib/ethosu/cascader/proposal.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,51 +22,85 @@
2222
from tvm.runtime import Object
2323

2424
from . import _ffi_api
25-
from .graph import Tensor, Part
25+
from .graph import Tensor, Part, CascaderGraph
2626
from .tensor_config import TensorConfig, MemoryRegion
2727

2828

2929
@tvm._ffi.register_object("contrib.ethosu.cascader.Proposal")
3030
class Proposal(Object):
31-
"""Proposal class"""
31+
"""A class which describes how to schedule a CascaderGraph as a series of disjoint Plans.
32+
33+
Attributes
34+
----------
35+
graph : CascaderGraph
36+
The CascaderGraph to which the Proposal applies.
37+
part_group : FrozenSet[Part]
38+
The Parts which are covered by the Proposal.
39+
plans : List[Plan]
40+
The Plans used in the Proposal.
41+
input_tensor_configs : Dict[Tensor, TensorConfig]
42+
The TensorConfigs indexed by Tensor in the Proposal which aren't produced by a Plan.
43+
cascade_region : MemoryRegion
44+
The MemoryRegion where cascading buffers should be homed.
45+
memory_usage : int
46+
The memory required to execute the Proposal in the cascading MemoryRegion.
47+
cycles : int
48+
The estimated cycles taken to execute the Proposal.
49+
50+
"""
3251

3352
def __init__(
3453
self,
54+
graph: CascaderGraph,
3555
part_group: FrozenSet[Part],
3656
plans: List[Plan],
3757
input_tensor_configs: Dict[Tensor, TensorConfig],
58+
cascade_region: MemoryRegion,
3859
memory_usage: Dict[MemoryRegion, int],
3960
cycles: int,
4061
):
4162
self.__init_handle_by_constructor__(
4263
_ffi_api.Proposal,
64+
graph,
4365
list(part_group),
4466
plans,
4567
input_tensor_configs,
68+
cascade_region,
4669
memory_usage,
4770
cycles,
4871
)
4972

5073
@property
51-
def graph(self):
74+
def graph(self) -> CascaderGraph:
75+
"""The CascaderGraph to which the Proposal applies."""
5276
return self._graph
5377

5478
@property
55-
def part_group(self):
79+
def part_group(self) -> FrozenSet[Part]:
80+
"""The Parts which are covered by the Proposal."""
5681
return frozenset(self._part_group)
5782

5883
@property
59-
def plans(self):
84+
def plans(self) -> List[Plan]:
85+
"""The Plans used in the Proposal."""
6086
return list(self._plans)
6187

6288
@property
63-
def input_tensor_configs(self):
89+
def input_tensor_configs(self) -> Dict[Tensor, TensorConfig]:
90+
"""The TensorConfigs indexed by Tensor in the Proposal which aren't produced by a Plan."""
6491
return dict(self._input_tensor_configs)
6592

6693
@property
67-
def memory_usage(self):
94+
def cascade_region(self) -> MemoryRegion:
95+
"""The MemoryRegion where cascading buffers should be homed."""
96+
return self._cascade_region
97+
98+
@property
99+
def memory_usage(self) -> int:
100+
"""The memory required to execute the Proposal in the cascading MemoryRegion."""
68101
return int(self._memory_usage)
69102

70103
@property
71-
def cycles(self):
104+
def cycles(self) -> int:
105+
"""The estimated cycles taken to execute the Proposal."""
72106
return int(self._cycles)

python/tvm/contrib/ethosu/cascader/proposal_generator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,26 @@ def generate_proposals(
2929
home_map: Dict[FrozenSet[Part], List[Plan]],
3030
options: CascaderOptions,
3131
) -> List[Proposal]:
32+
"""Generate Pareto optimal Proposals for a CascaderGraph.
33+
34+
This algorithm takes a top-down dynamic programming approach to determining how
35+
to optimally combine Plans into Proposals.
36+
37+
Parameters
38+
----------
39+
graph : CascaderGraph
40+
The CascaderGraph to generate Proposals for.
41+
home_map : Dict[FrozenSet[Part], List[Plan]]
42+
The Tensor homing map defining valid memory homes for Tensors.
43+
options : CascaderOptions
44+
The configuration options with which to run the generator.
45+
46+
Returns
47+
------
48+
List[Proposal]
49+
A list of Pareto optimal Proposals.
50+
51+
"""
3252
return list(
3353
_ffi_api.GenerateProposals(
3454
graph,

src/contrib/ethosu/cascader/pareto.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,6 @@ TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.ParetoCullPlans")
161161
return Array<Plan>(ParetoCullPlans(vplans, max_size));
162162
});
163163

164-
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.ParetoCullProposals")
165-
.set_body_typed([](Array<Proposal> proposals, int max_size) {
166-
std::vector<Proposal> vproposals(proposals.begin(), proposals.end());
167-
return Array<Proposal>(ParetoCullProposals(vproposals, max_size));
168-
});
169-
170164
} // namespace cascader
171165
} // namespace ethosu
172166
} // namespace contrib

src/contrib/ethosu/cascader/proposal.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <tvm/runtime/object.h>
2424
#include <tvm/runtime/registry.h>
2525

26+
#include <algorithm>
2627
#include <utility>
2728
#include <vector>
2829

src/contrib/ethosu/cascader/proposal_generator.cc

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,12 @@ std::unordered_set<TensorConfig> GetPlanBoundaryConfigs(const Plan& plan) {
5151
return boundary_configs;
5252
}
5353

54-
bool IsPlanCompatible(const Proposal& proposal,
55-
const std::vector<Part>& plan_part_group,
54+
bool IsPlanCompatible(const Proposal& proposal, const std::vector<Part>& plan_part_group,
5655
const std::unordered_set<TensorConfig>& plan_boundary_configs) {
5756
// Check the Plan Part group is disjoint with the Proposal Part group
58-
for(const auto& plan_part : plan_part_group) {
59-
for(const auto& proposal_part : proposal->GetPartGroup()) {
60-
if(plan_part == proposal_part) {
57+
for (const auto& plan_part : plan_part_group) {
58+
for (const auto& proposal_part : proposal->GetPartGroup()) {
59+
if (plan_part == proposal_part) {
6160
return false;
6261
}
6362
}
@@ -126,24 +125,25 @@ Proposal AddPlanToProposal(const Proposal& proposal, const Plan& plan,
126125
new_memory_usage = std::max(new_memory_usage, proposal->GetMemoryUsage());
127126
int new_cycles = proposal->GetCycles() + plan->GetCycles();
128127
std::vector<Part> new_part_group = proposal->GetPartGroup();
129-
new_part_group.insert(new_part_group.end(), plan->GetPartGroup().begin(), plan->GetPartGroup().end());
128+
new_part_group.insert(new_part_group.end(), plan->GetPartGroup().begin(),
129+
plan->GetPartGroup().end());
130130
std::sort(new_part_group.begin(), new_part_group.end());
131131
return Proposal(proposal->GetGraph(), new_part_group, new_plans, new_configs,
132132
proposal->GetCascadeRegion(), new_memory_usage, new_cycles);
133133
}
134134

135-
std::vector<Proposal> GeneratePartialProposals(const CascaderGraph& graph, const HomeMap& home_map,
136-
const CascaderOptions options,
137-
const std::unordered_map<Part, std::vector<Plan>, ObjectPtrHash, ObjectPtrEqual>& plans_by_part,
138-
const std::vector<Part>& partial_proposal_group,
139-
std::unordered_map<std::vector<Part>, std::vector<Proposal>>* proposals_by_group) {
135+
std::vector<Proposal> GeneratePartialProposals(
136+
const CascaderGraph& graph, const HomeMap& home_map, const CascaderOptions options,
137+
const std::unordered_map<Part, std::vector<Plan>, ObjectPtrHash, ObjectPtrEqual>& plans_by_part,
138+
const std::vector<Part>& partial_proposal_group,
139+
std::unordered_map<std::vector<Part>, std::vector<Proposal>>* proposals_by_group) {
140140
if (proposals_by_group->find(partial_proposal_group) != proposals_by_group->end()) {
141141
return proposals_by_group->at(partial_proposal_group);
142142
}
143143
if (partial_proposal_group.size() == 0) {
144144
(*proposals_by_group)[partial_proposal_group] =
145-
std::vector<Proposal>{Proposal(graph, std::vector<Part>(), std::vector<Plan>(),
146-
TensorConfigMap(), options->cascade_region, 0, 0)};
145+
std::vector<Proposal>{Proposal(graph, std::vector<Part>(), std::vector<Plan>(),
146+
TensorConfigMap(), options->cascade_region, 0, 0)};
147147
} else {
148148
Part part = partial_proposal_group.back();
149149
const auto& plans = plans_by_part.at(part);
@@ -158,26 +158,26 @@ std::vector<Proposal> GeneratePartialProposals(const CascaderGraph& graph, const
158158
// pick the current Plan.
159159
std::vector<Part> residual_proposal_group;
160160
std::copy_if(partial_proposal_group.begin(), partial_proposal_group.end(),
161-
std::back_inserter(residual_proposal_group), [&plan](Part value) {
162-
return std::find(plan->GetPartGroup().begin(),
163-
plan->GetPartGroup().end(),
161+
std::back_inserter(residual_proposal_group), [&plan](Part value) {
162+
return std::find(plan->GetPartGroup().begin(), plan->GetPartGroup().end(),
164163
value) == plan->GetPartGroup().end();
165-
});
164+
});
166165
// std::sort(residual_proposal_group.begin(), residual_proposal_group.end());
167-
const auto& residual_proposals = GeneratePartialProposals(graph, home_map, options, plans_by_part, residual_proposal_group, proposals_by_group);
166+
const auto& residual_proposals = GeneratePartialProposals(
167+
graph, home_map, options, plans_by_part, residual_proposal_group, proposals_by_group);
168168
auto plan_output_tensor = plan->GetOutputConfig()->GetTensor();
169169
ICHECK_LE(plan_output_tensor->GetProducers().size(), 1)
170170
<< "All tensors must have at most one producer.";
171171
for (const auto& residual_proposal : residual_proposals) {
172172
if (IsPlanCompatible(residual_proposal, plan->GetPartGroup(), plan_boundary_configs)) {
173-
(*proposals_by_group)[partial_proposal_group].push_back(AddPlanToProposal(
174-
residual_proposal, plan, plan_boundary_configs));
173+
(*proposals_by_group)[partial_proposal_group].push_back(
174+
AddPlanToProposal(residual_proposal, plan, plan_boundary_configs));
175175
}
176176
}
177177
}
178178
}
179-
(*proposals_by_group)[partial_proposal_group] = ParetoCullProposals(
180-
proposals_by_group->at(partial_proposal_group), options->max_proposals);
179+
(*proposals_by_group)[partial_proposal_group] =
180+
ParetoCullProposals(proposals_by_group->at(partial_proposal_group), options->max_proposals);
181181
}
182182
return proposals_by_group->at(partial_proposal_group);
183183
}
@@ -194,7 +194,8 @@ std::vector<Proposal> GenerateProposals(const CascaderGraph& graph, const HomeMa
194194
std::vector<Part> partial_proposal_group = graph->GetPartOrder();
195195
// A map of Proposals indexed by the Part group they cover
196196
std::unordered_map<std::vector<Part>, std::vector<Proposal>> proposals_by_group;
197-
return GeneratePartialProposals(graph, home_map, options, plans_by_part, partial_proposal_group, &proposals_by_group);
197+
return GeneratePartialProposals(graph, home_map, options, plans_by_part, partial_proposal_group,
198+
&proposals_by_group);
198199
}
199200

200201
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GenerateProposals")

0 commit comments

Comments
 (0)