diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc
index 748b0b035094..b067be4752a3 100644
--- a/src/runtime/threading_backend.cc
+++ b/src/runtime/threading_backend.cc
@@ -34,6 +34,10 @@
 #endif
 #if defined(__hexagon__)
 #include <dlfcn.h>
+#include <qurt.h>
+#include <stdlib.h>
+#define HEXAGON_STACK_SIZE 65536
+#define HEXAGON_STACK_ALIGNMENT 32
 #endif
 #include <algorithm>
 #include <thread>
@@ -41,6 +45,61 @@
 namespace tvm {
 namespace runtime {
 namespace threading {
+#ifdef __hexagon__
+// pthreads are broken on older versions of qurt, so
+// we need to use native APIs instead of std::thread
+class QuRTThread {
+  typedef std::function<void()> Callback;
+
+ public:
+  explicit QuRTThread(Callback worker_callback) : worker_callback_(worker_callback) {
+    static int id = 1;
+    qurt_thread_attr_t attr;
+    char name[32];
+    int ret = posix_memalign(&stack_, HEXAGON_STACK_ALIGNMENT, HEXAGON_STACK_SIZE);
+    CHECK_EQ(ret, 0);
+    // When a std::function<> is cast to bool,
+    // it indicates whether it stores a callable target
+    CHECK_EQ((bool)worker_callback_, true);
+    qurt_thread_attr_init(&attr);
+    qurt_thread_attr_set_stack_size(&attr, HEXAGON_STACK_SIZE);
+    qurt_thread_attr_set_stack_addr(&attr, stack_);
+    snprintf(name, sizeof(name), "worker %d", id++);
+    qurt_thread_attr_set_name(&attr, name);
+    ret = qurt_thread_create(&thread_, &attr, (void (*)(void*))RunFunction, this);
+    CHECK_EQ(ret, QURT_EOK);
+  }
+  QuRTThread(QuRTThread&& other)
+      : thread_(other.thread_),
+        worker_callback_(std::move(other.worker_callback_)),
+        stack_(other.stack_) {
+    other.thread_ = 0;
+    other.stack_ = nullptr;
+  }
+  ~QuRTThread() {
+    if (thread_) {
+      join();
+    }
+    if (stack_) {
+      free(stack_);
+    }
+  }
+  bool joinable() const { return qurt_thread_get_id() != thread_; }
+  void join() {
+    int status;
+    qurt_thread_join(thread_, &status);
+  }
+
+ private:
+  static void RunFunction(QuRTThread* qrt_thread) {
+    qrt_thread->worker_callback_();
+    qurt_thread_exit(QURT_EOK);
+  }
+  qurt_thread_t thread_;
+  Callback worker_callback_;
+  void* stack_ = nullptr;
+};
+#endif  // __hexagon__
 thread_local int max_concurrency = 0;
 class ThreadGroup::Impl {
  public:
@@ -116,6 +175,7 @@ class ThreadGroup::Impl {
   // if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
   // the main thread is bound to core 0.
   void SetAffinity(bool exclude_worker0, AffinityMode mode) {
+#ifndef __hexagon__
     const char* val = getenv("TVM_BIND_THREADS");
     if (val != nullptr && atoi(val) != 1) {
       return;
     }
@@ -172,6 +232,7 @@ class ThreadGroup::Impl {
         SetMasterThreadFullCpuAffinity(mode);
       }
     }
+#endif  // __hexagon__
   }
 
   void SetThreadFullCpuAffinity(std::thread::native_handle_type thread, AffinityMode mode) {
@@ -185,6 +246,7 @@ class ThreadGroup::Impl {
     // Note: this works well on x86 too. Because x86 doesn't have BIG.LITTLE,
     // our implementation will use kBig mode by default and will let main thread
     // run on intended cores.
+#ifndef __hexagon__
     std::vector<unsigned int> ids;
     switch (mode) {
       case kSpecifyOneCorePerThread:
@@ -206,6 +268,7 @@ class ThreadGroup::Impl {
         break;
     }
     SetThreadAffinity(thread, ids);
+#endif  // __hexagon__
   }
 
   void SetMasterThreadFullCpuAffinity(AffinityMode mode) {
@@ -259,7 +322,11 @@ class ThreadGroup::Impl {
   }
 
   int num_workers_;
+#if defined(__hexagon__)
+  std::vector<QuRTThread> threads_;
+#else
   std::vector<std::thread> threads_;
+#endif
   std::vector<unsigned int> sorted_order_;
   int big_count_ = 0;
   int little_count_ = 0;
@@ -276,7 +343,17 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0
   return impl_->Configure(mode, nthreads, exclude_worker0, cpus);
 }
 
-void Yield() { std::this_thread::yield(); }
+void Yield() {
+#ifdef __hexagon__
+  // QuRT doesn't have a yield API, so instead we sleep for the minimum amount
+  // of time to let the OS schedule another thread. std::this_thread::yield()
+  // compiles down to an empty function.
+  qurt_sleep(1);
+#else
+  std::this_thread::yield();
+#endif
+}
+
 /*!
  * \brief Set the maximum number of available cores.
 */
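Reviewer note (not part of the patch): QuRTThread deliberately mirrors the small slice of the std::thread interface that ThreadGroup::Impl actually uses (construct with a callable, joinable(), join()), so the #if defined(__hexagon__) switch on the threads_ vector is the only container-level change. A minimal usage sketch under that assumption; Example() and the lambda body are hypothetical, and the code presumes QuRTThread from this patch is in scope:

    #include <vector>

    void Example() {
      std::vector<QuRTThread> workers;
      for (int i = 0; i < 4; ++i) {
        // Each QuRTThread allocates an aligned stack via posix_memalign and
        // spawns a QuRT thread named "worker <id>" that runs the callback.
        workers.emplace_back([i] { (void)i; /* per-worker task body */ });
      }
      for (auto& w : workers) {
        // joinable() compares the calling thread's QuRT id with the worker's;
        // the destructor also joins, so the explicit join is optional here.
        if (w.joinable()) w.join();
      }
    }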
diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py
new file mode 100644
index 000000000000..a05404914607
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_thread_pool.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.contrib.hexagon
+import tvm.script
+import tvm.testing
+from tvm import te
+
+from .conftest import requires_hexagon_toolchain
+from tvm.script import tir as T
+
+
+@tvm.script.ir_module
+class ElemwiseSumIRModule:
+    @T.prim_func
+    def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
+        T.func_attr({"global_symbol": "elemwise_sum_serial", "tir.noalias": True})
+        A = T.match_buffer(a, (n,), dtype="float32")
+        B = T.match_buffer(b, (n,), dtype="float32")
+        C = T.match_buffer(c, (n,), dtype="float32")
+        for i in T.serial(n):
+            with T.block("C"):
+                vi = T.axis.spatial(n, i)
+                C[vi] = A[vi] + B[vi]
+
+    @T.prim_func
+    def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
+        T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
+        A = T.match_buffer(a, (n,), dtype="float32")
+        B = T.match_buffer(b, (n,), dtype="float32")
+        C = T.match_buffer(c, (n,), dtype="float32")
+        for i in T.parallel(n):
+            with T.block("C"):
+                vi = T.axis.spatial(n, i)
+                C[vi] = A[vi] + B[vi]
+
+
+def generate_add_test_data(hexagon_session, n=128 * 1024):
+    a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
+    b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
+    c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device)
+    return (a, b, c, n)
+
+
+def benchmark_func(mod, name, args, hexagon_session):
+    (a, b, c, n) = args
+    evaluator = mod.time_evaluator(name, hexagon_session.device, number=100)
+    return evaluator(a, b, c, n).mean
+
+
+@requires_hexagon_toolchain
+def test_speedup(hexagon_session, capsys):
+    if hexagon_session is None:
+        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")
+
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    func = tvm.build(
+        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
+    )
+    mod = hexagon_session.load_module(func)
+    args = generate_add_test_data(hexagon_session)
+    parallel_mean = benchmark_func(mod, "elemwise_sum_parallel", args, hexagon_session)
+    serial_mean = benchmark_func(mod, "elemwise_sum_serial", args, hexagon_session)
+
+    with capsys.disabled():
+        print("... speedup of {:.2f}".format(serial_mean / parallel_mean), end=" ")
+
+
+@requires_hexagon_toolchain
+def test_elemwise_sum_parallel(hexagon_session):
+    if hexagon_session is None:
+        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")
+
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    func = tvm.build(
+        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
+    )
+    mod = hexagon_session.load_module(func)
+
+    (a, b, c, n) = generate_add_test_data(hexagon_session)
+    mod["elemwise_sum_parallel"](a, b, c, n)
+    tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
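For context on what these tests exercise: tvm.build lowers the T.parallel loop to a call into the runtime's parallel launcher, which fans the iteration space out across the ThreadGroup workers (QuRTThread on Hexagon, std::thread elsewhere). Below is a hedged sketch of that calling convention using the public C backend API from include/tvm/runtime/c_backend_api.h; AddPayload, AddTask, ParallelAdd, and the chunking scheme are illustrative, not the exact code TVM generates:

    #include <tvm/runtime/c_backend_api.h>

    struct AddPayload {
      float* a;
      float* b;
      float* c;
      int n;
    };

    // FTVMParallelLambda-shaped task: each task_id takes one contiguous chunk
    // of the iteration space, split across penv->num_task workers.
    static int AddTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
      AddPayload* p = static_cast<AddPayload*>(cdata);
      int chunk = (p->n + penv->num_task - 1) / penv->num_task;
      int begin = task_id * chunk;
      int end = begin + chunk < p->n ? begin + chunk : p->n;
      for (int i = begin; i < end; ++i) p->c[i] = p->a[i] + p->b[i];
      return 0;  // nonzero would signal an error to the launcher
    }

    // num_task = 0 lets the runtime use all configured worker threads, i.e.
    // the same pool whose thread type and Yield() this patch changes on Hexagon.
    int ParallelAdd(float* a, float* b, float* c, int n) {
      AddPayload payload{a, b, c, n};
      return TVMBackendParallelLaunch(AddTask, &payload, 0);
    }

test_speedup then simply checks that this dispatch pays off, printing the serial_mean / parallel_mean ratio measured on device.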