Merged

31 commits
76877fe
Initial take on adding QuRT thread support to TVM's thread pool. WIP;…
supersat Apr 8, 2022
2605634
Allocate QuRT thread stacks automatically
supersat Apr 8, 2022
82bad79
Remove duplicate stack in QuRTThread
supersat Apr 8, 2022
cb07bd3
Add more logging to QuRTThread
supersat Apr 8, 2022
f360a9f
Use QuRT mutexes and condition variables
supersat Apr 8, 2022
ffeaa90
Get QuRT thread pools working perhaps
supersat Apr 12, 2022
bc0ef77
Sleep for a little bit to let race condition bugs shine through
supersat Apr 12, 2022
e1fd5ee
ayeee it works!
supersat Apr 13, 2022
5948675
Remove custom hexagon implementations of std::mutex and std::conditio…
supersat Apr 13, 2022
deebe25
threading_backend.cc code cleanup
supersat Apr 13, 2022
b1e9265
Formatting changes
supersat Apr 13, 2022
d27746b
remove hexagon debugging
supersat Apr 13, 2022
c9ce982
Initial take on adding QuRT thread support to TVM's thread pool. WIP;…
supersat Apr 8, 2022
39097ab
Allocate QuRT thread stacks automatically
supersat Apr 8, 2022
832c125
Remove duplicate stack in QuRTThread
supersat Apr 8, 2022
a80713f
Add more logging to QuRTThread
supersat Apr 8, 2022
f99d32b
Use QuRT mutexes and condition variables
supersat Apr 8, 2022
85bc9b4
Get QuRT thread pools working perhaps
supersat Apr 12, 2022
91d2b23
Sleep for a little bit to let race condition bugs shine through
supersat Apr 12, 2022
cb2104a
ayeee it works!
supersat Apr 13, 2022
0d20028
Remove custom hexagon implementations of std::mutex and std::conditio…
supersat Apr 13, 2022
a0bf101
threading_backend.cc code cleanup
supersat Apr 13, 2022
95e99d9
Formatting changes
supersat Apr 13, 2022
912ecc6
remove hexagon debugging
supersat Apr 13, 2022
15fafc0
Merge branch 'main' into qurt-thread-pool
supersat Apr 14, 2022
e4d2b14
Add hexagon thread pool test
supersat Apr 18, 2022
590c1b3
style fixes for tests/python/contrib/test_hexagon/test_thread_pool.py
supersat Apr 19, 2022
5576fd3
Fix some style issues
supersat Apr 19, 2022
69d586e
Merge branch 'qurt-thread-pool' of github.com:supersat/tvm into qurt-…
supersat Apr 19, 2022
53ef297
Address some reviewer comments
supersat Apr 21, 2022
24f14f7
Add QuRT thread pool backend
supersat Apr 29, 2022
79 changes: 78 additions & 1 deletion src/runtime/threading_backend.cc
@@ -34,13 +34,72 @@
#endif
#if defined(__hexagon__)
#include <dlfcn.h>
#include <qurt.h>
#include <stdlib.h>
#define HEXAGON_STACK_SIZE 65536
#define HEXAGON_STACK_ALIGNMENT 32
#endif
#include <algorithm>
#include <thread>
#define CURRENT_THREAD_HANDLE (static_cast<std::thread::native_handle_type>(0))
namespace tvm {
namespace runtime {
namespace threading {
#ifdef __hexagon__
Contributor

wondering if it's possible to do this with templating instead of with #ifdef __hexagon__?

Contributor Author

To a certain extent, although it's not clear it's worth the effort. As far as I can tell, there are a couple of options:

  • Have multiple classes that implement the ThreadGroup::Impl interface -- one for pthreads, and one for QuRT threads. @csullivan was concerned that this might lead to duplicated code, making it more fragile to maintain.
  • Parameterizing ThreadPool::Impl on the underlying thread type. However, we'd still need #ifdefs to avoid calling SetAffinity, which isn't available on Hexagon.

Yield() also doesn't work as-is on Hexagon, so we'd need a fix for that as well.

Contributor

One approach would be to introduce

template <typename ThreadType>
class ThreadGroup::Impl;

and then wrap both std::thread and QuRT thread into types with a common interface that ThreadGroup::Impl calls into.

The thing I got hung up on there is the runtime dispatch. Right now it should be doable from the device type -- when it is kDLHexagon, dispatch to ThreadGroup::Impl&lt;QuRTThreadInterface&gt;; when it's kDLCPU, dispatch to ThreadGroup::Impl&lt;StdThreadInterface&gt;. However, there is impetus to move Hexagon fully over to kDLCPU, in which case we could no longer do runtime dispatch based on the device type.
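
As an illustrative sketch only (not code from this PR -- the Launch/Join surface and member names are assumptions), the wrapper idea could look roughly like this:

#include <functional>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical wrapper exposing the minimal surface ThreadGroup::Impl would call into.
class StdThreadInterface {
 public:
  explicit StdThreadInterface(std::function<void()> f) : thread_(std::move(f)) {}
  void Join() { thread_.join(); }

 private:
  std::thread thread_;
};

// A QuRTThreadInterface would expose the same members, implemented on top of
// qurt_thread_create()/qurt_thread_join() instead of std::thread.

template <typename ThreadType>
class ThreadGroupImpl {
 public:
  ThreadGroupImpl(int num_workers, std::function<void(int)> task) {
    for (int i = 0; i < num_workers; ++i) {
      workers_.emplace_back([task, i] { task(i); });
    }
  }
  // Callers must Join() before destruction, mirroring std::thread's contract.
  void Join() {
    for (auto& w : workers_) w.Join();
  }

 private:
  std::vector<ThreadType> workers_;
};

Runtime dispatch would then pick ThreadGroupImpl<QuRTThreadInterface> or ThreadGroupImpl<StdThreadInterface>, e.g. from the device type as described above.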

Contributor
@csullivan Apr 22, 2022

Then the only other options for dispatch I see are

(1) TVM compile-time dispatch -- during codegen, do something specific based on the target
(2) Build time dispatch via either preprocessor or through changes to the build system to conditionally include one translation unit over another.

Contributor

I think that abstracting thread control would be the best approach here. The runtime built for a particular target could then contain the implementation provided by that target (e.g. via template specialization), and there wouldn't be any run-time dependency on device kind.
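
A compact sketch of that build-time selection (the names here -- ThreadControl and the target tags -- are hypothetical, not part of the PR):

#include <thread>

struct PosixTarget {};
struct HexagonTarget {};

template <typename Target>
struct ThreadControl;  // primary template deliberately left undefined

// The default runtime build compiles only this specialization.
template <>
struct ThreadControl<PosixTarget> {
  static void Yield() { std::this_thread::yield(); }
};

// A Hexagon runtime build would instead provide ThreadControl<HexagonTarget>,
// e.g. with Yield() implemented via qurt_sleep(1), so no device-kind check
// happens at run time.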

Contributor Author

I've refactored this PR to split up pthread and qurt threading implementations. This introduces the following changes:

  • ThreadGroup::Impl is now an abstract class.
  • A ThreadGroupImplTemplate class is introduced with common code between pthread and qurt implementations. It is a subclass of ThreadGroup::Impl. (AFAICT, you can't have a pointer to an unspecialized template class. We need a pointer to a concrete type for ThreadGroup to call.)
  • ThreadGroupPosixImpl now contains the bulk of the code from ThreadGroup::Impl. It inherits from ThreadGroupImplTemplate<std::thread>. This is in a new file, src/runtime/posix/threading_posix.cc.
  • ThreadGroupHexagonImpl inherits from ThreadGroupImplTemplate, which are both defined in a new file, src/runtime/hexagon/threading_hexagon.cc.
  • There are now two different versions of Yield()
  • CMakeLists.txt has been modified to either include src/runtime/posix/threading_posix.cc or src/runtime/hexagon/threading_hexagon.cc, depending on whether we're building for Hexagon.

Honestly, this seems like a pretty ugly solution, especially when threading_backend.cc is already littered with #ifdefs for various platforms. Thoughts?
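
A rough sketch of the class split described above, using the names from this comment (the bodies are illustrative, not the actual contents of the new files; the template argument on the Hexagon side is presumably QuRTThread):

#include <thread>
#include <vector>

// ThreadGroup::Impl, sketched here as a free-standing Impl: an abstract base,
// since ThreadGroup needs a pointer to a concrete (non-template) type.
class Impl {
 public:
  virtual ~Impl() = default;
  virtual void Join() = 0;
};

// Common code shared by the pthread and QuRT backends.
template <typename ThreadType>
class ThreadGroupImplTemplate : public Impl {
 public:
  void Join() override {
    for (auto& t : threads_) t.join();
  }

 protected:
  std::vector<ThreadType> threads_;
};

// src/runtime/posix/threading_posix.cc would hold something like:
class ThreadGroupPosixImpl : public ThreadGroupImplTemplate<std::thread> {
  // SetAffinity() and the other pthread-specific pieces.
};

// src/runtime/hexagon/threading_hexagon.cc would hold
// ThreadGroupHexagonImpl : public ThreadGroupImplTemplate<QuRTThread>,
// and CMakeLists.txt compiles exactly one of the two translation units.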

Contributor

@areusch @kparzysz-quic PTAL. I'm inclined to take @supersat's recommendation if this approach is too ugly. But we'd like to keep forward progress on this one. Thanks!

// pthreads are broken on older versions of qurt, so
// we need to use native APIs instead of std::threads
class QuRTThread {
  typedef std::function<void()> Callback;

 public:
  explicit QuRTThread(Callback worker_callback) : worker_callback_(worker_callback) {
    static int id = 1;
    qurt_thread_attr_t attr;
    char name[32];
    int ret = posix_memalign(&stack_, HEXAGON_STACK_ALIGNMENT, HEXAGON_STACK_SIZE);
    CHECK_EQ(ret, 0);
    // When a std::function<> is cast to bool,
    // it indicates whether it stores a callable target
    CHECK_EQ((bool)worker_callback_, true);
    qurt_thread_attr_init(&attr);
    qurt_thread_attr_set_stack_size(&attr, HEXAGON_STACK_SIZE);
    qurt_thread_attr_set_stack_addr(&attr, stack_);
    snprintf(name, sizeof(name), "worker %d", id++);
    qurt_thread_attr_set_name(&attr, name);
    ret = qurt_thread_create(&thread_, &attr, (void (*)(void*))RunFunction, this);
    CHECK_EQ(ret, QURT_EOK);
  }
  QuRTThread(QuRTThread&& other)
      : thread_(other.thread_),
        worker_callback_(std::move(other.worker_callback_)),
        stack_(other.stack_) {
    other.thread_ = 0;
Contributor

Although line 73 already has the logic to avoid a double free, after construction from the rvalue we should logically still clear the member values (stack_, f_) of 'other', because these values are no longer owned by 'other'. Doing this will help future maintenance.

Contributor Author

free(NULL) is defined to be a no-op, but I added an explicit check anyway.

    other.stack_ = nullptr;
  }
  ~QuRTThread() {
    if (thread_) {
      join();
    }
    if (stack_) {
      free(stack_);
    }
  }
  bool joinable() const { return qurt_thread_get_id() != thread_; }
  void join() {
    int status;
    qurt_thread_join(thread_, &status);
  }

 private:
  static void RunFunction(QuRTThread* qrt_thread) {
    qrt_thread->worker_callback_();
    qurt_thread_exit(QURT_EOK);
  }
  qurt_thread_t thread_;
  Callback worker_callback_;
  void* stack_ = nullptr;
};
#endif // __hexagon__
thread_local int max_concurrency = 0;
class ThreadGroup::Impl {
public:
@@ -116,6 +175,7 @@ class ThreadGroup::Impl {
// if worker 0 is offloaded to main, i.e. exclude_worker0 is true,
// the main thread is bound to core 0.
void SetAffinity(bool exclude_worker0, AffinityMode mode) {
#ifndef __hexagon__
const char* val = getenv("TVM_BIND_THREADS");
if (val != nullptr && atoi(val) != 1) {
return;
@@ -172,6 +232,7 @@ class ThreadGroup::Impl {
SetMasterThreadFullCpuAffinity(mode);
}
}
#endif // __hexagon__
}

void SetThreadFullCpuAffinity(std::thread::native_handle_type thread, AffinityMode mode) {
@@ -185,6 +246,7 @@ class ThreadGroup::Impl {
// Note: this works well on x86 too. Because x86 doesn't have BIG.LITTLE,
// our implementation will use kBig mode by default and will let main thread
// run on intended cores.
#ifndef __hexagon__
std::vector<unsigned> ids;
switch (mode) {
case kSpecifyOneCorePerThread:
@@ -206,6 +268,7 @@ class ThreadGroup::Impl {
break;
}
SetThreadAffinity(thread, ids);
#endif // __hexagon__
}

void SetMasterThreadFullCpuAffinity(AffinityMode mode) {
@@ -259,7 +322,11 @@ class ThreadGroup::Impl {
}

  int num_workers_;
#if defined(__hexagon__)
  std::vector<QuRTThread> threads_;
#else
  std::vector<std::thread> threads_;
#endif
  std::vector<unsigned int> sorted_order_;
  int big_count_ = 0;
  int little_count_ = 0;
@@ -276,7 +343,17 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0
return impl_->Configure(mode, nthreads, exclude_worker0, cpus);
}

void Yield() { std::this_thread::yield(); }
void Yield() {
#ifdef __hexagon__
  // QuRT doesn't have a yield API, so instead we sleep for the minimum amount
  // of time to let the OS schedule another thread. std::this_thread::yield()
  // compiles down to an empty function.
  qurt_sleep(1);
#else
  std::this_thread::yield();
#endif
}

/*!
 * \brief Set the maximum number of available cores.
 */
100 changes: 100 additions & 0 deletions tests/python/contrib/test_hexagon/test_thread_pool.py
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pytest

import tvm
import tvm.contrib.hexagon
import tvm.script
import tvm.testing
from tvm import te

from .conftest import requires_hexagon_toolchain
from tvm.script import tir as T


@tvm.script.ir_module
class ElemwiseSumIRModule:
    @T.prim_func
    def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
        T.func_attr({"global_symbol": "elemwise_sum_serial", "tir.noalias": True})
        A = T.match_buffer(a, (n,), dtype="float32")
        B = T.match_buffer(b, (n,), dtype="float32")
        C = T.match_buffer(c, (n,), dtype="float32")
        for i in T.serial(n):
            with T.block("C"):
                vi = T.axis.spatial(n, i)
                C[vi] = A[vi] + B[vi]

    @T.prim_func
    def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
        T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
        A = T.match_buffer(a, (n,), dtype="float32")
        B = T.match_buffer(b, (n,), dtype="float32")
        C = T.match_buffer(c, (n,), dtype="float32")
        for i in T.parallel(n):
            with T.block("C"):
                vi = T.axis.spatial(n, i)
                C[vi] = A[vi] + B[vi]


def generate_add_test_data(hexagon_session, n=128 * 1024):
    a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
    b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
    c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device)
    return (a, b, c, n)


def benchmark_func(mod, name, args, hexagon_session):
    (a, b, c, n) = args
    evaluator = mod.time_evaluator(name, hexagon_session.device, number=100)
    return evaluator(a, b, c, n).mean


@requires_hexagon_toolchain
def test_speedup(hexagon_session, capsys):
    if hexagon_session is None:
        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")

    target_hexagon = tvm.target.hexagon("v68", link_params=True)
    func = tvm.build(
        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
    )
    mod = hexagon_session.load_module(func)
    args = generate_add_test_data(hexagon_session)
    parallel_mean = benchmark_func(mod, "elemwise_sum_parallel", args, hexagon_session)
    serial_mean = benchmark_func(mod, "elemwise_sum_serial", args, hexagon_session)

    with capsys.disabled():
        print("... speedup of {:.2f}".format(serial_mean / parallel_mean), end=" ")


@requires_hexagon_toolchain
def test_elemwise_sum_parallel(hexagon_session):
    if hexagon_session is None:
        pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")

    target_hexagon = tvm.target.hexagon("v68", link_params=True)
    func = tvm.build(
        ElemwiseSumIRModule, target=tvm.target.Target(target_hexagon, host=target_hexagon)
    )
    mod = hexagon_session.load_module(func)

    (a, b, c, n) = generate_add_test_data(hexagon_session)
    mod["elemwise_sum_parallel"](a, b, c, n)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())