45 changes: 35 additions & 10 deletions src/relax/transform/static_plan_block_memory.cc
@@ -353,16 +353,21 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
* the input function signature in the analyzer.
* \param func The function to be analyzed.
* \param ana The analyzer which contains the TIR var upper bounds.
* \param dom_map The domain map of the TIR variables.
*/
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
Map<tir::Var, arith::IntSet>* dom_map) {
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
// memory planning.
// NOTE: we only apply the annotated upper bounds to the TIR variables that
// appear in the **function signature**.
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
.value_or(Map<ObjectRef, ObjectRef>());
Array<ObjectRef> non_negative_var_attr_raw =
func->GetAttr<Array<ObjectRef>>("tir_non_negative_var").value_or(Array<ObjectRef>());
std::unordered_map<String, IntImm> var_upper_bound_attr;
std::unordered_set<String> non_negative_var_attr;
// We manually check the value type to ensure the values are all positive IntImm.
for (auto it : var_upper_bound_attr_raw) {
const auto* key = it.first.as<StringObj>();
@@ -378,13 +383,23 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
<< value->value << " is got.";
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
}
for (ObjectRef var_name : non_negative_var_attr_raw) {
const auto* key = var_name.as<StringObj>();
    CHECK(key != nullptr) << "The element of attr `tir_non_negative_var` should be string. However "
                          << var_name->GetTypeKey() << " is got.";
non_negative_var_attr.insert(GetRef<String>(key));
}
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
for (const tir::Var& tir_var : var_in_signature) {
auto it = var_upper_bound_attr.find(tir_var->name_hint);
if (it != var_upper_bound_attr.end()) {
ana->Bind(tir_var,
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
tvm::Range range =
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
tvm::IntImm(DataType::Int(64), (*it).second->value + 1));
ana->Bind(tir_var, range);
dom_map->Set(tir_var, arith::IntSet::FromRange(range));
} else if (non_negative_var_attr.count(tir_var->name_hint)) {
ana->MarkGlobalNonNegValue(tir_var);
}
}
}
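
For context, a minimal sketch of how these two attributes are annotated on a Relax function, mirroring the test added below (function body elided, names taken from the test): "batch_size" receives a hard upper bound, while "vocab_size" is only marked non-negative and stays symbolic during planning.

from tvm.script import relax as R

@R.function
def main(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")):
    R.func_attr(
        {
            # Plan with batch_size bound to the range [0, 32].
            "tir_var_upper_bound": {"batch_size": 32},
            # vocab_size has no upper bound; only mark it non-negative.
            "tir_non_negative_var": ["vocab_size"],
        }
    )
    return probs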
@@ -398,14 +413,20 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
* \return The upper-bounded shape. When a dimension's upper bound
* cannot be determined, we keep the dimension unchanged.
*/
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana,
const Map<tir::Var, arith::IntSet>& dom_map) {
// Use the upper bounds of TIR vars as their values.
Array<PrimExpr> upper_bounded_shape;
upper_bounded_shape.reserve(shape.size());
for (const PrimExpr& dim_len : shape) {
int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
if (max_bound == std::numeric_limits<int64_t>::max()) {
upper_bounded_shape.push_back(dim_len);
arith::IntSet int_set = ana->int_set(dim_len, dom_map);
if (int_set.HasUpperBound()) {
upper_bounded_shape.push_back(int_set.max());
} else {
upper_bounded_shape.push_back(dim_len);
}
} else {
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
}
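
As a rough illustration of the fallback above (not part of this diff), assuming the tvm.arith Python bindings for Analyzer.int_set, IntervalSet, and ConstIntBound mirror the C++ API used here, with illustrative variable names: const_int_bound alone cannot bound an expression containing an unbounded symbolic variable, but interval evaluation over the domain map still yields a symbolic upper bound.

import tvm
from tvm import arith, tir

n = tir.Var("n", "int64")  # annotated upper bound: n <= 32
v = tir.Var("v", "int64")  # non-negative, but no upper bound

ana = arith.Analyzer()
dim_len = n * v * 4

# Without a domain map, the maximum is the int64 "+inf" sentinel.
print(ana.const_int_bound(dim_len).max_value)

# The C++ pass calls MarkGlobalNonNegValue for v; here we approximate
# that with a const-int-bound update so the product's sign is provable.
ana.update(v, arith.ConstIntBound(0, arith.ConstIntBound.POS_INF))

# With n confined to [0, 32], interval evaluation produces the symbolic
# upper bound obtained by substituting n = 32, i.e. 32 * v * 4.
dom_map = {n: arith.IntervalSet(tir.IntImm("int64", 0), tir.IntImm("int64", 32))}
print(ana.int_set(dim_len, dom_map).max_value)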
@@ -462,7 +483,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {

void VisitExpr_(const FunctionNode* func) final {
// Set the upper bound of TIR variables in the analyzer.
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_);
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_, &dom_map_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are externally referenced.
@@ -565,7 +586,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {

// Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic
// if the upper bounds of some variables are not provided.
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_);
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_, dom_map_);

// Create and set token.
StringImm storage_scope = Downcast<StringImm>(call->args[3]);
@@ -641,6 +662,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
const IRModule& ctx_mod_;
/*! \brief The arithmetic analyzer. */
arith::Analyzer* analyzer_;
/*! \brief The domain map of dynamic TIR variables for analysis. */
Map<tir::Var, arith::IntSet> dom_map_;
/*! \brief The mapping from each token to the binding block where it is created. */
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
/*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -816,7 +839,7 @@ class StorageAllocationRewriter : public ExprMutator {
plan_dynamic_output_ = static_cast<bool>(
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
if (plan_dynamic_output_) {
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_, &dom_map_);
}
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
@@ -879,7 +902,7 @@ class StorageAllocationRewriter : public ExprMutator {
ICHECK_NOTNULL(sinfo);
const auto* shape = sinfo->shape.as<ShapeExprNode>();
ICHECK_NOTNULL(shape);
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_);
if (!IsStaticShape(shape->values)) {
ICHECK(!sinfo->IsUnknownDtype());
ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
@@ -906,6 +929,8 @@ class StorageAllocationRewriter : public ExprMutator {

/*! \brief The arithmetic analyzer. */
arith::Analyzer ana_;
/*! \brief The domain map of dynamic TIR variables for analysis. */
Map<tir::Var, arith::IntSet> dom_map_;
/*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
bool plan_dynamic_output_;
/*!
102 changes: 102 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1347,5 +1347,107 @@ def main(x: R.Tensor((2, "n"), dtype="float32")):
relax.transform.StaticPlanBlockMemory()(Module)


def test_add():
@I.ir_module
class Module:
@T.prim_func(private=True)
def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.handle):
T.evaluate(0)

@R.function
def main(
probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
batch_size = T.int64()
vocab_size = T.int64()
R.func_attr(
{
"relax.force_pure": 1,
"relax.memory_plan_dynamic_func_output": 1,
"tir_var_upper_bound": {"batch_size": 32},
"tir_non_negative_var": ["vocab_size"],
}
)
cls = Module
lv1: R.Tensor(
(2 * (batch_size * vocab_size * 4) + 4194304,),
dtype="uint8",
) = R.builtin.alloc_tensor(
R.shape([2 * (batch_size * vocab_size * 4) + 4194304]),
R.dtype("uint8"),
R.prim_value(0),
R.str("global"),
)
alloc1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.builtin.alloc_tensor(
R.shape([batch_size, vocab_size]),
R.dtype("float32"),
R.prim_value(0),
R.str("global"),
)
cls.cumsum(probs, lv1, alloc1)
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1
lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed(
"vm.builtin.reshape",
cumsum,
R.shape([batch_size, vocab_size]),
sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float"),),
)
return lv1_1

@I.ir_module
class Expected:
@T.prim_func(private=True)
def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.handle):
T.evaluate(0)

@R.function
def main(
probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"):
batch_size = T.int64()
vocab_size = T.int64()
R.func_attr(
{
"relax.force_pure": 1,
"tir_non_negative_var": ["vocab_size"],
"tir_var_upper_bound": {"batch_size": 32},
}
)
cls = Expected
storage: R.Object = R.memory.alloc_storage(
R.shape([32 * vocab_size * 4 * 2 + 4194304]),
R.prim_value(0),
R.str("global"),
R.dtype("uint8"),
)
lv1: R.Tensor(
(2 * (batch_size * vocab_size * 4) + 4194304,),
dtype="uint8",
) = R.memory.alloc_tensor(
storage,
R.prim_value(0),
R.shape([2 * (batch_size * vocab_size * 4) + 4194304]),
R.dtype("uint8"),
)
storage1: R.Object = R.memory.alloc_storage(
R.shape([128 * vocab_size]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
alloc1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.memory.alloc_tensor(
storage1, R.prim_value(0), R.shape([batch_size, vocab_size]), R.dtype("float32")
)
cls.cumsum(probs, lv1, alloc1)
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1
lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed(
"vm.builtin.reshape",
cumsum,
R.shape([batch_size, vocab_size]),
sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),),
)
return lv1_1

mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
tvm.testing.main()