
[RFC][TVM] Extend TensorComputeOp to allow scalar inputs #2606

@jdavies-huawei


Motivation

TensorComputeOp allows a TensorIntrin to be called by the user directly, rather than relying on schedule.tensorize to match a pattern and perform a replacement. The original motivation for TensorComputeOp was to generalize TVM's compute to tensor regions, as described in issue #1485.

Currently, all arguments passed to the tensor intrinsic must be tensor regions. We believe this is too restrictive. For example, it is common across multiple architectures for hardware intrinsics to take scalar arguments.

In a regular compute, it is already possible to use arbitrary expressions over the iterator variables, e.g.:

n = 10
A = tvm.placeholder((n, n))
B = tvm.compute((n, n), lambda i, j : A[i, j] + i*i)
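For readers less familiar with compute, the following plain-Python sketch (no TVM required; the helper `compute_2d` is hypothetical, not a TVM API) spells out what the declaration above means element by element. The lambda is evaluated at every output index, so it can freely mix tensor reads with arithmetic on the iterators.

```python
# Plain-Python model of tvm.compute((n, n), lambda i, j: A[i, j] + i*i).
# compute_2d is a hypothetical helper used only for illustration.
def compute_2d(shape, fcompute):
    rows, cols = shape
    # Evaluate the compute lambda once per output element (i, j).
    return [[fcompute(i, j) for j in range(cols)] for i in range(rows)]


n = 10
# Arbitrary input values for A, stored as a list of rows.
A = [[float(i * n + j) for j in range(n)] for i in range(n)]
B = compute_2d((n, n), lambda i, j: A[i][j] + i * i)
```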

However, it isn't possible to do something similar with TensorComputeOp, for example:

tfunc = intrin_tfunc(n)
C1 = tvm.compute((n, n), lambda i : tfunc(A[i, 0:n], i*i))

In the above, passing i*i to the TensorIntrin tfunc fails, because i*i is a scalar expression, not a tensor region.

One current workaround is to store i*i in another tensor:

S = tvm.compute((n, ), lambda i : i*i)
C2 = tvm.compute((n, n), lambda i : tfunc(A[i, 0:n], S[i]))

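However, this workaround introduces an extra tensor that does not need to exist: S must be computed in its own loop nest and stored in memory before C can be computed, which adds both memory and compute overhead. Therefore, we propose to extend TensorComputeOp so that scalar expressions can be passed directly to the TensorIntrin call.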
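To make the cost of the workaround concrete, here is a plain-Python model of the two formulations (no TVM; `tfunc_row` is a hypothetical stand-in for the intrinsic, assumed to add the scalar to every element of a row, as the intrinsic in the example further below does). Both produce the same values, but the workaround needs a separate loop to fill `S` and an n-element buffer to hold it, whereas the proposed form evaluates `i*i` inline.

```python
# Hypothetical stand-in for the tensor intrinsic: processes one row of A
# and adds a scalar to every element (assumed semantics).
def tfunc_row(a_row, s):
    return [x + s for x in a_row]


n = 10
A = [[float(i * n + j) for j in range(n)] for i in range(n)]

# Workaround: materialize S first (an extra n-element buffer and loop).
S = [i * i for i in range(n)]
C2 = [tfunc_row(A[i], S[i]) for i in range(n)]

# Proposed: pass the scalar expression straight to the intrinsic.
C = [tfunc_row(A[i], i * i) for i in range(n)]
```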

Proposed Syntax

A list of scalar expressions is passed to the TensorIntrin call as a keyword argument 'scalar_inputs':

C = tvm.compute((n, n), lambda i: tfunc(A[i, 0:n], scalar_inputs=[i*i]))

When declaring the TensorIntrin, the expected scalar parameters must be listed in a keyword argument "scalar_params":

tfunc = tvm.decl_tensor_intrin(D.op, intrin_func, binds={a: Ab, c: Cb}, scalar_params=[s])

where the scalar parameters must be variables used in D's compute:

s = tvm.var("s")
a = tvm.placeholder((n, ))
D = tvm.compute((n,), lambda i: a[i] + s)

Finally, the intrin_func function passed to decl_tensor_intrin must take a third argument containing the list of scalar inputs ('sp' below). The scalar inputs can then be used in the emitted call:

# sp will be the list of scalar inputs passed to the TensorIntrin call
def intrin_func(ins, outs, sp):
  aa = ins[0]
  cc = outs[0]
  def _body():
    ib = tvm.ir_builder.create()
    ib.emit(tvm.call_extern("int32", "test_intrin",
                          cc.access_ptr("w"),
                          aa.access_ptr("r"),
                          sp[0]))
    return ib.get()
  return _body()
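The binding rule implied by this syntax can be sketched in plain Python (no TVM; the class `TensorIntrinModel` and the name `tfunc_model` are hypothetical). The declaration fixes an ordered list of scalar parameters, each call supplies one scalar expression per parameter, and intrin_func receives them positionally as `sp`. Whether the real implementation validates the count exactly this way is an assumption of the sketch.

```python
class TensorIntrinModel:
    """Hypothetical plain-Python model of a TensorIntrin with scalar_params."""

    def __init__(self, intrin_func, scalar_params):
        self.intrin_func = intrin_func
        self.scalar_params = list(scalar_params)

    def __call__(self, *ins, scalar_inputs=()):
        sp = list(scalar_inputs)
        # Assumed check: one scalar input per declared scalar parameter.
        if len(sp) != len(self.scalar_params):
            raise ValueError(
                f"expected {len(self.scalar_params)} scalar input(s), "
                f"got {len(sp)}")
        out = [0.0] * len(ins[0])
        # sp is passed positionally: sp[k] corresponds to scalar_params[k].
        self.intrin_func(ins, [out], sp)
        return out


def intrin_func(ins, outs, sp):
    aa, cc = ins[0], outs[0]
    # Stand-in for the emitted test_intrin call: cc[k] = aa[k] + sp[0].
    for k in range(len(aa)):
        cc[k] = aa[k] + sp[0]


tfunc_model = TensorIntrinModel(intrin_func, scalar_params=["s"])
i = 3
out = tfunc_model([1.0, 2.0, 3.0], scalar_inputs=[i * i])
```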

Example

import tvm

def intrin_test(n):
  s = tvm.var("s")
  a = tvm.placeholder((n,), name='a')
  d = tvm.compute((n,), lambda i: a[i] + s, name='d')

  def intrin_func(ins, outs, sp):
    aa = ins[0]
    cc = outs[0]
    def _body():
      ib = tvm.ir_builder.create()
      ib.emit(tvm.call_extern("int32", "test_intrin",
                            cc.access_ptr("w"),
                            aa.access_ptr("r"),
                            sp[0]))
      return ib.get()
    return _body()

  with tvm.build_config(offset_factor=1):
    return tvm.decl_tensor_intrin(d.op, intrin_func, scalar_params=[s])

if __name__ == '__main__':

    n = 10
    A = tvm.placeholder((n, n), name='A')
    tfunc = intrin_test(n)
    C = tvm.compute((n, n), lambda i: tfunc(A[i, 0:n], scalar_inputs=[i*i]), name='C')
    s = tvm.create_schedule(C.op)
    print(tvm.lower(s, [A, C], simple_mode=True))

The above example program produces the following output:

produce C {
  for (i, 0, 10) {
    test_intrin(tvm_address_of(C[(i*10)]), tvm_address_of(A[(i*10)]), (i*i))
  }
}

Implementation

The extension is implemented in the accompanying pull request.

Testing

The pull request includes one new unit test for this feature, but additional, stronger tests are still required. We would appreciate advice on how best to test this feature.
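One possible direction for stronger tests, offered only as a suggestion, is to check results numerically against an element-wise reference and to use an intrinsic with more than one scalar parameter whose parameters do not commute, so that a mix-up in the order of scalar_inputs is caught. The sketch below is plain Python; `intrin_body`, `run_intrinsic_rows`, and the two-parameter intrinsic `c[k] = a[k] * s0 + s1` are all hypothetical.

```python
# Hypothetical two-parameter intrinsic: c[k] = a[k] * sp[0] + sp[1].
# Swapping the scalars changes the result, so ordering bugs are visible.
def intrin_body(a_row, sp):
    return [x * sp[0] + sp[1] for x in a_row]


def run_intrinsic_rows(A, scalar_exprs):
    # Mimics compute((n, n), lambda i: tfunc(A[i, 0:n], scalar_inputs=[...])).
    return [intrin_body(A[i], [f(i) for f in scalar_exprs]) for i in range(len(A))]


def reference(A, s0, s1):
    # Element-wise reference written without the intrinsic.
    return [[A[i][j] * s0(i) + s1(i) for j in range(len(A[i]))]
            for i in range(len(A))]


def s0(i):
    return i * i


def s1(i):
    return i + 1


n = 6
A = [[float(i + j) for j in range(n)] for i in range(n)]
ok = run_intrinsic_rows(A, [s0, s1]) == reference(A, s0, s1)
swapped_differs = run_intrinsic_rows(A, [s1, s0]) != reference(A, s0, s1)
```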
