@@ -185,8 +185,8 @@ class CosineSimilarityCompressed {
185185namespace detail {
186186
187187struct MinMaxAccumulator {
188- float min = 0.0 ;
189- float max = 0.0 ;
// Seed the running extrema with the opposite ends of the float range so that the first
// accumulated value always replaces both of them.
// NOTE: `max` must start at `lowest()` (the most negative finite float).
// `std::numeric_limits<float>::min()` is the smallest *positive* normal value, which
// would win against every negative input and report a bogus maximum.
float min = std::numeric_limits<float>::max();
float max = std::numeric_limits<float>::lowest();
190190
191191 void accumulate (float val) {
192192 min = std::min (min, val);
@@ -288,6 +288,51 @@ template <typename Element, typename Data> struct Compressor {
288288 float bias_;
289289};
290290
291+ template <typename Element, typename Data> struct Decompressor {
292+ using element_type = Element;
293+ using data_type = Data;
294+
295+ Decompressor (float scale, float bias)
296+ : scale_{scale}
297+ , bias_{bias} {}
298+
299+ template <
300+ data::ImmutableMemoryDataset Dataset,
301+ threads::ThreadPool Pool,
302+ typename Alloc>
303+ data_type
304+ operator ()(const Dataset& data, Pool& threadpool, const Alloc& allocator) const {
305+ static constexpr size_t batch_size = 512 ;
306+
307+ data_type decompressed{data.size (), data.dimensions (), allocator};
308+
309+ threads::parallel_for (
310+ threadpool,
311+ threads::DynamicPartition (data.size (), batch_size),
312+ [&](const auto & indices, uint64_t /* tid*/ ) {
313+ threads::UnitRange range{indices};
314+ // Allocate a buffer of given dimensionality, will be re-used for each datum
315+ std::vector<element_type> buffer (data.dimensions ());
316+ for (size_t i = range.start (); i < range.stop (); ++i) {
317+ // Compress datum
318+ auto datum = data.get_datum (i);
319+ std::transform (datum.begin (), datum.end (), buffer.begin (), [&](auto v) {
320+ return decompress<float >(v, scale_, bias_);
321+ });
322+ // Store to compressed dataset
323+ decompressed.set_datum (i, buffer);
324+ }
325+ }
326+ );
327+
328+ return decompressed;
329+ }
330+
331+ private:
332+ float scale_;
333+ float bias_;
334+ };
335+
291336// Map from baseline distance functors to the local versions.
292337template <typename T, typename ElementType> struct CompressedDistance ;
293338
@@ -337,6 +382,8 @@ class SQDataset {
337382 public:
338383 constexpr static size_t extent = Extent;
339384 constexpr static bool uses_compressed_data = true ;
385+ constexpr static T MIN = std::numeric_limits<T>::min();
386+ constexpr static T MAX = std::numeric_limits<T>::max();
340387
341388 using allocator_type = Alloc;
342389 using element_type = T;
@@ -347,14 +394,18 @@ class SQDataset {
347394 private:
348395 float scale_;
349396 float bias_;
397+ float min_;
398+ float max_;
350399 data_type data_;
351400
352401 public:
353402 SQDataset (size_t size, size_t dims)
354403 : data_{size, dims} {}
355- SQDataset (data_type data, float scale, float bias)
404+ SQDataset (data_type data, float scale, float bias, float min, float max )
356405 : scale_(scale)
357406 , bias_(bias)
407+ , min_(min)
408+ , max_(max)
358409 , data_{std::move (data)} {}
359410
360411 size_t size () const { return data_.size (); }
@@ -375,18 +426,66 @@ class SQDataset {
375426 }
376427
377428 template <typename QueryType, size_t N>
378- void set_datum (size_t i, std::span<QueryType, N> datum) {
379- auto dims = dimensions ();
380- assert (datum.size () == dims);
429+ void set_datum (
430+ size_t i, std::span<QueryType, N> datum, const allocator_type& allocator = {}
431+ ) {
432+ return set_datum (i, datum, 1 , allocator);
433+ }
381434
382- // Compress elements
383- std::vector<element_type> buffer (dims);
384- std::transform (datum.begin (), datum.end (), buffer.begin (), [&](QueryType v) {
385- return detail::compress<QueryType, element_type>(v, scale_, bias_);
386- });
435+ template <typename QueryType, size_t N>
436+ void set_datum (
437+ size_t i,
438+ std::span<QueryType, N> datum,
439+ size_t num_threads,
440+ const allocator_type& allocator = {}
441+ ) {
442+ auto pool = threads::DefaultThreadPool{num_threads};
443+ return set_datum (i, datum, pool, allocator);
444+ }
387445
388- data_.set_datum (i, buffer);
389- // TODO: Float16 truncation check? (see codec.h, line 1[14)
446+ template <typename QueryType, size_t N, threads::ThreadPool Pool>
447+ void set_datum (
448+ size_t i,
449+ std::span<QueryType, N> datum,
450+ Pool& threadpool,
451+ const allocator_type& allocator = {}
452+ ) {
453+ auto dims = dimensions ();
454+ if (datum.size () != dims) {
455+ throw ANNEXCEPTION (" Datum size mismatch!" );
456+ }
457+
458+ // Check if the datum is within the range
459+ float datum_min = *std::min_element (datum.begin (), datum.end ());
460+ float datum_max = *std::max_element (datum.begin (), datum.end ());
461+ bool in_range = (datum_min >= min_ && datum_max <= max_);
462+
463+ if (in_range) {
464+ // Compress elements
465+ std::vector<element_type> buffer (dims);
466+ std::transform (datum.begin (), datum.end (), buffer.begin (), [&](QueryType v) {
467+ return detail::compress<QueryType, element_type>(v, scale_, bias_);
468+ });
469+
470+ data_.set_datum (i, buffer);
471+ } else {
472+ // Need to re-compress entire dataset plus the new datum with new parameters
473+ auto decompressor =
474+ detail::Decompressor<element_type, data_type>{scale_, bias_};
475+ auto decompressed = decompressor (data_, threadpool, allocator);
476+ decompressed.set_datum (i, datum);
477+
478+ // Update min_ and max_ values, recalculate scale and bias
479+ min_ = std::min (min_, datum_min);
480+ max_ = std::max (max_, datum_max);
481+ scale_ = (max_ - min_) / (MAX - MIN);
482+ bias_ = min_ - MIN * scale_;
483+
484+ // Recompress with new parameters
485+ auto compressor = detail::Compressor<element_type, data_type>{scale_, bias_};
486+ auto compressed = compressor (decompressed, threadpool, allocator);
487+ data_ = std::move (compressed);
488+ }
390489 }
391490
392491 template <data::ImmutableMemoryDataset Dataset>
@@ -412,10 +511,11 @@ class SQDataset {
412511 // Get dataset extrema
413512 auto minmax = detail::MinMax{};
414513 auto global = minmax (data, threadpool);
514+ if (global.min == global.max ) {
515+ throw ANNEXCEPTION (" Trivial dataset can't be compressed" );
516+ }
415517
416518 // Compute scale and bias
417- constexpr float MIN = std::numeric_limits<element_type>::min ();
418- constexpr float MAX = std::numeric_limits<element_type>::max ();
419519 float scale = (global.max - global.min ) / (MAX - MIN);
420520 float bias = global.min - MIN * scale;
421521
@@ -424,7 +524,7 @@ class SQDataset {
424524 auto compressed = compressor (data, threadpool, allocator);
425525
426526 return SQDataset<element_type, extent, allocator_type>{
427- std::move (compressed), scale, bias};
527+ std::move (compressed), scale, bias, global. min , global. max };
428528 }
429529
430530 // / @brief Compact the dataset
@@ -452,7 +552,9 @@ class SQDataset {
452552 save_version,
453553 {SVS_LIST_SAVE_ (data, ctx),
454554 {" scale" , lib::save (scale_, ctx)},
455- {" bias" , lib::save (bias_, ctx)}}
555+ {" bias" , lib::save (bias_, ctx)},
556+ {" min" , lib::save (min_, ctx)},
557+ {" max" , lib::save (max_, ctx)}}
456558 );
457559 }
458560
@@ -462,7 +564,10 @@ class SQDataset {
462564 return SQDataset<element_type, extent, allocator_type>{
463565 SVS_LOAD_MEMBER_AT_ (table, data, allocator),
464566 lib::load_at<float >(table, " scale" ),
465- lib::load_at<float >(table, " bias" )};
567+ lib::load_at<float >(table, " bias" ),
568+ lib::load_at<float >(table, " min" ),
569+ lib::load_at<float >(table, " max" ),
570+ };
466571 }
467572
468573 // / @brief Prefetch data in the dataset.
0 commit comments