@@ -69,22 +69,22 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
6969 }
7070};
7171
72- template <typename T> using BinaryOp = T(T, T);
72+ template <typename OutType, typename InType = OutType>
73+ using BinaryOp = OutType(InType, InType);
7374
74- template <typename T, mpfr::Operation Op, BinaryOp<T> Func>
75+ template <typename OutType, typename InType, mpfr::Operation Op,
76+ BinaryOp<OutType, InType> Func>
7577struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
76- using FloatType = T ;
78+ using FloatType = InType ;
7779 using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
7880 using StorageType = typename FPBits::StorageType;
7981
80- static constexpr BinaryOp<FloatType> *FUNC = Func;
81-
8282 // Check in a range, return the number of failures.
8383 uint64_t check (StorageType x_start, StorageType x_stop, StorageType y_start,
8484 StorageType y_stop, mpfr::RoundingMode rounding) {
8585 mpfr::ForceRoundingMode r (rounding);
8686 if (!r.success )
87- return ( x_stop > x_start || y_stop > y_start) ;
87+ return x_stop > x_start || y_stop > y_start;
8888 StorageType xbits = x_start;
8989 uint64_t failed = 0 ;
9090 do {
@@ -93,12 +93,12 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
9393 do {
9494 FloatType y = FPBits (ybits).get_val ();
9595 mpfr::BinaryInput<FloatType> input{x, y};
96- bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY (Op, input, FUNC (x, y),
96+ bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY (Op, input, Func (x, y),
9797 0.5 , rounding);
9898 failed += (!correct);
9999 // Uncomment to print out failed values.
100100 // if (!correct) {
101- // TEST_MPFR_MATCH (Op::Operation, x, Op::func (x, y), 0.5, rounding);
101+ // EXPECT_MPFR_MATCH_ROUNDING (Op, input, Func (x, y), 0.5, rounding);
102102 // }
103103 } while (ybits++ < y_stop);
104104 } while (xbits++ < x_stop);
@@ -108,20 +108,45 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
108108
109109// Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
110110// StorageType and check method.
111- template <typename Checker>
111+ template <typename Checker, size_t Increment = 1 << 20 >
112112struct LlvmLibcExhaustiveMathTest
113113 : public virtual LIBC_NAMESPACE::testing::Test,
114114 public Checker {
115115 using FloatType = typename Checker::FloatType;
116116 using FPBits = typename Checker::FPBits;
117117 using StorageType = typename Checker::StorageType;
118118
119- static constexpr StorageType INCREMENT = (1 << 20 );
119+ static constexpr StorageType INCREMENT = Increment;
120+
121+ void explain_failed_range (std::stringstream &msg, StorageType x_begin,
122+ StorageType x_end) {
123+ #ifdef LIBC_TYPES_HAS_FLOAT16
124+ using T = LIBC_NAMESPACE::cpp::conditional_t <
125+ LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float , FloatType>;
126+ #else
127+ using T = FloatType;
128+ #endif
129+
130+ msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << " , 0x"
131+ << x_end << " ), [" << std::hexfloat
132+ << static_cast <T>(FPBits (x_begin).get_val ()) << " , "
133+ << static_cast <T>(FPBits (x_end).get_val ()) << " )" ;
134+ }
135+
136+ void explain_failed_range (std::stringstream &msg, StorageType x_begin,
137+ StorageType x_end, StorageType y_begin,
138+ StorageType y_end) {
139+ msg << " x " ;
140+ explain_failed_range (msg, x_begin, x_end);
141+ msg << " , y " ;
142+ explain_failed_range (msg, y_begin, y_end);
143+ }
120144
121145 // Break [start, stop) into `nthreads` subintervals and apply *check to each
122146 // subinterval in parallel.
123- void test_full_range (StorageType start, StorageType stop,
124- mpfr::RoundingMode rounding) {
147+ template <typename ... T>
148+ void test_full_range (mpfr::RoundingMode rounding, StorageType start,
149+ StorageType stop, T... extra_range_bounds) {
125150 int n_threads = std::thread::hardware_concurrency ();
126151 std::vector<std::thread> thread_list;
127152 std::mutex mx_cur_val;
@@ -158,15 +183,14 @@ struct LlvmLibcExhaustiveMathTest
158183 std::cout << msg.str () << std::flush;
159184 }
160185
161- uint64_t failed_in_range =
162- Checker::check ( range_begin, range_end, rounding);
186+ uint64_t failed_in_range = Checker::check (
187+ range_begin, range_end, extra_range_bounds... , rounding);
163188 if (failed_in_range > 0 ) {
164189 std::stringstream msg;
165190 msg << " Test failed for " << std::dec << failed_in_range
166- << " inputs in range: " << range_begin << " to " << range_end
167- << " [0x" << std::hex << range_begin << " , 0x" << range_end
168- << " ), [" << std::hexfloat << FPBits (range_begin).get_val ()
169- << " , " << FPBits (range_end).get_val () << " )\n " ;
191+ << " inputs in range: " ;
192+ explain_failed_range (msg, start, stop, extra_range_bounds...);
193+ msg << " \n " ;
170194 std::cerr << msg.str () << std::flush;
171195
172196 failed.fetch_add (failed_in_range);
@@ -189,127 +213,46 @@ struct LlvmLibcExhaustiveMathTest
189213 void test_full_range_all_roundings (StorageType start, StorageType stop) {
190214 std::cout << " -- Testing for FE_TONEAREST in range [0x" << std::hex << start
191215 << " , 0x" << stop << " ) --" << std::dec << std::endl;
192- test_full_range (start, stop, mpfr::RoundingMode::Nearest);
216+ test_full_range (mpfr::RoundingMode::Nearest, start, stop );
193217
194218 std::cout << " -- Testing for FE_UPWARD in range [0x" << std::hex << start
195219 << " , 0x" << stop << " ) --" << std::dec << std::endl;
196- test_full_range (start, stop, mpfr::RoundingMode::Upward);
220+ test_full_range (mpfr::RoundingMode::Upward, start, stop );
197221
198222 std::cout << " -- Testing for FE_DOWNWARD in range [0x" << std::hex << start
199223 << " , 0x" << stop << " ) --" << std::dec << std::endl;
200- test_full_range (start, stop, mpfr::RoundingMode::Downward);
224+ test_full_range (mpfr::RoundingMode::Downward, start, stop );
201225
202226 std::cout << " -- Testing for FE_TOWARDZERO in range [0x" << std::hex
203227 << start << " , 0x" << stop << " ) --" << std::dec << std::endl;
204- test_full_range (start, stop, mpfr::RoundingMode::TowardZero);
228+ test_full_range (mpfr::RoundingMode::TowardZero, start, stop );
205229 };
206- };
207-
208- template <typename Checker>
209- struct LlvmLibcBinaryInputExhaustiveMathTest
210- : public virtual LIBC_NAMESPACE::testing::Test,
211- public Checker {
212- using FloatType = typename Checker::FloatType;
213- using FPBits = typename Checker::FPBits;
214- using StorageType = typename Checker::StorageType;
215-
216- static constexpr StorageType Increment = (1 << 2 );
217-
218- // Break [start, stop) into `nthreads` subintervals and apply *check to each
219- // subinterval in parallel.
220- void test_full_range (StorageType x_start, StorageType x_stop,
221- StorageType y_start, StorageType y_stop,
222- mpfr::RoundingMode rounding) {
223- int n_threads = std::thread::hardware_concurrency ();
224- std::vector<std::thread> thread_list;
225- std::mutex mx_cur_val;
226- int current_percent = -1 ;
227- StorageType current_value = x_start;
228- std::atomic<uint64_t > failed (0 );
229-
230- for (int i = 0 ; i < n_threads; ++i) {
231- thread_list.emplace_back ([&, this ]() {
232- while (true ) {
233- StorageType range_begin, range_end;
234- int new_percent = -1 ;
235- {
236- std::lock_guard<std::mutex> lock (mx_cur_val);
237- if (current_value == x_stop)
238- return ;
239-
240- range_begin = current_value;
241- if (x_stop >= Increment && x_stop - Increment >= current_value) {
242- range_end = current_value + Increment;
243- } else {
244- range_end = x_stop;
245- }
246- current_value = range_end;
247- int pc = 100.0 * (range_end - x_start) / (x_stop - x_start);
248- if (current_percent != pc) {
249- new_percent = pc;
250- current_percent = pc;
251- }
252- }
253- if (new_percent >= 0 ) {
254- std::stringstream msg;
255- msg << new_percent << " % is in process \r " ;
256- std::cout << msg.str () << std::flush;
257- }
258-
259- uint64_t failed_in_range =
260- Checker::check (range_begin, range_end, y_start, y_stop, rounding);
261- if (failed_in_range > 0 ) {
262- using T = LIBC_NAMESPACE::cpp::conditional_t <
263- LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float ,
264- FloatType>;
265- std::stringstream msg;
266- msg << " Test failed for " << std::dec << failed_in_range
267- << " inputs in range: " << range_begin << " to " << range_end
268- << " [0x" << std::hex << range_begin << " , 0x" << range_end
269- << " ), [" << std::hexfloat
270- << static_cast <T>(FPBits (range_begin).get_val ()) << " , "
271- << static_cast <T>(FPBits (range_end).get_val ()) << " )\n " ;
272- std::cerr << msg.str () << std::flush;
273-
274- failed.fetch_add (failed_in_range);
275- }
276- }
277- });
278- }
279-
280- for (auto &thread : thread_list) {
281- if (thread.joinable ()) {
282- thread.join ();
283- }
284- }
285-
286- std::cout << std::endl;
287- std::cout << " Test " << ((failed > 0 ) ? " FAILED" : " PASSED" ) << std::endl;
288- ASSERT_EQ (failed.load (), uint64_t (0 ));
289- }
290230
291231 void test_full_range_all_roundings (StorageType x_start, StorageType x_stop,
292232 StorageType y_start, StorageType y_stop) {
293- test_full_range (x_start, x_stop, y_start, y_stop,
294- mpfr::RoundingMode::Nearest);
233+ std::cout << " -- Testing for FE_TONEAREST in x range [0x" << std::hex
234+ << x_start << " , 0x" << x_stop << " ), y range [0x" << y_start
235+ << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
236+ test_full_range (mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
237+ y_stop);
295238
296239 std::cout << " -- Testing for FE_UPWARD in x range [0x" << std::hex
297- << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
298- << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
299- test_full_range (x_start, x_stop, y_start, y_stop ,
300- mpfr::RoundingMode::Upward );
240+ << x_start << " , 0x" << x_stop << " ), y range [0x" << y_start
241+ << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
242+ test_full_range (mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
243+ y_stop );
301244
302245 std::cout << " -- Testing for FE_DOWNWARD in x range [0x" << std::hex
303- << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
304- << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
305- test_full_range (x_start, x_stop, y_start, y_stop ,
306- mpfr::RoundingMode::Downward );
246+ << x_start << " , 0x" << x_stop << " ), y range [0x" << y_start
247+ << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
248+ test_full_range (mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
249+ y_stop );
307250
308251 std::cout << " -- Testing for FE_TOWARDZERO in x range [0x" << std::hex
309- << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
310- << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
311- test_full_range (x_start, x_stop, y_start, y_stop ,
312- mpfr::RoundingMode::TowardZero );
252+ << x_start << " , 0x" << x_stop << " ), y range [0x" << y_start
253+ << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
254+ test_full_range (mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
255+ y_stop );
313256 };
314257};
315258
@@ -324,4 +267,5 @@ using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
324267
325268template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
326269using LlvmLibcBinaryOpExhaustiveMathTest =
327- LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>;
270+ LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
271+ 1 << 2 >;
0 commit comments