Skip to content

Commit 364e2db

Browse files
authored
[microNPU] Add support for scalar values (#9794)
* [microNPU] Add support for scalar values PR #9515 enabled support for scalar constants, but didn't consider the case of a scalar value where the underlying constant data does not have a shape i.e. `constant.shape == []`. See the test case for a visual differece when the scalar value is 1. Change-Id: Id7a238cb5bf999dd5a8428c097202f9fb940a5f0 * Fix failing test by removing constant Before this PR scalar constants were handled differently so this test was able to pass. Now that scalar constants are handled in the same manner as tensor constants, the test fails since unexpected tir is produced in the compilation pipeline. Since the relay used in this test case is not expected to be produced by higher levels of the compiler, removing this constant for now. Change-Id: I4ea5155778809041339e6faac05af3f72c3e3ea5 * clean up finding tensor from inputs Change-Id: Ideccf84f8c9149148ff23e2406229cf637c982a3
1 parent 211291f commit 364e2db

File tree

5 files changed

+15
-13
lines changed

5 files changed

+15
-13
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,8 @@ def callback(
652652
ifm2_zero_point=int(params.ifm2.q_params.zero_point),
653653
ofm_scale=float(params.ofm.q_params.scale_f32),
654654
ofm_zero_point=int(params.ofm.q_params.zero_point),
655-
ifm_channels=params.ifm.shape[-1],
656-
ifm2_channels=params.ifm2.shape[-1],
655+
ifm_channels=params.ifm.shape[-1] if params.ifm.shape else 1,
656+
ifm2_channels=params.ifm2.shape[-1] if params.ifm2.shape else 1,
657657
reversed_operands=params.reversed_operands,
658658
ofm_dtype=params.ofm.dtype,
659659
activation=activation,

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,11 @@ def __init__(self):
123123

124124
def visit_constant(self, const):
125125
if isinstance(const.checked_type, relay.ty.TensorType):
126-
if const.checked_type.concrete_shape != ():
127-
self.constants.append(const.data.asnumpy())
128-
name = "p" + str(len(self.constants))
129-
var = relay.var(type_annotation=const.checked_type, name_hint=name)
130-
self.const_vars.append(var)
131-
return var
126+
self.constants.append(const.data.asnumpy())
127+
name = "p" + str(len(self.constants))
128+
var = relay.var(type_annotation=const.checked_type, name_hint=name)
129+
self.const_vars.append(var)
130+
return var
132131

133132
return const
134133

python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def _visit(tensor, reader, lut):
136136
if tensor not in planned:
137137
planned.add(tensor)
138138
if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut:
139-
index = list(cached_func.inputs).index(tensor)
139+
# Find index of input using 'same_as' check to prevent equality
140+
# ambiguity when encountering a scalar.
141+
is_same = [var.same_as(tensor) for var in cached_func.inputs]
142+
index = is_same.index(True)
140143
if index in const_dict:
141144
sch.cache_read(tensor, "global", [reader])
142145

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,13 @@ def create_mod_from_relay():
629629

630630
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
631631
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
632-
def test_elementwise_add_from_constant_scalar(accel_type, dtype):
632+
@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
633+
def test_elementwise_add_from_constant_scalar(accel_type, dtype, constant):
633634
ifm_shape = (1, 4, 4, 8)
634635

635636
def create_relay_graph():
636637
inp = relay.var("input", shape=ifm_shape, dtype=dtype)
637-
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
638+
scalar = relay.const(constant, dtype=dtype)
638639
add = relay.qnn.op.add(
639640
inp,
640641
scalar,

tests/python/contrib/test_ethosu/test_compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def test_lower_to_tir():
3434
kernel_layout="HWIO",
3535
out_dtype="int32",
3636
)
37-
multiply = relay.multiply(relay.const(-22, dtype="int32"), p2)
38-
tile = relay.tile(multiply, reps=(1, 1, 1, 1001))
37+
tile = relay.tile(p2, reps=(1, 1, 1, 1001))
3938
subtract = relay.subtract(conv, tile)
4039
func = subtract
4140
expr = relay.Function(relay.analysis.free_vars(func), func)

0 commit comments

Comments
 (0)