diff --git a/gallery/how_to/work_with_msc/_resnet.py b/gallery/how_to/work_with_msc/_resnet.py new file mode 100644 index 000000000000..d05172337638 --- /dev/null +++ b/gallery/how_to/work_with_msc/_resnet.py @@ -0,0 +1,350 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# build resnet for cifar10, debug use only +# from https://github.com/huyvnphan/PyTorch_CIFAR10/blob/master/cifar10_models/resnet.py + +import os +import requests +from tqdm import tqdm +import zipfile + +import torch +import torch.nn as nn + +__all__ = [ + "ResNet", + "resnet18", + "resnet34", + "resnet50", +] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * 
self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + num_classes=10, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + # END + + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.reshape(x.size(0), -1) + x = self.fc(x) + + return x + + +def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + if os.path.isdir(pretrained): + state_dict = torch.load(pretrained + "/" + arch + ".pt", map_location=device) + else: + script_dir = os.path.dirname(__file__) + state_dict = torch.load( + script_dir + "/state_dicts/" + arch + ".pt", map_location=device + ) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs) + + +def resnet34(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs) + + +def resnet50(pretrained=False, progress=True, device="cpu", **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs) + + +def download_weights(): + url = "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip" + + # Streaming, so we can iterate over the response. 
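+    # (note: stream=True with iter_content keeps memory bounded for the large
+    # zip, while tqdm tracks progress against the server-reported Content-Length)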
+ r = requests.get(url, stream=True) + + # Total size in Mebibyte + total_size = int(r.headers.get("content-length", 0)) + block_size = 2**20 # Mebibyte + t = tqdm(total=total_size, unit="MiB", unit_scale=True) + + with open("state_dicts.zip", "wb") as f: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) + t.close() + + if total_size != 0 and t.n != total_size: + raise Exception("Error, something went wrong") + + print("Download successful. Unzipping file...") + path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip") + directory_to_extract_to = os.path.join(os.getcwd(), "cifar10_models") + with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: + zip_ref.extractall(directory_to_extract_to) + print("Unzip file successful!") diff --git a/gallery/how_to/work_with_msc/using_tools.py b/gallery/how_to/work_with_msc/using_tools.py new file mode 100644 index 000000000000..3c3f528d959d --- /dev/null +++ b/gallery/how_to/work_with_msc/using_tools.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Wrap pytorch model with quantizer. +This example shows how to run PTQ, QAT, PTQ with distill... 
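+
+Example invocations (a sketch; the flags map to the argparse options defined below):
+    python using_tools.py --quantize                # post-training quantization
+    python using_tools.py --quantize --distill      # PTQ with distillation
+    python using_tools.py --prune --gym             # pruning with gym search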
+Reference for MSC: +https://discuss.tvm.apache.org/t/rfc-unity-msc-introduction-to-multi-system-compiler/15251/5 + +This example uses resnet50 from https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master; +please download the pretrained .pt file and copy it into args.checkpoint before running the example. +""" + +import argparse +import torch +import torch.optim as optim + +from tvm.contrib.msc.pipeline import TorchWrapper +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.message import MSCStage +from _resnet import resnet50 +from utils import * + +parser = argparse.ArgumentParser(description="MSC train && eval example") +parser.add_argument( + "--dataset", + type=str, + default="/tmp/msc_dataset", + help="The folder saving training and testing data", +) +parser.add_argument( + "--checkpoint", + type=str, + default="/tmp/msc_models", + help="The folder saving the pretrained checkpoints", +) +parser.add_argument("--compile_type", type=str, default="tvm", help="The compile type of model") +parser.add_argument("--prune", action="store_true", help="Whether to use the pruner") +parser.add_argument("--quantize", action="store_true", help="Whether to use the quantizer") +parser.add_argument("--distill", action="store_true", help="Whether to use the distiller") +parser.add_argument("--gym", action="store_true", help="Whether to use gym to search tool configs") +parser.add_argument("--test_batch", type=int, default=1, help="The batch size for test") +parser.add_argument("--test_iter", type=int, default=100, help="The max iteration for test") +parser.add_argument("--calibrate_iter", type=int, default=100, help="The max iteration for calibration") +parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train") +parser.add_argument("--train_iter", type=int, default=200, help="The max iteration for train") +parser.add_argument("--train_epoch", type=int, default=5, help="The epochs for train") +args = parser.parse_args() + + +def get_config(calib_loader, train_loader): + tools, dataset = [], {MSCStage.PREPARE: {"loader": calib_loader}} + if args.prune: + config = {"gym_configs": ["default"]} if args.gym else "default" + tools.append((ToolType.PRUNER, config)) + if args.quantize: + config = {"gym_configs": ["default"]} if args.gym else "default" + tools.append((ToolType.QUANTIZER, config)) + if args.distill: + config = { + "options": { + "optimizer": "adam", + "opt_config": {"lr": 0.00000001, "weight_decay": 0.08}, + } + } + tools.append((ToolType.DISTILLER, config)) + dataset[MSCStage.DISTILL] = {"loader": train_loader} + return TorchWrapper.create_config( + inputs=[("input", [args.test_batch, 3, 32, 32], "float32")], + outputs=["output"], + compile_type=args.compile_type, + dataset=dataset, + tools=tools, + skip_config={"all": "check"}, + verbose="info", + ) + + +if __name__ == "__main__": + trainloader, testloader = get_dataloaders(args.dataset, args.train_batch, args.test_batch) + + def _get_calib_datas(): + for i, (inputs, _) in enumerate(testloader, 0): + if i >= args.calibrate_iter > 0: + break + yield {"input": inputs} + + def _get_train_datas(): + for i, (inputs, _) in enumerate(trainloader, 0): + if i >= args.train_iter > 0: + break + yield {"input": inputs} + + model = resnet50(pretrained=args.checkpoint) + if torch.cuda.is_available(): + model = model.to(torch.device("cuda:0")) + + acc = eval_model(model, testloader, max_iter=args.test_iter) + print("Baseline acc: " + str(acc)) + + model = TorchWrapper(model, get_config(_get_calib_datas, _get_train_datas)) + + # optimize the model with tool +
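+    # (roughly: optimize() builds the runner with the configured tools and drives
+    # the calibration/train loaders from get_config so the pruner and quantizer
+    # can collect their plans before compilation)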
model.optimize() + acc = eval_model(model, testloader, max_iter=args.test_iter) + print("Optimized acc: " + str(acc)) + + # train the model with tool + optimizer = optim.Adam(model.parameters(), lr=0.0000001, weight_decay=0.08) + for ep in range(args.train_epoch): + train_model(model, trainloader, optimizer, max_iter=args.train_iter) + acc = eval_model(model, testloader, max_iter=args.test_iter) + print("Train[{}] acc: {}".format(ep, acc)) + + # compile the model + model.compile() + acc = eval_model(model, testloader, max_iter=args.test_iter) + print("Compiled acc: " + str(acc)) diff --git a/gallery/how_to/work_with_msc/utils.py b/gallery/how_to/work_with_msc/utils.py new file mode 100644 index 000000000000..3ff20afec6d3 --- /dev/null +++ b/gallery/how_to/work_with_msc/utils.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Utils of using msc examples """ + +import numpy as np + +import torch +from torch import nn +import torchvision +import torchvision.transforms as transforms + + +def get_dataloaders(path, train_batch=32, test_batch=1, dataset="cifar10"): + """Get the data loaders for torch process""" + + if dataset == "cifar10": + mean = (0.4914, 0.4822, 0.4465) + std = (0.2471, 0.2435, 0.2616) + train_transform = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ) + trainset = torchvision.datasets.CIFAR10( + root=path, train=True, download=True, transform=train_transform + ) + test_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ) + testset = torchvision.datasets.CIFAR10( + root=path, train=False, download=True, transform=test_transform + ) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=train_batch, shuffle=True, num_workers=2 + ) + testloader = torch.utils.data.DataLoader( + testset, batch_size=test_batch, shuffle=False, num_workers=2 + ) + return trainloader, testloader + raise Exception("Unexpected dataset " + str(dataset)) + + +def eval_model(model, dataloader, max_iter=-1, log_step=100): + """Evaluate the model""" + + model.eval() + device = next(model.parameters()).device + num_correct, num_datas = 0, 0 + for i, (inputs, labels) in enumerate(dataloader, 0): + with torch.no_grad(): + outputs = model(inputs.to(device)) + cls_idices = torch.argmax(outputs, axis=1) + labels = labels.to(device) + num_datas += len(cls_idices) + num_correct += torch.where(cls_idices == labels, 1, 0).sum() + if num_datas > 0 and num_datas % log_step == 0: + print("[{}/{}] Torch eval acc: {}".format(i, len(dataloader), num_correct / num_datas)) + if max_iter > 0 and num_datas >= max_iter: + break + acc = num_correct / num_datas + return 
acc.detach().cpu().numpy().tolist() + + +def train_model(model, dataloader, optimizer, max_iter=-1, log_step=100): + """Train the model""" + + model.train() + device = next(model.parameters()).device + num_correct, num_datas = 0, 0 + criterion = nn.CrossEntropyLoss() + running_loss = 0.0 + for i, (inputs, labels) in enumerate(dataloader, 0): + optimizer.zero_grad() + outputs = model(inputs.to(device)) + cls_idices = torch.argmax(outputs, axis=1) + labels = labels.to(device) + num_datas += len(cls_idices) + num_correct += torch.where(cls_idices == labels, 1, 0).sum() + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + # gather loss + running_loss += loss.item() + if num_datas > 0 and num_datas % log_step == 0: + print( + "[{}/{}] Torch train loss: {}, acc {}".format( + i, len(dataloader), running_loss / (i + 1), num_correct / num_datas + ) + ) + if max_iter > 0 and num_datas >= max_iter: + break diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index 8ffaf9dd5fa1..c2711231f400 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -21,8 +21,10 @@ from typing import Dict, List, Optional, Any, Callable import tvm -from tvm.relax.transform import BindParams -from tvm.contrib.msc.core.ir import MSCGraph +from tvm import relax +from tvm.relax import PyExprVisitor +from tvm.contrib.msc.core import transform as msc_transform +from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor from tvm.contrib.msc.core.frontend import from_relay from tvm.contrib.msc.core import utils as msc_utils @@ -126,6 +128,95 @@ def load( return obj +def to_relax( + graph: MSCGraph, + weights: Optional[Dict[str, tvm.nd.array]] = None, + codegen_config: Optional[Dict[str, str]] = None, + print_config: Optional[Dict[str, str]] = None, + build_folder: msc_utils.MSCDirectory = None, + plugin: Any = None, + use_alias: bool = True, +) -> tvm.IRModule: + """Change MSCGraph to IRModule. + + Parameters + ---------- + graph: tvm.contrib.msc.core.ir.MSCGraph + The translated graph. + weights: dict of + The parameters of the IRModule. + codegen_config: dict + The config for codegen. + print_config: dict + The config for print. + build_folder: MSCDirectory + The folder for saving scripts and datas. + plugin: PluginManager + The plugin manager. + use_alias: bool + Whether to use alias for input. + + Returns + ------- + mod: IRModule + The IRModule of relax. 
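+
+    Examples
+    --------
+    A minimal sketch (assuming `graph` and `weights` come from a translate
+    step such as `from_relay`):
+
+    .. code-block:: python
+
+        mod = to_relax(graph, weights)
+        executable = tvm.relax.build(mod, target="llvm")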
+ """ + + @relax.expr_functor.visitor + class NamesGetter(PyExprVisitor): + """Visitor for get attributes in span""" + + def get_names(self, expr: relax.Expr) -> dict: + self._names = {} + if isinstance(expr, relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + self.visit_binding_block(expr) + return self._names + + def visit_var_binding_(self, binding: relax.VarBinding) -> None: + super().visit_var_binding_(binding) + self._names[binding.var.name_hint] = binding.var.name_hint + + def _to_var(tensor: MSCTensor): + v_name = tensor.alias if use_alias else graph.find_producer(tensor).name + return tvm.relax.Var( + v_name, tvm.relax.TensorStructInfo(tensor.get_shape(), tensor.dtype_name) + ) + + def _save_weights(folder: msc_utils.MSCDirectory): + if weights: + with open(folder.relpath(graph.name + "_params.bin"), "wb") as f_params: + f_params.write(tvm.runtime.save_param_dict(weights)) + + # pylint: disable=unused-argument + def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: + passes, var_names = [], NamesGetter().get_names(mod["main"]) + if weights: + passes.append(msc_transform.BindNamedParams("main", weights)) + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence. + passes.extend( + [ + msc_transform.SetExprName(var_names=var_names), + tvm.relax.transform.CanonicalizeBindings(), + tvm.relax.transform.ConvertToDataflow(min_size=1), + ] + ) + return tvm.ir.transform.Sequential( + passes, name="tvm.contrib.msc.core.codegen.to_relax_postproc" + )(mod) + + source_getter = tvm.get_global_func("msc.framework.tvm.GetRelaxSources") + codegen = CodeGen(graph, source_getter, codegen_config, print_config, build_folder) + model_args = [_to_var(i) for i in graph.get_inputs()] + if plugin: + model_args = model_args + [plugin] + return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc) + + def relay_to_relax( relay_mod: tvm.IRModule, params: Optional[Dict[str, tvm.nd.array]] = None, @@ -133,7 +224,7 @@ def relay_to_relax( build_config: Optional[Dict[str, str]] = None, opt_config: Optional[Dict[str, str]] = None, ) -> tvm.IRModule: - """Change IRModule to MSCGraph. + """Change relay IRModule to relax MSCGraph. Parameters ---------- @@ -161,26 +252,5 @@ def relay_to_relax( build_config=build_config, opt_config=opt_config, ) - source_getter = tvm.get_global_func("msc.framework.tvm.GetRelaxSources") - codegen = CodeGen(graph, source_getter, codegen_config={"from_relay": True}) - inputs = [ - tvm.relax.Var(i.alias, tvm.relax.TensorStructInfo(i.get_shape(), i.dtype_name)) - for i in graph.get_inputs() - ] - - # pylint: disable=unused-argument - def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - mod = BindParams("main", weights)(mod) - return tvm.ir.transform.Sequential( - [ - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. 
- tvm.relax.transform.CanonicalizeBindings(), - tvm.relax.transform.ConvertToDataflow(min_size=1), - ], - name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc", - )(mod) - return codegen.load(inputs, post_load=_post_proc) + return to_relax(graph, weights, codegen_config={"from_relay": True}) diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 5bfe1cec2a6f..19a16a375b7a 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -278,6 +278,25 @@ def weight_at(self, wtype: str) -> MSCTensor: return _ffi_api.MSCJointWeightAt(self, wtype) + def weight_type(self, name: str) -> str: + """Get the weight type of weight + + Parameters + ---------- + name: str + The name of weight. + + Returns + ------- + wtype: str + The type of weight. + """ + + for w_type, weight in self.get_weights().items(): + if weight.name == name: + return w_type + raise Exception("Can not find weight type for " + name) + def get_inputs(self) -> List[MSCTensor]: """Get all the inputs. @@ -727,6 +746,23 @@ def get_outputs(self) -> List[MSCTensor]: return _ffi_api.MSCGraphGetOutputs(self) + def get_tensors(self) -> List[MSCTensor]: + """Get all the tensors. + + Returns + ------- + tensors: list + The Tensors. + """ + + for node in self.get_nodes(): + for t_input in node.get_inputs(): + yield t_input + for weight in node.get_weights().values(): + yield weight + for t_output in self.get_outputs(): + yield t_output + def to_json(self) -> str: """Dump the graph to json. diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index 6d3a364e90ec..c4f4016d148f 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -26,6 +26,7 @@ import tvm from tvm.contrib.msc.core.ir import MSCGraph from tvm.contrib.msc.core.frontend import from_relax +from tvm.contrib.msc.core.codegen import to_relax from tvm.contrib.msc.core.tools import BaseTool, ToolType, ToolScope, create_tool, remove_tools from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core.utils.message import MSCStage @@ -43,7 +44,7 @@ class BaseRunner(object): The IRModule of relax. params: dict of The parameters of the IRModule. - tools_config: dict + tools_config: list The config of MSC Tools. translate_config: dict The config for translate IRModule to MSCGraph. 
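# Note: tools_config is now a list of per-tool dicts rather than a single dict,
# preserving the order in which tools are applied. A sketch of the expected shape:
#   tools_config = [{"tool_type": ToolType.QUANTIZER, "tool_config": {...}}]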
@@ -70,7 +71,7 @@ class BaseRunner(object): def __init__( self, mod: tvm.IRModule, - tools_config: Optional[Dict[str, Any]] = None, + tools_config: Optional[List[dict]] = None, translate_config: Optional[Dict[str, str]] = None, generate_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, @@ -83,7 +84,13 @@ def __init__( logger: logging.Logger = None, ): self._mod = mod - self._tools_config = msc_utils.copy_dict(tools_config) + if tools_config: + self._tools_type = [t["tool_type"] for t in tools_config] + self._tools_config = { + t["tool_type"]: msc_utils.copy_dict(t["tool_config"]) for t in tools_config + } + else: + self._tools_type, self._tools_config = [], {} self._translate_config = msc_utils.copy_dict(translate_config) self._generate_config = msc_utils.copy_dict(generate_config) self._build_config = msc_utils.copy_dict(build_config) @@ -94,11 +101,8 @@ def __init__( self._debug_level = debug_level self._training, self._trained = training, training self._logger = logger or msc_utils.get_global_logger() - self._logger.info( - msc_utils.msg_block( - "RUNNER.SETUP({} @ {})".format(self._stage, self.framework), self.setup() - ) - ) + self._logger.info(msc_utils.msg_block(self.runner_mark("SETUP"), self.setup())) + self._tools = self.setup_tools() def setup(self) -> dict: """Setup the runner @@ -114,23 +118,10 @@ def setup(self) -> dict: self._graphs, self._weights = [], {} self._model, self._model_info = None, {} self._runnable = None - # Setup tools - self._tools = {} - if self._tools_config: - self._update_codegen({"use_tools": True, "tools_tag": self._name}) - for t_type, config in self._tools_config.items(): - self._tools[t_type] = create_tool( - self.framework, - t_type, - self._name, - training=self._training, - stage=self._stage, - **config, - ) if self._plugin: self._update_codegen({"use_plugin": True}) return { - "tools": {k: v.tool_style() for k, v in self._tools.items()}, + "tools": {k: v.get("tool_style", "default") for k, v in self._tools_config.items()}, "plugin": self._plugin, "translate_config": self._translate_config, "generate_config": self._generate_config, @@ -140,6 +131,29 @@ def setup(self) -> dict: "debug_level": self._debug_level, } + def setup_tools(self) -> Dict[str, BaseTool]: + """Setup tools + + Returns + ------- + tools: dict + The tools. + """ + + tools = {} + if self._tools_type: + self._update_codegen({"use_tools": True, "tools_tag": self._name}) + for t_type in self._tools_type: + tools[t_type] = create_tool( + self.framework, + t_type, + self._name, + training=self._training, + stage=self._stage, + **self._tools_config[t_type], + ) + return tools + def change_stage(self, stage: str): """Change the stage of runner and tools""" @@ -154,7 +168,12 @@ def change_logger(self, logger: logging.Logger): for tool in self._tools.values(): tool.change_logger(logger) - def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = False) -> Any: + def build( + self, + cache_dir: msc_utils.MSCDirectory = None, + force_build: bool = False, + disable_tools: List[str] = None, + ) -> Any: """Build the runnable object Parameters @@ -163,6 +182,8 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = Fa cache path for save/load info force_build: bool Whether to force build the runner. + disable_tools: list + The tool types to be disabled. 
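+            (a sketch: runner.build(disable_tools=[ToolType.DISTILLER]) rebuilds
+            the runnable without resetting the distiller)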
Returns ------- @@ -179,31 +200,32 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = Fa else: cache_info = {} + # set tools to reset + if disable_tools: + tools = [t for t in self.get_tools() if t.tool_type not in disable_tools] + else: + tools = None + + build_msg = "" # Load graphs from cache if not self._graphs and cache_info.get("graphs"): self._graphs = self._load_graphs(cache_dir, cache_info["graphs"]) assert "weights" in cache_info, "Missing weights in cache_info" with open(cache_dir.relpath(cache_info["weights"]), "rb") as f: self._weights = tvm.runtime.load_param_dict(f.read()) - self._logger.info( - "Load %d graphs %d weights from %s", - len(self._graphs), - len(self._weights), - cache_dir, - ) + build_msg += "Load " # Translate graphs from module if not self._graphs: self._graphs, self._weights = self.translate() - self._logger.info( - "Translate %d graphs %d weights from module", len(self._graphs), len(self._weights) - ) + build_msg += "Translate " + build_msg += "{} graphs {} weights -> ".format(len(self._graphs), len(self._weights)) # Load model from cache if not self._model and cache_info.get("model"): - self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir) + self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) self._model = self._load_model(cache_dir, cache_info["model"]) - self._logger.info("Load model(%s) from %s", self.framework, cache_dir) + build_msg += "Load " # Generate model if not self._model: @@ -218,37 +240,41 @@ def _build_scope_model(scope: str, apply_hooks: bool): # Generate distill model teacher_model = _build_scope_model(ToolScope.TEACHER, False) - self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir) + self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) student_model = _build_scope_model(ToolScope.STUDENT, True) self._model = distiller.build_model(teacher_model, student_model) else: # Generate normal model - self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir) + self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) self._model = self.generate_model() + build_msg += "Generate " - generate_msg = "Generate model({})".format(self.framework) - if self._tools: - self._logger.info("%s with tools: %s", generate_msg, ",".join(self._tools.keys())) - else: - self._logger.info("%s without tools", generate_msg) + # Add tool message + if self._tools: + build_msg += "model with tools " + str(",".join(self._tools.keys())) + " -> " + else: + build_msg += "model without tools -> " # Inspect model self._model_info = self._inspect_model() if self._debug_level >= 2: - self._logger.debug(msc_utils.msg_block("RUNNER.MODEL_INFO", self._model_info)) + self._logger.debug( + msc_utils.msg_block(self.runner_mark("MODEL_INFO"), self._model_info) + ) - runnable_msg = "runnable({}, {}) @ {}".format( - self.framework, "train" if self._training else "eval", self._device - ) # Load runnable from cache if not self._runnable and cache_info.get("runnable"): self._runnable = self._load_runnable(cache_dir, cache_info["runnable"]) - self._logger.info("Load %s from %s", runnable_msg, cache_dir) + build_msg += "Load " # Build runnable if not self._runnable: self._runnable = self.build_runnable() - self._logger.info("Build %s", runnable_msg) + build_msg += "Build " + build_msg += "runnable({}, {}) on {}".format( + self.framework, "train" if self._training else "eval", self._device + ) + self._logger.info(build_msg) return self._runnable def 
run( @@ -280,14 +306,14 @@ def run( inputs, type(inputs) ) assert all( - isinstance(data, np.ndarray) for data in inputs.values() - ), "Expected all inputs as np.ndarray" + msc_utils.is_array(data) for data in inputs.values() + ), "Expected all inputs as array like" inputs = {i["name"]: inputs[i["name"]] for i in model_inputs} outputs = self._call_runnable(self._runnable, inputs, self._device) if ret_type == "native": return outputs if ret_type == "dict": - if isinstance(outputs, (list, tuple)): + if isinstance(outputs, (list, tuple, tvm.ir.container.Array)): assert len(outputs) == len( model_outputs ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) @@ -297,8 +323,8 @@ def run( model_outputs ) outputs = {model_outputs[0]["name"]: outputs} - outputs = {name: msc_utils.cast_array(data) for name, data in outputs.items()} - elif ret_type == "list": + return {name: msc_utils.cast_array(data) for name, data in outputs.items()} + if ret_type == "list": if isinstance(outputs, dict): assert len(outputs) == len( model_outputs @@ -306,7 +332,7 @@ def run( outputs = [outputs[o["name"]] for o in model_outputs] if not isinstance(outputs, (list, tuple)): outputs = [outputs] - outputs = [msc_utils.cast_array(data) for data in outputs] + return [msc_utils.cast_array(data) for data in outputs] return outputs def save_cache( @@ -343,9 +369,8 @@ def save_cache( cache_info[t_type] = tool.save_cache(cache_dir) with open(cache_dir.relpath("cache_info.json"), "w") as f: f.write(json.dumps(cache_info, indent=2)) - self._logger.debug( - msc_utils.msg_block("RUNNER.SAVE_CACHE", {"folder": cache_dir, "info": cache_info}) - ) + title = self.runner_mark("SAVE_CACHE") + self._logger.debug(msc_utils.msg_block(title, {"folder": cache_dir, "info": cache_info})) def translate(self, apply_hooks: bool = True) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Translate IRModule to MSCgraphs @@ -421,7 +446,8 @@ def reset_tools( graphs = graphs or self._graphs weights = weights or self._weights - tools = tools or self._tools.values() + if tools is None: + tools = list(self.get_tools()) for tool in tools: graphs, weights = tool.reset(graphs, weights, cache_dir) return graphs, weights @@ -508,6 +534,22 @@ def _build_runnable(self, model: Any) -> Any: raise NotImplementedError("_build_runnable is not implemented for " + str(self.__class__)) + def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: + """Export the module from graphs + + Parameters + ---------- + folder: MSCDirectory + The export folder. + + Returns + ------- + module: IRModule + The exported module + """ + + raise NotImplementedError("export_module is not supported in BaseRunner") + def train(self): """Change status to train""" @@ -583,7 +625,7 @@ def get_tools(self) -> Iterable[BaseTool]: if tool: yield tool - def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: + def make_plan(self, tool_type: str, data_loader: Any = None) -> str: """Execute tool and get plan Parameters @@ -591,7 +633,7 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: tool_type: str The tool type, should be in ToolType data_loader: - The data loader + The data loader. 
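+            (a callable: make_plan invokes data_loader() and feeds each yielded
+            batch through run() until the tool has gathered what it needs)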
Returns ------- @@ -602,7 +644,7 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: assert tool_type in self._tools, "Can not find tool " + str(tool_type) if tool_type == ToolType.PRUNER: pruner = self.get_tool(ToolType.PRUNER) - if not pruner.finalize(): + if not pruner.pruned: assert data_loader, "data_loader should be given to plan prune" for inputs in data_loader(): self.run(inputs, ret_type="native") @@ -625,13 +667,21 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: distiller.learn(loss) distiller.distill() plan = distiller.finalize() + elif tool_type == ToolType.TRACKER: + tracker = self.get_tool(ToolType.TRACKER) + if not tracker.tracked: + assert data_loader, "data_loader should be given to plan prune" + for inputs in data_loader(): + self.run(inputs, ret_type="native") + if tracker.tracked: + break + plan = tracker.finalize() else: plan = self.get_tool(tool_type).finalize() - assert plan, "Failed to create plan for {}".format(tool_type) + self._logger.debug("Made %d plan for %s", len(plan), tool_type) plan_file = self._tools_config[tool_type]["plan_file"] with open(plan_file, "w") as f: f.write(json.dumps(plan, indent=2)) - self._logger.info("Save %d plan(%s) -> %s", len(plan), tool_type, plan_file) return plan_file def _apply_hook(self, desc: str, hook_def: dict, *args, **kwargs) -> Any: @@ -738,6 +788,30 @@ def get_weights(self, framework: str = None, device: str = None) -> Iterable[tvm data = msc_utils.cast_array(data, framework, device) yield data + def get_runtime_params(self) -> Dict[str, tvm.nd.array]: + """Get the runtime parameters + + Returns + ------- + params: dict + The parameters from runtime. + """ + + return self._get_runtime_params() + + def _get_runtime_params(self) -> Dict[str, tvm.nd.array]: + """Get the runtime parameters + + Returns + ------- + params: dict + The parameters from runtime. + """ + + raise NotImplementedError( + "_get_runtime_params is not implemented for " + str(self.__class__) + ) + def destory(self): """Destory runner""" @@ -897,6 +971,22 @@ def _device_enabled(self, device: str) -> bool: return True + def runner_mark(self, msg: Any) -> str: + """Mark the message with runner info + + Parameters + ------- + msg: str + The message + + Returns + ------- + msg: str + The message with mark. + """ + + return "RUNNER({} @ {}) {}".format(self.framework, self._stage, msg) + @property def stage(self): return self._stage @@ -930,21 +1020,27 @@ def framework(self): return MSCFramework.MSC @classmethod - def load_native(cls, model: Any) -> Any: + def load_native(cls, model: Any, config: dict) -> Tuple[Any, str, bool]: """Load the native model Parameters ------- model: The native model. + config: dict + The config for pipeline. Returns ------- model: The loaded native model. + device: str + The device of the model. + training: + Whether the model is for training. 
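+        (the base implementation returns (model, "cpu", False); framework-specific
+        runners presumably override this to detect device and training mode)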
""" - return model, "cpu" + return model, "cpu", False @classmethod def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: @@ -982,10 +1078,6 @@ def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: config[stage]["run_config"] = run_config return config - @classmethod - def support_tool(cls, tool_type: str) -> bool: - return True - class ModelRunner(BaseRunner): """Model runner of MSC""" @@ -1090,6 +1182,26 @@ def _inspect_model(self) -> dict: return self._graphs[0].inspect() + def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: + """Export the module from graphs + + Parameters + ---------- + folder: MSCDirectory + The export folder. + + Returns + ------- + module: IRModule + The exported module + """ + + build_folder = folder.create_dir("export_build", keep_history=False, cleanup=True) + module = to_relax( + self._graphs[0], self.get_runtime_params(), build_folder=build_folder, use_alias=False + ) + return module + class BYOCRunner(BaseRunner): """BYOC runner of MSC""" @@ -1189,7 +1301,7 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: The cache info. """ - sub_graphs = [g.name + "_graph.info" for g in self._graphs] + sub_graphs = [g.name + "_graph.json" for g in self._graphs] with cache_dir: for graph, g_file in zip(self._graphs, sub_graphs): with open(g_file, "w") as f_graph: @@ -1288,16 +1400,10 @@ def _call_runnable( The outputs in list. """ - model_inputs = self.get_inputs() - if device == "cpu": - tvm_inputs = [tvm.nd.array(inputs[i["name"]]) for i in model_inputs] - elif device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - tvm_inputs = [ - tvm.nd.array(inputs[i["name"]], device=tvm.cuda(dev_id)) for i in model_inputs - ] - else: - raise NotImplementedError("Unsupported device " + str(device)) + input_names = [i["name"] for i in self.get_inputs()] + tvm_inputs = [ + msc_utils.cast_array(inputs[i], MSCFramework.TVM, device) for i in input_names + ] return runnable["main"](*tvm_inputs) def _inspect_model(self) -> dict: @@ -1310,10 +1416,9 @@ def _inspect_model(self) -> dict: """ if self._debug_level >= 2: - for idx, graph in enumerate(self._graphs): - self._logger.debug( - msc_utils.msg_block("GRAPH[{}].INFO".format(idx), graph.inspect()) - ) + sub_graphs = {g.name: g.inspect for g in self._graphs} + title = self.runner_mark("SUBGRAPHS({})".format(len(sub_graphs))) + self._logger.debug(msc_utils.msg_block(title, sub_graphs)) return self._byoc_graph.inspect() def _device_enabled(self, device: str) -> bool: diff --git a/python/tvm/contrib/msc/core/tools/configer.py b/python/tvm/contrib/msc/core/tools/configer.py new file mode 100644 index 000000000000..c9ac6dd876b2 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/configer.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.configer""" + +from typing import Union +from tvm.contrib.msc.core import utils as msc_utils +from .tool import ToolType + + +class ToolConfiger(object): + """Base configer for tool""" + + def config(self, raw_config: dict = None) -> dict: + """Get the config + + Parameters + ---------- + raw_config: dict + The raw config. + + Returns + ------- + config: dict + The update config. + """ + + config = {} + if isinstance(raw_config, dict) and "gym_configs" in raw_config: + config["gym_configs"] = [self.config_gym(g) for g in raw_config.pop("gym_configs")] + if raw_config: + config["tool_config"] = self.update_tool(raw_config) + else: + config["tool_config"] = self.config_tool() + if self.run_type: + config["run_type"] = self.run_type + if self.apply_once: + config["apply_once"] = self.apply_once + return config + + def config_tool(self) -> dict: + """Get the default config of tool + + Returns + ------- + config: dict + The default config. + """ + + raise NotImplementedError("config_tool is not implemented in ToolConfiger") + + def update_tool(self, raw_config: dict) -> dict: + """Update tool config from raw_config + + Parameters + ---------- + raw_config: dict + The raw config. + + Returns + ------- + config: dict + The update config. + """ + + config = self.config_tool() + return msc_utils.update_dict(config, raw_config) + + def config_gym(self, gym_config: Union[dict, str]) -> dict: + """Config the gym + + Parameters + ---------- + gym_config: dict + The raw config. + + Returns + ------- + gym_config: dict + The update config. + """ + + raise NotImplementedError("config_gym is not implemented in ToolConfiger") + + @property + def run_type(self): + return "" + + @property + def apply_once(self): + return False + + @classmethod + def tool_type(cls): + return ToolType.BASE diff --git a/python/tvm/contrib/msc/core/tools/distill/__init__.py b/python/tvm/contrib/msc/core/tools/distill/__init__.py index 8714eae4e4da..a3478d7b9682 100644 --- a/python/tvm/contrib/msc/core/tools/distill/__init__.py +++ b/python/tvm/contrib/msc/core/tools/distill/__init__.py @@ -18,3 +18,4 @@ from .distiller import * from .method import * +from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/distill/configer.py b/python/tvm/contrib/msc/core/tools/distill/configer.py new file mode 100644 index 000000000000..b531bc3a88fa --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/distill/configer.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
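+# (a sketch of the flow: a tool config given as the string "default" resolves to
+# the configer registered with config_style() == "default", whose config() fills
+# in tool_config / run_type / apply_once)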
+"""tvm.contrib.msc.core.tools.distill.configer""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.configer import ToolConfiger +from tvm.contrib.msc.core import utils as msc_utils + + +class DistillConfiger(ToolConfiger): + """Configer for distill""" + + @classmethod + def tool_type(cls): + return ToolType.DISTILLER + + +@msc_utils.register_tool_configer +class DefaultDistillConfiger(DistillConfiger): + """Default configer for distill""" + + def config_tool(self) -> dict: + """Get the default config of tool + + Returns + ------- + config: dict + The default config. + """ + + return { + "plan_file": "msc_distiller.json", + "strategys": [ + { + "methods": {"mark": "loss_lp_norm"}, + "marks": ["loss"], + }, + ], + } + + @classmethod + def config_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py index 58cf3fd2d953..7eee93cbc9e6 100644 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py @@ -39,7 +39,10 @@ def setup(self) -> dict: self._max_iter = self._options.get("max_iter", 5) self._save_step = self._options.get("save_step", 50) - self._weights_folder = msc_utils.get_weights_dir().create_dir("Distill") + if "weights_folder" in self._options: + self._weights_folder = msc_utils.msc_dir(self._options["weights_folder"]) + else: + self._weights_folder = msc_utils.get_weights_dir().create_dir("Distill") self._weights_path = self._weights_folder.relpath("distill_{}.bin".format(self._max_iter)) self._distilled = os.path.isfile(self._weights_path) return super().setup() @@ -64,8 +67,7 @@ def _reset( The weights. """ - self._current_iter = 0 - self._total_loss = 0 + self._current_iter, self._total_loss = 0, 0 if self._distilled: with open(self._weights_path, "rb") as f: distilled_weights = tvm.runtime.load_param_dict(f.read()) @@ -100,8 +102,8 @@ def learn(self, loss: Any): The loss after forward """ - if self.on_debug(3): - self._logger.debug("%sStart Learn", self.msg_mark()) + if self.on_debug(3, in_forward=False): + self._logger.debug("%s start learn[%d]", self.tool_type(), self._current_iter) self._total_loss += float(self._learn(loss)) def _learn(self, loss: Any): @@ -242,6 +244,24 @@ def _distill_tensor( self._plan[name][scope] = plan return tensor + def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: + """Export the config for tool + + Parameters + ------- + config: dict + The source config. + folder: MSCDirectory + The export folder. + + Returns + ------- + config: dict + The exported config. + """ + + return {} + @property def distilled(self): return self._distilled diff --git a/python/tvm/contrib/msc/core/tools/prune/__init__.py b/python/tvm/contrib/msc/core/tools/prune/__init__.py index 8317d52ac12b..8954cd6b90a1 100644 --- a/python/tvm/contrib/msc/core/tools/prune/__init__.py +++ b/python/tvm/contrib/msc/core/tools/prune/__init__.py @@ -18,3 +18,4 @@ from .pruner import * from .method import * +from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/prune/configer.py b/python/tvm/contrib/msc/core/tools/prune/configer.py new file mode 100644 index 000000000000..74a4f598862f --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/prune/configer.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.prune.configer""" + +from typing import Union +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.configer import ToolConfiger +from tvm.contrib.msc.core import utils as msc_utils + + +class PruneConfiger(ToolConfiger): + """Configer for prune""" + + def config_gym(self, raw_config: Union[dict, str]) -> dict: + """Config the gym + + Parameters + ---------- + gym_config: dict + The raw config. + + Returns + ------- + gym_config: dict + The update config. + """ + + if isinstance(raw_config, dict): + return raw_config + if raw_config == "default": + return { + "env": { + "executors": { + "action_space": { + "method": "action_prune_density", + "start": 0.2, + "end": 0.8, + "step": 0.1, + } + }, + }, + "agent": {"role_type": "search.grid", "executors": {}}, + } + else: + raise TypeError("Unexpected gym config " + str(raw_config)) + + @classmethod + def tool_type(cls): + return ToolType.PRUNER + + +@msc_utils.register_tool_configer +class DefaultPruneConfiger(PruneConfiger): + """Default configer for prune""" + + def config_tool(self) -> dict: + """Get the default config of tool + + Returns + ------- + config: dict + The default config. + """ + + return { + "plan_file": "msc_pruner.json", + "strategys": [ + { + "methods": { + "weights": {"method_name": "per_channel", "density": 0.8}, + "output": {"method_name": "per_channel", "density": 0.8}, + } + } + ], + } + + @classmethod + def config_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 7eb4434a62f3..515ea09e0145 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -22,6 +22,7 @@ import tvm from tvm.contrib.msc.core.ir import MSCGraph, WeightJoint, MSCTensor from tvm.contrib.msc.core.tools.tool import ToolType, WeightTool, ToolStrategy +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import _ffi_api from tvm.contrib.msc.core import utils as msc_utils from .method import PruneMethod @@ -30,6 +31,19 @@ class BasePruner(WeightTool): """Base pruner for all""" + def setup(self) -> dict: + """Setup the tool + + Returns + ------- + info: dict + The setup info. 
+ """ + + if not self._plan: + self.change_stage(MSCStage.PRUNE) + return super().setup() + def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: """Get the weight types from options @@ -65,13 +79,13 @@ def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: } return main_wtypes, relation_wtypes - def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]: + def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: """Parse the strategy to get valid strategy Parameters ------- - strategy_list: dict - The given strategy + strategy_list: list + The given strategys. Returns ------- @@ -79,10 +93,12 @@ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]: The parsed strategy. """ + if self._stage != MSCStage.PRUNE: + return {} + def _update_stages(strategy): if "stages" not in strategy: - strategy["stages"] = [msc_utils.MSCStage.PRUNE] - strategy["tensor_types"] = ["weight", "output"] + strategy["stages"] = [MSCStage.PRUNE] return strategy return super()._parse_strategys([_update_stages(s) for s in strategy_list]) @@ -203,11 +219,8 @@ def _process_tensor( strategys = self._get_tensor_strategys(lazy_name, info["consumer"]) self._prune_tensor(lazy_name, info["consumer"], strategys) t_mark = ".".join([s.get_executor().name for s in strategys]) - self.debug_tensor( - self.find_tensor(lazy_name), - lazy_name, - consumer, - "lazy processed({})".format(t_mark), + self.debug_tensors( + lazy_name, consumer, t_mark, {"lazy": self.find_tensor(lazy_name)} ) lazy_pruned.add(lazy_name) if lazy_pruned: @@ -476,40 +489,24 @@ def create_tasks(self, **kwargs) -> List[dict]: if w_node.get_attr("weight_strategy") != "main": continue consumer = self.find_producer(w_node.name).name - strategy = self._get_tensor_strategy(w_node.name, consumer) + executor = self._get_tensor_strategy(w_node.name, consumer).get_executor(MSCStage.PRUNE) tasks.append( - { - "tensor_names": [self.to_tensor_id(w_node.name, consumer)], - **strategy.meta, - } + {"methods": {"tensor": executor.method_def}, "tensor_names": [w_node.name]} ) return tasks - def plan_by_strategys(self, strategys: List[dict]) -> dict: - """Plan the pruning with startegys and get plan + def change_strategys(self, strategy_list: List[dict]): + """Change the strategys Parameters ------- - strategys: list - The given strategys - - Returns - ------- - plan: dict - The plan after new strategy applied. + strategy_list: list + The given strategys. 
""" - self._tensor_cache, self._processed_tensor = {}, {} self._plan = {} - self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys)) - info = {k: v.inspect() for k, v in self._strategys.items()} - title = "{}.PRUNE_STRATEGYS".format(self.tool_type().upper()) - self._logger.debug(msc_utils.msg_block(title, info, width=0)) - for w_node in self.get_w_nodes(): - consumer = self.find_consumers(w_node.name)[0] - self.process_tensor(w_node.weight, w_node.name, consumer.name, "") - self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} - return self._plan + self.change_stage(MSCStage.PRUNE) + super().change_strategys(strategy_list) def finalize(self) -> dict: """Get the plan""" @@ -517,6 +514,28 @@ def finalize(self) -> dict: self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} return super().finalize() + def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: + """Export the config for tool + + Parameters + ------- + config: dict + The source config. + folder: MSCDirectory + The export folder. + + Returns + ------- + config: dict + The exported config. + """ + + return {} + + @property + def pruned(self): + return len(self._plan) > 0 + @classmethod def tool_type(cls): return ToolType.PRUNER diff --git a/python/tvm/contrib/msc/core/tools/quantize/__init__.py b/python/tvm/contrib/msc/core/tools/quantize/__init__.py index 1aad17c0553c..ed7942a7c330 100644 --- a/python/tvm/contrib/msc/core/tools/quantize/__init__.py +++ b/python/tvm/contrib/msc/core/tools/quantize/__init__.py @@ -18,3 +18,4 @@ from .quantizer import * from .method import * +from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/quantize/configer.py b/python/tvm/contrib/msc/core/tools/quantize/configer.py new file mode 100644 index 000000000000..81a6149806d2 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/quantize/configer.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.quantize.configer""" + +from typing import Union + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.configer import ToolConfiger +from tvm.contrib.msc.core import utils as msc_utils +from .quantizer import QuantizeStage + + +class QuantizeConfiger(ToolConfiger): + """Configer for quantize""" + + def config_gym(self, gym_config: Union[dict, str]) -> dict: + """Config the gym + + Parameters + ---------- + gym_config: dict + The raw config. + + Returns + ------- + gym_config: dict + The update config. 
+ """ + + if isinstance(gym_config, dict): + return gym_config + if gym_config == "default": + return { + "env": { + "executors": { + "action_space": { + "method": "action_quantize_scale", + "start": 0.8, + "end": 1.2, + "step": 0.1, + } + }, + }, + "agent": {"agent_type": "search.grid", "executors": {}}, + } + else: + raise TypeError("Unexpected gym config " + str(gym_config)) + + @classmethod + def tool_type(cls): + return ToolType.QUANTIZER + + +@msc_utils.register_tool_configer +class DefaultQuantizeConfiger(QuantizeConfiger): + """Default configer for quantize""" + + def config_tool(self) -> dict: + """Get the default config of tool + + Returns + ------- + config: dict + The default config. + """ + + op_types = [ + "nn.conv1d", + "msc.conv1d_bias", + "nn.conv2d", + "msc.conv2d_bias", + "nn.conv3d", + "msc.conv3d_bias", + "msc.linear", + "msc.linear_bias", + "nn.avg_pool1d", + "nn.avg_pool2d", + "nn.avg_pool3d", + ] + + return { + "plan_file": "msc_quantizer.json", + "strategys": [ + { + "methods": { + "input": "gather_maxmin", + "output": "gather_maxmin", + "weights": "gather_max_per_channel", + }, + "op_types": op_types, + "stages": [QuantizeStage.GATHER], + }, + { + "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"}, + "op_types": op_types, + "stages": [QuantizeStage.CALIBRATE], + }, + { + "methods": { + "input": "quantize_normal", + "weights": "quantize_normal", + "output": "dequantize_normal", + }, + "op_types": op_types, + }, + ], + } + + @classmethod + def config_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py index 3b0f3267df85..8bf8242bb4b2 100644 --- a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py @@ -19,6 +19,7 @@ from typing import List, Dict, Any from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils @@ -41,7 +42,7 @@ def setup(self) -> dict: if self._plan: self._calibrated = True - self.change_stage(msc_utils.MSCStage.QUANTIZE) + self.change_stage(MSCStage.QUANTIZE) else: self._calibrated = False self._calibrate_plan = {} @@ -73,17 +74,21 @@ def calibrate(self) -> dict: self._calibrated = True for name, plan in new_plan.items(): self._plan[name] = {k: v for k, v in plan.items() if k not in ("calibrated")} - self.change_stage(msc_utils.MSCStage.QUANTIZE) + self.change_stage(MSCStage.QUANTIZE) + calib_type = "calibrate" if self._calibrated else "gather" + self._logger.info( + "Quantizer %s %d plan after %d batch", calib_type, len(new_plan), self._forward_cnt + ) self._forward_cnt = 0 return new_plan - def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]: + def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: """Parse the strategy to get valid strategy Parameters ------- - strategy_list: dict - The given strategy + strategy_list: list + The given strategys Returns ------- @@ -93,7 +98,7 @@ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]: def _update_stages(strategy): if "stages" not in strategy: - strategy["stages"] = [msc_utils.MSCStage.QUANTIZE] + strategy["stages"] = [MSCStage.QUANTIZE] return strategy return super()._parse_strategys([_update_stages(s) for s in strategy_list]) @@ -115,10 +120,7 @@ def _check_tensor(self, name: str, consumer: str) -> bool: """ if 
self._calibrated: - tensor_id = self.to_tensor_id(name, consumer) - if tensor_id not in self._plan: - return False - return self._plan.get(tensor_id, {}).get("nbits", 8) != -1 + return self.to_tensor_id(name, consumer) in self._plan strategys = self._get_tensor_strategys(name, consumer) if not strategys: return False @@ -226,14 +228,21 @@ def create_tasks(self, **kwargs) -> List[dict]: """ tasks, recorded = [], set() - for tensor_id, plan in self._plan.items(): - name, _ = self.from_tensor_id(tensor_id) + for tensor_id in self._plan: + name, consumer = self.from_tensor_id(tensor_id) if self.is_weight(name) and not kwargs.get("quantize_weights", False): continue if name not in recorded: - tasks.append({"name": tensor_id, **plan}) + executor = self._get_tensor_strategy(name, consumer).get_executor(MSCStage.QUANTIZE) + task = {"methods": {"tensor": executor.method_def}} if self._cache_processed: + task["tensor_ids"] = [ + self.to_tensor_id(name, c.name) for c in self.find_consumers(name) + ] recorded.add(name) + else: + task["tensor_ids"] = [tensor_id] + tasks.append(task) return tasks @property diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index fec391339f20..7cd0742c0753 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -91,7 +91,7 @@ def execute(self, *args, **kwargs) -> Any: The plan generated by method or processed tensor. """ - kwargs.update({k: v for k, v in self._config.items() if k not in kwargs}) + kwargs.update(self._config) return self._method(*args, **kwargs) def copy(self, name: str = None, method: callable = None, config: dict = None): @@ -116,6 +116,10 @@ def copy(self, name: str = None, method: callable = None, config: dict = None): new_config.update({k: v for k, v in self._config.items() if k not in new_config}) return ToolExecutor(name or self._name, method or self._method, new_config) + @property + def method_def(self): + return {"method_name": self._name, **self._config} + @property def name(self): return self._name @@ -140,12 +144,11 @@ class ToolStrategy(object): The meta strategy config. """ - def __init__(self, name: str, tensor_type: str, stage: str = "default", meta: dict = None): + def __init__(self, name: str, tensor_type: str, stage: str = "default"): self._name = name self._tensor_type = tensor_type self._stage = stage self._executors = {} - self._meta = meta def __str__(self): return "{}({} @ {}) ".format(self._name, self._tensor_type, self._stage) + "; ".join( @@ -161,7 +164,7 @@ def inspect(self) -> dict: The inspect of the strategy. """ - return {"{}({})".format(s, self._tensor_type): str(e) for s, e in self._executors.items()} + return {s: str(e) for s, e in self._executors.items()} def __call__(self, *args, **kwargs) -> Any: return self.apply(*args, **kwargs) @@ -204,17 +207,23 @@ def add_executor(self, stage: str, executor: ToolExecutor): if not self._stage: self._stage = stage - def get_executor(self) -> Tuple[callable, dict]: + def get_executor(self, stage: str = None) -> Tuple[callable, dict]: """Get executor of current stage + Parameters + ---------- + stage: str + The mark of the executor. 
+ Returns ------- executor: tuple The method and config to execute strategy """ - if self._stage in self._executors: - return self._executors[self._stage] + stage = stage or self._stage + if stage in self._executors: + return self._executors[stage] return self._executors["default"] def get_config(self) -> dict: @@ -273,10 +282,6 @@ def copy( strategy.add_executor(st_name, new_executor) return strategy - @property - def meta(self): - return self._meta - class BaseTool(object): """Basic tool of MSC @@ -316,22 +321,20 @@ def __init__( logger: logging.Logger = None, ): self._stage = stage + self._plan_file = plan_file if os.path.isfile(plan_file): self._plan = msc_utils.load_dict(plan_file) else: self._plan = {} - self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys)) + self._meta_strategys, self._strategys = msc_utils.copy_dict(strategys), {} self._training = training self._cache_processed = cache_processed self._options = options or {} self._debug_level = debug_level self._verbose_step = verbose_step self._logger = logger or msc_utils.get_global_logger() - title = "{}.SETUP({} @ {})".format(self.tool_type().upper(), self._stage, self.framework()) + title = self.tool_mark("APPLY_PLAN" if self._plan else "MAKE_PLAN") self._logger.info(msc_utils.msg_block(title, self.setup(), width=0)) - if self._debug_level >= 3 and self._plan: - title = "{}.PLAN".format(self.tool_type().upper()) - self._logger.debug(msc_utils.msg_block(title, self._plan)) def setup(self) -> dict: """Setup the tool @@ -347,79 +350,15 @@ def setup(self) -> dict: self._graphs, self._weights = [], {} self._graph_id, self._forward_cnt = 0, 0 self._processed_tensor = {} + plan_info = self._plan if self._plan and self._debug_level >= 2 else self._plan_file return { "style": self.tool_style(), - "strategys": {k: v.inspect() for k, v in self._strategys.items()}, "cache_processed": self._cache_processed, "options": self._options, - "planed_num": len(self._plan), - "verbose_step": self._verbose_step, - "debug_level": self._debug_level, + "debug_step({})".format(self._debug_level): self._verbose_step, + "plan({})".format(len(self._plan)): plan_info, } - def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: - """Parse the strategy to get valid strategy - - Parameters - ------- - strategy_list: list - The given strategys - - Returns - ------- - strategys: dict - The parsed strategy. 
- """ - - strategys = {} - assert isinstance(strategy_list, list) and all( - isinstance(s, dict) for s in strategy_list - ), "ToolStrategy should be given as list of dict" - for strategy in strategy_list: - meta_strategy = msc_utils.copy_dict(strategy) - method_cls_name = strategy.pop("method_cls") if "method_cls" in strategy else "default" - method_cls = msc_utils.get_registered_tool_method( - self.framework(), self.tool_type(), method_cls_name - ) - method_name = strategy.pop("method") if "method" in strategy else "default" - method = None - if hasattr(method_cls, method_name): - method = getattr(method_cls, method_name) - if not method: - default_cls = msc_utils.get_registered_tool_method( - MSCFramework.MSC, self.tool_type(), method_cls_name - ) - if hasattr(default_cls, method_name): - method = getattr(default_cls, method_name) - if not method: - method = msc_utils.get_registered_func(method_name) - assert method, "Can not find method with " + str(method_name) - tensor_types = ( - strategy.pop("tensor_types") - if "tensor_types" in strategy - else ["input", "output", "weight"] - ) - if "op_types" in strategy: - op_types = strategy.pop("op_types") - marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)] - elif "op_names" in strategy: - op_names = strategy.pop("op_names") - marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)] - elif "tensor_names" in strategy: - tensor_names = strategy.pop("tensor_names") - marks = [(n, "tensor") for n in tensor_names] - else: - marks = [("default." + str(t), t) for t in tensor_types] - stages = strategy.pop("stages") if "stages" in strategy else ["default"] - for mark, t_type in marks: - if mark not in strategys: - strategys[mark] = ToolStrategy(mark, t_type, self._stage, meta_strategy) - for stage in stages: - strategys[mark].add_executor( - stage, ToolExecutor(method_name, method, copy.deepcopy(strategy)) - ) - return strategys - def reset( self, graphs: List[MSCGraph], @@ -454,12 +393,11 @@ def reset( if self.tool_type() in cache_info: self.load_cache(cache_dir, cache_info[self.tool_type()]) self._graphs, self._weights = self._reset(graphs, weights) - self._logger.debug( - "%s reset %d graphs, %d weights", - self.tool_type(), - len(self._graphs), - len(self._weights), - ) + self._strategys = self._parse_strategys(self._meta_strategys) + if self._strategys: + title = self.tool_mark("STRATEGYS({})".format(len(self._strategys))) + strategys_info = {k: v.inspect() for k, v in self._strategys.items()} + self._logger.info(msc_utils.msg_block(title, strategys_info, width=0)) return self._graphs, self._weights def _reset( @@ -484,6 +422,105 @@ def _reset( return graphs, weights + def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: + """Parse the strategy to get valid strategy + + Parameters + ------- + strategy_list: list + The given strategys. + + Returns + ------- + strategys: dict + The parsed strategy. + """ + + assert isinstance(strategy_list, list) and all( + isinstance(s, dict) for s in strategy_list + ), "ToolStrategy should be given as list of dict" + assert self._graphs, "graphs are needed to parse strategys" + all_tensor_names = set(t.name for t in self.get_tensors()) + all_tensor_ids = set(self.get_tensor_ids()) + all_op_types = set(n.optype for n in self.get_nodes()) + all_op_names = set(n.name for n in self.get_nodes()) + strategys = {} + + def _get_method(method_name): + if "." 
in method_name: + method_cls_name, method_name = method_name.split(".") + else: + method_cls_name = "default" + method_cls = msc_utils.get_registered_tool_method( + self.framework(), self.tool_type(), method_cls_name + ) + if hasattr(method_cls, method_name): + return getattr(method_cls, method_name) + default_cls = msc_utils.get_registered_tool_method( + MSCFramework.MSC, self.tool_type(), method_cls_name + ) + if hasattr(default_cls, method_name): + return getattr(default_cls, method_name) + method = msc_utils.get_registered_func(method_name) + assert method, "Can not find method with " + str(method_name) + return method + + for strategy in strategy_list: + meta_strategy = msc_utils.copy_dict(strategy) + for t_type, method_def in meta_strategy["methods"].items(): + if isinstance(method_def, str): + method_name, method_kwargs = method_def, {} + elif isinstance(method_def, dict): + assert "method_name" in method_def, "Can not find method_name" + method_name = method_def["method_name"] + method_kwargs = {k: v for k, v in method_def.items() if k != "method_name"} + else: + raise TypeError( + "Only support string and dict as method define, get " + str(method_def) + ) + method = _get_method(method_name) + if "marks" in strategy: + assert t_type == "mark", "mark strategy only support mark method, get " + str( + meta_strategy + ) + marks = strategy["marks"] + elif "tensor_names" in strategy: + assert ( + t_type == "tensor" + ), "tensor strategy only support tensor method, get " + str(meta_strategy) + marks = [t for t in strategy["tensor_names"] if t in all_tensor_names] + elif "tensor_ids" in strategy: + assert ( + t_type == "tensor" + ), "tensor strategy only support tensor method, get " + str(meta_strategy) + marks = [t for t in strategy["tensor_ids"] if t in all_tensor_ids] + elif "op_types" in strategy: + op_types = [t for t in strategy["op_types"] if t in all_op_types] + marks = ["{}.{}".format(t, t_type) for t in op_types] + elif "op_names" in strategy: + op_names = [t for t in strategy["op_names"] if t in all_op_names] + marks = ["{}.{}".format(t, t_type) for t in op_names] + else: + marks = ["default." + str(t_type)] + for mark, stage in product(marks, strategy.get("stages", ["default"])): + if mark not in strategys: + strategys[mark] = ToolStrategy(mark, t_type, self._stage) + strategys[mark].add_executor( + stage, ToolExecutor(method_name, method, copy.deepcopy(method_kwargs)) + ) + return strategys + + def change_strategys(self, strategy_list: List[dict]): + """Change the strategys + + Parameters + ------- + strategy_list: list + The given strategys. + """ + + self._meta_strategys = strategy_list + def change_stage(self, stage: str): """Change the stage of tool and strategy""" @@ -501,6 +538,28 @@ def destory(self): self._graphs, self._weights = [], {} + def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: + """Export the config for tool + + Parameters + ------- + config: dict + The source config. + folder: MSCDirectory + The export folder. + + Returns + ------- + config: dict + The exported config. 
+ """ + + config = msc_utils.copy_dict(config) + plan_file = msc_utils.to_abs_path(config["plan_file"], msc_utils.get_config_dir()) + if os.path.isfile(plan_file): + config["plan_file"] = folder.create_dir("tools").copy(plan_file) + return config + def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): """Save runner to cache @@ -545,7 +604,7 @@ def execute_before_build(self, *args, **kwargs): self._graph_id = self._infer_graph_id(kwargs) self._processed_tensor = {} if self.on_debug(3, in_forward=False): - self._logger.debug("%sStart Build", self.msg_mark(in_forward=False)) + self._logger.debug(self.msg_mark("Start Build", in_forward=False)) self._execute_before_build(*args, **kwargs) def _execute_before_build(self, *args, **kwargs): @@ -578,7 +637,7 @@ def execute_after_build(self, output: Any) -> Any: if self._enabled: output = self._execute_after_build(output) if self.on_debug(3, in_forward=False): - self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False)) + self._logger.debug(self.msg_mark("End Build", in_forward=False)) return output def _execute_after_build(self, output: Any) -> Any: @@ -612,7 +671,7 @@ def execute_before_forward(self, *args, **kwargs): self._graph_id = self._infer_graph_id(kwargs) self._processed_tensor = {} if self.on_debug(3): - self._logger.debug("%sStart Forward", self.msg_mark()) + self._logger.debug(self.msg_mark("Start Forward")) self._execute_before_forward(*args, **kwargs) def _execute_before_forward(self, *args, **kwargs): @@ -645,11 +704,8 @@ def execute_after_forward(self, output: Any) -> Any: if self._enabled: output = self._execute_after_forward(output) if self.on_debug(3): - self._logger.debug( - "%sEnd Forward, process %d tensors", - self.msg_mark(), - len(self._processed_tensor), - ) + msg = "End Forward, process {} tensors".format(len(self._processed_tensor)) + self._logger.debug(self.msg_mark(msg)) self._forward_cnt += 1 return output @@ -699,20 +755,21 @@ def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> A t_mark += "." + scope cached_tensor = self._get_processed(name, consumer, t_mark) if cached_tensor is not None: - self.debug_tensor(cached_tensor, name, consumer, "cached({})".format(t_mark)) + if msc_utils.is_array(cached_tensor): + self.debug_tensors(name, consumer, t_mark, {"cached": cached_tensor}) return cached_tensor process = self._get_tensor_cache(name, consumer, "process") if process is None: process = self._check_tensor(name, consumer) self._save_tensor_cache(name, consumer, "process", process) - if process and self.on_debug(3): - self._logger.debug("%sprocess tensor %s-%s", self.msg_mark(), name, consumer) if not process: return tensor - tensor = self._process_tensor(tensor, name, consumer, scope, strategys) - self._save_processed(name, consumer, tensor, t_mark) - self.debug_tensor(tensor, name, consumer, "processed({})".format(t_mark)) - return tensor + new_tensor = self._process_tensor(tensor, name, consumer, scope, strategys) + self._save_processed(name, consumer, new_tensor, t_mark) + if msc_utils.is_array(tensor) and id(new_tensor) != id(tensor): + tensors = {"pre": tensor, "post": new_tensor, "diff": tensor - new_tensor} + self.debug_tensors(name, consumer, t_mark, tensors) + return new_tensor def _support_scope(self, scope: str) -> bool: """Check if the scope si supported @@ -862,20 +919,6 @@ def visualize(self, visual_dir: msc_utils.MSCDirectory): return None - def set_plan(self, plan: dict): - """Set the plan - - Parameters - ---------- - plan: dict - The new plan. 
-        """
-
-        if self._plan:
-            self._plan = msc_utils.update_dict(self._plan, plan)
-        else:
-            self._plan = plan
-
     def finalize(self) -> dict:
         """Get the plan"""

@@ -973,54 +1016,69 @@ def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool:
             return False
         return self._debug_level >= debug_level

-    def msg_mark(self, in_forward: bool = True) -> str:
-        """Get the debug title
+    def tool_mark(self, msg: Any) -> str:
+        """Mark the message with tool info

         Parameters
         -------
+        msg: str
+            The message
+
+        Returns
+        -------
+        msg: str
+            The message with mark.
+        """
+
+        return "{}({} @ {}) {}".format(self.tool_type().upper(), self.framework(), self._stage, msg)
+
+    def msg_mark(self, msg: Any, in_forward: bool = True) -> str:
+        """Mark the message with debug info
+
+        Parameters
+        -------
+        msg: str
+            The message
         in_forward: bool
             Whether to add forward mark.

         Returns
         -------
-        msg_mark: str
-            Get the debug title.
+        msg: str
+            The message with mark.
         """

-        title = "{}.G[{}]".format(self.tool_type().upper(), self._graph_id)
+        mark = "{}.G[{}]".format(self.tool_type().upper(), self._graph_id)
         if in_forward:
-            title += ".F[{}]".format(self._forward_cnt)
-        title += "({}) ".format(self._stage)
-        return title
+            mark += ".F[{}]".format(self._forward_cnt)
+        mark += "({}) ".format(self._stage)
+        return mark + str(msg)

-    def debug_tensor(
-        self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 3
+    def debug_tensors(
+        self, name: str, consumer: str, t_mark: str, tensors: Dict[str, Any], debug_level: int = 3
     ) -> str:
         """Get the debug tensor info

         Parameters
         -------
-        tensor: array_like
-            The tensor
         name: str
             The name of tensor.
         consumer: str
             The name of consumer.
         t_mark: str
             The mark of tensor.
+        tensors: dict
+            The tensors.
         debug_level: int
             The given debug_level.
         """

         if self.on_debug(debug_level):
-            self._logger.debug(
-                "%s%s %s-%s: %s",
-                self.msg_mark(),
-                t_mark,
-                name,
-                consumer,
-                msc_utils.inspect_array(tensor),
+            msg = "{}-{}({})".format(name, consumer, t_mark)
+            tensor_des = "\n ".join(
+                ["{:6s}:{}".format(k, msc_utils.inspect_array(v)) for k, v in tensors.items()]
             )
+            self._logger.debug("%s\n %s", self.msg_mark(msg), tensor_des)

     def _infer_graph_id(self, kwargs: dict) -> int:
         """Infer graph id from kwargs
@@ -1072,6 +1130,35 @@ def find_node(self, name: str) -> MSCJoint:
             return g.find_node(name)
         raise Exception("Can not find node {} from {} graphs".format(name, len(self._graphs)))

+    def get_tensors(self) -> Iterable[MSCTensor]:
+        """Get all the tensors in the graphs.
+
+        Returns
+        -------
+        tensors: generator
+            The generator of tensors.
+        """
+
+        for graph in self._graphs:
+            for tensor in graph.get_tensors():
+                yield tensor
+
+    def get_tensor_ids(self) -> Iterable[str]:
+        """Get all the tensor ids in the graphs.
+
+        Returns
+        -------
+        tensor_ids: generator
+            The generator of tensor ids.
+        """
+
+        for graph in self._graphs:
+            for node in graph.get_nodes():
+                for tensor in node.get_inputs():
+                    yield self.to_tensor_id(tensor.name, node.name)
+                for weight in node.get_weights().values():
+                    yield self.to_tensor_id(weight.name, node.name)
+
     def find_tensor(self, name: str) -> MSCTensor:
         """Find tensor by name.
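The reworked `_parse_strategys` later in this diff accepts each entry under a strategy's `methods` either as a bare method name or as a dict carrying `method_name` plus keyword arguments, the same shape that the new `ToolExecutor.method_def` property emits. A minimal sketch of the two equivalent spellings, reusing the `gather_maxmin` method and `op_types` values from the default quantize configer in this diff; the `nbits` kwarg is a hypothetical extra argument for illustration only:

# Hedged sketch of the two method-definition forms parsed by _parse_strategys;
# "nbits" is a hypothetical kwarg, not a documented parameter of gather_maxmin.
strategy_as_name = {
    "methods": {"output": "gather_maxmin"},
    "op_types": ["nn.conv2d"],
}
strategy_as_def = {
    "methods": {"output": {"method_name": "gather_maxmin", "nbits": 8}},
    "op_types": ["nn.conv2d"],
}
# Both resolve to a ToolExecutor whose method_def round-trips as
# {"method_name": "gather_maxmin", ...any extra kwargs...}.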
@@ -1151,7 +1238,7 @@ def get_data(self, name: str) -> np.ndarray: return msc_utils.cast_array(self._weights[name]) raise Exception("Can not find data {} from {} weights".format(name, len(self._weights))) - def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any): + def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any) -> Any: """Save the data to tensor cache Parameters @@ -1164,12 +1251,18 @@ def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any): The data key. value: any The value to cache. + + Returns + ------- + value: any + The saved value. """ tensor_id = self.to_tensor_id(name, consumer) if tensor_id not in self._tensor_cache: self._tensor_cache[tensor_id] = {} self._tensor_cache[tensor_id][key] = value + return value def _get_tensor_cache(self, name: str, consumer: str, key: str) -> Any: """Get the cached tensor data @@ -1212,37 +1305,37 @@ def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]: tensor_id = self.to_tensor_id(name, consumer) mark = "strategy.{}".format(self._stage) - - def _check_strategy(s_ref): - return s_ref in self._strategys and self._strategys[s_ref].support_stage(self._stage) - if mark not in self._tensor_cache.get(tensor_id, {}): strategys = [] - tensor_strategy = self._strategys.get(tensor_id) + + def _add_strategy(ref): + if ref in self._strategys and self._strategys[ref].support_stage(self._stage): + strategys.append(self._strategys[ref]) + return True + return False + + tensor_strategy = self._strategys.get(tensor_id) or self._strategys.get(name) if tensor_strategy and tensor_strategy.support_stage(self._stage): strategys.append(tensor_strategy) elif self.is_weight(name): consumer = self.find_node(consumer) - for ref in [consumer.name, consumer.optype, "default"]: - if _check_strategy(ref + ".weight"): - strategys.append(self._strategys[ref + ".weight"]) - break + for w_type in [consumer.weight_type(name), "weights"]: + for ref in [consumer.name, consumer.optype, "default"]: + if not strategys and _add_strategy(ref + "." 
+ w_type): + break elif consumer == "exit": producer = self.find_producer(name) for ref in [producer.name, producer.optype, "exit", "default"]: - if _check_strategy(ref + ".output"): - strategys.append(self._strategys[ref + ".output"]) + if _add_strategy(ref + ".output"): break else: - consumer = self.find_node(consumer) - for ref in [consumer.name, consumer.optype, "default"]: - if _check_strategy(ref + ".input"): - strategys.append(self._strategys[ref + ".input"]) - break producer = self.find_producer(name) for ref in [producer.name, producer.optype, "default"]: - if _check_strategy(ref + ".output"): - strategys.append(self._strategys[ref + ".output"]) + if _add_strategy(ref + ".output"): + break + consumer = self.find_node(consumer) + for ref in [consumer.name, consumer.optype, "default"]: + if _add_strategy(ref + ".input"): break self._save_tensor_cache(name, consumer, mark, strategys) return self._get_tensor_cache(name, consumer, mark) @@ -1274,6 +1367,10 @@ def _get_tensor_strategy(self, name: str, consumer: str) -> ToolStrategy: def get_graph(self): return self._graphs[self._graph_id] + @property + def plan(self): + return self._plan + @classmethod def tool_type(cls): return ToolType.BASE @@ -1337,13 +1434,12 @@ def _reset( for graph in graphs ] self._logger.debug( - "%s reset %d weight graphs", self.tool_type(), len(self._weight_graphs) + "%s build %d weight graphs", self.tool_type(), len(self._weight_graphs) ) if self.on_debug(2, in_forward=False): - for idx, graph in enumerate(self._weight_graphs): - self._logger.debug( - msc_utils.msg_block("WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect()) - ) + weight_graphs = {g.name: g.inspect() for g in self._weight_graphs} + title = self.tool_mark("WEIGHT_GRAPHS({})".format(len(weight_graphs))) + self._logger.debug(msc_utils.msg_block(title, weight_graphs)) return graphs, weights def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: diff --git a/python/tvm/contrib/msc/core/tools/track/__init__.py b/python/tvm/contrib/msc/core/tools/track/__init__.py index 2c82a6d48627..cdcf16fad3af 100644 --- a/python/tvm/contrib/msc/core/tools/track/__init__.py +++ b/python/tvm/contrib/msc/core/tools/track/__init__.py @@ -18,3 +18,4 @@ from .tracker import * from .method import * +from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/track/configer.py b/python/tvm/contrib/msc/core/tools/track/configer.py new file mode 100644 index 000000000000..fafb30d4842c --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/track/configer.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
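The new-file configers in this diff (quantize above, track below) are registered with `register_tool_configer` and fetched by tool type and config style through `get_registered_tool_configer`, both added to `register.py` later in this diff. A hedged usage sketch; the no-argument constructor and the `msc_utils` re-export of the getter are assumptions, since neither is shown in this diff:

# Hedged lookup sketch; assumes ToolConfiger subclasses construct with no args
# and that get_registered_tool_configer is re-exported via msc_utils like its
# register_tool_configer counterpart.
from tvm.contrib.msc.core.tools.tool import ToolType
from tvm.contrib.msc.core import utils as msc_utils

configer_cls = msc_utils.get_registered_tool_configer(ToolType.TRACKER, "default")
if configer_cls:  # the getter returns None when nothing is registered
    tool_config = configer_cls().config_tool()  # default plan_file + strategys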
+"""tvm.contrib.msc.core.tools.track.configer""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.configer import ToolConfiger +from tvm.contrib.msc.core.utils import MSCStage +from tvm.contrib.msc.core import utils as msc_utils + + +class TrackConfiger(ToolConfiger): + """Configer for track""" + + @property + def apply_once(self): + return False + + @classmethod + def tool_type(cls): + return ToolType.TRACKER + + +@msc_utils.register_tool_configer +class DefaultTrackConfiger(TrackConfiger): + """Default configer for track""" + + def config_tool(self) -> dict: + """Get the default config of tool + + Returns + ------- + config: dict + The default config. + """ + + return { + "plan_file": "msc_tracker.json", + "strategys": [ + { + "methods": { + "output": { + "method_name": "save_compared", + "compare_to": { + MSCStage.OPTIMIZE: [MSCStage.BASELINE], + MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE], + }, + } + }, + "op_types": ["nn.relu"], + } + ], + } + + @classmethod + def config_style(cls): + return "default" diff --git a/python/tvm/contrib/msc/core/tools/track/method.py b/python/tvm/contrib/msc/core/tools/track/method.py index a86a6af881f3..7d02456f4359 100644 --- a/python/tvm/contrib/msc/core/tools/track/method.py +++ b/python/tvm/contrib/msc/core/tools/track/method.py @@ -62,7 +62,7 @@ def save_compared( config = {"info": msc_utils.inspect_array(data)} # save the data tracker._saver.save_datas({name: data}, tracker._forward_cnt) - tracker.debug_tensor(data, name, consumer, "save") + tracker.debug_tensors(name, consumer, "save_compares", {"save": data}) # compare datas if tracker._stage in compare_to: diffs = {} @@ -72,13 +72,11 @@ def save_compared( continue golden = tracker._loaders[stage].load_data(name, tracker._forward_cnt) report = msc_utils.compare_arrays({name: golden}, {name: data}) - diff_msg = "{}{} to {} -> {}".format( - tracker.msg_mark(), name, stage, report["info"][name] - ) + diff_msg = "{} to {} -> {}".format(name, stage, report["info"][name]) if report["passed"] == 0: - tracker._logger.info(diff_msg) + tracker._logger.info(tracker.msg_mark(diff_msg)) elif tracker.on_debug(): - tracker._logger.debug(diff_msg) + tracker._logger.debug(tracker.msg_mark(diff_msg)) diffs[stage] = { "pass": report["passed"] == 1, "info": msc_utils.inspect_array(np.abs(golden - data)), @@ -94,5 +92,9 @@ def framework(cls): def tool_type(cls): return ToolType.TRACKER + @classmethod + def method_style(cls): + return "default" + msc_utils.register_tool_method(TrackMethod) diff --git a/python/tvm/contrib/msc/core/tools/track/tracker.py b/python/tvm/contrib/msc/core/tools/track/tracker.py index e43a390e850f..bb60b9fe8b2d 100644 --- a/python/tvm/contrib/msc/core/tools/track/tracker.py +++ b/python/tvm/contrib/msc/core/tools/track/tracker.py @@ -33,11 +33,10 @@ def setup(self) -> dict: The setup info. """ - # filter plan - def _filter_info(info: dict) -> dict: - return {k: v for k, v in info.items() if k != self._stage} + suffix = "." 
+ msc_utils.MSCStage.TRACK + if self._stage.endswith(suffix): + self.change_stage(self._stage[: -len(suffix)]) - self._plan = {k: _filter_info(v) for k, v in self._plan.items()} data_folder = msc_utils.get_dataset_dir().create_dir("Track") self._loaders = {} for folder in data_folder.listdir(): @@ -46,7 +45,7 @@ def _filter_info(info: dict) -> dict: if msc_utils.is_simple_dataset(data_folder.relpath(folder)): self._loaders[folder] = msc_utils.SimpleDataLoader(data_folder.relpath(folder)) self._saver = msc_utils.SimpleDataSaver(data_folder.relpath(self._stage)) - self._max_iter = self._options.get("max_iter", 1) + self._max_iter, self._tracked = self._options.get("max_iter", 1), False info = super().setup() info.update({"saver": self._saver, "loaders": self._loaders}) return info @@ -55,7 +54,7 @@ def finalize(self) -> dict: """Get the plan""" self._saver.finalize() - return super().finalize() + return {} def _execute_after_forward(self, output: Any) -> Any: """Execute after model forward @@ -89,6 +88,8 @@ def _execute_after_forward(self, output: Any) -> Any: ["{}: {}/{}".format(s, i["passed"], i["total"]) for s, i in passed.items()] ) self._logger.info(msg) + else: + self._tracked = True return output def _check_tensor(self, name: str, consumer: str) -> bool: @@ -175,6 +176,10 @@ def _track_tensor( plan.update(strategy(self, tensor, name, consumer)) return tensor + @property + def tracked(self): + return self._tracked + @classmethod def tool_type(cls): return ToolType.TRACKER diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index ddcfffc210fa..fe8882f7f296 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -17,13 +17,18 @@ # pylint: disable=invalid-name """tvm.contrib.msc.core.transform.transform""" +from typing import Dict + import tvm from tvm.relax.transform import _ffi_api as relax_api from tvm.relay.transform import _ffi_api as relay_api def SetExprName( - as_relax: bool = True, entry_name: str = "main", target: str = "" + as_relax: bool = True, + entry_name: str = "main", + target: str = "", + var_names: Dict[str, str] = None, ) -> tvm.ir.transform.Pass: """Set name for the call and constant in IRModule. @@ -35,6 +40,8 @@ def SetExprName( The entry name target: str The target prefix for target functions + var_names: dict + The var names. Returns ------- @@ -42,7 +49,13 @@ def SetExprName( """ if as_relax: - return relax_api.SetRelaxExprName(entry_name, target) # type: ignore + + def _get_name(name): + return name.replace("/", "_").replace(".", "_").strip("_") + + var_names = var_names or {} + var_names = {k: _get_name(v) for k, v in var_names.items()} + return relax_api.SetRelaxExprName(entry_name, target, var_names) # type: ignore return relay_api.SetRelayExprName(entry_name) # type: ignore @@ -136,3 +149,25 @@ def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass: """ return relax_api.SetBYOCAttrs(target, entry_name) # type: ignore + + +def BindNamedParams( + func_name: str, + params: Dict[str, tvm.runtime.NDArray], +) -> tvm.ir.transform.Pass: + """Bind params of function of the module to constant tensors with span names. + + Parameters + ---------- + func_name: str + The function name to be bound + params: dict + The map from parameter or parameter name to constant + tensors. 
+ + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + return relax_api.BindNamedParams(func_name, params) # type: ignore diff --git a/python/tvm/contrib/msc/core/utils/arguments.py b/python/tvm/contrib/msc/core/utils/arguments.py index dba54da3a4e8..a1b8e918e8ac 100644 --- a/python/tvm/contrib/msc/core/utils/arguments.py +++ b/python/tvm/contrib/msc/core/utils/arguments.py @@ -19,7 +19,7 @@ import os import json import copy -import numpy as np +from typing import Any from .info import MSCArray @@ -39,6 +39,8 @@ def load_dict(str_dict: str, flavor: str = "json") -> dict: The loaded dict. """ + if not str_dict: + return {} if isinstance(str_dict, str) and os.path.isfile(str_dict): with open(str_dict, "r") as f: dict_obj = json.load(f) @@ -52,6 +54,29 @@ def load_dict(str_dict: str, flavor: str = "json") -> dict: return dict_obj +def save_dict(dict_obj: Any, path: str, indent: int = 2) -> str: + """Save dict object + + Parameters + ---------- + dict_obj: + The object that can be load as dict. + path: str + The output path. + indent: int + The indent + + Returns + ------- + path: str + The output path. + """ + + with open(path, "w") as f: + f.write(json.dumps(load_dict(dict_obj), indent=indent)) + return path + + def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = True) -> dict: """Update src_dict with new_dict. @@ -116,20 +141,22 @@ def _get_lines(value, indent=2): lines.append("{}{}:".format(indent * " ", k)) lines.extend(_get_lines(v, indent + 2)) elif isinstance(v, (tuple, list)) and len(str(k) + str(v)) > max_size: - if all(isinstance(e, (int, float)) for e in v): + if MSCArray.is_array(v): lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) else: lines.append("{}{}:".format(indent * " ", k)) - lines.extend( - [ - "{}<{}>{}".format((indent + 2) * " ", idx, ele) - for idx, ele in enumerate(v) - ] - ) + for idx, ele in enumerate(v): + if isinstance(ele, dict) and len(str(ele)) > max_size: + lines.append("{}[{}.{}]:".format((indent + 2) * " ", k, idx)) + lines.extend(_get_lines(ele, indent + 4)) + else: + lines.append("{}<{}>{}".format((indent + 2) * " ", idx, ele)) elif isinstance(v, bool): lines.append("{}{}: {}".format(indent * " ", k, "true" if v else "false")) - elif isinstance(v, np.ndarray): + elif MSCArray.is_array(v): lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) + elif hasattr(v, "__name__"): + lines.append("{}{}: {}({})".format(indent * " ", k, v.__name__, type(v))) else: lines.append("{}{}: {}".format(indent * " ", k, v)) return lines @@ -220,9 +247,11 @@ def map_dict(dict_obj: dict, mapper: callable) -> dict: new_dict = {} for k, v in dict_obj.items(): if isinstance(v, (tuple, list)): - new_dict[k] = [map_dict(e, mapper) if isinstance(e, dict) else e for e in v] + new_dict[k] = [ + map_dict(mapper(e), mapper) if isinstance(e, dict) else mapper(e) for e in v + ] elif isinstance(v, dict): - new_dict[k] = map_dict(v, mapper) + new_dict[k] = map_dict(mapper(v), mapper) else: new_dict[k] = mapper(v) return new_dict diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py index 8ca8d8ae1a0d..3da57abb4384 100644 --- a/python/tvm/contrib/msc/core/utils/dataset.py +++ b/python/tvm/contrib/msc/core/utils/dataset.py @@ -24,6 +24,7 @@ import numpy as np from .arguments import load_dict +from .info import cast_array class BaseDataLoader(object): @@ -344,6 +345,7 @@ def _save_data(self, index: int, name: str, data: np.ndarray, collect: str) -> s The folder that data saved to. 
""" + data = cast_array(data) save_name = name.replace("/", "_").replace(":", "_") sub_folder = f_path = os.path.join(self._folder, save_name) if not os.path.isdir(sub_folder): @@ -428,6 +430,8 @@ def finalize(self): """Finalize the saver""" super().finalize() + if "inputs" not in self._info: + return with open(os.path.join(self._folder, "datas_info.txt"), "w") as f: for name in self._input_names: info = self._info["inputs"][name] diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 49d2bdd96a9b..26afedfa282d 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -35,24 +35,24 @@ class MSCArray(object): """ def __init__(self, data: Any): - self._type, self._device, self._data = self._analysis(data) + self._meta_data = data + self._framework, self._type, self._device = self._analysis(data) def __str__(self): - return "<{}>{}".format(self._type, self.abstract()) + return "<{} @{}>{}".format(self._framework, self._device, self.abstract()) def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): - return "list", "cpu", np.array(data) + return MSCFramework.MSC, "list", "cpu" if isinstance(data, np.ndarray): - return "np", "cpu", data + return MSCFramework.MSC, "tensor", "cpu" if isinstance(data, tvm.runtime.NDArray): device = tvm.runtime.Device.MASK2STR[data.device.device_type] if data.device.device_id: device += ":{}".format(data.device.device_id) - return "tvm", device, data.asnumpy() + return MSCFramework.TVM, "tensor", device if isinstance(data, tvm.relax.Var): - shape = [int(s) for s in data.struct_info.shape] - return "var", "cpu", np.zeros(shape, dtype=data.struct_info.dtype) + return MSCFramework.TVM, "var", "cpu" try: import torch # pylint: disable=import-outside-toplevel @@ -62,7 +62,7 @@ def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: device = "{}:{}".format(ref_dev.type, ref_dev.index) else: device = ref_dev.type - return "torch", device, data.detach().cpu().numpy() + return MSCFramework.TORCH, "tensor", device except: # pylint: disable=bare-except pass @@ -71,16 +71,63 @@ def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: def abstract(self) -> str: """Get abstract describe of the data""" - return "[S:{},D:{}] Max {:g}, Min {:g}, Avg {:g}".format( - ";".join([str(s) for s in self._data.shape]), - self._data.dtype.name, - self._data.max(), - self._data.min(), - self._data.sum() / self._data.size, + data = self._to_ndarray() + if data.size < 10: + return ",".join([str(i) for i in data.flatten()]) + return "[{},{}] Max {:g}, Min {:g}, Avg {:g}".format( + ";".join([str(s) for s in data.shape]), + data.dtype.name, + data.max(), + data.min(), + data.sum() / data.size, ) - def cast(self, framework: str, device: str = None) -> Any: - """Cast np.ndarray to array like object + def _to_ndarray(self) -> np.ndarray: + """Cast array like object to np.ndarray + + Returns + ------- + data: np.ndarray + The data as np.ndarray. 
+        """
+
+        if self._framework == MSCFramework.MSC:
+            if self._type == "list":
+                return np.array(self._meta_data)
+            return self._meta_data
+        if self._framework == MSCFramework.TVM:
+            if self._type == "var":
+                shape = [int(s) for s in self._meta_data.struct_info.shape]
+                return np.zeros(shape, dtype=self._meta_data.struct_info.dtype)
+            return self._meta_data.asnumpy()
+        if self._framework == MSCFramework.TORCH:
+            return self._meta_data.detach().cpu().numpy()
+        return self._meta_data
+
+    def _to_device(self, device: str) -> Any:
+        """Move the array like object to the given device
+
+        Parameters
+        ----------
+        device: str
+            The device for tensor.
+
+        Returns
+        -------
+        output:
+            The output as framework tensor.
+        """
+
+        if self._device == device:
+            return self._meta_data
+        if self._framework == MSCFramework.TORCH:
+            return self._meta_data.to(self.get_device(device))
+        if self._framework == MSCFramework.TVM:
+            return tvm.nd.array(self._to_ndarray(), device=self.get_device(device))
+        return self._meta_data
+
+    def cast(self, framework: str, device: str = "cpu") -> Any:
+        """Cast array like object to the target framework and device

         Parameters
         ----------
@@ -96,20 +143,48 @@ def cast(self, framework: str, device: str = None) -> Any:
         """

         device = device or self._device
+        if framework == self._framework and device == self._device and self._type == "tensor":
+            return self._meta_data
+        if framework == self._framework:
+            return self._to_device(device)
+        data = self._to_ndarray()
         if framework == MSCFramework.TORCH:
             import torch  # pylint: disable=import-outside-toplevel

-            return torch.from_numpy(self._data).to(torch.device(device))
+            return torch.from_numpy(data).to(self.get_device(device, framework))
+        if framework == MSCFramework.TVM:
+            return tvm.nd.array(data, device=self.get_device(device, framework))
+        return data
+
+    def get_device(self, device: str, framework: str = None) -> Any:
+        """Change device from name to device obj
+
+        Parameters
+        ----------
+        device: str
+            The device for tensor.
+        framework: str
+            The target framework.
+
+        Returns
+        -------
+        device: any
+            The device object.
+        """
+
+        framework = framework or self._framework
         if framework == MSCFramework.TVM:
             if device.startswith("cpu"):
-                t_device = tvm.cpu()
-            elif device.startswith("cuda"):
+                return tvm.cpu()
+            if device.startswith("cuda"):
                 dev_id = int(device.split(":")[1]) if ":" in device else 0
-                t_device = tvm.cuda(dev_id)
-            else:
-                raise NotImplementedError("device {} is not supported for tvm")
-            return tvm.nd.array(self._data, device=t_device)
-        return self._data
+                return tvm.cuda(dev_id)
+            raise TypeError("Unexpected tvm device " + str(device))
+        if framework == MSCFramework.TORCH:
+            import torch  # pylint: disable=import-outside-toplevel
+
+            return torch.device(device)
+        return device

     @classmethod
     def is_array(cls, data: Any) -> bool:
@@ -142,19 +217,36 @@ def is_array(cls, data: Any) -> bool:

         return False

     @property
-    def type(self):
-        return self._type
+    def framework(self):
+        return self._framework

     @property
     def device(self):
         return self._device

     @property
-    def data(self):
-        return self._data
+    def type(self):
+        return self._type


-def cast_array(data: Any, framework: str = None, device: str = None) -> Any:
+def is_array(data: Any) -> bool:
+    """Check if the data is array
+
+    Parameters
+    ----------
+    data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ...
+        The data object.
+
+    Returns
+    -------
+    is_array: bool
+        Whether the data is array.
+ """ + + return MSCArray.is_array(data) + + +def cast_array(data: Any, framework: str = MSCFramework.MSC, device: str = "cpu") -> Any: """Cast array like object to np.ndarray Parameters @@ -173,8 +265,6 @@ def cast_array(data: Any, framework: str = None, device: str = None) -> Any: """ assert MSCArray.is_array(data), "{} is not array like".format(data) - if not framework: - return MSCArray(data).data return MSCArray(data).cast(framework, device) @@ -293,7 +383,7 @@ def get_version(framework: str) -> List[int]: raw_version = "1.0.0" except: # pylint: disable=bare-except raw_version = "1.0.0" - + raw_version = raw_version or "1.0.0" return LooseVersion(raw_version).version diff --git a/python/tvm/contrib/msc/core/utils/message.py b/python/tvm/contrib/msc/core/utils/message.py index 7ff0e187b05b..1479a99dd5db 100644 --- a/python/tvm/contrib/msc/core/utils/message.py +++ b/python/tvm/contrib/msc/core/utils/message.py @@ -18,9 +18,9 @@ import datetime import logging -from typing import List +from typing import List, Tuple -from .arguments import dump_dict +from .arguments import dump_dict, map_dict from .log import get_global_logger from .namespace import MSCMap, MSCKey @@ -31,14 +31,27 @@ class MSCStage(object): SETUP = "setup" PREPARE = "prepare" PARSE = "parse" - BASELINE = "baseline" PRUNE = "prune" QUANTIZE = "quantize" DISTILL = "distill" + TRACK = "track" + BASELINE = "baseline" OPTIMIZE = "optimize" COMPILE = "compile" SUMMARY = "summary" - ALL = [SETUP, PREPARE, PARSE, BASELINE, PRUNE, QUANTIZE, DISTILL, OPTIMIZE, COMPILE, SUMMARY] + ALL = [ + SETUP, + PREPARE, + PARSE, + PRUNE, + QUANTIZE, + DISTILL, + TRACK, + BASELINE, + OPTIMIZE, + COMPILE, + SUMMARY, + ] @classmethod def all_stages(cls) -> List[str]: @@ -73,7 +86,8 @@ def time_stamp(stage: str, log_stage: bool = True, logger: logging.Logger = None logger.info("\n{0} {1} {0}".format("#" * 20, start_msg.center(40))) MSCMap.set(MSCKey.MSC_STAGE, stage.upper()) elif log_stage: - logger.debug("Start {}".format(stage)) + start_msg = "Start {}".format(stage) + logger.debug("\n{0} {1} {0}".format("+" * 20, start_msg.center(40))) def get_duration() -> dict: @@ -89,65 +103,43 @@ def get_duration() -> dict: if not time_stamps: return {} - def _get_duration(start_idx, end_idx): - return (time_stamps[end_idx][1] - time_stamps[start_idx][1]).total_seconds() - - total = _get_duration(0, -1) - duration = {"total": total} - for idx in range(len(time_stamps) - 1): - duration[time_stamps[idx][0]] = _get_duration(idx, idx + 1) - sub_durations = {} - for stage, _ in time_stamps: - if stage not in duration: - continue - if "." 
in stage: - main_stage = stage.split(".")[0] - if main_stage not in sub_durations: - sub_durations[main_stage] = {"total": 0} - if main_stage in duration and "init" not in sub_durations[main_stage]: - sub_durations[main_stage]["init"] = duration[main_stage] - sub_durations[main_stage]["total"] += duration[main_stage] - sub_duration = duration.pop(stage) - sub_durations[main_stage][stage.replace(main_stage + ".", "")] = sub_duration - sub_durations[main_stage]["total"] += sub_duration - - # change to report format - def _to_str(dur): - return "{:.2f} s({:.2f}%)".format(dur, dur * 100 / total) - - for sub_dur in sub_durations.values(): - for stage in sub_dur: - sub_dur[stage] = _to_str(sub_dur[stage]) - for stage in duration: - duration[stage] = _to_str(duration[stage]) - duration.update(sub_durations) - return duration - + def _get_duration(idx): + return (time_stamps[idx + 1][1] - time_stamps[idx][1]).total_seconds() -def msg_table(title: str, msg: str, width: int = 100): - """Log message in table format - - Parameters - ---------- - title: str - The title of the block - msg: str - The message to log. - width: int - The max width of block message + def _set_stage(stage: str, info: Tuple[float, dict], collect: dict): + if "." in stage: + main_stage, sub_stage = stage.split(".", 1) + _set_stage(sub_stage, info, collect.setdefault(main_stage, {})) + else: + collect[stage] = info + + def _set_total(collect: dict): + collect["total"] = 0 + for dur in collect.values(): + collect["total"] += _set_total(dur) if isinstance(dur, dict) else dur + return collect["total"] + + duration, depth = {}, 1 + left_durs = {time_stamps[i][0]: _get_duration(i) for i in range(len(time_stamps) - 1)} + while left_durs: + current_durs = {s: dur for s, dur in left_durs.items() if len(s.split(".")) == depth} + left_durs = {k: v for k, v in left_durs.items() if k not in current_durs} + for stage, dur in current_durs.items(): + info = {"init": dur} if any(s.startswith(stage + ".") for s in left_durs) else dur + _set_stage(stage, info, duration) + depth += 1 + + _set_total(duration) - Returns - ------- - msg: str - The block message. - """ + def _to_str(dur): + if not isinstance(dur, float): + return dur + return "{:.2f} s({:.2f}%)".format(dur, dur * 100 / duration["total"]) - if isinstance(msg, dict): - msg = dump_dict(msg, "table:" + str(width)) - return "\n{0} {1} {0}\n{2}\n".format("-" * 20, title.center(40), msg) + return map_dict(duration, _to_str) -def msg_block(title: str, msg: str, width: int = 100): +def msg_block(title: str, msg: str, width: int = 100, symbol: str = "-"): """Log message in block format Parameters @@ -158,6 +150,8 @@ def msg_block(title: str, msg: str, width: int = 100): The message to log. width: int The max width of block message + symbol: str + The split symbol. 
Returns ------- @@ -167,7 +161,7 @@ def msg_block(title: str, msg: str, width: int = 100): if isinstance(msg, dict): msg = dump_dict(msg, "table:" + str(width)) - return "\n{0} {1} {0}\n{2}\n{3} {1} {3}".format(">" * 20, title.center(40), msg, "<" * 20) + return "\n{0} {1} {0}\n{2}".format(symbol * 20, title.center(40), msg) def current_stage(): diff --git a/python/tvm/contrib/msc/core/utils/namespace.py b/python/tvm/contrib/msc/core/utils/namespace.py index 6744548ddfc4..330499764159 100644 --- a/python/tvm/contrib/msc/core/utils/namespace.py +++ b/python/tvm/contrib/msc/core/utils/namespace.py @@ -67,6 +67,7 @@ class MSCKey: TRACKERS = "trackers" FUSED_CNT = "fused_cnt" + ROOT_MARK = "$" class MSCFramework: diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index 855c28f8b4b2..ae7c8eac03b3 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -27,6 +27,7 @@ class MSCRegistery: MSC_FUNCS = "msc_funcs" MSC_TOOLS_CLS = "msc_tools_cls" MSC_TOOLS_METHOD = "msc_tools_method" + TOOL_CONFIGERS = "tool_configers" GYM_CONFIGERS = "gym_configers" GYM_CONTROLLERS = "gym_controllers" GYM_AGENTS = "gym_agents" @@ -192,6 +193,44 @@ def get_registered_tool_method( return tools_method.get(framework, {}).get(register_name) +def register_tool_configer(configer: Any): + """Register a tool configer. + + Parameters + ---------- + configer: class + The configer class. + """ + + for key in ["tool_type", "config_style"]: + assert hasattr(configer, key), "{} should be given to register tool configer".format(key) + tool_configers = MSCRegistery.get(MSCRegistery.TOOL_CONFIGERS, {}) + col = tool_configers.setdefault(configer.tool_type(), {}) + col[configer.config_style()] = configer + MSCRegistery.register(MSCRegistery.TOOL_CONFIGERS, tool_configers) + return configer + + +def get_registered_tool_configer(tool_type: str, config_style: str) -> Any: + """Get the registered configer. + + Parameters + ---------- + tool_type: string + The type of tool. + config_style: string + The style of tool. + + Returns + ------- + configer: class + The configer class. + """ + + tool_configers = MSCRegistery.get(MSCRegistery.TOOL_CONFIGERS, {}) + return tool_configers.get(tool_type, {}).get(config_style) + + def register_gym_configer(configer: Any): """Register a gym configer. diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index c33fc89fa790..2fff6d1c75dc 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -29,6 +29,7 @@ from tvm.contrib.msc.core.runtime import ModelRunner from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow from tvm.contrib.msc.framework.tensorflow import tf_v1 @@ -154,7 +155,8 @@ def _call_runnable( The outputs in list or dict. 
""" - feed_dict = {i["name"] + ":0": inputs[i["name"]] for i in self.get_inputs()} + input_names = [i["name"] for i in self.get_inputs()] + feed_dict = {i + ":0": msc_utils.cast_array(inputs[i]) for i in input_names} return runnable.run(self._tf_outputs, feed_dict) def _device_enabled(self, device: str) -> bool: @@ -182,13 +184,15 @@ def framework(self): return MSCFramework.TENSORFLOW @classmethod - def load_native(cls, model: Any) -> Tuple[tf_v1.GraphDef, str, bool]: + def load_native(cls, model: Any, config: dict) -> Tuple[tf_v1.GraphDef, str, bool]: """Load the native model Parameters ------- model: The native model. + config: dict + The config for pipeline. Returns ------- diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index 43e85b601579..d74a6a42461c 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -56,7 +56,7 @@ def train(self): raise Exception("TensorRT only support eval") - def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: + def make_plan(self, tool_type: str, data_loader: Any = None) -> dict: """Execute tool and get plan Parameters @@ -76,7 +76,7 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: self._generate_model(self._graphs, self._weights) quantizer.calibrate() assert quantizer.calibrated, "Failed to calibrate the tenosrrt quantizer" - return super().apply_tool(tool_type, data_loader) + return super().make_plan(tool_type, data_loader) def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: """Codegen the model according to framework diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py index f97118619603..e2402e2dfa62 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py @@ -188,7 +188,7 @@ def _execute_before_forward(self, step_context: dict) -> dict: {name: data.asnumpy() for name, data in step_context["datas"].items()} ) for name, data in step_context["datas"].items(): - self.debug_tensor(data, name, "any", "ctx_gathered") + self.debug_tensors(name, "any", "ctx_gather", {"gather": data}) super()._execute_before_forward(step_context) def _quantize_tensor( @@ -261,12 +261,8 @@ def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]: generate_config["codegen"], self._calibrate_savers, self._range_files ): saver.finalize() - self._logger.debug( - "%ssave %d datas to %s", - self.msg_mark(in_forward=False), - self._forward_cnt, - saver.folder, - ) + msg = "Save {} batch to {}".format(self._forward_cnt, saver.folder) + self._logger.debug(self.msg_mark(msg, in_forward=False)) config.update( {"dataset": saver.folder, "range_file": r_file, "precision": "int8"} ) diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index 97dbdebcb3a9..67812e7e5219 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -102,19 +102,34 @@ def _call_runnable( The outputs in list. 
""" - model_inputs = self.get_inputs() - parameters = list(runnable.parameters()) - if parameters: - in_dev = parameters[0].device - elif device == "cpu": - in_dev = torch.device(device) - elif device.startswith("cuda"): - in_dev = torch.device(device) - else: - raise NotImplementedError("Unsupported device " + str(device)) - torch_inputs = [torch.from_numpy(inputs[i["name"]]).to(in_dev) for i in model_inputs] + input_names = [i["name"] for i in self.get_inputs()] + torch_inputs = [ + msc_utils.cast_array(inputs[i], MSCFramework.TORCH, device) for i in input_names + ] return runnable(*torch_inputs) + def _get_runtime_params(self) -> Dict[str, tvm.nd.array]: + """Get the runtime parameters + + Returns + ------- + params: dict + The parameters from runtime. + """ + + assert self._runnable, "runnable is needed to get params" + state_dict = self._runnable.state_dict() + params = {} + for graph in self._graphs: + for weight in graph.get_weights(): + assert weight.alias in state_dict, "Missing weight {} in state_dict".format( + weight.alias + ) + params[weight.name] = msc_utils.cast_array( + state_dict[weight.alias], MSCFramework.TVM, "cpu" + ) + return params + def _device_enabled(self, device: str) -> bool: """Check if the device is enabled @@ -139,13 +154,15 @@ def framework(self): return MSCFramework.TORCH @classmethod - def load_native(cls, model: Any) -> Tuple[torch.nn.Module, str, bool]: + def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bool]: """Load the native model Parameters ------- model: The native model. + config: dict + The config for pipeline. Returns ------- @@ -249,10 +266,16 @@ def run_native( parameters = list(model.parameters()) if parameters: - device = parameters[0].device + ref_dev = parameters[0].device + if ref_dev.index: + device = "{}:{}".format(ref_dev.type, ref_dev.index) + else: + device = ref_dev.type else: - device = torch.device("cpu") - torch_inputs = [torch.from_numpy(inputs[i_name]).to(device) for i_name in input_names] + device = "cpu" + torch_inputs = [ + msc_utils.cast_array(inputs[i], MSCFramework.TORCH, device) for i in input_names + ] def _run_once(): return model(*torch_inputs) diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py index ee5c895603e4..688cfd8b30b9 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py @@ -79,8 +79,8 @@ def build_model(self, teacher: Any, student: Any) -> Any: raise NotImplementedError("optimizer {} is not supported".format(optimizer)) # Get loss function - loss_strategy = self._strategys.get("loss.output") - assert loss_strategy, "Can not find loss.output in strategys" + loss_strategy = self._strategys.get("loss") + assert loss_strategy, "Can not find loss in strategys" def get_loss(teacher_outputs, student_outputs): return loss_strategy(self, teacher_outputs, student_outputs) diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py index 6f82a796e167..9b36d89b7b93 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py @@ -14,16 +14,47 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=unused-argument +# pylint: disable=unused-argument, arguments-differ """tvm.contrib.msc.framework.torch.tools.quantize.method""" +from functools import wraps import numpy as np + import torch +from torch.autograd import Function from tvm.contrib.msc.core.tools.quantize import QuantizeMethod, BaseQuantizer from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils +def fake_quantize(func): + """Fake quantize without backward""" + + @wraps(func) + def wrapper( + cls, quantizer: BaseQuantizer, data: torch.Tensor, name: str, consumer: str, *args, **kwargs + ): + func_name = "quantize_func." + func.__name__ + quantize_func = quantizer._get_tensor_cache(name, consumer, func_name) + if quantize_func is None: + + class FakeQuantize(Function): + """Fake quantize func for torch""" + + @staticmethod + def forward(ctx, data): + return func(cls, quantizer, data, name, consumer, *args, **kwargs) + + @staticmethod + def backward(ctx, grad_outputs): + return grad_outputs + + quantize_func = quantizer._save_tensor_cache(name, consumer, func_name, FakeQuantize) + return quantize_func.apply(data) + + return wrapper + + class TorchQuantizeMethod(QuantizeMethod): """Default quantize method for torch""" @@ -174,6 +205,7 @@ def gather_max_per_channel( return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True} @classmethod + @fake_quantize def quantize_normal( cls, quantizer: BaseQuantizer, diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py index 4038b74b7ea2..3c964464043a 100644 --- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py @@ -19,11 +19,9 @@ from typing import Dict, Optional, Any import tvm -from tvm.relax.transform import BindParams from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.codegen import CodeGen +from tvm.contrib.msc.core import codegen as msc_codegen from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.framework.tvm import _ffi_api def to_relax( @@ -57,34 +55,4 @@ def to_relax( The IRModule of relax. """ - inputs = [ - tvm.relax.Var(i.alias, tvm.relax.TensorStructInfo(i.get_shape(), i.dtype_name)) - for i in graph.get_inputs() - ] - - def _save_weights(folder: msc_utils.MSCDirectory): - if weights: - with open(folder.relpath(graph.name + "_params.bin"), "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(weights)) - - # pylint: disable=unused-argument - def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - if weights: - mod = BindParams("main", weights)(mod) - return tvm.ir.transform.Sequential( - [ - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. 
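The fake_quantize decorator above caches a torch.autograd.Function whose backward is the identity, i.e. a straight-through estimator, so fake-quantized forwards remain trainable. The same pattern in isolation (RoundSTE is an illustrative name, not part of MSC):

import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Round in the forward pass, pass gradients through unchanged."""

    @staticmethod
    def forward(ctx, data):
        return torch.round(data)

    @staticmethod
    def backward(ctx, grad_outputs):
        # Identity backward: the straight-through estimator
        return grad_outputs


x = torch.randn(4, requires_grad=True)
y = RoundSTE.apply(x).sum()
y.backward()
assert torch.allclose(x.grad, torch.ones_like(x))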
- tvm.relax.transform.CanonicalizeBindings(), - tvm.relax.transform.ConvertToDataflow(min_size=1), - ], - name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc", - )(mod) - - codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder) - model_args = inputs - if plugin: - model_args = model_args + [plugin] - return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc) + return msc_codegen.to_relax(graph, weights, codegen_config, print_config, build_folder, plugin) diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index 690e146becfd..ab52b8de99d2 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -121,16 +121,10 @@ def _call_runnable( The outputs in list. """ - model_inputs = self.get_inputs() - if device == "cpu": - tvm_inputs = [tvm.nd.array(inputs[i["name"]]) for i in model_inputs] - elif device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - tvm_inputs = [ - tvm.nd.array(inputs[i["name"]], device=tvm.cuda(dev_id)) for i in model_inputs - ] - else: - raise NotImplementedError("Unsupported device " + str(device)) + input_names = [i["name"] for i in self.get_inputs()] + tvm_inputs = [ + msc_utils.cast_array(inputs[i], MSCFramework.TVM, device) for i in input_names + ] return runnable(*tvm_inputs) def _device_enabled(self, device: str) -> bool: @@ -158,18 +152,24 @@ def framework(self): return MSCFramework.TVM @classmethod - def load_native(cls, model: Any) -> tvm.IRModule: + def load_native(cls, model: Any, config: dict) -> Tuple[tvm.IRModule, str, bool]: """Load the native model Parameters ------- model: The native model. + config: dict + The config for pipeline. Returns ------- model: tvm.IRModule The loaded native model. + device: str + The device of the model. + training: bool + Whether the model is for training. """ if isinstance(model, dict) and "model" in model: diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py index 9966e9c1af5d..5a534991b93f 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py @@ -74,6 +74,7 @@ def get_quantize_cache( zero_point = quantizer._get_tensor_cache(name, consumer, "zero_point") if scale_tensor is None: scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon, expand_dims=False) + scale_tensor = 1 / scale_tensor if isinstance(scale_tensor, float): scale_tensor = np.array(scale_tensor) scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name) diff --git a/python/tvm/contrib/msc/pipeline/__init__.py b/python/tvm/contrib/msc/pipeline/__init__.py index 99a8699ad9ab..b27b09d5d764 100644 --- a/python/tvm/contrib/msc/pipeline/__init__.py +++ b/python/tvm/contrib/msc/pipeline/__init__.py @@ -17,3 +17,4 @@ """tvm.contrib.msc.pipeline""" from .manager import * +from .wrapper import * diff --git a/python/tvm/contrib/msc/pipeline/config.py b/python/tvm/contrib/msc/pipeline/config.py new file mode 100644 index 000000000000..16ff34f2eca6 --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/config.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.pipeline.config""" + +from typing import List, Union, Dict, Tuple + +from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.message import MSCStage +from tvm.contrib.msc.core import utils as msc_utils + + +def support_tool(tool: dict, stage: str, run_type: str) -> bool: + """Check if the tool is supported + + Parameters + ---------- + tool: dict + The tool config, + stage: str + The compile stage. + run_type: str + The runtime type. + + Returns + ------- + supported: bool + Whether the tool is supported. + """ + + run_type = tool.get("run_type", run_type) + if stage == MSCStage.BASELINE: + return tool["tool_type"] == ToolType.TRACKER + return True + + +def config_tool(tool_type: str, raw_config: Union[dict, str]) -> dict: + """Config the tool + + Parameters + ---------- + tool_type: str + The tool type, + raw_config: str| dict + The tool config or style. + + Returns + ------- + config: dict + The config for tool. + """ + + if isinstance(raw_config, dict): + if "config_style" in raw_config: + config_style = raw_config.pop("config_style") + else: + config_style = "default" + else: + config_style, raw_config = raw_config, None + configer_cls = msc_utils.get_registered_tool_configer(tool_type, config_style) + assert configer_cls, "Can not find configer for {}:{}".format(tool_type, config_style) + return {"tool_type": tool_type, **configer_cls().config(raw_config)} + + +def create_config( + inputs: List[dict], + outputs: List[str], + model_type: str, + baseline_type: str = None, + optimize_type: str = None, + compile_type: str = None, + dataset: Dict[str, dict] = None, + tools: List[Tuple[str, Union[dict, str]]] = None, + skip_config: Dict[str, str] = None, + **extra_config, +) -> dict: + """Create config for msc pipeline + + Parameters + ---------- + inputs: list + The inputs info, + outputs: list + The output names. + model_type: str + The model type. + baseline_type: str + The baseline type. + compile_type: str + The compile type. + optimize_type: str + The optimize type. + dataset: dict + The datasets for compile pipeline. + tools: list + The tools config. + skip_config: dict + The skip config for compile. + extra_config: dict + The extra config. 
+ """ + + baseline_type = baseline_type or model_type + optimize_type = optimize_type or baseline_type + compile_type = compile_type or optimize_type + if tools: + tools = [config_tool(t_type, t_config) for t_type, t_config in tools] + # basic config + config = { + "model_type": model_type, + "inputs": inputs, + "outputs": outputs, + "dataset": dataset, + "tools": tools, + MSCStage.PREPARE: {"profile": {"benchmark": {"repeat": -1}}}, + MSCStage.BASELINE: { + "run_type": baseline_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + }, + } + + # config optimize + if tools: + config[MSCStage.OPTIMIZE] = { + "run_type": optimize_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + } + + # config compile + config[MSCStage.COMPILE] = { + "run_type": compile_type, + "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, + } + + # skip stages + skip_config = skip_config or {} + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in config: + continue + for key in ["all", stage]: + if key not in skip_config: + continue + if skip_config[key] == "stage": + config.pop(stage) + elif skip_config[key] == "profile": + config[stage].pop("profile") + elif skip_config[key] == "check": + config[stage]["profile"].pop("check") + elif skip_config[key] == "benchmark": + config[stage]["profile"].pop("benchmark") + else: + raise TypeError("Unexpected skip type " + str(skip_config[key])) + + # update config + if extra_config: + config = msc_utils.update_dict(config, extra_config) + return config diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index 42ef227b551b..c0b93569c843 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -20,12 +20,11 @@ import os import time import json -from typing import Dict, Any +from typing import Dict, Any, Union, List import traceback import numpy as np import tvm -from tvm.contrib.msc.core import transform as msc_transform from tvm.contrib.msc.core.runtime import BaseRunner from tvm.contrib.msc.core.tools import ToolType from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey @@ -33,7 +32,8 @@ from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.core.gym.control import create_controller from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.plugin.utils import load_plugins +from tvm.contrib.msc.plugin.utils import export_plugins, load_plugins +from .config import support_tool class BaseManager(object): @@ -49,9 +49,21 @@ class BaseManager(object): The plugins for pipeline. root: str The root path for files. + run_optimize: bool + Whether to run optimize. + run_compile: bool + Whether to run compile. 
""" - def __init__(self, model: Any, config: dict, plugins: dict = None, root: str = None): + def __init__( + self, + model: Any, + config: dict, + plugins: dict = None, + root: str = None, + run_optimize: bool = True, + run_compile: bool = True, + ): # change path to root path if root: @@ -66,19 +78,15 @@ def _from_root_mark(val): # check stage for stage in ["inputs", "outputs", "dataset", MSCStage.PREPARE, MSCStage.COMPILE]: - assert stage in config, "{} should be given to run the pipeline".format(stage) + config.setdefault(stage, {}) MSCMap.reset() - self._model_type = config["model_type"] - self._model, self._device, self._training = self._get_runner_cls( - self._model_type - ).load_native(model) - if plugins: - self._plugins = load_plugins(plugins) - else: - self._plugins = {} use_cache = config.get("use_cache", True) self._workspace = msc_utils.set_workspace(config.get("workspace"), use_cache) + self._model_type = config["model_type"] + runner_cls = self._get_runner_cls(self._model_type) + self._model, self._device, self._training = runner_cls.load_native(model, config) + self._plugins = load_plugins(plugins) if plugins else {} self._verbose = config.get("verbose", "info") if "logger" in config: self._logger = config["logger"] @@ -90,15 +98,21 @@ def _from_root_mark(val): self._logger = msc_utils.set_global_logger(self._verbose, log_path) self._optimized, self._compiled = False, False msc_utils.time_stamp(MSCStage.SETUP) - self._logger.info(msc_utils.msg_block("SETUP", self.setup(config))) + self._logger.info( + msc_utils.msg_block("SETUP", self.setup(config, run_optimize, run_compile)) + ) - def setup(self, config: dict) -> dict: + def setup(self, config: dict, run_optimize: bool = True, run_compile: bool = True) -> dict: """Setup the manager Parameters ---------- config: dict The config for manager. + run_optimize: bool + Whether to run optimize. + run_compile: bool + Whether to run compile. 
Returns ------- @@ -116,7 +130,11 @@ def setup(self, config: dict) -> dict: for name, plugin in self._plugins[self._model_type].get_ops_info().items(): _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) self._config, self._debug_levels = self.update_config(config) - self._tools_config = {} + if not run_optimize and MSCStage.OPTIMIZE in self._config: + self._config.pop(MSCStage.OPTIMIZE) + if not run_compile and MSCStage.COMPILE in self._config: + self._config.pop(MSCStage.COMPILE) + self._tools_config = [] self._relax_mod, self._runner = None, None self._sample_inputs = None self._report = { @@ -128,7 +146,7 @@ def setup(self, config: dict) -> dict: "duration": {}, "profile": {}, } - return {"workspace": self._workspace.path, "plugins": self._plugins, "config": config} + return {"workspace": self._workspace.path, "plugins": self._plugins, "config": self._config} def update_config(self, config: dict) -> dict: """Update config @@ -154,23 +172,26 @@ def update_config(self, config: dict) -> dict: config = self._get_runner_cls(self._model_type).update_config( MSCStage.PARSE, config, self._model ) + + # update runner config for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: if stage not in config: continue if "run_type" not in config[stage]: config[stage]["run_type"] = self._model_type - config = self._get_runner_cls(config[stage]["run_type"]).update_config( - stage, config, self._model - ) - if MSCStage.OPTIMIZE in config: - config[MSCStage.OPTIMIZE] = self._update_tool_config(config[MSCStage.OPTIMIZE]) + runner_cls = self._get_runner_cls(config[stage]["run_type"]) + config = runner_cls.update_config(stage, config, self._model) - def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dict: - if "debug_level" in stage_config: - debug_levels[stage] = stage_config["debug_level"] + # update tool config + if config.get("tools"): + config["tools"] = self._update_tools_config(config["tools"]) + + def _set_debug_level(stage: str, sub_config: dict, default: int = None) -> dict: + if "debug_level" in sub_config: + debug_levels[stage] = sub_config["debug_level"] elif default is not None: debug_levels[stage] = default - stage_config["debug_level"] = default + sub_config["debug_level"] = default return debug_levels if self._verbose.startswith("debug:"): @@ -181,18 +202,17 @@ def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dic if stage not in config: continue debug_levels = _set_debug_level(stage, config[stage]["run_config"], debug_level) - if MSCStage.OPTIMIZE in config: - for t_type in ToolType.all_types(): - if t_type not in config[MSCStage.OPTIMIZE]: + for t_config in config.get("tools", []): + if not support_tool(t_config, stage, config[stage]["run_type"]): continue - debug_levels = _set_debug_level( - self._get_tool_stage(t_type), config[MSCStage.OPTIMIZE][t_type], debug_level - ) + t_stage = stage + "." + self._get_tool_stage(t_config["tool_type"]) + debug_levels = _set_debug_level(t_stage, t_config["tool_config"], debug_level) ordered_keys = [ "model_type", "inputs", "outputs", "dataset", + "tools", MSCStage.PREPARE, MSCStage.PARSE, MSCStage.BASELINE, @@ -201,16 +221,9 @@ def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dic ] return {k: config[k] for k in ordered_keys if k in config}, debug_levels - def run_pipe(self, run_optimize: bool = True, run_compile: bool = True) -> dict: + def run_pipe(self) -> dict: """Run the pipeline and return object. 
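With run_optimize and run_compile moved onto the constructor, a caller that previously wrote run_pipe(run_compile=False) now does roughly the following (a sketch; the toy model and config values are invented for the example):

import torch.nn as nn
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.pipeline.config import create_config

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
config = create_config(
    inputs=[["input_0", [1, 3, 32, 32], "float32"]],
    outputs=["output"],
    model_type=MSCFramework.TORCH,
    dataset={"prepare": {"loader": "from_random", "max_batch": 5}},
)
# setup() drops the compile stage from the config, so the pipeline
# stops after the last remaining stage instead of compiling
manager = MSCManager(model, config, run_compile=False)
report = manager.run_pipe()
runner = manager.get_runnable("runner")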
- Parameters - ---------- - run_optimize: bool - Whether to run the optimize. - run_compile: bool - Whether to run the compile. - Returns ------- report: @@ -223,9 +236,9 @@ def run_pipe(self, run_optimize: bool = True, run_compile: bool = True) -> dict: self.parse() if MSCStage.BASELINE in self._config: self.baseline() - if run_optimize and MSCStage.OPTIMIZE in self._config: + if MSCStage.OPTIMIZE in self._config: self.optimize() - if run_compile: + if MSCStage.COMPILE in self._config: self.compile() except Exception as exc: # pylint: disable=broad-exception-caught err_msg = "Pipeline failed:{}\nTrace: {}".format(exc, traceback.format_exc()) @@ -271,7 +284,9 @@ def prepare(self) -> Dict[str, np.ndarray]: if cnt >= max_golden > 0: break if not self._sample_inputs: - self._sample_inputs = inputs + self._sample_inputs = { + k: msc_utils.cast_array(v) for k, v in inputs.items() + } outputs, _ = run_func(self._model, inputs, input_names, self._config["outputs"]) cnt = saver.save_batch(inputs, outputs) report["datas_info"] = saver.info @@ -298,7 +313,7 @@ def _to_tensor_str(info): if "profile" in stage_config and run_func: benchmark = stage_config["profile"].get("benchmark", {}) benchmark["repeat"] = self._get_repeat(benchmark) - self._logger.debug("Prepare profile with %s(%s)", run_func, benchmark) + self._logger.debug("Prepare profile with %s(%s)", run_func.__name__, benchmark) _, avg_time = run_func( self._model, self._sample_inputs, input_names, self._config["outputs"], **benchmark ) @@ -335,24 +350,31 @@ def parse(self) -> tvm.IRModule: plugin = self._plugins[self._model_type] parse_config["custom_convert_map"] = plugin.get_convert_map() self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) + transformed = set() for stage in [MSCStage.OPTIMIZE, MSCStage.COMPILE]: if stage not in self._config: continue - runner_cls = self._get_runner_cls(self._config[stage]["run_type"]) + run_type = self._config[stage]["run_type"] + if run_type in transformed: + continue + transformed.add(run_type) + runner_cls = self._get_runner_cls(run_type) if hasattr(runner_cls, "target_transform"): - self._logger.info( - "Transform for stage %s: %s", stage, runner_cls.target_transform - ) + self._logger.info("Transform for %s(%s)", run_type, stage) self._relax_mod = runner_cls.target_transform(self._relax_mod) - self._relax_mod = msc_transform.SetExprName()(self._relax_mod) if cache_path: with open(cache_path, "w") as f: f.write(tvm.ir.save_json(self._relax_mod)) self._logger.debug("Save parsed mod to %s", cache_path) return self._relax_mod - def baseline(self) -> BaseRunner: - """Run the baseline. + def _run_stage(self, stage: str) -> BaseRunner: + """Run the stage. + + Parameters + ---------- + stage: str + The compile stage. Returns ------- @@ -360,14 +382,26 @@ def baseline(self) -> BaseRunner: The runner. """ - msc_utils.time_stamp(MSCStage.BASELINE) + msc_utils.time_stamp(stage) + self.apply_tools(stage) self._runner = self._create_runner( - MSCStage.BASELINE, - self._config[MSCStage.BASELINE], + stage, + self._config[stage], use_cache=self._config.get("use_cache", True), ) return self._runner + def baseline(self) -> BaseRunner: + """Run the baseline. + + Returns + ------- + runner: BaseRunner + The runner. + """ + + return self._run_stage(MSCStage.BASELINE) + def optimize(self) -> BaseRunner: """Run the optimize and return object. @@ -377,17 +411,9 @@ def optimize(self) -> BaseRunner: The runner. 
""" - stage_config = self._config[MSCStage.OPTIMIZE] - self.apply_tools(stage_config) - msc_utils.time_stamp(MSCStage.OPTIMIZE) - self._runner = self._create_runner( - MSCStage.OPTIMIZE, - stage_config, - tools_config=self._tools_config, - use_cache=self._config.get("use_cache", True), - ) + runner = self._run_stage(MSCStage.OPTIMIZE) self._optimized = True - return self._runner + return runner def compile(self) -> BaseRunner: """Run the compile and return object. @@ -398,43 +424,28 @@ def compile(self) -> BaseRunner: The runner. """ - stage_config = self._config[MSCStage.COMPILE] - self.apply_tools(stage_config) - msc_utils.time_stamp(MSCStage.COMPILE) - self._runner = self._create_runner( - MSCStage.COMPILE, - stage_config, - tools_config=self._tools_config, - use_cache=self._config.get("use_cache", True), - ) + runner = self._run_stage(MSCStage.COMPILE) self._compiled = True - return self._runner + return runner - def apply_tools(self, stage_config: dict): + def apply_tools(self, stage: str): """Apply tools for a stage. Parameters ---------- - stage_config: dict - The config of this stage. + stage: str + The compile stage. """ - runner_cls = self._get_runner_cls(stage_config["run_type"]) - - def _tool_enabled(tool_type: str) -> bool: - return tool_type in stage_config and runner_cls.support_tool(tool_type) - - # run prune - if _tool_enabled(ToolType.PRUNER): - self._apply_tool(ToolType.PRUNER, stage_config) - - # run quantize - if _tool_enabled(ToolType.QUANTIZER): - self._apply_tool(ToolType.QUANTIZER, stage_config) - - # run distill - if _tool_enabled(ToolType.DISTILLER): - self._apply_tool(ToolType.DISTILLER, stage_config) + self._tools_config = [] + for tool in self._config.get("tools", []): + run_type = tool.get("run_type", self._config[stage]["run_type"]) + if not support_tool(tool, stage, run_type): + continue + self._apply_tool(tool, stage) + if tool.get("apply_once", False): + self._logger.debug("Remove apply once tool %s", tool["tool_type"]) + self._tools_config = self._tools_config[:-1] def summary(self, err_msg=None): """Summary the pipeline. @@ -458,6 +469,155 @@ def summary(self, err_msg=None): self._report["duration"] = msc_utils.get_duration() return self._report + def export(self, path: str = None, dump: bool = True) -> Union[str, dict]: + """Export the pipeline + + Parameters + ---------- + path: str + The export path. + dump: bool + Whether to dump the info. + + Returns + ------- + export_path/pipeline: str/dict + The exported path/pipeline info. 
+ """ + + path = path or "msc_export" + if path.endswith(".tar.gz"): + folder, dump = msc_utils.msc_dir(path.replace(".tar.gz", ""), keep_history=False), True + else: + folder = msc_utils.msc_dir(path, keep_history=False) + if dump: + plugins = export_plugins(self._plugins, folder.create_dir("plugin")) + else: + plugins = self._plugins + + def _to_root_mark(val): + if isinstance(val, str) and folder.path != val and folder.path in val: + return val.replace(folder.path, MSCKey.ROOT_MARK) + return val + + pipeline = { + "model": self.export_model(folder.create_dir("model"), dump), + "config": self.export_config(folder, dump), + "plugins": plugins, + "root": folder.path, + } + pipeline = msc_utils.map_dict(pipeline, _to_root_mark) + if not dump: + return pipeline + with open(folder.relpath("pipeline.json"), "w") as f: + f.write(json.dumps(pipeline, indent=2)) + if path.endswith(".tar.gz"): + msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar") + return path + + def export_model(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: + """Export the model + + Parameters + ---------- + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + exported: + The exported model. + """ + + if self._compiled: + return self._runner._save_runnable(folder) if dump else self._runner.runnable + if self._optimized: + module = self._runner.export_module(folder) + if not dump: + return module + path = folder.relpath("model.json") + with open(path, "w") as f: + f.write(tvm.ir.save_json(module)) + return {"model": path} + if not dump: + return self._model + return self._get_runner_cls(self._model_type).dump_nativate(self._model, folder) + + def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> dict: + """Export the config + + Parameters + ---------- + folder: MSCDirectory + The export folder. + dump: bool + Whether to dump info. + + Returns + ------- + config: dict + The updated config. 
+ """ + + if self._compiled: + return {"model_info": self.runner.model_info} + + # dump the dataloader + def _save_dataset(name, info, dump: bool): + loader, max_batch = info["loader"], info.get("max_batch", -1) + data_folder = folder.create_dir("dataset") + if isinstance(loader, str) and msc_utils.is_callable(loader): + path, func_name = loader.split(":") + exp_loader = data_folder.copy(path) + ":" + func_name + elif msc_utils.is_io_dataset(loader): + exp_loader = data_folder.copy(loader, name) + elif callable(loader) and dump: + saver_options = { + "input_names": [i[0] for i in self._config["inputs"]], + "output_names": self._config["outputs"], + } + batch_cnt = 0 + exp_loader = data_folder.create_dir(name).path + with msc_utils.IODataSaver(exp_loader, saver_options) as saver: + for inputs in loader(): + if batch_cnt >= max_batch > 0: + break + batch_cnt = saver.save_batch(inputs) + else: + exp_loader = loader + return {"loader": exp_loader, "max_batch": max_batch} + + config = msc_utils.copy_dict(self._meta_config) + config["dataset"] = { + k: _save_dataset(k, v, dump) for k, v in self._config["dataset"].items() + } + if self._optimized: + config["model_type"] = MSCFramework.TVM + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE]: + if stage in config: + config.pop(stage) + if "profile" in config[MSCStage.COMPILE]: + config[MSCStage.COMPILE]["profile"].setdefault("check", {})["err_rate"] = -1 + config["tools"] = [] + for tool in self._config.get("tools", []): + if not support_tool(tool, MSCStage.COMPILE, self._compile_type): + continue + run_tool = self.runner.get_tool(tool["tool_type"]) + tool["tool_config"] = run_tool.export_config(tool["tool_config"], folder) + if tool["tool_config"]: + config["tools"].append(tool) + else: + self._logger.info( + "Skip compile with tool %s as no config exported", tool["tool_type"] + ) + # remove not serializable items + if dump: + remove_keys = {"workspace", "logger"} + config = {k: v for k, v in config.items() if k not in remove_keys} + return config + def destory(self, keep_workspace: bool = False): """Destroy the manager @@ -476,7 +636,6 @@ def _create_runner( self, stage: str, stage_config: dict, - tools_config: dict = None, visualize: bool = True, profile: bool = True, use_cache: bool = True, @@ -489,8 +648,6 @@ def _create_runner( The stage name stage_config: dict The config of this stage. 
- tools_config: dict - The config of the tools visualize: bool Whether to visualize the runner profile: bool @@ -507,7 +664,6 @@ def _create_runner( if self._runner: self._runner.destory() cache_dir = msc_utils.get_cache_dir().create_dir(stage) if use_cache else None - tools_config = tools_config or {} msc_utils.time_stamp(stage + ".build", False) runner_cls = self._get_runner_cls(stage_config["run_type"]) run_config = msc_utils.copy_dict(stage_config.get("run_config")) @@ -521,41 +677,34 @@ def _create_runner( run_config["device"] = self._device if "training" not in run_config: run_config["training"] = self._training - opt_config = self._config.get(MSCStage.OPTIMIZE, {}) - if ToolType.TRACKER in opt_config and runner_cls.support_tool(ToolType.TRACKER): - tools_config = {**tools_config, ToolType.TRACKER: opt_config[ToolType.TRACKER]} # Build runner runner = runner_cls( self._relax_mod, - tools_config=tools_config, + tools_config=self._tools_config, plugin=self._plugins.get(stage_config["run_type"]), stage=stage, logger=self._logger, **run_config, ) runner.build(cache_dir=cache_dir) - self._report["info"][stage + "_by"] = "{}({})".format(runner.framework, runner.device) + self._report["info"][stage + "_type"] = "{}({})".format(runner.framework, runner.device) if visualize: runner.visualize(msc_utils.get_visual_dir().create_dir(stage)) if profile and "profile" in stage_config: self._report["profile"][stage] = self._profile_runner(runner, stage_config) if use_cache: runner.save_cache(cache_dir) - if runner.get_tool(ToolType.TRACKER): - runner.apply_tool(ToolType.TRACKER) return runner - def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) -> str: + def _apply_tool(self, tool: dict, stage: str) -> str: """Apply tool with runner Parameters ---------- - tool_type: str - The tool type. - stage_config: dict - The config of this stage. - add_tool: bool - Whether to add tool in self._tools. + tool: dict + The tool config. + stage: str + The compile stage. Returns ------- @@ -563,51 +712,51 @@ def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) The plan_file path. """ - assert tool_type in stage_config, "Can not find config for tool " + str(tool_type) - tool_stage, tool_config = self._get_tool_stage(tool_type), stage_config[tool_type] - if "run_type" in tool_config: - run_type = tool_config.pop("run_type") - else: - run_type = stage_config["run_type"] + self._tools_config.append(tool) + tool_type, tool_config = tool["tool_type"], tool["tool_config"] + tool_stage = self._get_tool_stage(tool_type) plan_file = tool_config["plan_file"] - if "gym_configs" in tool_config: - gym_configs = tool_config.pop("gym_configs") - else: - gym_configs = None - if add_tool: - self._tools_config[tool_type] = tool_config - tools_config = self._tools_config - else: - tools_config = {**self._tools_config, tool_type: tool_config} if os.path.isfile(plan_file): self._logger.info("Skip %s with plan %s", tool_type, plan_file) return plan_file - msc_utils.time_stamp(tool_stage) - t_stage_config = {"run_type": run_type, "run_config": stage_config["run_config"]} - runner = self._create_runner( - tool_stage, t_stage_config, tools_config=tools_config, profile=False, use_cache=False - ) - if gym_configs: + t_stage = stage + "." 
+ tool_stage
+        msc_utils.time_stamp(t_stage)
+        stage_config = {
+            "run_type": tool.get("run_type", self._config[stage]["run_type"]),
+            "run_config": self._config[stage]["run_config"],
+        }
+        runner = self._create_runner(t_stage, stage_config, profile=False, use_cache=False)
+        if "gym_configs" in tool:
             knowledge = None
-            for idx, config in enumerate(gym_configs):
-                self._logger.info("GYM[%d/%d].CREATE(%s)", idx, len(gym_configs), tool_stage)
-                extra_config = {
-                    "env": {
-                        "runner": runner,
-                        "data_loader": self._get_loader(tool_stage),
-                        "knowledge": knowledge,
-                    },
-                    "verbose": self._verbose,
-                }
-                controller = create_controller(runner.stage, config, extra_config)
-                knowledge = controller.run()
-                with open(plan_file, "w") as f:
-                    f.write(json.dumps(knowledge, indent=2))
-                self._logger.info(
-                    "Gym save %d knowledge(%s) -> %s", len(knowledge), tool_type, plan_file
-                )
-            return plan_file
-        return runner.apply_tool(tool_type, self._get_loader(tool_stage))
+            for idx, config in enumerate(tool["gym_configs"]):
+                knowledge_file = msc_utils.get_config_dir().relpath(
+                    "gym_knowledge_{}.json".format(idx)
+                )
+                gym_mark = "GYM[{}/{}]({} @ {}) ".format(
+                    idx, len(tool["gym_configs"]), runner.framework, t_stage
+                )
+                if os.path.isfile(knowledge_file):
+                    knowledge = knowledge_file
+                    self._logger.info("%sLoad from %s", gym_mark, knowledge)
+                else:
+                    msc_utils.time_stamp(t_stage + ".gym_{}".format(idx))
+                    self._logger.info("%sStart search", gym_mark)
+                    extra_config = {
+                        "env": {
+                            "runner": runner,
+                            "data_loader": self._get_loader(tool_stage),
+                            "knowledge": knowledge,
+                        },
+                        "verbose": self._verbose,
+                    }
+                    controller = create_controller(tool_stage, config, extra_config)
+                    knowledge = controller.run()
+                    msc_utils.save_dict(knowledge, knowledge_file)
+            plan = msc_utils.load_dict(knowledge)
+            self._logger.info("%sFound %d plans", gym_mark, len(plan))
+            return msc_utils.save_dict(plan, plan_file)
+        msc_utils.time_stamp(t_stage + ".make_plan", False)
+        return runner.make_plan(tool_type, self._get_loader(tool_stage))

     def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict:
         """Profile the runner.

@@ -682,30 +831,28 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict:
         self._logger.info(msg)
         return report

-    def _update_tool_config(self, opt_config: dict) -> dict:
+    def _update_tools_config(self, tools: List[dict]) -> List[dict]:
         """Update tool in stage config.

         Parameters
         ----------
-        opt_config: dict
-            The config of optimize.
+        tools: list
+            The config of tools.

         Returns
         -------
-        config: dict
-            The updated config of optimize.
+        tools: list
+            The updated config of tools.
""" - for tool_type in ToolType.all_types(): - if tool_type not in opt_config: - continue - tool_config = opt_config[tool_type] + for tool in tools: + tool_config = tool["tool_config"] if "plan_file" not in tool_config: - tool_config["plan_file"] = "msc_{}.json".format(tool_type) + tool_config["plan_file"] = "msc_{}.json".format(tool["tool_type"]) tool_config["plan_file"] = msc_utils.to_abs_path( tool_config["plan_file"], msc_utils.get_config_dir() ) - return opt_config + return tools def _get_tool_stage(self, tool_type: str) -> str: """Map the stage according to tool_type @@ -727,6 +874,8 @@ def _get_tool_stage(self, tool_type: str) -> str: return MSCStage.QUANTIZE if tool_type == ToolType.DISTILLER: return MSCStage.DISTILL + if tool_type == ToolType.TRACKER: + return MSCStage.TRACK return tool_type def get_runnable(self, ret_type: str = "runner") -> Any: @@ -743,6 +892,7 @@ def get_runnable(self, ret_type: str = "runner") -> Any: The runner or model. """ + assert self._runner, "Failed to create runner, call run_pipe first" if ret_type == "runner": return self._runner elif ret_type == "runnable": @@ -772,10 +922,9 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: config = self._config["dataset"].get(name, self._config["dataset"][MSCStage.PREPARE]) source_loader = config.get("loader") - max_batch = config.get("max_batch", 5) assert source_loader, "Dataset loader should be given for msc pipeline" if source_loader == "from_random": - max_batch = max(max_batch, 5) + max_batch = config.get("max_batch", 5) def get_random(): for _ in range(max_batch): @@ -783,6 +932,7 @@ def get_random(): loader, source_type = get_random, "Random" elif msc_utils.is_io_dataset(source_loader): + max_batch = config.get("max_batch", -1) def load_datas(): for inputs, _ in msc_utils.IODataLoader(source_loader, end=max_batch): @@ -790,9 +940,11 @@ def load_datas(): loader, source_type = load_datas, "IOData" elif callable(source_loader): + max_batch = config.get("max_batch", -1) + load_kwargs = config.get("load_kwargs", {}) def get_source(): - for idx, inputs in enumerate(source_loader()): + for idx, inputs in enumerate(source_loader(**load_kwargs)): if idx >= max_batch > 0: break yield inputs @@ -802,7 +954,7 @@ def get_source(): raise TypeError( "Unexpected source loader {}({})".format(source_loader, type(source_loader)) ) - self._logger.debug("Create data loader(%s) %s(%s)", name, loader, source_type) + self._logger.debug("Create data loader(%s) %s(%s)", name, loader.__name__, source_type) return loader def _get_repeat(self, benchmark: dict, device: str = None) -> int: diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py new file mode 100644 index 000000000000..c790b5ef27be --- /dev/null +++ b/python/tvm/contrib/msc/pipeline/wrapper.py @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.pipeline.wrapper"""
+
+import shutil
+from typing import Any, Union, List
+
+from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType
+from tvm.contrib.msc.core.utils.namespace import MSCFramework
+from tvm.contrib.msc.core import utils as msc_utils
+from .manager import MSCManager
+from .config import create_config
+
+
+class BaseWrapper(object):
+    """Base Wrapper of models
+
+    Parameters
+    ----------
+    model: Any
+        The raw model in framework.
+    config: dict
+        The config for pipeline.
+    plugins: dict
+        The plugins for pipeline.
+    debug: bool
+        Whether to use debug mode.
+    """
+
+    def __init__(
+        self,
+        model: Any,
+        config: dict,
+        workspace: str = "msc_workspace",
+        plugins: dict = None,
+        debug: bool = False,
+    ):
+        self._meta_model = model
+        self._optimized_model, self._compiled_model = None, None
+        self._config = config
+        self._plugins = plugins
+        verbose = config.get("verbose", "info")
+        self._debug = True if verbose.startswith("debug") else debug
+        self._workspace = msc_utils.msc_dir(workspace, keep_history=self._debug)
+        log_path = self._workspace.relpath("MSC_LOG", keep_history=False)
+        self._config["logger"] = msc_utils.create_file_logger(verbose, log_path)
+        self._manager = None
+        self.setup()
+
+    def __str__(self):
+        if self.compiled:
+            phase = "compiled"
+        elif self.optimized:
+            phase = "optimized"
+        else:
+            phase = "meta"
+        return "({}) {}".format(phase, self._get_model().__str__())
+
+    def __getattr__(self, name):
+        if hasattr(self._get_model(), name):
+            return getattr(self._get_model(), name)
+        return self._get_model().__getattr__(name)
+
+    def setup(self):
+        """Setup the wrapper"""
+
+        return
+
+    def optimize(self, workspace: str = "Optimize"):
+        """Optimize the model
+
+        Parameters
+        ----------
+        workspace: str
+            The workspace.
+        """
+
+        self.logger.info("[Wrapper] Start optimize model")
+        config = msc_utils.copy_dict(self._config)
+        config["workspace"] = self._workspace.create_dir(workspace)
+        self._manager = MSCManager(self._meta_model, config, self._plugins, run_compile=False)
+        self._manager.run_pipe()
+        self._optimized_model = self._manager.get_runnable("runnable")
+        return self
+
+    def compile(
+        self, workspace: str = "Compile", ckpt_path: str = "Checkpoint", dump: bool = False
+    ):
+        """Compile the model
+
+        Parameters
+        ----------
+        workspace: str
+            The workspace.
+        ckpt_path: str
+            The path to export checkpoint.
+        dump: bool
+            Whether to dump the info.
+        """
+
+        if self._optimized_model:
+            self.logger.info("[Wrapper] Start compile checkpoint")
+            ckpt_path = self._workspace.create_dir(ckpt_path).path
+            pipeline = self.export(ckpt_path, dump=dump)
+            pipeline["config"]["workspace"] = self._workspace.create_dir(workspace)
+            self._manager = MSCManager(**pipeline)
+            self._manager.run_pipe()
+            self._compiled_model = self._manager.get_runnable("runnable")
+            if not self._debug:
+                shutil.rmtree(ckpt_path)
+        else:
+            self.logger.info("[Wrapper] Start compile model")
+            config = msc_utils.copy_dict(self._config)
+            config["workspace"] = self._workspace.create_dir(workspace)
+            self._manager = MSCManager(self._meta_model, config, self._plugins)
+            self._manager.run_pipe()
+            self._compiled_model = self._manager.get_runnable("runnable")
+        return self
+
+    def export(self, path: str = "msc_export", dump: bool = True) -> Union[str, dict]:
+        """Export compile pipeline
+
+        Parameters
+        ----------
+        path: str
+            The export path.
+ dump: bool + Whether to dump the info. + + Returns + ------- + export_path/pipeline: str/dict + The exported path/pipeline info. + """ + + if not self._manager: + self._manager = MSCManager(self._meta_model, self._config, self._plugins) + exported = self._manager.export(path, dump=dump) + if not self._debug: + self._manager.destory() + return exported + + def get_tools(self, tool_types: List[str]) -> List[BaseTool]: + """Get the tools from manager + + Parameters + ---------- + tool_types: list + The tool types. + + Returns + ------- + tools: list + The tools. + """ + + if not self._manager: + return [] + tool_types = tool_types or ToolType.all_types() + tools = [] + for t in tool_types: + tool = self._manager.runner.get_tool(t) + if tool: + tools.append(tool) + return tools + + def disable_tools(self, tool_types: List[str]): + """Disable the tools + + Parameters + ---------- + tool_types: list + The tool types. + """ + + for tool in self.get_tools(tool_types): + tool.disable() + + def enable_tools(self, tool_types: List[str]): + """Enable the tools + + Parameters + ---------- + tool_types: list + The tool types. + """ + + for tool in self.get_tools(tool_types): + tool.enable() + + def _get_model(self) -> Any: + return self._compiled_model or self._optimized_model or self._meta_model + + def _get_framework(self) -> str: + return self._manager.runner.framework if self._manager else self.model_type() + + @property + def optimized(self): + return self._optimized_model is not None + + @property + def compiled(self): + return self._compiled_model is not None + + @property + def device(self): + if self._manager: + return self._manager.runner.device + return "cpu" + + @property + def logger(self): + return self._config["logger"] + + @classmethod + def create_config( + cls, + inputs: List[dict], + outputs: List[str], + baseline_type: str = None, + optimize_type: str = None, + compile_type: str = None, + **kwargs, + ) -> dict: + """Create config for msc pipeline + + Parameters + ---------- + inputs: list + The inputs info, + outputs: list + The output names. + baseline_type: str + The baseline type. + compile_type: str + The compile type. + optimize_type: str + The optimize type. + kwargs: dict + The config kwargs. 
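get_tools, enable_tools, and disable_tools operate on the tools attached to the live runner, which makes it easy to, say, switch the tracker off while benchmarking. A sketch, assuming wrapper is an already-optimized wrapper instance with a tracker tool:

from tvm.contrib.msc.core.tools.tool import ToolType

# assume `wrapper` is a TorchWrapper optimized with a tracker tool
wrapper.disable_tools([ToolType.TRACKER])
# ... timed runs without tracking overhead ...
wrapper.enable_tools([ToolType.TRACKER])
active = wrapper.get_tools([ToolType.TRACKER])  # tracker is active again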
+ """ + + return create_config( + inputs, outputs, cls.model_type(), baseline_type, optimize_type, compile_type, **kwargs + ) + + @classmethod + def model_type(cls): + return MSCFramework.MSC + + +class TorchWrapper(BaseWrapper): + """Wrapper of torch models""" + + def __call__(self, *inputs): + framework = self._get_framework() + if framework != MSCFramework.TORCH: + inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs] + outputs = self._get_model()(*inputs) + if framework == MSCFramework.TORCH: + return outputs + if isinstance(outputs, (tuple, list)): + return [msc_utils.cast_array(o, MSCFramework.TORCH, self.device) for o in outputs] + return msc_utils.cast_array(outputs, MSCFramework.TORCH) + + def parameters(self): + framework = self._get_framework() + if framework == MSCFramework.TORCH: + return self._get_model().parameters() + return self._manager.runner.get_weights(MSCFramework.TORCH) + + def train(self): + if self._manager: + self._manager.runner.train() + if self._get_framework() == MSCFramework.TORCH: + return self._get_model().train() + return self._get_model() + + def eval(self): + if self._manager: + self._manager.runner.eval() + if self._get_framework() == MSCFramework.TORCH: + return self._get_model().eval() + return self._get_model() + + @classmethod + def model_type(cls): + return MSCFramework.TORCH diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 71f3208db94d..ca1bff09725f 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -1067,8 +1067,7 @@ void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) { } // set friends for (const auto& j_joint : j_graph.nodes) { - name = j_joint.name; - const auto& node = Downcast(nodes[name]); + const auto& node = Downcast(nodes[j_joint.name]); for (const auto& f_name : j_joint.friends) { ICHECK(nodes.count(f_name)) << "Can not find friend " << f_name; node->friends.push_back(nodes[f_name]); diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc new file mode 100644 index 000000000000..5ba1ca30eb1c --- /dev/null +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace relax { +using namespace tvm::contrib::msc; + +std::tuple, Map> NormalizeNamedBindings( + const Function& func, const Map& untyped_params) { + ICHECK(func.defined()); + ICHECK(untyped_params.defined()); + + // Map from string to the variable(s) with that name. 
+ std::unordered_map> string_lookup; + std::unordered_set var_set; + for (const auto& param : func->params) { + string_lookup[param->name_hint()].push_back(param); + var_set.insert(param.get()); + } + + Map relax_var_remap; + + auto normalize_key = [&](ObjectRef obj) -> relax::Var { + if (auto opt_str = obj.as()) { + std::string str = opt_str.value(); + auto it = string_lookup.find(str); + CHECK(it != string_lookup.end()) + << "Function does not have parameter with name \"" << str << "\". " + << "Function parameters are named " + << func->params.Map([](const auto& param) { return param->name_hint(); }); + CHECK_EQ(it->second.size(), 1) + << "Function contains multiple parameters with name \"" << str << "\". " + << "The Relax variables " << it->second << " are all named \"" << str << "\""; + auto var = it->second[0]; + CHECK(!relax_var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; + + return var; + } else if (auto opt_var = obj.as()) { + auto var = opt_var.value(); + CHECK(!relax_var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; + CHECK(var_set.count(var.get())) + << "Function does not use Relax variable " << var << " as a parameter. " + << "Function parameters are " << func->params; + return var; + } else { + LOG(FATAL) + << "Expected bound parameter to be a relax::Var, " + << " or a string that uniquely identifies a relax::Var param within the function. " + << "However, received object " << obj << " of type " << obj->GetTypeKey(); + } + }; + auto normalize_value = [&](Var key, ObjectRef obj) -> relax::Expr { + if (auto opt = obj.as()) { + return opt.value(); + } else if (auto opt = obj.as()) { + const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, key->name_hint()); + return Constant(opt.value(), StructInfo(), span); + } else { + LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() + << " into relax expression"; + } + }; + + for (const auto& [key, value] : untyped_params) { + relax_var_remap.Set(normalize_key(key), normalize_value(normalize_key(key), value)); + } + + arith::Analyzer analyzer; + Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + + return {relax_var_remap, symbolic_var_map}; +} + +/*! + * \brief Bind params to function by using name with span name + * \param func Relax function + * \param params params dict + * \return Function + */ +Function FunctionBindNamedParams(Function func, const Map& untyped_params) { + auto [bind_dict, symbolic_var_map] = NormalizeNamedBindings(func, untyped_params); + + Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); + return Downcast(bound_expr); +} + +/*! + * \brief Bind params to a specific function in a module with span name + * \param m The module + * \param func_name The name of the specific function + * \param param The param dict + * \return The module after binding params. 
+ */ +IRModule BindNamedParam(IRModule m, String func_name, Map bind_params) { + IRModuleNode* new_module = m.CopyOnWrite(); + Map functions = m->functions; + for (const auto& func_pr : functions) { + if (const auto* relax_f = func_pr.second.as()) { + if (relax_f->GetLinkageType() == LinkageType::kExternal) { + // Use global_symbol if it's external linkage + Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol.value() == func_name) { + Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + new_module->Update(func_pr.first, f_after_bind); + } + } else { + // Use global var's name_hint if it's internal linkage + if (func_pr.first->name_hint == func_name) { + Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + new_module->Update(func_pr.first, f_after_bind); + } + } + } + } + return GetRef(new_module); +} + +namespace transform { + +Pass BindNamedParams(String func_name, Map params) { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { + return BindNamedParam(std::move(mod), func_name, params); + }; + return CreateModulePass(pass_func, 0, "BindNamedParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index dfed1a242a50..163d86833593 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -84,8 +84,9 @@ class FuncNameGetter : public ExprVisitor { */ class RelaxExprNameSetter : public ExprVisitor { public: - explicit RelaxExprNameSetter(const IRModule& ref_module, const String& target) - : ref_module_(ref_module), target_{target} {} + explicit RelaxExprNameSetter(const IRModule& ref_module, const String& target, + const Map& var_names) + : ref_module_(ref_module), target_{target}, var_names_{var_names} {} void VisitBindingBlock(const BindingBlock& block) final { String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); @@ -170,7 +171,9 @@ class RelaxExprNameSetter : public ExprVisitor { ExprVisitor::VisitBinding_(binding, val); String name_hint, optype; bool use_unique = true; - if (const auto* op_node = val->op.as()) { + if (var_names_.count(binding->var->name_hint())) { + name_hint = var_names_[binding->var->name_hint()]; + } else if (const auto* op_node = val->op.as()) { const std::string& op_name = op_node->name; if (op_name == "relax.call_dps_packed" && val->args[0]->IsInstance()) { const auto& func = Downcast(val->args[0]); @@ -306,18 +309,21 @@ class RelaxExprNameSetter : public ExprVisitor { Map local_funcs_; IRModule ref_module_; String target_; + Map var_names_; }; // class ExprNameSetter -void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const String& target) { - RelaxExprNameSetter(ref_module, target).VisitExpr(e); +void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const String& target, + const Map& var_names) { + RelaxExprNameSetter(ref_module, target, var_names).VisitExpr(e); } namespace transform { -Pass SetRelaxExprName(const String& entry_name, const String& target) { +Pass SetRelaxExprName(const String& entry_name, const String& target, + const Map& var_names) { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { - relax::SetRelaxExprName(m, m->Lookup(entry_name), target); + 
relax::SetRelaxExprName(m, m->Lookup(entry_name), target, var_names); return m; }; return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 59d30e774000..e355626f859f 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -580,18 +580,22 @@ class TorchStridedSliceCodeGen : public TorchOpCode { void CodeGenForward() final { const auto& begin = node()->GetTypeArrayAttr("begin"); const auto& end = node()->GetTypeArrayAttr("end"); - const auto& strides = node()->GetTypeArrayAttr("strides"); + std::vector strides; + if (!node()->GetAttr("strides", &strides)) { + strides = std::vector(begin.size(), 1); + } const auto& axes = CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); - std::set axes_set; - for (const auto& a : axes) { - axes_set.insert(a); + std::unordered_map axes_map; + for (size_t i = 0; i < axes.size(); i++) { + axes_map[axes[i]] = i; } Array slice; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - if (axes_set.count(i)) { - slice.push_back(std::to_string(begin[i]) + ":" + std::to_string(end[i]) + ":" + - std::to_string(strides[i])); + if (axes_map.count(i)) { + size_t idx = axes_map[i]; + slice.push_back(std::to_string(begin[idx]) + ":" + std::to_string(end[idx]) + ":" + + std::to_string(strides[idx])); } else { slice.push_back(":"); } diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 7161b4b42f40..3a56b255efdb 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -36,7 +36,7 @@ def _get_config( model_type, compile_type, - tools_config, + tools, inputs, outputs, atol=1e-2, @@ -45,7 +45,7 @@ def _get_config( ): """Get msc config""" - path = "_".join(["test_tools", model_type, compile_type] + list(tools_config.keys())) + path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools]) return { "workspace": msc_utils.msc_dir(path), "verbose": "critical", @@ -53,6 +53,7 @@ def _get_config( "inputs": inputs, "outputs": outputs, "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, + "tools": tools, "prepare": {"profile": {"benchmark": {"repeat": 10}}}, "baseline": { "run_type": model_type, @@ -61,7 +62,6 @@ def _get_config( "optimize": { "run_type": optimize_type or model_type, "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - **tools_config, }, "compile": { "run_type": compile_type, @@ -70,79 +70,93 @@ def _get_config( } -def get_tool_config(tool_type, use_distill=False): +def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC): """Get config for the tool""" - config = {} + tools = [] if tool_type == ToolType.PRUNER: config = { "plan_file": "msc_pruner.json", - "strategys": [{"method": "per_channel", "density": 0.8}], + "strategys": [ + { + "methods": { + "weights": {"method_name": "per_channel", "density": 0.8}, + "output": {"method_name": "per_channel", "density": 0.8}, + } + } + ], } + tools.append({"tool_type": ToolType.PRUNER, "tool_config": config}) elif tool_type == ToolType.QUANTIZER: # pylint: disable=import-outside-toplevel from tvm.contrib.msc.core.tools.quantize import QuantizeStage - config = { - "plan_file": "msc_quantizer.json", - "strategys": [ - { - "method": "gather_maxmin", - "op_types": ["nn.conv2d", "msc.linear"], - "tensor_types": 
["input", "output"], - "stages": [QuantizeStage.GATHER], - }, - { - "method": "gather_max_per_channel", - "op_types": ["nn.conv2d", "msc.linear"], - "tensor_types": ["weight"], - "stages": [QuantizeStage.GATHER], - }, - { - "method": "calibrate_maxmin", - "op_types": ["nn.conv2d", "msc.linear"], - "tensor_types": ["input", "output"], - "stages": [QuantizeStage.CALIBRATE], - }, - { - "method": "quantize_normal", - "op_types": ["nn.conv2d", "msc.linear"], - "tensor_types": ["input", "weight"], - }, - { - "method": "dequantize_normal", - "op_types": ["nn.conv2d", "msc.linear"], - "tensor_types": ["output"], - }, - ], - } + if run_type == MSCFramework.TENSORRT: + config = {"plan_file": "msc_quantizer.json", "strategys": []} + else: + op_types = ["nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias"] + config = { + "plan_file": "msc_quantizer.json", + "strategys": [ + { + "methods": { + "input": "gather_maxmin", + "output": "gather_maxmin", + "weights": "gather_max_per_channel", + }, + "op_types": op_types, + "stages": [QuantizeStage.GATHER], + }, + { + "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"}, + "op_types": op_types, + "stages": [QuantizeStage.CALIBRATE], + }, + { + "methods": { + "input": "quantize_normal", + "weights": "quantize_normal", + "output": "dequantize_normal", + }, + "op_types": op_types, + }, + ], + } + tools.append({"tool_type": ToolType.QUANTIZER, "tool_config": config}) elif tool_type == ToolType.TRACKER: + # pylint: disable=import-outside-toplevel + from tvm.contrib.msc.core.utils import MSCStage + config = { "plan_file": "msc_tracker.json", "strategys": [ { - "method": "save_compared", - "compare_to": { - "optimize": ["baseline"], - "compile": ["optimize", "baseline"], + "methods": { + "output": { + "method_name": "save_compared", + "compare_to": { + MSCStage.OPTIMIZE: [MSCStage.BASELINE], + MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE], + }, + } }, "op_types": ["nn.relu"], - "tensor_types": ["output"], } ], } + tools.append({"tool_type": ToolType.TRACKER, "tool_config": config, "apply_once": True}) if use_distill: - distill_config = { + config = { "plan_file": "msc_distiller.json", "strategys": [ { - "method": "loss_lp_norm", - "op_types": ["loss"], + "methods": {"mark": "loss_lp_norm"}, + "marks": ["loss"], }, ], } - return {tool_type: config, ToolType.DISTILLER: distill_config} - return {tool_type: config} + tools.append({"tool_type": ToolType.DISTILLER, "tool_config": config}) + return tools def _get_torch_model(name, training=False): @@ -181,7 +195,7 @@ def _check_manager(manager, expected_info): def _test_from_torch( compile_type, - tools_config, + tools, expected_info, training=False, atol=1e-1, @@ -195,7 +209,7 @@ def _test_from_torch( config = _get_config( MSCFramework.TORCH, compile_type, - tools_config, + tools, inputs=[["input_0", [1, 3, 224, 224], "float32"]], outputs=["output"], atol=atol, @@ -245,21 +259,16 @@ def get_model_info(compile_type): def test_tvm_tool(tool_type): """Test tools for tvm""" - tool_config = get_tool_config(tool_type) - _test_from_torch( - MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), training=False - ) + tools = get_tools(tool_type) + _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False) -@tvm.testing.requires_cuda @pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) def test_tvm_distill(tool_type): """Test tools for tvm with distiller""" - tool_config = get_tool_config(tool_type, use_distill=True) - 
_test_from_torch( - MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), training=False - ) + tools = get_tools(tool_type, use_distill=True) + _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False) @requires_tensorrt @@ -270,15 +279,14 @@ def test_tvm_distill(tool_type): def test_tensorrt_tool(tool_type): """Test tools for tensorrt""" - tool_config = get_tool_config(tool_type) + tools = get_tools(tool_type, run_type=MSCFramework.TENSORRT) if tool_type == ToolType.QUANTIZER: - tool_config[ToolType.QUANTIZER]["strategys"] = [] optimize_type = MSCFramework.TENSORRT else: optimize_type = None _test_from_torch( MSCFramework.TENSORRT, - tool_config, + tools, get_model_info(MSCFramework.TENSORRT), training=False, atol=1e-1, @@ -292,9 +300,9 @@ def test_tensorrt_tool(tool_type): def test_tensorrt_distill(tool_type): """Test tools for tensorrt with distiller""" - tool_config = get_tool_config(tool_type, use_distill=True) + tools = get_tools(tool_type, use_distill=True) _test_from_torch( - MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), training=False + MSCFramework.TENSORRT, tools, get_model_info(MSCFramework.TENSORRT), training=False )
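Taken together, the new wrapper API is meant to be used roughly as follows; the torchvision model and the config values are illustrative only, the flow mirrors wrapper.py above:

import torchvision
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.pipeline import TorchWrapper

model = torchvision.models.resnet18()
config = TorchWrapper.create_config(
    inputs=[["input_0", [1, 3, 224, 224], "float32"]],
    outputs=["output"],
    compile_type=MSCFramework.TVM,
    dataset={"prepare": {"loader": "from_random", "max_batch": 5}},
)
wrapper = TorchWrapper(model, config)
wrapper.optimize()   # run the pipeline up to the optimize stage
wrapper.compile()    # export a checkpoint and compile it
path = wrapper.export("msc_export.tar.gz")  # package the whole pipeline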