diff --git a/Makefile b/Makefile
index 7daddbd955af..e7dcebc3c586 100644
--- a/Makefile
+++ b/Makefile
@@ -37,7 +37,7 @@ LIBHALIDEIR:
 	+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
 
 lint:
-	python2 dmlc-core/scripts/lint.py tvm cpp include src
+	python2 dmlc-core/scripts/lint.py tvm all include src python
 
 doc:
 	doxygen docs/Doxyfile
diff --git a/dmlc-core b/dmlc-core
index f294fc2271b2..749e570c1942 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit f294fc2271b27b0b6e2b117003ed2dc3d3ba8fda
+Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
new file mode 100644
index 000000000000..beed7e9d1281
--- /dev/null
+++ b/include/tvm/buffer.h
@@ -0,0 +1,98 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file buffer.h
+ * \brief Symbolic n-dimensional array, to represent a memory buffer.
+ */
+#ifndef TVM_BUFFER_H_
+#define TVM_BUFFER_H_
+
+#include <string>
+#include <memory>
+
+#include "./base.h"
+#include "./expr.h"
+
+namespace tvm {
+
+// Internal node container Buffer
+class BufferNode;
+/*!
+ * \brief Buffer is a symbolic n-dimensional array structure.
+ *  It is a composition of primitive symbolic types,
+ *  used to specify the input/output structure of the program.
+ */
+class Buffer : public NodeRef {
+ public:
+  Buffer() {}
+  explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief construct a new buffer based on shape and strides.
+   */
+  explicit Buffer(Array<Expr> shape,
+                  Type dtype = Float(32),
+                  std::string name = "buffer");
+  /*!
+   * \brief Generate a load expression loading the index location of the buffer.
+   * \param index The index to the buffer.
+   * \return The load expression.
+   */
+  Expr MakeLoad(Array<Expr> index) const;
+  /*!
+   * \brief Generate a store statement.
+   * \param index The index to the buffer.
+   * \param value The value to be stored.
+   * \return The store statement.
+   */
+  Stmt MakeStore(Array<Expr> index, Expr value) const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const BufferNode* operator->() const;
+};
+
+/*! \brief Node to represent a buffer */
+class BufferNode : public Node {
+ public:
+  /*! \brief optional name of the buffer */
+  std::string name;
+  /*! \brief The pointer to the head of the data */
+  Var ptr;
+  /*! \brief The shape of the buffer */
+  Array<Expr> shape;
+  /*!
+   * \brief The strides of each dimension
+   *  This can be an empty array, indicating the array is contiguous.
+   */
+  Array<Expr> strides;
+  /*! \brief data type in the content of the tensor */
+  Type dtype;
+  // Maybe need more information (e.g. alignment) later
+  /*! \brief constructor */
+  BufferNode() {}
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("ptr", &ptr);
+    v->Visit("shape", &shape);
+    v->Visit("strides", &strides);
+    v->Visit("dtype", &dtype);
+  }
+
+  static Buffer make(std::string name,
+                     Var ptr,
+                     Array<Expr> shape,
+                     Array<Expr> strides,
+                     Type dtype);
+
+  static constexpr const char* _type_key = "Buffer";
+  TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
+};
+
+inline const BufferNode* Buffer::operator->() const {
+  return static_cast<const BufferNode*>(node_.get());
+}
+
+}  // namespace tvm
+#endif  // TVM_BUFFER_H_
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index c93d748a1856..d4456ed74cd4 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -13,6 +13,7 @@
 #include <unordered_map>
 #include <vector>
 #include "./expr.h"
+#include "./buffer.h"
 #include "./schedule.h"
 
 namespace tvm {
@@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt);
  *
  * \note All the passes in this file use SSA form and output SSA form.
  */
-Stmt Inline(FunctionRef f,
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
             Array<Var> args,
-            Expr body,
-            Stmt stmt);
+            Expr body);
+
+
+/*!
+ * \brief Flatten the multi-dimensional read/write
+ *  to single dimensional Load/Store.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param extern_buffer Map specifies external
+ *  buffer assignment of inputs and outputs.
+ */
+Stmt StorageFlatten(Stmt stmt,
+                    Map<Tensor, Buffer> extern_buffer);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 07fca6ecab39..3725cd62c15d 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -1,3 +1,4 @@
+# pylint: disable=redefined-builtin, wildcard-import
 """C++ backend related python scripts"""
 from __future__ import absolute_import as _abs
 from ._ctypes._api import register_node
diff --git a/python/tvm/_base.py b/python/tvm/_base.py
index 88ac476ecc6f..b67275cb4030 100644
--- a/python/tvm/_base.py
+++ b/python/tvm/_base.py
@@ -1,10 +1,9 @@
 # coding: utf-8
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, no-member
 """ ctypes library of nnvm and helper functions """
 from __future__ import absolute_import
 
 import sys
-import os
 import ctypes
 import numpy as np
 from . import libinfo
diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py
index c9d980928967..4de64f9db4b6 100644
--- a/python/tvm/_ctypes/_api.py
+++ b/python/tvm/_ctypes/_api.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 # pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
+# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
 """Symbolic configuration API."""
 from __future__ import absolute_import as _abs
 
@@ -14,6 +15,7 @@ from .. import _function_internal
 
 class ArgVariant(ctypes.Union):
+    """ArgVariant in C API"""
     _fields_ = [("v_long", ctypes.c_long),
                 ("v_double", ctypes.c_double),
                 ("v_str", ctypes.c_char_p),
@@ -30,8 +32,8 @@ class ArgVariant(ctypes.Union):
 
 def _return_node(x):
     handle = x.v_handle
-    if not isinstance(handle, ctypes.c_void_p):
-        handle = ctypes.c_void_p(handle)
+    if not isinstance(handle, NodeHandle):
+        handle = NodeHandle(handle)
     ret_val = ArgVariant()
     ret_typeid = ctypes.c_int()
     ret_success = ctypes.c_int()
@@ -47,7 +49,7 @@ def _return_node(x):
     kLong: lambda x: x.v_long,
     kDouble: lambda x: x.v_double,
     kStr: lambda x: py_str(x.v_str),
-    kNodeHandle: lambda x: _return_node(x)
+    kNodeHandle: _return_node
 }
 
 class SliceBase(object):
@@ -251,6 +253,7 @@ def register_node(type_key=None):
     """
     if isinstance(type_key, str):
         def register(cls):
+            """internal register function"""
            NODE_TYPE[type_key] = cls
            return cls
         return register
@@ -273,9 +276,9 @@ def _init_function_module(root_namespace):
     module_obj = sys.modules["%s.function" % root_namespace]
     module_internal = sys.modules["%s._function_internal" % root_namespace]
     namespace_match = {
-        "_make_" : sys.modules["%s.make" % root_namespace],
-        "_pass_" : sys.modules["%s.ir_pass" % root_namespace],
-        "_schedule_" : sys.modules["%s.schedule" % root_namespace]
+        "_make_": sys.modules["%s.make" % root_namespace],
+        "_pass_": sys.modules["%s.ir_pass" % root_namespace],
+        "_schedule_": sys.modules["%s.schedule" % root_namespace]
     }
 
     for name in op_names:
diff --git a/python/tvm/collections.py b/python/tvm/collections.py
index eb988930408d..85e629cc96da 100644
--- a/python/tvm/collections.py
+++ b/python/tvm/collections.py
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access, no-member
 """Collection structure in the high level DSL."""
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
@@ -6,6 +7,7 @@
 
 @register_node
 class Array(NodeBase):
+    """Array container of TVM"""
     def __getitem__(self, i):
         if i >= len(self):
             raise IndexError("array index out of range")
@@ -19,6 +21,7 @@ def __repr__(self):
 
 @register_node
 class Map(NodeBase):
+    """Map container of TVM"""
     def __getitem__(self, k):
         return _function_internal._MapGetItem(self, k)
 
@@ -26,6 +29,7 @@ def __contains__(self, k):
         return _function_internal._MapCount(self, k) != 0
 
     def items(self):
+        """Get the items from the map"""
         akvs = _function_internal._MapItems(self)
         return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
 
@@ -38,9 +42,17 @@ def __repr__(self):
 
 @register_node
 class Range(NodeBase):
+    """Represent range in TVM"""
     pass
 
 
 @register_node
 class IterVar(NodeBase, _expr.ExprOp):
+    """Represent iteration variable."""
+    pass
+
+
+@register_node
+class Buffer(NodeBase):
+    """Represent a Buffer in TVM."""
     pass
diff --git a/python/tvm/expr.py b/python/tvm/expr.py
index a6de17f8cb55..3d75fa4c7d3f 100644
--- a/python/tvm/expr.py
+++ b/python/tvm/expr.py
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access, no-member, missing-docstring
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
 from . import make as _make
@@ -174,7 +175,7 @@ class Call(Expr):
     Halide = 3
     Intrinsic = 4
     PureIntrinsic = 5
-    pass
+
 
 @register_node
 class Let(Expr):
diff --git a/python/tvm/function.py b/python/tvm/function.py
index 43e688276362..78491404d7b1 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -1,5 +1,8 @@
+# pylint: disable=protected-access, no-member, invalid-name
+# pylint: disable=redefined-builtin, undefined-variable
+"""Functions defined in TVM."""
 from __future__ import absolute_import as _abs
-from numbers import Number as _Number, Integral as _Integral
+from numbers import Integral as _Integral
 from ._ctypes._api import _init_function_module, convert
 from . import _function_internal
 from . import make as _make
@@ -8,6 +11,7 @@
 
 int32 = "int32"
 float32 = "float32"
+handle = "handle"
 
 def const(value, dtype=None):
     """construct a constant"""
@@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32):
     return _function_internal._Var(name, dtype)
 
 
-def placeholder(shape, dtype = None, name="placeholder"):
+def placeholder(shape, dtype=None, name="placeholder"):
     """Construct an empty tensor object.
 
     Parameters
@@ -84,6 +88,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
     tensor: tensor.Tensor
         The created tensor
     """
+    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     dtype = float32 if dtype is None else dtype
     return _function_internal._Placeholder(
         shape, dtype, name)
@@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"):
     tensor: tensor.Tensor
         The created tensor
     """
-    if isinstance(shape, _expr.Expr):
-        shape = (shape, )
+    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     ndim = len(shape)
     arg_names = fcompute.__code__.co_varnames
@@ -125,7 +129,44 @@
     op_node = _function_internal._ComputeOp(
         name, dim_var, body)
     return _function_internal._Tensor(
-        shape, name, body.dtype, op_node, 0)
+        shape, body.dtype, op_node, 0)
+
+
+def Buffer(shape, dtype=None,
+           name="buffer", ptr=None,
+           strides=None):
+    """Create a new buffer
+
+    Parameters
+    ----------
+    shape : tuple of Expr
+        The shape of the buffer.
+
+    dtype : str, optional
+        The data type of the buffer.
+
+    name : str, optional
+        The name of the buffer.
+
+    ptr : Var, optional
+        The data pointer in the buffer.
+
+    strides : array of Expr, optional
+        The strides of the buffer.
+
+    Returns
+    -------
+    buffer : Buffer
+        The created buffer
+    """
+    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    dtype = float32 if dtype is None else dtype
+    strides = () if strides is None else strides
+    if ptr is None:
+        ptr = Var(name, "handle")
+
+    return _function_internal._Buffer(
+        name, ptr, shape, strides, dtype)
 
 
 def IterVar(dom, name='iter', thread_tag=''):
@@ -170,7 +211,7 @@ def sum(expr, rdom):
         The reduction domain
     """
     rdom = rdom if isinstance(rdom, list) else [rdom]
-    x =  _make.Reduce("Add", expr, rdom)
+    x = _make.Reduce("Add", expr, rdom)
     return x
 
@@ -186,7 +227,7 @@ def min(expr, rdom):
         The reduction domain
     """
     rdom = rdom if isinstance(rdom, list) else [rdom]
-    x =  _make.Reduce("Min", expr, rdom)
+    x = _make.Reduce("Min", expr, rdom)
     return x
 
@@ -202,7 +243,7 @@ def max(expr, rdom):
         The reduction domain
     """
     rdom = rdom if isinstance(rdom, list) else [rdom]
-    x =  _make.Reduce("Max", expr, rdom)
+    x = _make.Reduce("Max", expr, rdom)
     return x
 
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 7a5282b1219f..a8ecb97bf27b 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access, no-member
 """Collection structure in the high level DSL."""
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
@@ -6,15 +7,18 @@
 
 @register_node
 class Split(NodeBase):
+    """Split operation on axis."""
     pass
 
 
 @register_node
 class Fuse(NodeBase):
+    """Fuse operation on axis."""
     pass
 
 
 @register_node
 class Schedule(NodeBase):
+    """Schedule for all the stages."""
     def __getitem__(self, k):
         if isinstance(k, _tensor.Tensor):
             k = k.op
@@ -26,6 +30,7 @@ def __getitem__(self, k):
 
 @register_node
 class Stage(NodeBase):
+    """A Stage represents the schedule for one operation."""
     def split(self, parent, factor=None, outer=None):
         """Split the stage either by factor
         providing outer scope, or both
@@ -132,6 +137,32 @@ def reorder(self, *args):
         _function_internal._StageReorder(self, args)
 
     def tile(self, x_parent, y_parent, x_factor, y_factor):
+        """Perform tiling on two dimensions.
+
+        The final loop order from outermost to innermost is
+        [x_outer, y_outer, x_inner, y_inner].
+
+        Parameters
+        ----------
+        x_parent : IterVar
+            The original x dimension
+        y_parent : IterVar
+            The original y dimension
+        x_factor : Expr
+            The stride factor on x axis
+        y_factor : Expr
+            The stride factor on y axis
+
+        Returns
+        -------
+        x_outer : IterVar
+            Outer axis of x dimension
+        y_outer : IterVar
+            Outer axis of y dimension
+        x_inner : IterVar
+            Inner axis of x dimension
+        y_inner : IterVar
+            Inner axis of y dimension
+        """
         x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
             self, x_parent, y_parent, x_factor, y_factor)
         return x_outer, y_outer, x_inner, y_inner
diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py
index cc22e168782a..97ef3b6b3c8c 100644
--- a/python/tvm/stmt.py
+++ b/python/tvm/stmt.py
@@ -1,6 +1,6 @@
+# pylint: disable=protected-access, no-member, missing-docstring
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
-from . import make as _make
 
 class Stmt(NodeBase):
     pass
@@ -23,7 +23,6 @@ class For(Stmt):
     Parallel = 1
     Vectorized = 2
     Unrolled = 3
-    pass
 
 @register_node
 class Store(Stmt):
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 99e14180aa7b..fdaec1d33c25 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -1,3 +1,5 @@
+# pylint: disable=protected-access, no-member, invalid-name
+"""Tensor related abstractions"""
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, SliceBase, register_node, convert
 from . import collections as _collections
@@ -51,10 +53,12 @@ def __eq__(self, other):
 
     @property
     def ndim(self):
+        """Dimension of the tensor."""
         return len(self.shape)
 
 
 class Operation(NodeBase):
+    """Represent an operation that generates a tensor"""
     def output(self, index):
         """Get the index-th output of the operation
 
@@ -72,8 +76,10 @@ def output(self, index):
 
 @register_node
 class ComputeOp(Operation):
+    """Compute operation."""
     pass
 
 
 @register_node
 class PlaceholderOp(Operation):
+    """Placeholder operation."""
     pass
diff --git a/src/base/common.h b/src/base/common.h
index 66ffffef5cc2..0485bdfc4af0 100644
--- a/src/base/common.h
+++ b/src/base/common.h
@@ -12,6 +12,7 @@
 namespace tvm {
 
 inline std::string Type2String(const Type& t) {
+  if (t.code() == Type::Handle) return "handle";
   std::ostringstream os;
   os << t;
   return os.str();
@@ -28,6 +29,8 @@ inline Type String2Type(std::string s) {
     code = Type::UInt; s = s.substr(4);
   } else if (s.substr(0, 5) == "float") {
     code = Type::Float; s = s.substr(5);
+  } else if (s == "handle") {
+    return Type(Type::Handle, 0, 0);
   } else {
     LOG(FATAL) << "unknown type " << s;
   }
diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc
index af66e9db6690..c2110c76d9f6 100644
--- a/src/c_api/c_api_ir.cc
+++ b/src/c_api/c_api_ir.cc
@@ -123,6 +123,7 @@ REGISTER_MAKE3(Let);
 REGISTER_MAKE3(LetStmt);
 REGISTER_MAKE2(AssertStmt);
 REGISTER_MAKE3(ProducerConsumer);
+REGISTER_MAKE3(Load);
 REGISTER_MAKE3(Store);
 REGISTER_MAKE4(Provide);
 REGISTER_MAKE1(Free);
diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc
index 3f2b4e2a0abd..948783966dfd 100644
--- a/src/c_api/c_api_lang.cc
+++ b/src/c_api/c_api_lang.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/expr.h>
 #include <tvm/tensor.h>
+#include <tvm/buffer.h>
 #include "./c_api_registry.h"
@@ -140,14 +141,23 @@ TVM_REGISTER_API(Range)
 .add_argument("begin", "Expr", "beginning of the range.")
 .add_argument("end", "Expr", "extent of the range");
 
-TVM_REGISTER_API(_Tensor)
+TVM_REGISTER_API(_Buffer)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = TensorNode::make(args.at(0),
+    *ret = BufferNode::make(args.at(0),
+                            args.at(1),
                             args.at(2),
                             args.at(3),
                             args.at(4));
   });
 
+TVM_REGISTER_API(_Tensor)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = TensorNode::make(args.at(0),
+                            args.at(1),
+                            args.at(2),
+                            args.at(3));
+  });
+
 TVM_REGISTER_API(_TensorEqual)
 .set_body([](const ArgStack& args, RetValue *ret) {
     *ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
new file mode 100644
index 000000000000..02cd05224e53
--- /dev/null
+++ b/src/lang/buffer.cc
@@ -0,0 +1,78 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file buffer.cc
+ */
+#include <tvm/buffer.h>
+#include <tvm/ir.h>
+
+namespace tvm {
+
+Array<Expr> GetStrides(Array<Expr> shape) {
+  CHECK_NE(shape.size(), 0U);
+  std::vector<Expr> vec{make_const(shape[0].type(), 1)};
+  for (size_t i = shape.size() - 1; i != 0; --i) {
+    vec.push_back(shape[i - 1] * vec.back());
+  }
+  return Array<Expr>(vec.rbegin(), vec.rend());
+}
+
+Buffer::Buffer(Array<Expr> shape,
+               Type dtype,
+               std::string name)
+    : Buffer(BufferNode::make(
+          name,
+          Var(name, Type(Type::Handle, 0, 0)),
+          shape, Array<Expr>(), dtype)) {
+}
+
+inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
+  Expr base;
+  if (n->strides.size() == 0) {
+    CHECK_EQ(n->shape.size(), index.size());
+    base = index[0];
+    for (size_t i = 1; i < index.size(); ++i) {
+      base = base * n->shape[i] + index[i];
+    }
+  } else {
+    CHECK_EQ(n->strides.size(), index.size());
+    base = index[0] * n->strides[0];
+    for (size_t i = 1; i < index.size(); ++i) {
+      base = base + index[i] * n->strides[i];
+    }
+  }
+  return base;
+}
+
+Expr Buffer::MakeLoad(Array<Expr> index) const {
+  const BufferNode* n = operator->();
+  return ir::Load::make(n->dtype, n->ptr, BufferOffset(n, index));
+}
+
+Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
+  const BufferNode* n = operator->();
+  CHECK_EQ(value.type(), n->dtype);
+  return ir::Store::make(n->ptr, BufferOffset(n, index), value);
+}
+
+Buffer BufferNode::make(std::string name,
+                        Var ptr,
+                        Array<Expr> shape,
+                        Array<Expr> strides,
+                        Type dtype) {
+  auto n = std::make_shared<BufferNode>();
+  n->name = name;
+  n->ptr = ptr;
+  n->shape = shape;
+  n->strides = strides;
+  n->dtype = dtype;
+  return Buffer(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
+    p->stream << "buffer(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(BufferNode);
+
+}  // namespace tvm
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
index 1883a5eacff3..ce26e65da8fe 100644
--- a/src/lang/operation.cc
+++ b/src/lang/operation.cc
@@ -1,4 +1,3 @@
-
 /*!
  * Copyright (c) 2016 by Contributors
  * \file operation.cc
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
index b912e30897db..085fe738eaeb 100644
--- a/src/pass/inline.cc
+++ b/src/pass/inline.cc
@@ -8,7 +8,6 @@
 
 namespace tvm {
 namespace ir {
-namespace {
 
 // inliner to inline a function
 // the result may not be SSA,
@@ -50,12 +49,10 @@ class IRInline : public IRMutator {
   }
 };
 
-}  // namespace
-
-Stmt Inline(FunctionRef f,
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
             Array<Var> args,
-            Expr body,
-            Stmt stmt) {
+            Expr body) {
   CHECK_EQ(f->num_outputs(), 1)
       << "can only inline output single value operation";
   return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc
index 3cdb8f171a7f..a62cf678b8cf 100644
--- a/src/pass/schedule_ops.cc
+++ b/src/pass/schedule_ops.cc
@@ -13,7 +13,6 @@
 
 namespace tvm {
 namespace ir {
-namespace {
 
 /*!
  * \brief use message passing to calculate the assignment of each Var inside the loop body.
@@ -256,7 +255,7 @@ Stmt MakePipeline(const Stage& sch,
   if (sch->op.as<ComputeOpNode>()) {
     provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
   } else {
-    LOG(FATAL) << "not supported op";
+    LOG(FATAL) << "not supported op " << sch->op->type_key();
   }
   std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
   Stmt producer = MergeNest(nest, provide);
@@ -317,10 +316,9 @@ Stmt InjectInline(const Operation op, Stmt body) {
   for (auto iv : compute->axis) {
     args.push_back(iv->var);
   }
-  return Inline(op, args, compute->body, body);
+  return Inline(body, op, args, compute->body);
 }
 
-}  // namespace
 
 Stmt ScheduleOps(
     Schedule sch, Map<IterVar, Range> dom_map) {
@@ -328,6 +326,8 @@ Stmt ScheduleOps(
   // reverse the post DFS order.
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage s = sch->stages[i - 1];
+    // no need to schedule a placeholder op.
+    if (s->op.as<PlaceholderOpNode>()) continue;
     if (s->attach_type == kInline) {
       body = InjectInline(s->op, body);
     } else if (s->attach_type == kRoot || s->attach_type == kNone) {
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 9ac720305590..6a23f48d8c90 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -151,8 +151,10 @@ BoundProp(const Array<Operation>& post_order,
       }
     };
     ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+  } else if (op.as<PlaceholderOpNode>()) {
+    // do nothing
   } else {
-    LOG(FATAL) << "unknown operation mode";
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
   }
 }
 return result;
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 5b13c1569078..530eecaac971 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -42,10 +42,11 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
       };
       ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
       rmap.Set(op, deps);
+    } else if (op.as<PlaceholderOpNode>()) {
+      // empty set of deps
+      rmap.Set(op, deps);
     } else {
-      if (!op.as<PlaceholderOpNode>()) {
-        LOG(FATAL) << "unknown Operation" << op->type_key();
-      }
+      LOG(FATAL) << "unknown Operation" << op->type_key();
     }
   }
   return rmap;
@@ -56,7 +57,7 @@ void PostDFSOrder(const Operation& op,
                   const ReadGraph& g,
                   std::unordered_set<Operation>* visited,
                   Array<Operation>* post_order) {
-  if (op.as<PlaceholderOpNode>() || visited->count(op)) return;
+  if (visited->count(op)) return;
   visited->insert(op);
   for (const auto& t : g.at(op)) {
     PostDFSOrder(t->op, g, visited, post_order);
diff --git a/tests/python/test_lang_buffer.py b/tests/python/test_lang_buffer.py
new file mode 100644
index 000000000000..b1b6b7d9b8fb
--- /dev/null
+++ b/tests/python/test_lang_buffer.py
@@ -0,0 +1,16 @@
+import tvm
+
+def test_buffer():
+    m = tvm.Var('m')
+    n = tvm.Var('n')
+    l = tvm.Var('l')
+    Ab = tvm.Buffer((m, n), tvm.float32)
+    Bb = tvm.Buffer((n, l), tvm.float32)
+
+    assert isinstance(Ab, tvm.collections.Buffer)
+    assert Ab.dtype == tvm.float32
+    assert tuple(Ab.shape) == (m, n)
+
+
+if __name__ == "__main__":
+    test_buffer()
diff --git a/tests/python/test_lang_tensor.py b/tests/python/test_lang_tensor.py
index af0632866404..01ab5109f628 100644
--- a/tests/python/test_lang_tensor.py
+++ b/tests/python/test_lang_tensor.py
@@ -33,6 +33,7 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))
 
+
 if __name__ == "__main__":
     test_tensor()
     test_tensor_reduce()
diff --git a/tests/python/test_pass_inline.py b/tests/python/test_pass_inline.py
index 858864c60b75..43149fd3d966 100644
--- a/tests/python/test_pass_inline.py
+++ b/tests/python/test_pass_inline.py
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        T.op, [x.var for x in T.op.axis], T.op.body, stmt)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
diff --git a/tests/python/test_schedule_bound_inference.py b/tests/python/test_schedule_bound_inference.py
index 9e8c70cac66b..e80fb275c561 100644
--- a/tests/python/test_schedule_bound_inference.py
+++ b/tests/python/test_schedule_bound_inference.py
@@ -63,8 +63,8 @@ def test_create_read_graph():
     assert g[A2.op][0] == A1
     assert g[A1.op][0] == A
     post_order = tvm.schedule.PostDFSOrder([A2.op], g)
-    assert(post_order[0] == A1.op)
-    assert(post_order[1] == A2.op)
+    assert(post_order[0] == A.op)
+    assert(post_order[1] == A1.op)
 
 
 if __name__ == "__main__":
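
As a companion to the tests above, a minimal usage sketch of the Buffer API this diff introduces. It mirrors tests/python/test_lang_buffer.py; the explicit-strides call at the end is an illustrative assumption based on the strides parameter of the python Buffer() wrapper and the BufferOffset logic in src/lang/buffer.cc, and is not exercised by the diff itself.

import tvm

# Symbolic buffer over a symbolic shape (assumes this diff is applied).
m = tvm.Var('m')
n = tvm.Var('n')

# Contiguous case: strides is left empty, so BufferOffset flattens
# element (i, j) to the linear index i * n + j.
Ab = tvm.Buffer((m, n), tvm.float32, name='A')
assert isinstance(Ab, tvm.collections.Buffer)
assert Ab.dtype == tvm.float32

# Strided case (illustrative, not covered by the tests): with explicit
# strides, BufferOffset computes i * strides[0] + j * strides[1] instead.
Bb = tvm.Buffer((m, n), tvm.float32, name='B', strides=(n, tvm.const(1)))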