Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions src/runtime/contrib/sort/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,25 @@ namespace contrib {

using namespace runtime;

// Comparator for ascending order on (original_index, value) pairs.
//
// When `stable_comparison` is true, ties on the value are broken by the
// original index (`first`), so equal elements keep their input order —
// required when this comparator drives the top-k heap so results match a
// stable sort. Default (false) compares values only.
//
// Fix: removed a stray duplicated `template <typename DType>` header line
// left over from a diff artifact, which made the definition ill-formed.
template <typename DType, bool stable_comparison = false>
bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    if (lhs.second == rhs.second) {
      return lhs.first < rhs.first;
    }
  }

  return lhs.second < rhs.second;
}

// Comparator for descending order on (original_index, value) pairs.
//
// When `stable_comparison` is true, ties on the value are broken by the
// original index (`first`), so equal elements keep their input order —
// required when this comparator drives the top-k heap so results match a
// stable sort. Default (false) compares values only.
//
// Fix: removed a stray duplicated `template <typename DType>` header line
// left over from a diff artifact, which made the definition ill-formed.
template <typename DType, bool stable_comparison = false>
bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    if (lhs.second == rhs.second) {
      return lhs.first < rhs.first;
    }
  }

  return lhs.second > rhs.second;
}

Expand All @@ -49,18 +61,14 @@ struct float16 {
float to_float() const {
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
}
};

template <>
bool CompareAscend(const std::pair<int64_t, float16>& lhs, const std::pair<int64_t, float16>& rhs) {
return lhs.second.to_float() < rhs.second.to_float();
}

template <>
bool CompareDescend(const std::pair<int64_t, float16>& lhs,
const std::pair<int64_t, float16>& rhs) {
return lhs.second.to_float() > rhs.second.to_float();
}
inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); }
inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); }
inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); }
inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); }
inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); }
inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); }
};

// Argsort implemented with the C library sort, used for NMS.
// Returns the indices of the sorted tensor.
Expand Down Expand Up @@ -346,7 +354,12 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i
(out_values == nullptr) ? nullptr : static_cast<DataType*>(out_values->data);
IndicesType* indices_ptr =
(out_indices == nullptr) ? nullptr : static_cast<IndicesType*>(out_indices->data);
std::vector<std::pair<int64_t, DataType>> sorter;

// Maintain a min/max containing the top-k elements
std::vector<std::pair<int64_t, DataType>> running_heap;

// Need +1 when inserting new element before maintaining heap invariant
running_heap.reserve(k + 1);

int axis_mul_before = 1;
int axis_mul_after = 1;
Expand All @@ -363,26 +376,56 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i

for (int i = 0; i < axis_mul_before; ++i) {
for (int j = 0; j < axis_mul_after; ++j) {
sorter.clear();
running_heap.clear();
int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
int64_t dst_base_idx = i * k * axis_mul_after + j;
for (int64_t kk = 0; kk < input->shape[axis]; ++kk) {
int64_t full_idx = src_base_idx + kk * axis_mul_after;
sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx]));

// Start by creating min/max heap with fixed-k elements
int cur_axis_index = 0;
for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) {
int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after;
running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx]));
}
if (!is_ascend) {
std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
} else {
std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
}

// Iterate through all elements, adding to heap along the way
for (; cur_axis_index < input->shape[axis]; cur_axis_index++) {
int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after;
std::pair<int64_t, DataType> cur_val = {cur_axis_index, data_ptr[full_idx]};

// Eq. to cur_val.second > running_heap.second
if (!is_ascend && CompareDescend<DataType, true>(cur_val, running_heap[0])) {
running_heap.push_back(cur_val);
std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
running_heap.pop_back();
} else if (is_ascend && CompareAscend<DataType, true>(cur_val, running_heap[0])) {
running_heap.push_back(cur_val);
std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
running_heap.pop_back();
}
}

// finally sort heap and deliver results
if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>);
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>);
}
int64_t cnt = k > 0 ? k : input->shape[axis];
for (int64_t kk = 0; kk < cnt; ++kk) {

for (uint32_t kk = 0; kk < running_heap.size(); ++kk) {
if (indices_ptr != nullptr) {
indices_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<IndicesType>(sorter[kk].first);
static_cast<IndicesType>(running_heap[kk].first);
}
if (values_ptr != nullptr) {
values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast<DataType>(sorter[kk].second);
values_ptr[dst_base_idx + kk * axis_mul_after] =
static_cast<DataType>(running_heap[kk].second);
}
}
}
Expand Down