[Hexagon][Runtime] Add QuRT thread pool backend #11018
Changes from all commits
threading_backend.cc:

@@ -34,13 +34,72 @@
#endif
#if defined(__hexagon__)
#include <dlfcn.h>
#include <qurt.h>
#include <stdlib.h>
#define HEXAGON_STACK_SIZE 65536
#define HEXAGON_STACK_ALIGNMENT 32
#endif
#include <algorithm>
#include <thread>
#define CURRENT_THREAD_HANDLE (static_cast<std::thread::native_handle_type>(0))
namespace tvm {
namespace runtime {
namespace threading {
#ifdef __hexagon__
// pthreads are broken on older versions of qurt, so
// we need to use native APIs instead of std::threads
class QuRTThread {
  typedef std::function<void()> Callback;

 public:
  explicit QuRTThread(Callback worker_callback) : worker_callback_(worker_callback) {
    static int id = 1;
    qurt_thread_attr_t attr;
    char name[32];
    int ret = posix_memalign(&stack_, HEXAGON_STACK_ALIGNMENT, HEXAGON_STACK_SIZE);
    CHECK_EQ(ret, 0);
    // When a std::function<> is cast to bool,
    // it indicates whether it stores a callable target
    CHECK_EQ((bool)worker_callback_, true);
    qurt_thread_attr_init(&attr);
    qurt_thread_attr_set_stack_size(&attr, HEXAGON_STACK_SIZE);
    qurt_thread_attr_set_stack_addr(&attr, stack_);
    snprintf(name, sizeof(name), "worker %d", id++);
    qurt_thread_attr_set_name(&attr, name);
    ret = qurt_thread_create(&thread_, &attr, (void (*)(void*))RunFunction, this);
    CHECK_EQ(ret, QURT_EOK);
  }
  QuRTThread(QuRTThread&& other)
      : thread_(other.thread_),
        worker_callback_(std::move(other.worker_callback_)),
        stack_(other.stack_) {
    other.thread_ = 0;
Contributor: Although line 73 already has the logic to avoid a double free, after construction from an rvalue we still logically need to clear the member values (stack_, f_) of other, because those values are no longer owned by other. Doing this helps future maintenance.

Author: free(NULL) is defined to be a no-op, but I added an explicit check anyway.
    other.stack_ = nullptr;
  }
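To make the ownership-transfer point from the thread above concrete, here is a minimal, self-contained sketch of the pattern (not the PR's code; OwnedStack is a hypothetical name): the move constructor nulls out the moved-from pointer, so the destructor's free() can only ever run once per buffer.

```cpp
#include <cstdlib>
#include <utility>

// Hypothetical RAII owner of a malloc'd stack buffer. The move constructor
// clears the source's pointer, so only one of the two objects ever frees it.
class OwnedStack {
 public:
  explicit OwnedStack(std::size_t size) : stack_(std::malloc(size)) {}

  OwnedStack(OwnedStack&& other) noexcept : stack_(other.stack_) {
    other.stack_ = nullptr;  // 'other' no longer owns the buffer
  }

  OwnedStack(const OwnedStack&) = delete;
  OwnedStack& operator=(const OwnedStack&) = delete;

  ~OwnedStack() {
    // free(nullptr) is already a no-op; the explicit check documents intent,
    // mirroring the destructor in the diff above.
    if (stack_ != nullptr) {
      std::free(stack_);
    }
  }

 private:
  void* stack_ = nullptr;
};

int main() {
  OwnedStack a(65536);
  OwnedStack b(std::move(a));  // 'b' now owns the buffer; 'a' frees nothing
  return 0;
}
```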
  ~QuRTThread() {
    if (thread_) {
      join();
    }
    if (stack_) {
      free(stack_);
    }
  }
  bool joinable() const { return qurt_thread_get_id() != thread_; }
  void join() {
    int status;
    qurt_thread_join(thread_, &status);
  }

 private:
  static void RunFunction(QuRTThread* qrt_thread) {
    qrt_thread->worker_callback_();
    qurt_thread_exit(QURT_EOK);
  }
  qurt_thread_t thread_;
  Callback worker_callback_;
  void* stack_ = nullptr;
};
#endif  // __hexagon__
thread_local int max_concurrency = 0;
class ThreadGroup::Impl {
 public:
@@ -116,6 +175,7 @@ class ThreadGroup::Impl {
  // if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
  // the main thread is bound to core 0.
  void SetAffinity(bool exclude_worker0, AffinityMode mode) {
#ifndef __hexagon__
    const char* val = getenv("TVM_BIND_THREADS");
    if (val != nullptr && atoi(val) != 1) {
      return;

@@ -172,6 +232,7 @@ class ThreadGroup::Impl {
        SetMasterThreadFullCpuAffinity(mode);
      }
    }
#endif  // __hexagon__
  }

  void SetThreadFullCpuAffinity(std::thread::native_handle_type thread, AffinityMode mode) {

@@ -185,6 +246,7 @@ class ThreadGroup::Impl {
    // Note: this works well on x86 too. Because x86 doesn't have BIG.LITTLE,
    // our implementation will use kBig mode by default and will let main thread
    // run on intended cores.
#ifndef __hexagon__
    std::vector<unsigned> ids;
    switch (mode) {
      case kSpecifyOneCorePerThread:

@@ -206,6 +268,7 @@ class ThreadGroup::Impl {
        break;
    }
    SetThreadAffinity(thread, ids);
#endif  // __hexagon__
  }

  void SetMasterThreadFullCpuAffinity(AffinityMode mode) {

@@ -259,7 +322,11 @@ class ThreadGroup::Impl {
  }

  int num_workers_;
#if defined(__hexagon__)
  std::vector<QuRTThread> threads_;
#else
  std::vector<std::thread> threads_;
#endif
  std::vector<unsigned int> sorted_order_;
  int big_count_ = 0;
  int little_count_ = 0;

@@ -276,7 +343,17 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0
  return impl_->Configure(mode, nthreads, exclude_worker0, cpus);
}

-void Yield() { std::this_thread::yield(); }
+void Yield() {
+#ifdef __hexagon__
+  // QuRT doesn't have a yield API, so instead we sleep for the minimum amount
+  // of time to let the OS schedule another thread. std::this_thread::yield()
+  // compiles down to an empty function.
+  qurt_sleep(1);
+#else
+  std::this_thread::yield();
+#endif
+}

/*!
 * \brief Set the maximum number of available cores.
 */
New test file (Hexagon thread pool test):

@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pytest

import tvm
import tvm.contrib.hexagon
import tvm.script
import tvm.testing
from tvm import te

from .conftest import requires_hexagon_toolchain
from tvm.script import tir as T


@tvm.script.ir_module
class ElemwiseSumIRModule:
    @T.prim_func
    def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
        T.func_attr({"global_symbol": "elemwise_sum_serial", "tir.noalias": True})
        A = T.match_buffer(a, (n,), dtype="float32")
        B = T.match_buffer(b, (n,), dtype="float32")
        C = T.match_buffer(c, (n,), dtype="float32")
        for i in T.serial(n):
            with T.block("C"):
                vi = T.axis.spatial(n, i)
                C[vi] = A[vi] + B[vi]

    @T.prim_func
    def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
        T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
        A = T.match_buffer(a, (n,), dtype="float32")
        B = T.match_buffer(b, (n,), dtype="float32")
        C = T.match_buffer(c, (n,), dtype="float32")
        for i in T.parallel(n):
            with T.block("C"):
                vi = T.axis.spatial(n, i)
                C[vi] = A[vi] + B[vi]


def generate_add_test_data(hexagon_session, n=128 * 1024):
    a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
    b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
    c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device)
    return (a, b, c, n)


def benchmark_func(mod, name, args, hexagon_session):
    (a, b, c, n) = args
    evaluator = mod.time_evaluator(name, hexagon_session.device, number=100)
    return evaluator(a, b, c, n).mean


@requires_hexagon_toolchain
def test_speedup(hexagon_session, capsys):
    if hexagon_session is None:
        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")

    target_hexagon = tvm.target.hexagon("v68", link_params=True)
    func = tvm.build(
        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
    )
    mod = hexagon_session.load_module(func)
    args = generate_add_test_data(hexagon_session)
    parallel_mean = benchmark_func(mod, "elemwise_sum_parallel", args, hexagon_session)
    serial_mean = benchmark_func(mod, "elemwise_sum_serial", args, hexagon_session)

    with capsys.disabled():
        print("... speedup of {:.2f}".format(serial_mean / parallel_mean), end=" ")


@requires_hexagon_toolchain
def test_elemwise_sum_parallel(hexagon_session):
    if hexagon_session is None:
        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")

    target_hexagon = tvm.target.hexagon("v68", link_params=True)
    func = tvm.build(
        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
    )
    mod = hexagon_session.load_module(func)

    (a, b, c, n) = generate_add_test_data(hexagon_session)
    mod["elemwise_sum_parallel"](a, b, c, n)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
Wondering if it's possible to do this with templating instead of with #ifdef __hexagon__?

To a certain extent, although it's not clear it's worth the effort. As far as I can tell there are a couple of options. Yield() also doesn't work as-is on Hexagon, so we'd need a fix for that as well.
One approach would be to introduce a common thread interface and then wrap both std::thread and the QuRT thread into types with that interface, which ThreadGroup::Impl calls into. The thing I got hung up on there is the runtime dispatch. Right now it should be doable from the device type -- when it is kDLHexagon, dispatch to ThreadGroup::Impl<QuRTThreadInterface>; when it's kDLCPU, dispatch to ThreadGroup::Impl<StdThreadInterface>. However, there is impetus to move Hexagon fully over to kDLCPU, wherein we could no longer do runtime dispatch based on the device type.
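A rough, hypothetical sketch of that runtime-dispatch idea follows. ThreadGroupImplBase, MakeThreadGroupImpl, and DeviceKind are invented names standing in for the real TVM types and the DLPack device kinds; QuRTThread refers to the wrapper added by this PR.

```cpp
#include <functional>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

// Non-templated base so one handle can own either instantiation at run time.
class ThreadGroupImplBase {
 public:
  virtual ~ThreadGroupImplBase() = default;
  virtual void Launch(std::function<void()> fn) = 0;
  virtual void JoinAll() = 0;
};

// ThreadType only needs to be constructible from a callable and provide
// joinable()/join(), which both std::thread and QuRTThread satisfy.
template <typename ThreadType>
class ThreadGroupImpl final : public ThreadGroupImplBase {
 public:
  void Launch(std::function<void()> fn) override { threads_.emplace_back(std::move(fn)); }
  void JoinAll() override {
    for (auto& t : threads_) {
      if (t.joinable()) t.join();
    }
  }

 private:
  std::vector<ThreadType> threads_;
};

// Hypothetical stand-in for the kDLCPU / kDLHexagon device kinds mentioned above.
enum class DeviceKind { kCPU, kHexagon };

std::unique_ptr<ThreadGroupImplBase> MakeThreadGroupImpl(DeviceKind kind) {
#if defined(__hexagon__)
  // QuRTThread is only available in Hexagon builds of the runtime.
  if (kind == DeviceKind::kHexagon) return std::make_unique<ThreadGroupImpl<QuRTThread>>();
#endif
  (void)kind;
  return std::make_unique<ThreadGroupImpl<std::thread>>();
}

int main() {
  auto group = MakeThreadGroupImpl(DeviceKind::kCPU);
  group->Launch([] { /* worker body would run here */ });
  group->JoinAll();
  return 0;
}
```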
Then the only other options for dispatch I see are:
(1) TVM compile-time dispatch -- during codegen, do something specific based on the target.
(2) Build-time dispatch, via either the preprocessor or changes to the build system to conditionally include one translation unit over another.
I think that abstracting thread control would be the best approach here. The runtime built for a particular target could then contain the implementation provided by that target (e.g. via template specialization), and there wouldn't be any run-time dependency on device kind.
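A minimal sketch of that build-time flavor, assuming a hypothetical traits type (ThreadTraits, ActiveBackend, and the backend tags are invented names; QuRTThread is the wrapper added by this PR). The thread type is fixed when the runtime for a target is built, so there is no device-kind check at run time.

```cpp
#include <thread>

// Hypothetical per-backend traits; each target's runtime build supplies the
// specialization it needs, so the choice is made entirely at compile time.
struct DefaultBackend {};
struct HexagonBackend {};

template <typename Backend>
struct ThreadTraits {
  using Thread = std::thread;  // portable default backend
};

#if defined(__hexagon__)
template <>
struct ThreadTraits<HexagonBackend> {
  using Thread = QuRTThread;  // QuRT-backed wrapper from this PR
};
using ActiveBackend = HexagonBackend;
#else
using ActiveBackend = DefaultBackend;
#endif

// ThreadGroup::Impl would then simply hold
//   std::vector<ThreadTraits<ActiveBackend>::Thread> threads_;
// with no run-time dependency on the device kind.

int main() {
  ThreadTraits<ActiveBackend>::Thread worker([] { /* worker body */ });
  if (worker.joinable()) worker.join();
  return 0;
}
```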
I've refactored this PR to split up pthread and qurt threading implementations. This introduces the following changes:
Honestly, this seems like a pretty ugly solution, especially when threading_backend.cc is already littered with #ifdefs for various platforms. Thoughts?
@areusch @kparzysz-quic PTAL. I'm inclined to take @supersat's recommendation if this approach is too ugly. But we'd like to keep forward progress on this one. Thanks!