Skip to content

Commit 80df07a

Browse files
committed
Fix with local args
1 parent c0a441a commit 80df07a

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

python/tvm/auto_scheduler/measure.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import tempfile
3838
import multiprocessing
3939
import logging
40-
import copy
4140

4241
import tvm._ffi
4342
from tvm.runtime import Object, module, ndarray
@@ -910,18 +909,18 @@ def _timed_eval_func(
910909
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
911910
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
912911
assert len(args) == len(build_res.args)
913-
loc_args = copy.deepcopy(args)
912+
loc_args = []
914913
# pylint: disable=consider-using-enumerate
915-
for idx in range(len(loc_args)):
916-
if loc_args[idx] is None:
914+
for idx in range(len(args)):
915+
if args[idx] is None:
917916
build_res_arg = build_res.args[idx]
918917
empty_array = ndarray.empty(
919918
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
920919
)
921920
random_fill(empty_array)
922-
loc_args[idx] = empty_array
921+
loc_args.append(empty_array)
923922
else:
924-
loc_args[idx] = ndarray.array(loc_args[idx], dev)
923+
loc_args.append(ndarray.array(arg))
925924
dev.sync()
926925
costs = time_f(*loc_args).results
927926
# pylint: disable=broad-except
@@ -1114,18 +1113,18 @@ def _rpc_run(
11141113
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
11151114

11161115
assert len(args) == len(build_res.args)
1117-
loc_args = copy.deepcopy(args)
1116+
loc_args = []
11181117
# pylint: disable=consider-using-enumerate
1119-
for idx in range(len(loc_args)):
1120-
if loc_args[idx] is None:
1118+
for idx in range(len(args)):
1119+
if args[idx] is None:
11211120
build_res_arg = build_res.args[idx]
11221121
empty_array = ndarray.empty(
11231122
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
11241123
)
11251124
random_fill(empty_array)
1126-
loc_args[idx] = empty_array
1125+
loc_args.append(empty_array)
11271126
else:
1128-
loc_args[idx] = ndarray.array(loc_args[idx], dev)
1127+
loc_args.append(ndarray.array(args[idx], dev))
11291128
dev.sync()
11301129

11311130
# First run for check that the kernel is correct

0 commit comments

Comments
 (0)