From 0ff2e4bd393810b74bae75ae5d4623a7efc7ea40 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 12 Dec 2022 12:19:03 -0800 Subject: [PATCH 1/5] new topk --- src/runtime/contrib/sort/sort.cc | 48 ++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 8ea2f4b60cdf..0fbb7f02ca0a 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -346,7 +346,12 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i (out_values == nullptr) ? nullptr : static_cast(out_values->data); IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : static_cast(out_indices->data); - std::vector> sorter; + + // Maintain a min_heap containing the top-k elements + std::vector> running_heap; + + // Need +1 when inserting new element before maintaining heap invariant + running_heap.reserve(k + 1); int axis_mul_before = 1; int axis_mul_after = 1; @@ -363,26 +368,47 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i for (int i = 0; i < axis_mul_before; ++i) { for (int j = 0; j < axis_mul_after; ++j) { - sorter.clear(); + running_heap.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; - for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { - int64_t full_idx = src_base_idx + kk * axis_mul_after; - sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); + + // Start by creating min heap with fixed-k elements + int cur_axis_index = 0; + for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); } + std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); + + // Iterate through all elements, adding to heap along the way + for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + std::pair cur_val = {cur_axis_index, data_ptr[full_idx]}; + + // Eq. to cur_val.second > running_heap.second + if (CompareDescend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend); + running_heap.pop_back(); + } + } + + // finally sort heap and deliver results if (is_ascend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); } - int64_t cnt = k > 0 ? k : input->shape[axis]; - for (int64_t kk = 0; kk < cnt; ++kk) { + + for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(running_heap[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(running_heap[kk].second); } } } From bf357c20003e0a594ee81efbe663b385db1f33c4 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 13 Dec 2022 12:46:30 -0800 Subject: [PATCH 2/5] reversed sorting --- src/runtime/contrib/sort/sort.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 0fbb7f02ca0a..88d39c13020a 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -396,9 +396,9 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i // finally sort heap and deliver results if (is_ascend) { - std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); - } else { std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); + } else { + std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); } for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { From b6fe10ed02c50a0a3d0ea01c87bd713d0da282af Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 3 Jan 2023 11:10:07 -0800 Subject: [PATCH 3/5] stable sorting --- src/runtime/contrib/sort/sort.cc | 55 ++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 88d39c13020a..1fe3e94ede07 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -34,13 +34,25 @@ namespace contrib { using namespace runtime; -template +template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { + if constexpr(stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second < rhs.second; } -template +template bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { + if constexpr(stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second > rhs.second; } @@ -49,18 +61,14 @@ struct float16 { float to_float() const { return __extendXfYf2__(bits); } -}; -template <> -bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { - return lhs.second.to_float() < rhs.second.to_float(); -} - -template <> -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { - return lhs.second.to_float() > rhs.second.to_float(); -} + inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } + inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } + inline bool operator< (const float16& rhs) const { return to_float() < rhs.to_float(); } + inline bool operator> (const float16& rhs) const { return to_float() > rhs.to_float(); } + inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } + inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } +}; // Argsort implemented C library sort for nms. // Return indices of sorted tensor. @@ -347,7 +355,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : static_cast(out_indices->data); - // Maintain a min_heap containing the top-k elements + // Maintain a min/max containing the top-k elements std::vector> running_heap; // Need +1 when inserting new element before maintaining heap invariant @@ -372,13 +380,13 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; - // Start by creating min heap with fixed-k elements + // Start by creating min/max heap with fixed-k elements int cur_axis_index = 0; for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); } - std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); // Iterate through all elements, adding to heap along the way for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { @@ -386,19 +394,24 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i std::pair cur_val = {cur_axis_index, data_ptr[full_idx]}; // Eq. to cur_val.second > running_heap.second - if (CompareDescend(cur_val, running_heap[0])) { + if (!is_ascend && CompareDescend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend); + running_heap.pop_back(); + } else if (is_ascend && CompareAscend(cur_val, running_heap[0])) { running_heap.push_back(cur_val); - std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend); - std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend); running_heap.pop_back(); } } // finally sort heap and deliver results if (is_ascend) { - std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); } else { - std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); } for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { From 16554efd642ebce7c7133a382f2233079532abbf Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 3 Jan 2023 12:13:04 -0800 Subject: [PATCH 4/5] fix heap construction oops --- src/runtime/contrib/sort/sort.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 1fe3e94ede07..a92874904d9c 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -386,8 +386,12 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); } - std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); - + if (!is_ascend) { + std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); + } else { + std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend); + } + // Iterate through all elements, adding to heap along the way for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; From 323db1fe8a70859804af66637b14388d375bd14e Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 3 Jan 2023 13:34:12 -0800 Subject: [PATCH 5/5] lint --- src/runtime/contrib/sort/sort.cc | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index a92874904d9c..bfb174a9206e 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -34,9 +34,9 @@ namespace contrib { using namespace runtime; -template +template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { - if constexpr(stable_comparison) { + if constexpr (stable_comparison) { if (lhs.second == rhs.second) { return lhs.first < rhs.first; } @@ -45,9 +45,9 @@ bool CompareAscend(const std::pair& lhs, const std::pair +template bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { - if constexpr(stable_comparison) { + if constexpr (stable_comparison) { if (lhs.second == rhs.second) { return lhs.first < rhs.first; } @@ -62,12 +62,12 @@ struct float16 { return __extendXfYf2__(bits); } - inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } - inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } - inline bool operator< (const float16& rhs) const { return to_float() < rhs.to_float(); } - inline bool operator> (const float16& rhs) const { return to_float() > rhs.to_float(); } - inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } - inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } + inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } + inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } + inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); } + inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); } + inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } + inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } }; // Argsort implemented C library sort for nms. @@ -391,7 +391,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i } else { std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend); } - + // Iterate through all elements, adding to heap along the way for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after;