[ARITH] Improve Canonical Simplification to Handle Fused Pattern  #1711

@ke1337

I have the following C++ code when testing TVM with CUDA:

    tvm::Array<tvm::Expr> a_shape1 = {5, 6};
    // tvm::Array<tvm::Expr> a_shape1 = {5, 2, 3};
    tvm::Tensor tvm_X = tvm::placeholder(a_shape1, tvm::Float(32), "A");
    tvm::Tensor tvm_Y = ::topi::where(less(0, tvm_X), 1 / (1 + exp(negative(tvm_X))), 1 - 1 / (1 + exp(tvm_X)));
    auto target1 = tvm::target::cuda();
    auto S1 = topi::cuda::schedule_injective(target1, {tvm_Y});

    auto args1 = tvm::Array<tvm::Tensor>({tvm_X, tvm_Y});
    std::unordered_map<tvm::Tensor, tvm::Buffer> binds1;
    auto config1 = tvm::build_config();
    config1->restricted_func = true;
    auto lowered1 = tvm::lower(S1, args1, "Sigmoid", binds1, config1);

    std::cout << lowered1[0]->body << std::endl;

When the input shape is 2D (5, 6), the lowered function looks close to a handwritten kernel:

  if ((threadIdx.x < 30)) {
    tensor[threadIdx.x] = tvm_if_then_else(((0.000000f < A[threadIdx.x]) == (uint1)0), (1.000000f - (1.000000f/(exp(A[threadIdx.x]) + 1.000000f))), (1.000000f/(exp((0.000000f - A[threadIdx.x])) + 1.000000f)))
  }

However, for the 3D input shape (5, 2, 3), the lowered function looks different:

  if ((threadIdx.x < 30)) {
    tensor[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))] = tvm_if_then_else(((0.000000f < A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))]) == (uint1)0), (1.000000f - (1.000000f/(exp(A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))]) + 1.000000f))), (1.000000f/(exp((0.000000f - A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))])) + 1.000000f)))
  }

From my reading of the injective schedule, it seems that all input axes are fused before the split, so the two cases above should generate identical code. Is my understanding correct?
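To make the expectation concrete, here is a minimal standalone sketch (plain C++, not TVM code) that checks whether the index expression from the 3D lowering reduces to `threadIdx.x` on the guarded range `[0, 30)`. The helper names `unfused_index` and `is_identity_on` are hypothetical and exist only for this illustration; `x` stands in for `threadIdx.x`. The reasoning behind it: for non-negative `x`, `x = 3*(x/3) + x%3` and `x/3 = 2*(x/6) + (x/3)%2`, so `(x/6)*6 + ((x/3)%2)*3 + x%3` should collapse back to `x`.

```cpp
#include <cassert>

// Index expression copied from the 3D lowering above, with x standing in
// for threadIdx.x. For shape (5, 2, 3) the strides are 6, 3 and 1.
// Hypothetical helper, not part of TVM.
inline int unfused_index(int x) {
    return ((x / 6) * 6) + ((((x / 3) % 2) * 3) + (x % 3));
}

// Returns true when unfused_index(x) == x for every x in [0, extent).
// Hypothetical helper, not part of TVM.
inline bool is_identity_on(int extent) {
    for (int x = 0; x < extent; ++x) {
        if (unfused_index(x) != x) {
            return false;
        }
    }
    return true;
}
```

Calling `is_identity_on(30)` covers exactly the range allowed by the `threadIdx.x < 30` guard. If it holds, the 3D index is arithmetically the same as the 2D one, and the difference between the two lowered functions comes down to the simplifier not recognizing the fused pattern.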
