Skip to content

Commit bc29367

Browse files
author
Krzysztof Parzyszek
authored
Move WrapTimeEvaluator from RPC to profiling, NFC (#11172)
1 parent 2160f73 commit bc29367

File tree

5 files changed

+81
-81
lines changed

5 files changed

+81
-81
lines changed

include/tvm/runtime/profiling.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,29 @@ String ShapeString(const std::vector<int64_t>& shape, DLDataType dtype);
511511
PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id,
512512
int warmup_iters, Array<MetricCollector> collectors);
513513

514+
/*!
515+
* \brief Wrap a timer function to measure the time cost of a given packed function.
516+
* \param f The function argument.
517+
* \param dev The device.
518+
* \param number The number of times to run this function for taking average.
519+
* We call these runs as one `repeat` of measurement.
520+
* \param repeat The number of times to repeat the measurement.
521+
* In total, the function will be invoked (1 + number x repeat) times,
522+
* where the first one is warm up and will be discarded.
523+
* The returned result contains `repeat` costs,
524+
* each of which is an average of `number` costs.
525+
* \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
526+
* By default, one `repeat` contains `number` runs. If this parameter is set,
527+
* the parameters `number` will be dynamically adjusted to meet the
528+
* minimum duration requirement of one `repeat`.
529+
* i.e., When the run time of one `repeat` falls below this time,
530+
* the `number` parameter will be automatically increased.
531+
* \param f_preproc The function to be executed before we excetute time evaluator.
532+
* \return f_timer A timer function.
533+
*/
534+
PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms,
535+
PackedFunc f_preproc = nullptr);
536+
514537
} // namespace profiling
515538
} // namespace runtime
516539
} // namespace tvm

src/runtime/graph_executor/debug/graph_executor_debug.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class GraphExecutorDebug : public GraphExecutor {
113113

114114
// assume host runs things which is first device
115115
Device& d = devices_[0];
116-
PackedFunc time_evaluator = WrapTimeEvaluator(
116+
PackedFunc time_evaluator = profiling::WrapTimeEvaluator(
117117
TypedPackedFunc<void()>([this, node_index]() { this->RunOpHost(node_index); }), d, number,
118118
repeat, min_repeat_ms);
119119
std::string result = time_evaluator();

src/runtime/profiling.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,61 @@ TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction")
739739
}
740740
});
741741

742+
PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms,
743+
PackedFunc f_preproc) {
744+
ICHECK(pf != nullptr);
745+
746+
if (static_cast<int>(dev.device_type) == static_cast<int>(kDLMicroDev)) {
747+
auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
748+
ICHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
749+
return (*get_micro_time_evaluator)(pf, dev, number, repeat);
750+
}
751+
752+
auto ftimer = [pf, dev, number, repeat, min_repeat_ms, f_preproc](TVMArgs args,
753+
TVMRetValue* rv) mutable {
754+
TVMRetValue temp;
755+
std::ostringstream os;
756+
// skip first time call, to activate lazy compilation components.
757+
pf.CallPacked(args, &temp);
758+
759+
DeviceAPI::Get(dev)->StreamSync(dev, nullptr);
760+
761+
for (int i = 0; i < repeat; ++i) {
762+
if (f_preproc != nullptr) {
763+
f_preproc.CallPacked(args, &temp);
764+
}
765+
double duration_ms = 0.0;
766+
767+
do {
768+
if (duration_ms > 0.0) {
769+
number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
770+
number * 1.618)); // 1.618 is chosen by random
771+
}
772+
773+
Timer t = Timer::Start(dev);
774+
// start timing
775+
for (int i = 0; i < number; ++i) {
776+
pf.CallPacked(args, &temp);
777+
}
778+
t->Stop();
779+
int64_t t_nanos = t->SyncAndGetElapsedNanos();
780+
duration_ms = t_nanos / 1e6;
781+
} while (duration_ms < min_repeat_ms);
782+
783+
double speed = duration_ms / 1e3 / number;
784+
os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
785+
}
786+
787+
std::string blob = os.str();
788+
TVMByteArray arr;
789+
arr.size = blob.length();
790+
arr.data = blob.data();
791+
// return the time.
792+
*rv = arr;
793+
};
794+
return PackedFunc(ftimer);
795+
}
796+
742797
} // namespace profiling
743798
} // namespace runtime
744799
} // namespace tvm

src/runtime/rpc/rpc_module.cc

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -357,61 +357,6 @@ inline void CPUCacheFlush(int begin_index, const TVMArgs& args) {
357357
}
358358
}
359359

360-
PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms,
361-
PackedFunc f_preproc) {
362-
ICHECK(pf != nullptr);
363-
364-
if (static_cast<int>(dev.device_type) == static_cast<int>(kDLMicroDev)) {
365-
auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
366-
ICHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
367-
return (*get_micro_time_evaluator)(pf, dev, number, repeat);
368-
}
369-
370-
auto ftimer = [pf, dev, number, repeat, min_repeat_ms, f_preproc](TVMArgs args,
371-
TVMRetValue* rv) mutable {
372-
TVMRetValue temp;
373-
std::ostringstream os;
374-
// skip first time call, to activate lazy compilation components.
375-
pf.CallPacked(args, &temp);
376-
377-
DeviceAPI::Get(dev)->StreamSync(dev, nullptr);
378-
379-
for (int i = 0; i < repeat; ++i) {
380-
if (f_preproc != nullptr) {
381-
f_preproc.CallPacked(args, &temp);
382-
}
383-
double duration_ms = 0.0;
384-
385-
do {
386-
if (duration_ms > 0.0) {
387-
number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
388-
number * 1.618)); // 1.618 is chosen by random
389-
}
390-
391-
Timer t = Timer::Start(dev);
392-
// start timing
393-
for (int i = 0; i < number; ++i) {
394-
pf.CallPacked(args, &temp);
395-
}
396-
t->Stop();
397-
int64_t t_nanos = t->SyncAndGetElapsedNanos();
398-
duration_ms = t_nanos / 1e6;
399-
} while (duration_ms < min_repeat_ms);
400-
401-
double speed = duration_ms / 1e3 / number;
402-
os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
403-
}
404-
405-
std::string blob = os.str();
406-
TVMByteArray arr;
407-
arr.size = blob.length();
408-
arr.data = blob.data();
409-
// return the time.
410-
*rv = arr;
411-
};
412-
return PackedFunc(ftimer);
413-
}
414-
415360
TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
416361
.set_body_typed([](Optional<Module> opt_mod, std::string name, int device_type, int device_id,
417362
int number, int repeat, int min_repeat_ms, std::string f_preproc_name) {
@@ -434,7 +379,7 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
434379
}
435380
PackedFunc pf = m.GetFunction(name, true);
436381
CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry";
437-
return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc);
382+
return profiling::WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc);
438383
}
439384
} else {
440385
auto* pf = runtime::Registry::Get(name);
@@ -446,7 +391,7 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
446391
<< "Cannot find " << f_preproc_name << " in the global function";
447392
f_preproc = *pf_preproc;
448393
}
449-
return WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, f_preproc);
394+
return profiling::WrapTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, f_preproc);
450395
}
451396
});
452397

src/runtime/rpc/rpc_session.h

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -282,29 +282,6 @@ struct RemoteSpace {
282282
std::shared_ptr<RPCSession> sess;
283283
};
284284

285-
/*!
286-
* \brief Wrap a timer function to measure the time cost of a given packed function.
287-
* \param f The function argument.
288-
* \param dev The device.
289-
* \param number The number of times to run this function for taking average.
290-
* We call these runs as one `repeat` of measurement.
291-
* \param repeat The number of times to repeat the measurement.
292-
* In total, the function will be invoked (1 + number x repeat) times,
293-
* where the first one is warm up and will be discarded.
294-
* The returned result contains `repeat` costs,
295-
* each of which is an average of `number` costs.
296-
* \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
297-
* By default, one `repeat` contains `number` runs. If this parameter is set,
298-
* the parameters `number` will be dynamically adjusted to meet the
299-
* minimum duration requirement of one `repeat`.
300-
* i.e., When the run time of one `repeat` falls below this time,
301-
* the `number` parameter will be automatically increased.
302-
* \param f_preproc The function to be executed before we excetute time evaluator.
303-
* \return f_timer A timer function.
304-
*/
305-
PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms,
306-
PackedFunc f_preproc = nullptr);
307-
308285
/*!
309286
* \brief Create a Global RPC module that refers to the session.
310287
* \param sess The RPC session of the global module.

0 commit comments

Comments
 (0)