python/tvm/meta_schedule/testing/te_workload.py (68 additions, 0 deletions)
@@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]:  # pylint: disable=invalid-
return (a, b)


def conv2d_nhwc_f16( # pylint: disable=invalid-name,missing-docstring
N: int,
H: int,
W: int,
CI: int,
CO: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
):
inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16")
weight = te.placeholder(
(kernel_size, kernel_size, CI // groups, CO), name="weight", dtype="float16"
)
batch_size, in_h, in_w, _ = inputs.shape
k_h, k_w, channel_per_group, out_channel = weight.shape
out_channel_per_group = out_channel // groups

out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
rh = te.reduce_axis((0, k_h), name="rh")
rw = te.reduce_axis((0, k_w), name="rw")
rc = te.reduce_axis((0, channel_per_group), name="rc")

padded = topi.nn.pad(inputs, [0, padding, padding, 0])
output = te.compute(
(batch_size, out_h, out_w, out_channel),
lambda n, h, w, co: te.sum(
(
tir.Cast(
value=padded[
n,
h * stride + rh * dilation,
w * stride + rw * dilation,
co // out_channel_per_group * channel_per_group + rc,
],
dtype="float32",
)
* tir.Cast(value=weight[rh, rw, rc, co], dtype="float32")
),
axis=[rh, rw, rc],
),
name="conv2d_nhwc",
)
return (inputs, weight, output)
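
For context, a minimal usage sketch (not part of this diff), assuming the new function is importable from `tvm.meta_schedule.testing.te_workload`; the shapes are arbitrary illustration values:

```python
# Sketch: lower the f16 conv2d workload to a TIR PrimFunc for scheduling/tuning.
from tvm import te
from tvm.meta_schedule.testing.te_workload import conv2d_nhwc_f16

# Illustrative shapes only, not taken from this PR.
inputs, weight, output = conv2d_nhwc_f16(
    N=1, H=14, W=14, CI=256, CO=256, kernel_size=3, stride=1, padding=1
)
prim_func = te.create_prim_func([inputs, weight, output])
print(prim_func.script())
```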


def batch_matmul_nkkm_f16( # pylint: disable=invalid-name,missing-docstring
B: int,
N: int,
M: int,
K: int,
) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
x = te.placeholder((B, N, K), name="X", dtype="float16")
y = te.placeholder((B, K, M), name="Y", dtype="float16")
k = te.reduce_axis((0, K), name="k")
z = te.compute( # pylint: disable=invalid-name
(B, N, M),
lambda b, i, j: te.sum(
tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), axis=[k]
),
name="Z",
)
return (x, y, z)
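
A rough correctness sketch (again not part of the diff) illustrating that the f16 operands are accumulated in float32, matching the `tir.Cast` calls in the compute body above; it assumes an `llvm` target is available and compares against a NumPy reference:

```python
import numpy as np
import tvm
from tvm import te
from tvm.meta_schedule.testing.te_workload import batch_matmul_nkkm_f16

x, y, z = batch_matmul_nkkm_f16(B=2, N=16, M=16, K=32)
mod = tvm.build(te.create_prim_func([x, y, z]), target="llvm")

dev = tvm.cpu()
a = tvm.nd.array(np.random.uniform(size=(2, 16, 32)).astype("float16"), dev)
b = tvm.nd.array(np.random.uniform(size=(2, 32, 16)).astype("float16"), dev)
c = tvm.nd.array(np.zeros((2, 16, 16), dtype="float32"), dev)
mod(a, b, c)

# Reference accumulates in float32, mirroring the casts in the compute definition.
ref = np.einsum(
    "bnk,bkm->bnm", a.numpy().astype("float32"), b.numpy().astype("float32")
)
np.testing.assert_allclose(c.numpy(), ref, rtol=1e-3)
```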


def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
workload_func, params = CONFIGS[name]
return te.create_prim_func(workload_func(*params[idx])) # type: ignore
python/tvm/tir/schedule/analysis.py (34 additions, 0 deletions)
@@ -87,3 +87,37 @@ def get_tensorize_loop_mapping(
TensorizeInfo structure if a valid mapping is found, None otherwise
"""
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore


@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
class AutoTensorizeMappingInfo(Object):
"""Necessary information used to perform transformations for tensorization."""


def get_auto_tensorize_mapping_info(
sch: Schedule, block: BlockRV, desc_func: PrimFunc
) -> Optional[AutoTensorizeMappingInfo]:
"""Get mapping info between a target block and an intrinsic description including layout
transformations to apply.

Parameters
----------
sch : Schedule
The schedule to be tensorized
block : BlockRV
The compute block for auto tensorization
desc_func : PrimFunc
The prim func describing the computation to be tensorized

Returns
-------
auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
AutoTensorizeMappingInfo structure if potential mappings found, None otherwise.

Note
----
Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
We will need to apply the suggested layout transformations and then match against the tensor
intrinsics.
"""
return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func) # type: ignore
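
A hedged sketch of how this analysis entry point might be exercised; the WMMA intrinsic name and the `tvm.tir.tensor_intrin.cuda` registration import are assumptions for illustration and are not part of this diff:

```python
from tvm import te, tir
import tvm.tir.tensor_intrin.cuda  # assumed to register the WMMA tensor intrinsics
from tvm.meta_schedule.testing import te_workload
from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info

mod = te.create_prim_func(te_workload.batch_matmul_nkkm_f16(B=1, N=16, M=16, K=16))
sch = tir.Schedule(mod)
block = sch.get_block("Z")

# Description PrimFunc of a registered tensor intrinsic; the name is illustrative.
desc_func = tir.TensorIntrin.get("wmma_sync_16x16x16_f16f16f32").desc

info = get_auto_tensorize_mapping_info(sch, block, desc_func)
if info is not None:
    # Candidate index maps for the block iters; one of them (plus the suggested
    # buffer layout transforms) still has to be applied and re-matched before
    # tensorization can succeed.
    for mapping in info.mappings:
        print(mapping)
```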
src/tir/schedule/analysis.h (50 additions, 0 deletions)
@@ -707,6 +707,56 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);

/*! \brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
public:
/*! \brief Possible mappings to apply to block iters */
Array<IndexMap> mappings;

/* Additional information from AutoTensorizeComparator */

/*! \brief Mapping from LHS buffer to RHS buffer */
Map<Buffer, Buffer> lhs_buffer_map;
/*! \brief Buffer indices on RHS */
Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
/*! \brief Block iters on LHS */
Array<IterVar> lhs_iters;
/*! \brief Block iters on RHS */
Array<IterVar> rhs_iters;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mappings", &mappings);
v->Visit("lhs_buffer_map", &lhs_buffer_map);
v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
v->Visit("lhs_iters", &lhs_iters);
v->Visit("rhs_iters", &rhs_iters);
}

static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
};

class AutoTensorizeMappingInfo : public ObjectRef {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef,
AutoTensorizeMappingInfoNode);
};

/*!
 * \brief Get mapping info between a target block and an intrinsic description, including the
 * layout transformations to apply.
* \param self The schedule state
* \param block_sref The compute block for auto tensorization
* \param desc_func The prim func describing the computation to be tensorized
* \return AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise.
* \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
* We will need to apply the suggested layout transformations and then match against the tensor
* intrinsics.
*/
Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const ScheduleState& self,
const StmtSRef& block_sref,
const PrimFunc& desc_func);

} // namespace tir
} // namespace tvm
