Skip to content

Commit d18182c

Browse files
committed
set_datum: recalculate scale and bias if necessary
1 parent d716062 commit d18182c

File tree

2 files changed

+163
-24
lines changed

2 files changed

+163
-24
lines changed

include/svs/quantization/scalar/scalar.h

Lines changed: 123 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ class CosineSimilarityCompressed {
185185
namespace detail {
186186

187187
struct MinMaxAccumulator {
188-
float min = 0.0;
189-
float max = 0.0;
188+
float min = std::numeric_limits<float>::max();
189+
float max = std::numeric_limits<float>::min();
190190

191191
void accumulate(float val) {
192192
min = std::min(min, val);
@@ -288,6 +288,51 @@ template <typename Element, typename Data> struct Compressor {
288288
float bias_;
289289
};
290290

291+
template <typename Element, typename Data> struct Decompressor {
292+
using element_type = Element;
293+
using data_type = Data;
294+
295+
Decompressor(float scale, float bias)
296+
: scale_{scale}
297+
, bias_{bias} {}
298+
299+
template <
300+
data::ImmutableMemoryDataset Dataset,
301+
threads::ThreadPool Pool,
302+
typename Alloc>
303+
data_type
304+
operator()(const Dataset& data, Pool& threadpool, const Alloc& allocator) const {
305+
static constexpr size_t batch_size = 512;
306+
307+
data_type decompressed{data.size(), data.dimensions(), allocator};
308+
309+
threads::parallel_for(
310+
threadpool,
311+
threads::DynamicPartition(data.size(), batch_size),
312+
[&](const auto& indices, uint64_t /*tid*/) {
313+
threads::UnitRange range{indices};
314+
// Allocate a buffer of given dimensionality, will be re-used for each datum
315+
std::vector<element_type> buffer(data.dimensions());
316+
for (size_t i = range.start(); i < range.stop(); ++i) {
317+
// Compress datum
318+
auto datum = data.get_datum(i);
319+
std::transform(datum.begin(), datum.end(), buffer.begin(), [&](auto v) {
320+
return decompress<float>(v, scale_, bias_);
321+
});
322+
// Store to compressed dataset
323+
decompressed.set_datum(i, buffer);
324+
}
325+
}
326+
);
327+
328+
return decompressed;
329+
}
330+
331+
private:
332+
float scale_;
333+
float bias_;
334+
};
335+
291336
// Map from baseline distance functors to the local versions.
292337
template <typename T, typename ElementType> struct CompressedDistance;
293338

@@ -337,6 +382,8 @@ class SQDataset {
337382
public:
338383
constexpr static size_t extent = Extent;
339384
constexpr static bool uses_compressed_data = true;
385+
constexpr static T MIN = std::numeric_limits<T>::min();
386+
constexpr static T MAX = std::numeric_limits<T>::max();
340387

341388
using allocator_type = Alloc;
342389
using element_type = T;
@@ -347,14 +394,18 @@ class SQDataset {
347394
private:
348395
float scale_;
349396
float bias_;
397+
float min_;
398+
float max_;
350399
data_type data_;
351400

352401
public:
353402
SQDataset(size_t size, size_t dims)
354403
: data_{size, dims} {}
355-
SQDataset(data_type data, float scale, float bias)
404+
SQDataset(data_type data, float scale, float bias, float min, float max)
356405
: scale_(scale)
357406
, bias_(bias)
407+
, min_(min)
408+
, max_(max)
358409
, data_{std::move(data)} {}
359410

360411
size_t size() const { return data_.size(); }
@@ -375,18 +426,66 @@ class SQDataset {
375426
}
376427

377428
template <typename QueryType, size_t N>
378-
void set_datum(size_t i, std::span<QueryType, N> datum) {
379-
auto dims = dimensions();
380-
assert(datum.size() == dims);
429+
void set_datum(
430+
size_t i, std::span<QueryType, N> datum, const allocator_type& allocator = {}
431+
) {
432+
return set_datum(i, datum, 1, allocator);
433+
}
381434

382-
// Compress elements
383-
std::vector<element_type> buffer(dims);
384-
std::transform(datum.begin(), datum.end(), buffer.begin(), [&](QueryType v) {
385-
return detail::compress<QueryType, element_type>(v, scale_, bias_);
386-
});
435+
template <typename QueryType, size_t N>
436+
void set_datum(
437+
size_t i,
438+
std::span<QueryType, N> datum,
439+
size_t num_threads,
440+
const allocator_type& allocator = {}
441+
) {
442+
auto pool = threads::DefaultThreadPool{num_threads};
443+
return set_datum(i, datum, pool, allocator);
444+
}
387445

388-
data_.set_datum(i, buffer);
389-
// TODO: Float16 truncation check? (see codec.h, line 1[14)
446+
template <typename QueryType, size_t N, threads::ThreadPool Pool>
447+
void set_datum(
448+
size_t i,
449+
std::span<QueryType, N> datum,
450+
Pool& threadpool,
451+
const allocator_type& allocator = {}
452+
) {
453+
auto dims = dimensions();
454+
if (datum.size() != dims) {
455+
throw ANNEXCEPTION("Datum size mismatch!");
456+
}
457+
458+
// Check if the datum is within the range
459+
float datum_min = *std::min_element(datum.begin(), datum.end());
460+
float datum_max = *std::max_element(datum.begin(), datum.end());
461+
bool in_range = (datum_min >= min_ && datum_max <= max_);
462+
463+
if (in_range) {
464+
// Compress elements
465+
std::vector<element_type> buffer(dims);
466+
std::transform(datum.begin(), datum.end(), buffer.begin(), [&](QueryType v) {
467+
return detail::compress<QueryType, element_type>(v, scale_, bias_);
468+
});
469+
470+
data_.set_datum(i, buffer);
471+
} else {
472+
// Need to re-compress entire dataset plus the new datum with new parameters
473+
auto decompressor =
474+
detail::Decompressor<element_type, data_type>{scale_, bias_};
475+
auto decompressed = decompressor(data_, threadpool, allocator);
476+
decompressed.set_datum(i, datum);
477+
478+
// Update min_ and max_ values, recalculate scale and bias
479+
min_ = std::min(min_, datum_min);
480+
max_ = std::max(max_, datum_max);
481+
scale_ = (max_ - min_) / (MAX - MIN);
482+
bias_ = min_ - MIN * scale_;
483+
484+
// Recompress with new parameters
485+
auto compressor = detail::Compressor<element_type, data_type>{scale_, bias_};
486+
auto compressed = compressor(decompressed, threadpool, allocator);
487+
data_ = std::move(compressed);
488+
}
390489
}
391490

392491
template <data::ImmutableMemoryDataset Dataset>
@@ -412,10 +511,11 @@ class SQDataset {
412511
// Get dataset extrema
413512
auto minmax = detail::MinMax{};
414513
auto global = minmax(data, threadpool);
514+
if (global.min == global.max) {
515+
throw ANNEXCEPTION("Trivial dataset can't be compressed");
516+
}
415517

416518
// Compute scale and bias
417-
constexpr float MIN = std::numeric_limits<element_type>::min();
418-
constexpr float MAX = std::numeric_limits<element_type>::max();
419519
float scale = (global.max - global.min) / (MAX - MIN);
420520
float bias = global.min - MIN * scale;
421521

@@ -424,7 +524,7 @@ class SQDataset {
424524
auto compressed = compressor(data, threadpool, allocator);
425525

426526
return SQDataset<element_type, extent, allocator_type>{
427-
std::move(compressed), scale, bias};
527+
std::move(compressed), scale, bias, global.min, global.max};
428528
}
429529

430530
/// @brief Compact the dataset
@@ -452,7 +552,9 @@ class SQDataset {
452552
save_version,
453553
{SVS_LIST_SAVE_(data, ctx),
454554
{"scale", lib::save(scale_, ctx)},
455-
{"bias", lib::save(bias_, ctx)}}
555+
{"bias", lib::save(bias_, ctx)},
556+
{"min", lib::save(min_, ctx)},
557+
{"max", lib::save(max_, ctx)}}
456558
);
457559
}
458560

@@ -462,7 +564,10 @@ class SQDataset {
462564
return SQDataset<element_type, extent, allocator_type>{
463565
SVS_LOAD_MEMBER_AT_(table, data, allocator),
464566
lib::load_at<float>(table, "scale"),
465-
lib::load_at<float>(table, "bias")};
567+
lib::load_at<float>(table, "bias"),
568+
lib::load_at<float>(table, "min"),
569+
lib::load_at<float>(table, "max"),
570+
};
466571
}
467572

468573
/// @brief Prefetch data in the dataset.

tests/svs/quantization/scalar/scalar.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
// catch2
2727
#include "catch2/catch_test_macros.hpp"
28+
#include "catch2/matchers/catch_matchers_string.hpp"
2829

2930
namespace scalar = svs::quantization::scalar;
3031

@@ -294,8 +295,6 @@ template <typename T, typename Distance> void test_distance_compressed() {
294295
}
295296

296297
CATCH_TEST_CASE("Testing SQDataset", "[quantization][scalar]") {
297-
CATCH_SECTION("Default SQDataset") {}
298-
299298
CATCH_SECTION("SQDataset dynamic extent") {
300299
auto x = scalar::SQDataset<std::int8_t>(10, 100);
301300

@@ -306,12 +305,10 @@ CATCH_TEST_CASE("Testing SQDataset", "[quantization][scalar]") {
306305

307306
CATCH_SECTION("SQDataset fixed extent") {
308307
constexpr size_t dims = 128;
309-
auto x = scalar::SQDataset<std::int8_t, dims>({}, 1.0F, 0.0F);
308+
auto x = scalar::SQDataset<std::int8_t, dims>(0, 128);
310309

311310
CATCH_REQUIRE(x.size() == 0);
312311
CATCH_REQUIRE(x.dimensions() == dims);
313-
CATCH_REQUIRE(x.get_scale() == 1.0F);
314-
CATCH_REQUIRE(x.get_bias() == 0.0F);
315312
CATCH_REQUIRE(x.extent == dims);
316313
}
317314

@@ -320,9 +317,46 @@ CATCH_TEST_CASE("Testing SQDataset", "[quantization][scalar]") {
320317
test_sq_top<std::int16_t, 128>();
321318
}
322319

323-
CATCH_SECTION("SQDataset compress and resize") {
320+
CATCH_SECTION("SQDataset compact and resize") {
324321
// TODO
325322
}
323+
324+
CATCH_SECTION("SQDataset trivial compression is not allowed") {
325+
// Compress single-value data, would result in 0 scale
326+
svs::data::SimpleData<float> simple_data(1, 4);
327+
std::vector<float> initial_data = {1, 1, 1, 1};
328+
simple_data.set_datum(0, initial_data);
329+
// the next line must throw
330+
CATCH_REQUIRE_THROWS_MATCHES(
331+
scalar::SQDataset<std::int8_t>::compress(simple_data),
332+
svs::ANNException,
333+
svs_test::ExceptionMatcher(
334+
Catch::Matchers::ContainsSubstring("Trivial dataset can't be compressed")
335+
)
336+
);
337+
}
338+
339+
CATCH_SECTION("SQDataset update scale and bias") {
340+
using A = svs::lib::Allocator<std::int8_t>;
341+
using blocked_type = svs::data::Blocked<A>;
342+
using compressed_type = scalar::SQDataset<std::int8_t, 4, blocked_type>;
343+
344+
// Create SQDataset from initial set of values
345+
auto initial_data = std::vector<float>{1, 2, 3, 4};
346+
auto simple_data = svs::data::SimpleData<float>(1, 4);
347+
simple_data.set_datum(0, initial_data);
348+
349+
auto data = compressed_type::compress(simple_data);
350+
auto initial_scale = data.get_scale();
351+
CATCH_REQUIRE(initial_scale != 0.0F);
352+
353+
// Add another value that's outside the range of the initial values
354+
data.resize(2);
355+
std::vector<float> new_data = {5, 6, 7, 8};
356+
data.set_datum(1, std::span(new_data));
357+
// Assert the scale was updated accordingly
358+
CATCH_REQUIRE(data.get_scale() != initial_scale);
359+
}
326360
}
327361

328362
CATCH_TEST_CASE(

0 commit comments

Comments
 (0)