@@ -51,13 +51,12 @@ std::unordered_set<TensorConfig> GetPlanBoundaryConfigs(const Plan& plan) {
5151 return boundary_configs;
5252}
5353
54- bool IsPlanCompatible (const Proposal& proposal,
55- const std::vector<Part>& plan_part_group,
54+ bool IsPlanCompatible (const Proposal& proposal, const std::vector<Part>& plan_part_group,
5655 const std::unordered_set<TensorConfig>& plan_boundary_configs) {
5756 // Check the Plan Part group is disjoint with the Proposal Part group
58- for (const auto & plan_part : plan_part_group) {
59- for (const auto & proposal_part : proposal->GetPartGroup ()) {
60- if (plan_part == proposal_part) {
57+ for (const auto & plan_part : plan_part_group) {
58+ for (const auto & proposal_part : proposal->GetPartGroup ()) {
59+ if (plan_part == proposal_part) {
6160 return false ;
6261 }
6362 }
@@ -126,24 +125,25 @@ Proposal AddPlanToProposal(const Proposal& proposal, const Plan& plan,
126125 new_memory_usage = std::max (new_memory_usage, proposal->GetMemoryUsage ());
127126 int new_cycles = proposal->GetCycles () + plan->GetCycles ();
128127 std::vector<Part> new_part_group = proposal->GetPartGroup ();
129- new_part_group.insert (new_part_group.end (), plan->GetPartGroup ().begin (), plan->GetPartGroup ().end ());
128+ new_part_group.insert (new_part_group.end (), plan->GetPartGroup ().begin (),
129+ plan->GetPartGroup ().end ());
130130 std::sort (new_part_group.begin (), new_part_group.end ());
131131 return Proposal (proposal->GetGraph (), new_part_group, new_plans, new_configs,
132132 proposal->GetCascadeRegion (), new_memory_usage, new_cycles);
133133}
134134
135- std::vector<Proposal> GeneratePartialProposals (const CascaderGraph& graph, const HomeMap& home_map,
136- const CascaderOptions options,
137- const std::unordered_map<Part, std::vector<Plan>, ObjectPtrHash, ObjectPtrEqual>& plans_by_part,
138- const std::vector<Part>& partial_proposal_group,
139- std::unordered_map<std::vector<Part>, std::vector<Proposal>>* proposals_by_group) {
135+ std::vector<Proposal> GeneratePartialProposals (
136+ const CascaderGraph& graph, const HomeMap& home_map, const CascaderOptions options,
137+ const std::unordered_map<Part, std::vector<Plan>, ObjectPtrHash, ObjectPtrEqual>& plans_by_part,
138+ const std::vector<Part>& partial_proposal_group,
139+ std::unordered_map<std::vector<Part>, std::vector<Proposal>>* proposals_by_group) {
140140 if (proposals_by_group->find (partial_proposal_group) != proposals_by_group->end ()) {
141141 return proposals_by_group->at (partial_proposal_group);
142142 }
143143 if (partial_proposal_group.size () == 0 ) {
144144 (*proposals_by_group)[partial_proposal_group] =
145- std::vector<Proposal>{Proposal (graph, std::vector<Part>(), std::vector<Plan>(),
146- TensorConfigMap (), options->cascade_region , 0 , 0 )};
145+ std::vector<Proposal>{Proposal (graph, std::vector<Part>(), std::vector<Plan>(),
146+ TensorConfigMap (), options->cascade_region , 0 , 0 )};
147147 } else {
148148 Part part = partial_proposal_group.back ();
149149 const auto & plans = plans_by_part.at (part);
@@ -158,26 +158,26 @@ std::vector<Proposal> GeneratePartialProposals(const CascaderGraph& graph, const
158158 // pick the current Plan.
159159 std::vector<Part> residual_proposal_group;
160160 std::copy_if (partial_proposal_group.begin (), partial_proposal_group.end (),
161- std::back_inserter (residual_proposal_group), [&plan](Part value) {
162- return std::find (plan->GetPartGroup ().begin (),
163- plan->GetPartGroup ().end (),
161+ std::back_inserter (residual_proposal_group), [&plan](Part value) {
162+ return std::find (plan->GetPartGroup ().begin (), plan->GetPartGroup ().end (),
164163 value) == plan->GetPartGroup ().end ();
165- });
164+ });
166165 // std::sort(residual_proposal_group.begin(), residual_proposal_group.end());
167- const auto & residual_proposals = GeneratePartialProposals (graph, home_map, options, plans_by_part, residual_proposal_group, proposals_by_group);
166+ const auto & residual_proposals = GeneratePartialProposals (
167+ graph, home_map, options, plans_by_part, residual_proposal_group, proposals_by_group);
168168 auto plan_output_tensor = plan->GetOutputConfig ()->GetTensor ();
169169 ICHECK_LE (plan_output_tensor->GetProducers ().size (), 1 )
170170 << " All tensors must have at most one producer." ;
171171 for (const auto & residual_proposal : residual_proposals) {
172172 if (IsPlanCompatible (residual_proposal, plan->GetPartGroup (), plan_boundary_configs)) {
173- (*proposals_by_group)[partial_proposal_group].push_back (AddPlanToProposal (
174- residual_proposal, plan, plan_boundary_configs));
173+ (*proposals_by_group)[partial_proposal_group].push_back (
174+ AddPlanToProposal ( residual_proposal, plan, plan_boundary_configs));
175175 }
176176 }
177177 }
178178 }
179- (*proposals_by_group)[partial_proposal_group] = ParetoCullProposals (
180- proposals_by_group->at (partial_proposal_group), options->max_proposals );
179+ (*proposals_by_group)[partial_proposal_group] =
180+ ParetoCullProposals ( proposals_by_group->at (partial_proposal_group), options->max_proposals );
181181 }
182182 return proposals_by_group->at (partial_proposal_group);
183183}
@@ -194,7 +194,8 @@ std::vector<Proposal> GenerateProposals(const CascaderGraph& graph, const HomeMa
194194 std::vector<Part> partial_proposal_group = graph->GetPartOrder ();
195195 // A map of Proposals indexed by the Part group they cover
196196 std::unordered_map<std::vector<Part>, std::vector<Proposal>> proposals_by_group;
197- return GeneratePartialProposals (graph, home_map, options, plans_by_part, partial_proposal_group, &proposals_by_group);
197+ return GeneratePartialProposals (graph, home_map, options, plans_by_part, partial_proposal_group,
198+ &proposals_by_group);
198199}
199200
200201TVM_REGISTER_GLOBAL (" contrib.ethosu.cascader.GenerateProposals" )
0 commit comments