# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Scheduler for the cascader, which converts Proposals into Schedules."""
from typing import Tuple, List, Dict, DefaultDict
from collections import defaultdict
import numpy as np

from tvm import te
from tvm import tir
from .cascader_options import CascaderOptions
from .graph import CascaderGraph, Part, Tensor, TESubgraph
from .tensor_config import MemoryRegion
from .proposal import Proposal
from .proposal_generator import generate_proposals
from .graph import create_cascader_graph
from .device_config import EthosuDeviceConfig


def tile_nd(
    sch: te.Schedule, tensor: te.Tensor, tile: Tuple[int, ...]
) -> Tuple[List[tir.IterVar], List[tir.IterVar]]:
    """Scheduling utility to perform N-dimensional tiling.

    Parameters
    ----------
    sch : te.Schedule
        The schedule to apply the tiling to.
    tensor : te.Tensor
        The tensor to apply the tiling to.
    tile : Tuple[int, ...]
        The N-dimensional tile size.

    Returns
    -------
    outer_indices : List[tir.IterVar]
        The outer iteration variables.
    inner_indices : List[tir.IterVar]
        The inner iteration variables.

    """
    outer_indices = []
    inner_indices = []
    for i, size in enumerate(tile):
        outer, inner = sch[tensor].split(tensor.op.axis[i], size)
        outer_indices.append(outer)
        inner_indices.append(inner)

    sch[tensor].reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices
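
# A minimal usage sketch for tile_nd (illustrative only; the tensor names and
# shapes below are hypothetical and not part of this module):
#
#   a = te.placeholder((8, 8), name="a")
#   b = te.compute((8, 8), lambda i, j: a[i, j] * 2, name="b")
#   sch = te.create_schedule([b.op])
#   outer, inner = tile_nd(sch, b, (4, 4))
#   # The loop nest over 'b' is now (i_outer, j_outer, i_inner, j_inner),
#   # i.e. the output is computed in 4x4 tiles.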


def stripe_part(
    part: Part, stripe_shape: Tuple[int, ...], sch: te.Schedule
) -> Tuple[te.Stage, tir.IterVar]:
    """Apply a striping schedule to the TE subgraph represented by a Part."""
    te_subgraph = part.subgraph
    te_output_tensor = te_subgraph.output_tensor
    outer_indices, _ = tile_nd(sch, te_output_tensor, stripe_shape)
    # Group the operations interior to the Part (excluding its inputs) so they
    # can be computed within the innermost stripe loop
    g = sch.create_group(
        outputs=te_output_tensor.op.input_tensors,
        inputs=te_subgraph.input_tensors,
        include_inputs=False,
    )
    g.compute_at(sch[te_output_tensor], outer_indices[-1])
    # Unroll the outer stripe loops so each stripe is emitted explicitly
    for ax in outer_indices:
        sch[te_output_tensor].unroll(ax)

    return sch[te_output_tensor], outer_indices[-1]


def cascade_part(
    part: Part, stripe_stage: te.Stage, stripe_axis: tir.IterVar, sch: te.Schedule
) -> None:
    """Schedule a Part into a cascade indicated by a stripe Stage."""
    te_subgraph = part.subgraph
    g = sch.create_group(
        outputs=te_subgraph.output_tensor, inputs=te_subgraph.input_tensors, include_inputs=False
    )
    g.compute_at(stripe_stage, stripe_axis)
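
# Schematically (an illustrative sketch, not code this module generates), the
# combination of stripe_part and cascade_part produces a loop nest along the
# lines of:
#
#   for each output stripe:                    # unrolled outer loops from stripe_part
#       compute the cascaded Parts' stripes    # attached here by cascade_part
#       for each element in the stripe:        # inner loops over the stripe
#           compute the output element
#
# so producer and consumer Parts execute stripe-by-stripe in an interleaved
# fashion rather than one whole operation after another.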


def update_readers(part: Part, readers: DefaultDict[te.Tensor, List[te.Tensor]]) -> None:
    """
    Update a dictionary mapping each te.Tensor to the te.Tensors which read it
    (i.e. its consumers) within the Part's subgraph.
    """
    visited = set()

    def _visit(tensor):
        if tensor not in visited and tensor not in part.subgraph.input_tensors:
            visited.add(tensor)
            for input_tensor in tensor.op.input_tensors:
                readers[input_tensor].append(tensor)
                _visit(input_tensor)

    _visit(part.subgraph.output_tensor)
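
# For a simple producer chain a -> b -> c inside a Part's subgraph (an
# illustrative example, where 'a' is the subgraph input and 'c' its output),
# calling update_readers would leave the dictionary as:
#
#   readers[b] == [c]
#   readers[a] == [b]
#
# i.e. each tensor maps to the tensors that consume it, which is later used as
# the 'readers' argument of cache_read when inserting copy stages.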


def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        if output_part.in_line:
            continue
        # Stripe the Plan's output Part according to its output stripe config
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        readers = defaultdict(list)
        for part in plan.part_group:
            # Cascade all the other Parts in the Plan into the striped output
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            # Tensors whose home region differs from their copy region need an
            # explicit copy stage
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])

        for te_tensor in copy_te_tensors:
            copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)


def create_home_map(
    graph: CascaderGraph,
    io_region: MemoryRegion,
    constant_region: MemoryRegion,
    working_regions: List[MemoryRegion],
) -> Dict[Tensor, List[MemoryRegion]]:
    """Create a map between Tensors and the MemoryRegions they can be homed in."""
    home_map = {}
    for tensor in graph.tensor_order:
        if tensor.is_constant:
            home_map[tensor] = [constant_region]
        elif tensor in graph.input_tensors or tensor in graph.output_tensors:
            home_map[tensor] = [io_region]
        else:
            home_map[tensor] = working_regions

    return home_map
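
# For example (illustrative only), with io_region=DRAM, constant_region=FLASH
# and working_regions=[SRAM], a graph with one input, one weights tensor, one
# intermediate and one output would be homed as:
#
#   {input: [DRAM], weights: [FLASH], intermediate: [SRAM], output: [DRAM]}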


def choose_proposal(proposals: List[Proposal], cascade_region: MemoryRegion):
    """Choose the best performing Proposal that doesn't overflow the cascade region."""
    # Fall back to the first Proposal if none fit within the cascade region
    proposal_choice = proposals[0]
    for proposal in reversed(proposals):
        if proposal.memory_usage < cascade_region.size:
            proposal_choice = proposal
            break

    return proposal_choice


def cascade(
    sch: te.Schedule,
    te_graph: TESubgraph,
    const_dict: Dict[int, np.ndarray],
    options: CascaderOptions,
    io_region: MemoryRegion,
    constant_region: MemoryRegion,
    working_regions: List[MemoryRegion],
    device_config: EthosuDeviceConfig,
) -> None:
    """Schedule a Tensor Expression graph using the technique of 'cascading'.

    'Cascading' is a technique whereby operations are split into smaller
    dependent tiles ('stripes') which can then execute in an interleaved
    fashion. This allows operations to execute together rather than
    sequentially, which can reduce intermediate memory requirements and in
    certain cases improve performance.

    For more detail on 'cascading' as well as how it is implemented, refer to
    the RFC here: https://github.com/apache/tvm-rfcs/pull/37.

    Parameters
    ----------
    sch : te.Schedule
        The Schedule to apply the cascading to.
    te_graph : TESubgraph
        The Tensor Expression graph from which the Schedule was created.
    const_dict : Dict[int, np.ndarray]
        A dictionary mapping input index to constant data if that input is
        to be a constant.
    options : CascaderOptions
        Configuration options for the cascading scheduler.
    io_region : MemoryRegion
        The MemoryRegion in which input/output tensors should reside.
    constant_region : MemoryRegion
        The MemoryRegion in which constants should reside.
    working_regions : List[MemoryRegion]
        The MemoryRegions in which intermediate working tensors can reside. The
        cascading scheduler will select which MemoryRegion to use per tensor.
    device_config : EthosuDeviceConfig
        Target device configuration.

    """
    assert options.cascade_region in working_regions
    # First convert the Tensor Expression graph into a CascaderGraph
    casc_graph = create_cascader_graph(te_graph, const_dict, device_config)
    # Then create a mapping between Tensors and their possible memory homes
    home_map = create_home_map(casc_graph, io_region, constant_region, working_regions)
    # Generate Proposals for Pareto-optimal ways to cascade the CascaderGraph
    proposals = generate_proposals(casc_graph, home_map, options)
    # Select the best Proposal subject to the memory constraints
    proposal_choice = choose_proposal(proposals, options.cascade_region)
    # Apply the selected Proposal to the Tensor Expression Schedule
    apply_proposal(proposal_choice, sch)
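
# Illustrative call site (a sketch only; the schedule, graph, regions, options
# and device config are assumed to have been constructed elsewhere, e.g. by the
# Ethos-U codegen, and the variable names below are hypothetical):
#
#   cascade(
#       sch,
#       te_graph,
#       const_dict,
#       options,
#       io_region=dram,
#       constant_region=flash,
#       working_regions=[sram],
#       device_config=device_config,
#   )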