diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index d9ca255502fa..5c13104cda16 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -77,9 +77,6 @@ RUN bash /install/ubuntu_install_onnx.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh -COPY install/ubuntu_install_caffe2.sh /install/ubuntu_install_caffe2.sh -RUN bash /install/ubuntu_install_caffe2.sh - COPY install/ubuntu_install_dgl.sh /install/ubuntu_install_dgl.sh RUN bash /install/ubuntu_install_dgl.sh diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index 8fc8157e1b2c..f94df2d64a17 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -36,5 +36,5 @@ pip3 install \ pip3 install future pip3 install \ - torch==1.10.1 \ - torchvision==0.11.2 + torch==1.11.0 \ + torchvision==0.12.0 diff --git a/gallery/how_to/compile_models/from_caffe2.py b/gallery/how_to/compile_models/from_caffe2.py deleted file mode 100644 index 263f98c9454f..000000000000 --- a/gallery/how_to/compile_models/from_caffe2.py +++ /dev/null @@ -1,145 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Compile Caffe2 Models -===================== -**Author**: `Hiroyuki Makino `_ - -This article is an introductory tutorial to deploy Caffe2 models with Relay. - -For us to begin with, Caffe2 should be installed. - -A quick solution is to install via conda - -.. code-block:: bash - - # for cpu - conda install pytorch-nightly-cpu -c pytorch - # for gpu with CUDA 8 - conda install pytorch-nightly cuda80 -c pytorch - -or please refer to official site -https://caffe2.ai/docs/getting-started.html -""" - -###################################################################### -# Load pretrained Caffe2 model -# ---------------------------- -# We load a pretrained resnet50 classification model provided by Caffe2. -from caffe2.python.models.download import ModelDownloader - -mf = ModelDownloader() - - -class Model: - def __init__(self, model_name): - self.init_net, self.predict_net, self.value_info = mf.get_c2_model(model_name) - - -resnet50 = Model("resnet50") - -###################################################################### -# Load a test image -# ------------------ -# A single cat dominates the examples! -from tvm.contrib.download import download_testdata -from PIL import Image -from matplotlib import pyplot as plt -import numpy as np - -img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" -img_path = download_testdata(img_url, "cat.png", module="data") -img = Image.open(img_path).resize((224, 224)) -plt.imshow(img) -plt.show() -# input preprocess -def transform_image(image): - image = np.array(image) - np.array([123.0, 117.0, 104.0]) - image /= np.array([58.395, 57.12, 57.375]) - image = image.transpose((2, 0, 1)) - image = image[np.newaxis, :].astype("float32") - return image - - -data = transform_image(img) - -###################################################################### -# Compile the model on Relay -# -------------------------- - -# Caffe2 input tensor name, shape and type -input_name = resnet50.predict_net.op[0].input[0] -shape_dict = {input_name: data.shape} -dtype_dict = {input_name: data.dtype} - -# parse Caffe2 model and convert into Relay computation graph -from tvm import relay, transform - -mod, params = relay.frontend.from_caffe2( - resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict -) - -# compile the model -# target x86 CPU -target = "llvm" -with transform.PassContext(opt_level=3): - lib = relay.build(mod, target, params=params) - -###################################################################### -# Execute on TVM -# --------------- -# The process is no different from other examples. -import tvm -from tvm import te -from tvm.contrib import graph_executor - -# context x86 CPU, use tvm.cuda(0) if you run on GPU -dev = tvm.cpu(0) -# create a runtime executor module -m = graph_executor.GraphModule(lib["default"](dev)) -# set inputs -m.set_input(input_name, tvm.nd.array(data.astype("float32"))) -# execute -m.run() -# get outputs -tvm_out = m.get_output(0) -top1_tvm = np.argmax(tvm_out.numpy()[0]) - -##################################################################### -# Look up synset name -# ------------------- -# Look up prediction top 1 index in 1000 class synset. -from caffe2.python import workspace - -synset_url = "".join( - [ - "https://gist.githubusercontent.com/zhreshold/", - "4d0b62f3d01426887599d4f7ede23ee5/raw/", - "596b27d23537e5a1b5751d2b0481ef172f58b539/", - "imagenet1000_clsid_to_human.txt", - ] -) -synset_name = "imagenet1000_clsid_to_human.txt" -synset_path = download_testdata(synset_url, synset_name, module="data") -with open(synset_path) as f: - synset = eval(f.read()) -print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm])) -# confirm correctness with caffe2 output -p = workspace.Predictor(resnet50.init_net, resnet50.predict_net) -caffe2_out = p.run({input_name: data}) -top1_caffe2 = np.argmax(caffe2_out) -print("Caffe2 top-1 id: {}, class name: {}".format(top1_caffe2, synset[top1_caffe2])) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e0bc9358cc9b..a5411ce4d0b6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2039,6 +2039,14 @@ def stack(self, inputs, input_types): assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg return self.tensor_array_stack(inputs, input_types) + def sub(self, inputs, input_types): + if len(inputs) == 3: + data0, data1, alpha = self.pytorch_promote_types(inputs, input_types) + return get_relay_op("subtract")(data0, alpha * data1) + else: + data0, data1= self.pytorch_promote_types(inputs, input_types) + return get_relay_op("subtract")(data0, data1) + def rsub(self, inputs, input_types): data0, data1, alpha = self.pytorch_promote_types(inputs, input_types) @@ -2859,7 +2867,10 @@ def all_any_common(self, op, inputs, input_types): inp = inputs[0] return op(inp, axis=dim, keepdims=keepdim) - def searchsorted_common(self, sorted_sequence, values, out_int32, right): + def searchsorted_common( + self, sorted_sequence, values, out_int32, right, side=None, out=None, sorter=None + ): + assert side is None and out is None and sorter is None, "unsupported parameters" dtype = "int32" if out_int32 else "int64" values_shape = _infer_shape(values) @@ -2959,7 +2970,7 @@ def create_convert_map(self): "aten::pixel_shuffle": self.pixel_shuffle, "aten::device": self.none, "prim::device": self.none, - "aten::sub": self.make_elemwise("subtract"), + "aten::sub": self.sub, "aten::max": self.max, "aten::min": self.min, "aten::mul": self.make_elemwise("multiply"), diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 8c9ce16acc1e..6e87b9ee4f6f 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -743,9 +743,10 @@ def __init__(self, inputsize=(128, 128)): self.backbone = Backbone() def fuse_model(self): + fuse_modules_qat = getattr(torch.ao.quantization, "fuse_modules_qat", fuse_modules) for idx, m in enumerate(self.modules()): if type(m) == ConvBnRelu: - torch.quantization.fuse_modules(m, ["conv", "bn", "relu"], inplace=True) + fuse_modules_qat(m, ["conv", "bn", "relu"], inplace=True) def forward(self, input): input = self.quant(input) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index af1c03bc2def..d7e1b5113f7c 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -56,9 +56,6 @@ function shard2 { i=$((i+1)) done - echo "Running relay caffe2 frontend test..." - run_pytest cython python-frontend-caffe2 tests/python/frontend/caffe2 - echo "Running relay DarkNet frontend test..." run_pytest cython python-frontend-darknet tests/python/frontend/darknet