Skip to content

Commit 4e860b9

Browse files
authored
Merge pull request #55 from ihsrobotics/thread-refactor
Thread refactor
2 parents 55fa8bc + d371a42 commit 4e860b9

File tree

4 files changed

+223
-34
lines changed

4 files changed

+223
-34
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cmake_minimum_required(VERSION 3.0)
2-
project(ihs_boost VERSION 1.6.1)
2+
project(ihs_boost VERSION 1.7.0)
33

44
# options
55
option(build_tests "build_tests" OFF)

modules/threading/include/threadable.hpp

Lines changed: 134 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
#define IHSBOOST_THREADABLE_HPP
1515

1616
#include <thread>
17-
#include <functional>
17+
#include <type_traits>
18+
#include <tuple>
1819

1920
/**
2021
* @brief A thread that runs the given function with the given arguments
@@ -35,35 +36,33 @@ class Threadable
3536
* @details upon creation, a Threadable is considered not done and not started
3637
*
3738
* @tparam _MemberFunc the type of the member function to call
38-
* @tparam _Class_Ptr the type, as a pointer, of the instance
39-
* @tparam _Args Types of the arguments to pass to the thread
39+
* @tparam _Class the type the instance
40+
* @tparam _Args the types of the arguments to pass to the thread
4041
* @tparam std::enable_if<std::is_member_function_pointer<_MemberFunc>::value, bool>::type used to enforce template specialization
41-
* @tparam std::enable_if<std::is_pointer<_Class_Ptr>::value, bool>::type used to enforce template specialization
4242
* @param func the member function to call. In most circumstances, this is `&CLASS_NAME::METHOD_NAME`
4343
* where CLASS_NAME is the name of the class and METHOD_NAME is the name of the method
4444
* @param c a pointer to the instance from which to run the member function.
45-
* @param args the arguments with which to call the member function
45+
* @param args the arguments with which to call the member function. Note that these can be lvalues or rvalues
4646
*/
47-
template <typename _MemberFunc, typename _Class_Ptr, typename... _Args,
48-
typename std::enable_if<std::is_member_function_pointer<_MemberFunc>::value, bool>::type = true,
49-
typename std::enable_if<std::is_pointer<_Class_Ptr>::value, bool>::type = true>
50-
Threadable(_MemberFunc &&func, _Class_Ptr c, _Args &&...args) : _started(false),
51-
_done(false),
52-
_thread(), _func([&func, &c, &args...]() -> void
53-
{ (c->*func)(args...); }){};
47+
template <typename _MemberFunc, typename _Class, typename... _Args,
48+
typename std::enable_if<std::is_member_function_pointer<_MemberFunc>::value, bool>::type = true>
49+
Threadable(_MemberFunc &&func, _Class *c, _Args &&...args) : _started(false),
50+
_done(false),
51+
_thread(), _func(new MemberFunctionWrapper<typename std::decay<_MemberFunc>::type, typename std::decay<_Class>::type, _Args...>(func, c, args...)){};
5452

5553
/**
5654
* @brief Construct a new Threadable object to run the given function
5755
* with the given parameters in a separate thread
5856
* @details upon creation, a Threadable is considered not done and not started.
5957
*
58+
* @tparam _Callable the type of the function to call
59+
* @tparam _Args the types of the arguments to pass to the thread
6060
* @param func the function to run
61-
* @param args the arguments to pass to the function
61+
* @param args the arguments to pass to the function. Note that these can be lvalues or rvalues
6262
*/
6363
template <typename _Callable, typename... _Args>
6464
Threadable(_Callable &&func, _Args &&...args) : _started(false), _done(false),
65-
_thread(), _func([&func, &args...]() -> void
66-
{ func(args...); }){};
65+
_thread(), _func(new StaticFunctionWrapper<typename std::decay<_Callable>::type, _Args...>(func, args...)){};
6766

6867
/**
6968
* @brief Destroy the Threadable object
@@ -116,21 +115,133 @@ class Threadable
116115
bool started() const;
117116

118117
Threadable &operator=(const Threadable &other) = delete;
118+
/**
119+
* @brief Equals operator for setting a Threadable equal to
120+
* a Threadable rvalue
121+
*
122+
* @param other the Threadable to set this equal to
123+
* @return Threadable& this Threadable
124+
*/
119125
Threadable &operator=(Threadable &&other);
120126

121127
private:
122128
/**
123-
* @brief Wrapper function to allow the use of threads with member functions
129+
* @brief Wrapper function that sets done to true after finishing
130+
*
131+
*/
132+
void wrapper();
133+
134+
/**
135+
* @brief Function wrapper class to allow for runtime polymorphism
124136
*
125-
* @param threadable the Threadable object
126-
* @param args the arguments to pass to the Threadable object's function
127137
*/
128-
static void wrapper(Threadable *threadable);
138+
class FunctionWrapper
139+
{
140+
public:
141+
/**
142+
* @brief Call the function with any associated arguments
143+
*
144+
*/
145+
virtual void call() = 0;
146+
virtual ~FunctionWrapper(){};
147+
148+
protected:
149+
template <std::size_t... Ts>
150+
struct index
151+
{
152+
};
153+
154+
template <std::size_t N, std::size_t... Ts>
155+
struct gen_seq : gen_seq<N - 1, N - 1, Ts...>
156+
{
157+
};
158+
159+
/**
160+
* @brief Generate a sequence of indexes given the size of
161+
* a parameter pack
162+
*
163+
* @tparam Ts
164+
*/
165+
template <std::size_t... Ts>
166+
struct gen_seq<0, Ts...> : index<Ts...>
167+
{
168+
};
169+
};
170+
171+
/**
172+
* @brief Function wrapper class for static functions (functions
173+
* that aren't member functions)
174+
*
175+
* @tparam _StaticFunc the type of the static function to call
176+
* @tparam _Args the types of the arguments that will be passed
177+
*/
178+
template <typename _StaticFunc, typename... _Args>
179+
class StaticFunctionWrapper : public FunctionWrapper
180+
{
181+
private:
182+
std::tuple<_Args...> _args; ///< used for storing all the arguments in a tuple
183+
_StaticFunc _func; ///< the static function to call
184+
185+
/**
186+
* @brief Unpack the tuple by getting all the arguments by index
187+
*
188+
* @tparam Is all the indexes
189+
*/
190+
template <std::size_t... Is>
191+
void func_caller(FunctionWrapper::index<Is...>) { _func(std::get<Is>(_args)...); }
192+
193+
public:
194+
/**
195+
* @brief Construct a new Static Function Wrapper object
196+
*
197+
* @param func the function to call
198+
* @param args the arguments to call the function with
199+
*/
200+
StaticFunctionWrapper(_StaticFunc func, _Args... args) : _args(std::forward<_Args>(args)...), _func(func) {}
201+
virtual ~StaticFunctionWrapper() = default;
202+
virtual void call() { func_caller(gen_seq<sizeof...(_Args)>{}); }
203+
};
204+
205+
/**
206+
* @brief Function wrapper class for member functions
207+
*
208+
* @tparam _MemberFunc the type of the member function to call
209+
* @tparam _Class the type of the class that will call it
210+
* @tparam _Args the types of the arguments that will be passed
211+
*/
212+
template <typename _MemberFunc, typename _Class, typename... _Args>
213+
class MemberFunctionWrapper : public FunctionWrapper
214+
{
215+
private:
216+
std::tuple<_Args...> _args; ///< used for storing arguments in a tuple
217+
_Class *_ptr; ///< pointer to instance that calls the member function
218+
_MemberFunc _func; ///< pointer to the member function to call
219+
220+
/**
221+
* @brief Unpack the tuple by getting all the arguments by index
222+
*
223+
* @tparam Is all the indexes
224+
*/
225+
template <std::size_t... Is>
226+
void func_caller(FunctionWrapper::index<Is...>) { (_ptr->*_func)(std::get<Is>(_args)...); }
227+
228+
public:
229+
/**
230+
* @brief Construct a new Member Function Wrapper object
231+
*
232+
* @param func the function to call
233+
* @param ptr the instance to call the function from
234+
* @param args the arguments to call the function with
235+
*/
236+
MemberFunctionWrapper(_MemberFunc func, _Class *ptr, _Args... args) : _args(std::forward<_Args>(args)...), _ptr(ptr), _func(func) {}
237+
virtual ~MemberFunctionWrapper() = default;
238+
virtual void call() { func_caller(gen_seq<sizeof...(_Args)>{}); }
239+
};
129240

130-
bool _started; ///< whether or not the thread was started
131-
volatile bool _done; ///< whether or not the thread is done
132-
std::thread _thread; ///< the thread itself
133-
std::function<void()> _func; ///< the function to call
241+
bool _started; ///< whether or not the thread was started
242+
volatile bool _done; ///< whether or not the thread is done
243+
std::thread _thread; ///< the thread itself
244+
FunctionWrapper *_func; ///< the function to call
134245
};
135246

136247
#endif

modules/threading/src/threadable.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
#include "threadable.hpp"
22

3-
Threadable::Threadable() : _started(false), _done(false), _thread(), _func(){};
4-
Threadable::~Threadable() { join(); }
3+
Threadable::Threadable() : _started(false), _done(false), _thread(), _func(nullptr){};
4+
Threadable::~Threadable()
5+
{
6+
// stop thread
7+
join();
8+
9+
// delete function storage
10+
if (_func != nullptr)
11+
{
12+
delete _func;
13+
}
14+
}
515
Threadable &Threadable::operator=(Threadable &&other)
616
{
717
if (this == &other)
@@ -14,14 +24,17 @@ Threadable &Threadable::operator=(Threadable &&other)
1424
_func = std::move(other._func);
1525
_done = other._done;
1626
_started = other._started;
27+
28+
// steal other's _func and make sure it doesn't delete it
29+
other._func = nullptr;
1730
return *this;
1831
}
1932

2033
void Threadable::start()
2134
{
2235
if (!_started)
2336
{
24-
_thread = std::thread(wrapper, this);
37+
_thread = std::thread(&Threadable::wrapper, this);
2538
_done = false;
2639
_started = true;
2740
}
@@ -41,12 +54,15 @@ bool Threadable::done() const { return _done; }
4154
bool Threadable::operator()() const { return done(); }
4255
bool Threadable::started() const { return _started; }
4356

44-
void Threadable::wrapper(Threadable *threadable)
57+
void Threadable::wrapper()
4558
{
46-
// call function
47-
threadable->_func();
59+
// call function if it exists
60+
if (_func != nullptr)
61+
{
62+
_func->call();
63+
}
4864

4965
// cleanup variables
50-
threadable->_done = true;
51-
threadable->_started = false;
66+
_done = true;
67+
_started = false;
5268
}

tests/threading/test/thread_test.cpp

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <chrono>
55
#include <thread>
66
#include <numeric>
7-
#include <mutex>
87
#include <vector>
98

109
using namespace std;
@@ -15,10 +14,10 @@ class Test
1514
public:
1615
Test(int val) : val(val){};
1716
void increment_val(int increment_amt) { val += increment_amt; }
17+
void add_vals(int a, int b) { val += a + b; }
1818
int get_val() { return val; }
1919

2020
private:
21-
mutex m;
2221
int val;
2322
};
2423

@@ -224,6 +223,65 @@ void test_dynamic(int val1, int val2)
224223
cout << "passed test dynamic" << endl;
225224
}
226225

226+
void test_member_func()
227+
{
228+
Test test(0);
229+
int amt1 = 10;
230+
int amt2 = 20;
231+
Threadable t1(&Test::increment_val, &test, amt1);
232+
Threadable t2(&Test::increment_val, &test, amt2);
233+
234+
t1.start();
235+
t2.start();
236+
237+
while (!t1.done() || !t2.done())
238+
;
239+
240+
assert_equals(amt1 + amt2, test.get_val(), "testing members");
241+
cout << "passed member funcs" << endl;
242+
}
243+
244+
void test_rvalue_member()
245+
{
246+
Test test(0);
247+
Threadable t1(&Test::increment_val, &test, 10);
248+
t1.start();
249+
250+
while (!t1.done())
251+
;
252+
253+
assert_equals(10, test.get_val(), "testing rvalues member");
254+
cout << "passed rvalue member" << endl;
255+
}
256+
257+
void test_rvalue_static()
258+
{
259+
Test test(0);
260+
Threadable t1(a, test, 35);
261+
t1.start();
262+
263+
while (!t1.done())
264+
;
265+
266+
assert_equals(35, test.get_val(), "testing rvalues static");
267+
cout << "passed rvalue static" << endl;
268+
}
269+
270+
void test_multiple_same_type()
271+
{
272+
Test test(0);
273+
Threadable t1(&Test::add_vals, &test, 30, 70);
274+
Threadable t2(&Test::add_vals, &test, 50, 10);
275+
t1.start();
276+
t2.start();
277+
278+
while (!t1.done() || !t2.done())
279+
;
280+
281+
assert_equals(160, test.get_val(), "testing multiple rvalues of same type");
282+
cout << "passed multiple rvalue same type" << endl;
283+
}
284+
227285
int main()
228286
{
229287
test_single_thread_ptr();
@@ -236,6 +294,10 @@ int main()
236294
test_multiple_threads();
237295

238296
test_dynamic(rand() % 10, rand() % 50);
297+
test_member_func();
298+
test_rvalue_member();
299+
test_rvalue_static();
300+
test_multiple_same_type();
239301

240302
return 0;
241303
}

0 commit comments

Comments
 (0)