Skip to content

Commit 6f81479

Browse files
committed
Fix test
1 parent b179697 commit 6f81479

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

python/tvm/relay/expr.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .._ffi import base as _base
2828
from .. import nd as _nd
2929
from .. import convert
30+
from ..ndarray import NDArray
3031

3132
# will be registered afterwards
3233
_op_make = None
@@ -305,10 +306,15 @@ def __call__(self, *args):
305306
"""
306307
return Call(self, args, None, None)
307308

308-
def get_params(self, params):
309-
return _expr.FunctionGet(self, params)
309+
def get_params(self):
310+
return _expr.FunctionGetParams(self)
310311

311312
def set_params(self, params):
313+
for key in params:
314+
value = params[key]
315+
if isinstance(value, NDArray):
316+
params[key] = Constant(value)
317+
312318
return _expr.FunctionSetParams(self, params)
313319

314320

src/relay/ir/expr.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) cons
163163
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
164164
}
165165

166-
TVM_REGISTER_API("relay._expr.FunctionSetParms")
166+
TVM_REGISTER_API("relay._expr.FunctionSetParams")
167167
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
168168
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
169169
return func->SetParams(parameters);
@@ -174,7 +174,7 @@ tvm::Map<Var, Constant> FunctionNode::GetParams() const {
174174
return Downcast<tvm::Map<Var, Constant>>(node_ref);
175175
}
176176

177-
TVM_REGISTER_API("relay._expr.FunctionGetParms")
177+
TVM_REGISTER_API("relay._expr.FunctionGetParams")
178178
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
179179
return func->GetParams();
180180
});

tests/python/relay/test_ir_nodes.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def test_global_var():
160160
str(gv)
161161
check_json_roundtrip(gv)
162162

163-
164163
def test_function():
165164
param_names = ['a', 'b', 'c', 'd']
166165
params = tvm.convert([relay.Var(n) for n in param_names])
@@ -184,7 +183,8 @@ def test_function_attrs():
184183
fn = relay.Function(params, body, ret_type, type_params)
185184
model_params = {}
186185
for param in params[:1]:
187-
tensor = np.random.rand(*param.shape).astype(param.dtype)
186+
cty = param.type_annotation
187+
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
188188
model_params[param] = tvm.nd.array(tensor)
189189
fn = fn.set_params(model_params)
190190
assert fn.params == params
@@ -196,8 +196,12 @@ def test_function_attrs():
196196
json_str = tvm.save_json(fn)
197197
fn_after = tvm.load_json(json_str)
198198
model_params_after = fn_after.get_params()
199-
for p1, p2 in zip(model_params, model_params_after):
200-
assert p1.asnumpy() == p2.asnumpy()
199+
after_keys = [item[0] for item in model_params_after.items()]
200+
for key1, key2 in zip(model_params, after_keys):
201+
assert key1.name_hint == key2.name_hint
202+
p1 = model_params[key1]
203+
p2 = model_params_after[key2]
204+
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
201205

202206
def test_call():
203207
op = relay.Var('f')
@@ -280,9 +284,11 @@ def test_conv2d_attrs():
280284
test_local_var()
281285
test_global_var()
282286
test_function()
287+
test_function_attrs()
283288
test_call()
284289
test_let()
285290
test_if()
286291
test_tuple_get_item()
287292
test_op()
288293
test_conv2d_attrs()
294+

0 commit comments

Comments
 (0)