Skip to content

Commit 4e84c70

Browse files
PSeitzPSeitz-dd
andauthored
Fix TopNComputer for reverse order (#2672)
Co-authored-by: Pascal Seitz <[email protected]>
1 parent f2c77f0 commit 4e84c70

File tree

1 file changed

+73
-5
lines changed

1 file changed

+73
-5
lines changed

src/collector/top_score_collector.rs

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,7 @@ impl<Score, D, const R: bool> From<TopNComputerDeser<Score, D, R>> for TopNCompu
970970
}
971971
}
972972

973-
impl<Score, D, const R: bool> TopNComputer<Score, D, R>
973+
impl<Score, D, const REVERSE_ORDER: bool> TopNComputer<Score, D, REVERSE_ORDER>
974974
where
975975
Score: PartialOrd + Clone,
976976
D: Ord,
@@ -991,7 +991,10 @@ where
991991
#[inline]
992992
pub fn push(&mut self, feature: Score, doc: D) {
993993
if let Some(last_median) = self.threshold.clone() {
994-
if feature < last_median {
994+
if !REVERSE_ORDER && feature > last_median {
995+
return;
996+
}
997+
if REVERSE_ORDER && feature < last_median {
995998
return;
996999
}
9971000
}
@@ -1026,7 +1029,7 @@ where
10261029
}
10271030

10281031
/// Returns the top n elements in sorted order.
1029-
pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> {
1032+
pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, D, REVERSE_ORDER>> {
10301033
if self.buffer.len() > self.top_n {
10311034
self.truncate_top_n();
10321035
}
@@ -1037,7 +1040,7 @@ where
10371040
/// Returns the top n elements in stored order.
10381041
/// Useful if you do not need the elements in sorted order,
10391042
/// for example when merging the results of multiple segments.
1040-
pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> {
1043+
pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, D, REVERSE_ORDER>> {
10411044
if self.buffer.len() > self.top_n {
10421045
self.truncate_top_n();
10431046
}
@@ -1047,9 +1050,11 @@ where
10471050

10481051
#[cfg(test)]
10491052
mod tests {
1053+
use proptest::prelude::*;
1054+
10501055
use super::{TopDocs, TopNComputer};
10511056
use crate::collector::top_collector::ComparableDoc;
1052-
use crate::collector::Collector;
1057+
use crate::collector::{Collector, DocSetCollector};
10531058
use crate::query::{AllQuery, Query, QueryParser};
10541059
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
10551060
use crate::time::format_description::well_known::Rfc3339;
@@ -1144,6 +1149,44 @@ mod tests {
11441149
}
11451150
}
11461151

1152+
proptest! {
1153+
#[test]
1154+
fn test_topn_computer_asc_prop(
1155+
limit in 0..10_usize,
1156+
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
1157+
) {
1158+
let mut computer: TopNComputer<_, _, false> = TopNComputer::new(limit);
1159+
for (feature, doc) in &docs {
1160+
computer.push(*feature, *doc);
1161+
}
1162+
let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::<Vec<_>>();
1163+
comparable_docs.sort();
1164+
comparable_docs.truncate(limit);
1165+
prop_assert_eq!(
1166+
computer.into_sorted_vec(),
1167+
comparable_docs,
1168+
);
1169+
}
1170+
1171+
#[test]
1172+
fn test_topn_computer_desc_prop(
1173+
limit in 0..10_usize,
1174+
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
1175+
) {
1176+
let mut computer: TopNComputer<_, _, true> = TopNComputer::new(limit);
1177+
for (feature, doc) in &docs {
1178+
computer.push(*feature, *doc);
1179+
}
1180+
let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::<Vec<_>>();
1181+
comparable_docs.sort();
1182+
comparable_docs.truncate(limit);
1183+
prop_assert_eq!(
1184+
computer.into_sorted_vec(),
1185+
comparable_docs,
1186+
);
1187+
}
1188+
}
1189+
11471190
#[test]
11481191
fn test_top_collector_not_at_capacity_without_offset() -> crate::Result<()> {
11491192
let index = make_index()?;
@@ -1645,4 +1688,29 @@ mod tests {
16451688
);
16461689
Ok(())
16471690
}
1691+
1692+
#[test]
1693+
fn test_topn_computer_asc() {
1694+
let mut computer: TopNComputer<u32, u32, false> = TopNComputer::new(2);
1695+
1696+
computer.push(1u32, 1u32);
1697+
computer.push(2u32, 2u32);
1698+
computer.push(3u32, 3u32);
1699+
computer.push(2u32, 4u32);
1700+
computer.push(4u32, 5u32);
1701+
computer.push(1u32, 6u32);
1702+
assert_eq!(
1703+
computer.into_sorted_vec(),
1704+
&[
1705+
ComparableDoc {
1706+
feature: 1u32,
1707+
doc: 1u32,
1708+
},
1709+
ComparableDoc {
1710+
feature: 1u32,
1711+
doc: 6u32,
1712+
}
1713+
]
1714+
);
1715+
}
16481716
}

0 commit comments

Comments
 (0)