diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h
index fdf3f4b82..187fc7440 100644
--- a/include/svs/index/flat/flat.h
+++ b/include/svs/index/flat/flat.h
@@ -158,6 +158,11 @@ template <
     typename Ownership = OwnsMembers>
 class FlatIndex {
   public:
+    static constexpr bool supports_insertions = false;
+    static constexpr bool supports_deletions = false;
+    static constexpr bool supports_saving = true; // tested via `if constexpr` by stream save()
+    static constexpr bool needs_id_translation = false;
+
     using const_value_type = data::const_value_type_t;
 
     /// The type of the distance functor.
diff --git a/include/svs/lib/file.h b/include/svs/lib/file.h
index aaab0f5b0..937bc1502 100644
--- a/include/svs/lib/file.h
+++ b/include/svs/lib/file.h
@@ -18,10 +18,12 @@
 
 // svs
 #include "svs/lib/exception.h"
+#include "svs/lib/uuid.h"
 
 // stl
 #include
 #include
+#include
 
 namespace svs::lib {
 
@@ -110,4 +112,232 @@ inline std::ifstream open_read(
     return std::ifstream(path, mode);
 }
 
+inline std::filesystem::path unique_temp_directory_path(const std::string& prefix) {
+    namespace fs = std::filesystem;
+    auto temp_dir = fs::temp_directory_path();
+    // Returns a not-yet-existing "<tmp>/<prefix>-<uuid>" path (nothing is created); max 10 tries.
+    for (int i = 0; i < 10; ++i) {
+        auto dir = temp_dir / (prefix + "-" + svs::lib::UUID().str());
+        if (!fs::exists(dir)) {
+            return dir;
+        }
+        // Candidate already exists: loop again with a freshly generated UUID.
+    }
+    throw ANNEXCEPTION("Could not create a unique temporary directory!");
+}
+
+// RAII helper: creates a uniquely named temporary directory, removes it recursively on destruction.
+struct UniqueTempDirectory { // NOTE(review): copyable; a copy's destructor removes the same dir
+    std::filesystem::path path;
+
+    UniqueTempDirectory(const std::string& prefix)
+        : path{unique_temp_directory_path(prefix)} {
+        std::filesystem::create_directories(path);
+    }
+
+    ~UniqueTempDirectory() {
+        try {
+            std::filesystem::remove_all(path);
+        } catch (...) {
+            // Best-effort cleanup: errors are ignored so the destructor never throws.
+        }
+    }
+
+    std::filesystem::path get() const { return path; }
+    operator const std::filesystem::path&() const { return path; }
+};
+
+// Simple directory archiver to pack/unpack a directory to/from a stream.
+// Layout: magic (u64), file count (u64), then per regular file: name length (u64),
+// root-relative name bytes, content size (u64), raw content. Integers are written as raw
+// host-order bytes. Not meant to be super efficient.
+struct DirectoryArchiver {
+    using size_type = uint64_t;
+
+    // TODO: Define CACHELINE_BYTES in a common place
+    // rather than duplicating it here and in prefetch.h
+    static constexpr auto CACHELINE_BYTES = 64;
+    static constexpr size_type magic_number = 0x5e2d58d9f3b4a6c1; // first field of an archive
+
+    static size_type write_size(std::ostream& os, size_type size) {
+        os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+        if (!os) {
+            throw ANNEXCEPTION("Error writing to stream!");
+        }
+        return sizeof(size); // number of bytes written
+    }
+
+    static size_type read_size(std::istream& is, size_type& size) {
+        is.read(reinterpret_cast<char*>(&size), sizeof(size));
+        if (!is) {
+            throw ANNEXCEPTION("Error reading from stream!");
+        }
+        return sizeof(size); // number of bytes consumed
+    }
+
+    static size_type write_name(std::ostream& os, const std::string& name) {
+        auto bytes = write_size(os, name.size()); // length prefix
+        os.write(name.data(), name.size());
+        if (!os) {
+            throw ANNEXCEPTION("Error writing to stream!");
+        }
+        return bytes + name.size();
+    }
+
+    static size_type read_name(std::istream& is, std::string& name) {
+        size_type size = 0;
+        auto bytes = read_size(is, size);
+        name.resize(size); // NOTE(review): length is taken from the stream without a bound
+        is.read(name.data(), size);
+        if (!is) {
+            throw ANNEXCEPTION("Error reading from stream!");
+        }
+        return bytes + size;
+    }
+
+    static size_type write_file(
+        std::ostream& stream,
+        const std::filesystem::path& path,
+        const std::filesystem::path& root
+    ) {
+        namespace fs = std::filesystem;
+        check_file(path, std::ios_base::in | std::ios_base::binary);
+
+        // Write the filename (relative to `root`) as a length-prefixed string.
+        std::string filename = fs::relative(path, root).string();
+        auto header_bytes = write_name(stream, filename);
+        if (!stream) {
+            throw ANNEXCEPTION("Error writing to stream!");
+        }
+
+        // Write the size of the file.
+        size_type filesize = fs::file_size(path);
+        header_bytes += write_size(stream, filesize);
+        if (!stream) {
+            throw ANNEXCEPTION("Error writing to stream!");
+        }
+
+        // Now write the actual file contents.
+        std::ifstream in(path, std::ios_base::in | std::ios_base::binary);
+        if (!in) {
+            throw ANNEXCEPTION("Error opening file {} for reading!", path);
+        }
+        if (filesize > 0) { stream << in.rdbuf(); } // `<<` sets failbit if nothing is copied
+        if (!stream) {
+            throw ANNEXCEPTION("Error writing to stream!");
+        }
+
+        return header_bytes + filesize; // total bytes emitted for this entry
+    }
+
+    static size_type read_file(std::istream& stream, const std::filesystem::path& root) {
+        namespace fs = std::filesystem;
+
+        // Read the filename as a string.
+        std::string filename;
+        auto header_bytes = read_name(stream, filename);
+        if (!stream) {
+            throw ANNEXCEPTION("Error reading from stream!");
+        }
+
+        auto path = root / filename; // NOTE(review): unvalidated; ".."/absolute names escape root
+        auto parent_dir = path.parent_path();
+        if (!fs::exists(parent_dir)) {
+            fs::create_directories(parent_dir);
+        } else if (!fs::is_directory(parent_dir)) {
+            throw ANNEXCEPTION("Path {} exists and is not a directory!", parent_dir);
+        }
+        check_file(path, std::ios_base::out | std::ios_base::binary);
+
+        // Read the size of the file.
+        std::uint64_t filesize = 0;
+        header_bytes += read_size(stream, filesize);
+        if (!stream) {
+            throw ANNEXCEPTION("Error reading from stream!");
+        }
+
+        // Now write the actual file contents to `path` on disk.
+        std::ofstream out(path, std::ios_base::out | std::ios_base::binary);
+        if (!out) {
+            throw ANNEXCEPTION("Error opening file {} for writing!", path);
+        }
+
+        // Copy the data in chunks.
+        constexpr size_t buffer_size = 1 << 13; // 8KB buffer
+        alignas(CACHELINE_BYTES) char buffer[buffer_size];
+
+        size_t bytes_remaining = filesize;
+        while (bytes_remaining > 0) {
+            size_t to_read = std::min(buffer_size, bytes_remaining);
+            stream.read(buffer, to_read);
+            if (!stream) {
+                throw ANNEXCEPTION("Error reading from stream!");
+            }
+            out.write(buffer, to_read);
+            if (!out) {
+                throw ANNEXCEPTION("Error writing to file {}!", path);
+            }
+            bytes_remaining -= to_read;
+        }
+
+        return header_bytes + filesize; // total bytes consumed for this entry
+    }
+
+    static size_t pack(const std::filesystem::path& dir, std::ostream& stream) {
+        namespace fs = std::filesystem;
+        if (!fs::is_directory(dir)) {
+            throw ANNEXCEPTION("Path {} is not a directory!", dir);
+        }
+
+        auto total_bytes = write_size(stream, magic_number);
+
+        // Count the regular files first: the count is written ahead of the entries.
+        uint64_t filesnum = std::count_if(
+            fs::recursive_directory_iterator{dir},
+            fs::recursive_directory_iterator{},
+            [&](const auto& entry) { return entry.is_regular_file(); }
+        );
+        total_bytes += write_size(stream, filesnum);
+
+        // Now serialize each file in the directory recursively.
+        for (const auto& entry : fs::recursive_directory_iterator{dir}) {
+            if (entry.is_regular_file()) {
+                total_bytes += write_file(stream, entry.path(), dir);
+            }
+            // Ignore other types of entries.
+        }
+
+        return total_bytes; // bytes written to `stream`
+    }
+
+    static size_t unpack(std::istream& stream, const std::filesystem::path& root) {
+        namespace fs = std::filesystem;
+
+        // Read and verify the magic number.
+        size_type magic = 0;
+        auto total_bytes = read_size(stream, magic);
+        if (magic != magic_number) {
+            throw ANNEXCEPTION("Invalid magic number in directory unpacking!");
+        }
+
+        size_type num_files = 0;
+        total_bytes += read_size(stream, num_files);
+        if (!stream) {
+            throw ANNEXCEPTION("Error reading from stream!");
+        }
+
+        if (!fs::exists(root)) {
+            fs::create_directories(root);
+        } else if (!fs::is_directory(root)) {
+            throw ANNEXCEPTION("Path {} exists and is not a directory!", root);
+        }
+
+        // Now deserialize each file in the directory.
+        for (size_type i = 0; i < num_files; ++i) {
+            total_bytes += read_file(stream, root);
+        }
+
+        return total_bytes; // bytes consumed from `stream`
+    }
+};
 } // namespace svs::lib
diff --git a/include/svs/orchestrators/dynamic_flat.h b/include/svs/orchestrators/dynamic_flat.h
index b658a9ed0..e06efb451 100644
--- a/include/svs/orchestrators/dynamic_flat.h
+++ b/include/svs/orchestrators/dynamic_flat.h
@@ -58,6 +58,7 @@ class DynamicFlatInterface {
         const std::filesystem::path& config_directory,
         const std::filesystem::path& data_directory
     ) = 0;
+    virtual void save(std::ostream& stream) = 0;
 };
 
 template
@@ -118,6 +119,21 @@ class DynamicFlatImpl
     ) override {
         impl().save(config_directory, data_directory);
     }
+
+    // Stream-based save: save into a temporary directory, then pack that directory into `stream`.
+    void save(std::ostream& stream) override {
+        if constexpr (Impl::supports_saving) {
+            lib::UniqueTempDirectory tempdir{"svs_dynflat_save"};
+            const auto config_dir = tempdir.get() / "config";
+            const auto data_dir = tempdir.get() / "data";
+            std::filesystem::create_directories(config_dir);
+            std::filesystem::create_directories(data_dir);
+            save(config_dir, data_dir);
+            lib::DirectoryArchiver::pack(tempdir, stream);
+        } else {
+            throw ANNEXCEPTION("The current DynamicFlat backend doesn't support saving!");
+        }
+    }
 };
 
 // Forward Declarations.
@@ -253,6 +269,44 @@ class DynamicFlat : public manager::IndexManager {
         );
     }
 
+    // Assembly from stream: unpack into a temp directory, then load its "config" and "data" parts.
+    template <
+        manager::QueryTypeDefinition QueryTypes,
+        typename Data,
+        typename Distance,
+        typename ThreadPoolProto,
+        typename... DataLoaderArgs>
+    static DynamicFlat assemble(
+        std::istream& stream,
+        const Distance& distance,
+        ThreadPoolProto threadpool_proto,
+        DataLoaderArgs&&... data_args
+    ) {
+        namespace fs = std::filesystem;
+        lib::UniqueTempDirectory tempdir{"svs_dynflat_load"};
+        lib::DirectoryArchiver::unpack(stream, tempdir);
+
+        const auto config_path = tempdir.get() / "config";
+        if (!fs::is_directory(config_path)) {
+            throw ANNEXCEPTION(
+                "Invalid Dynamic Flat index archive: missing config directory!"
+            );
+        }
+
+        const auto data_path = tempdir.get() / "data";
+        if (!fs::is_directory(data_path)) {
+            throw ANNEXCEPTION("Invalid Dynamic Flat index archive: missing data directory!"
+            );
+        }
+
+        return assemble(
+            config_path,
+            lib::load_from_disk(data_path, SVS_FWD(data_args)...),
+            distance,
+            threads::as_threadpool(std::move(threadpool_proto))
+        );
+    }
+
     ///// Distance
     /// @brief Get the distance between a vector in the index and a query vector
     /// @tparam Query The query vector type
diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h
index 465b85cc2..5d27a0d31 100644
--- a/include/svs/orchestrators/dynamic_vamana.h
+++ b/include/svs/orchestrators/dynamic_vamana.h
@@ -255,6 +255,8 @@ class DynamicVamana : public manager::IndexManager {
         impl_->save(config_dir, graph_dir, data_dir);
     }
 
+    void save(std::ostream& stream) { impl_->save(stream); }
+
     /// Reconstruction
     void reconstruct_at(data::SimpleDataView data, std::span ids) {
         impl_->reconstruct_at(data, ids);
@@ -340,6 +342,48 @@ class DynamicVamana : public manager::IndexManager {
         );
     }
 
+    // Assembly from stream: unpack into a temp directory holding "config", "graph" and "data".
+    template <
+        manager::QueryTypeDefinition QueryTypes,
+        typename Data,
+        typename Distance,
+        typename ThreadPoolProto,
+        typename... DataLoaderArgs>
+    static DynamicVamana assemble(
+        std::istream& stream,
+        const Distance& distance,
+        ThreadPoolProto threadpool_proto,
+        DataLoaderArgs&&... data_args
+    ) {
+        namespace fs = std::filesystem;
+        lib::UniqueTempDirectory tempdir{"svs_vamana_load"};
+        lib::DirectoryArchiver::unpack(stream, tempdir);
+
+        const auto config_path = tempdir.get() / "config";
+        if (!fs::is_directory(config_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!");
+        }
+
+        const auto graph_path = tempdir.get() / "graph";
+        if (!fs::is_directory(graph_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!");
+        }
+
+        const auto data_path = tempdir.get() / "data";
+        if (!fs::is_directory(data_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!");
+        }
+
+        return assemble(
+            config_path,
+            svs::GraphLoader{graph_path},
+            lib::load_from_disk(data_path, SVS_FWD(data_args)...),
+            distance,
+            threads::as_threadpool(std::move(threadpool_proto)),
+            false
+        );
+    }
+
     /// @copydoc svs::Vamana::batch_iterator
     template
     svs::VamanaIterator batch_iterator(
diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h
index d82f85133..b33b6cc4a 100644
--- a/include/svs/orchestrators/exhaustive.h
+++ b/include/svs/orchestrators/exhaustive.h
@@ -45,6 +45,7 @@ class FlatInterface {
     // Non-templated virtual method for distance calculation
     virtual double get_distance(size_t id, const AnonymousArray<1>& query) const = 0;
     virtual void save(const std::filesystem::path& data_directory) const = 0;
+    virtual void save(std::ostream& stream) const = 0;
 };
 
 template
@@ -83,6 +84,16 @@ class FlatImpl : public manager::ManagerImpl {
     void save(const std::filesystem::path& data_directory) const override {
         this->impl().save(data_directory);
     }
+
+    void save(std::ostream& stream) const override {
+        if constexpr (Impl::supports_saving) {
+            lib::UniqueTempDirectory tempdir{"svs_flat_save"};
+            save(tempdir); // directory-based save into the temporary directory
+            lib::DirectoryArchiver::pack(tempdir, stream);
+        } else {
+            throw ANNEXCEPTION("The current Flat backend doesn't support saving!");
+        }
+    }
 };
 
 // Forward Declarations
@@ -107,6 +118,8 @@ class Flat : public manager::IndexManager {
         impl_->save(data_directory);
     }
 
+    void save(std::ostream& stream) const { impl_->save(stream); }
+
     ///// Loading
 
     ///
@@ -154,6 +167,45 @@ class Flat : public manager::IndexManager {
         }
     }
 
+    ///
+    /// @brief Load a Flat Index from a stream.
+    ///
+    /// @tparam QueryTypes The element types of the vectors that will be used for querying.
+    ///
+    /// @param stream The stream to read the packed index archive from.
+    /// @param distance A distance functor to use or a ``svs::DistanceType`` enum.
+    /// @param threadpool_proto Precursor for the thread pool to use. Can either be an
+    ///     acceptable thread pool instance or an integer specifying the number of threads
+    ///     to use. In the latter case, a new default thread pool will be constructed using
+    ///     ``threadpool_proto`` as the number of threads to create.
+    /// @param data_args Arguments to be passed to data loader.
+    ///
+    /// @copydoc hidden_flat_auto_assemble
+    ///
+    /// @copydoc threadpool_requirements
+    ///
+    template <
+        manager::QueryTypeDefinition QueryTypes,
+        typename Data,
+        typename Distance,
+        typename ThreadPoolProto,
+        typename... DataLoaderArgs>
+    static Flat assemble(
+        std::istream& stream,
+        Distance distance,
+        ThreadPoolProto threadpool_proto,
+        DataLoaderArgs&&... data_args
+    ) {
+        namespace fs = std::filesystem;
+        lib::UniqueTempDirectory tempdir{"svs_flat_load"};
+        lib::DirectoryArchiver::unpack(stream, tempdir);
+        return assemble(
+            lib::load_from_disk(tempdir, SVS_FWD(data_args)...),
+            distance,
+            threads::as_threadpool(std::move(threadpool_proto))
+        );
+    }
+
     ///// Distance
     /// @brief Get the distance between a vector in the index and a query vector
     /// @tparam Query The query vector type
diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h
index d41bd0cf8..6b698c4f9 100644
--- a/include/svs/orchestrators/vamana.h
+++ b/include/svs/orchestrators/vamana.h
@@ -35,6 +35,7 @@
 
 // stdlib
 #include
+#include
 #include
 #include
 
@@ -71,6 +72,8 @@ class VamanaInterface {
         const std::filesystem::path& data_dir
     ) = 0;
 
+    virtual void save(std::ostream& stream) = 0;
+
     ///// Reconstruction
     // TODO: Allow threadpools to be const-invocable.
     virtual void
@@ -172,6 +175,22 @@ class VamanaImpl : public manager::ManagerImpl {
         }
     }
 
+    void save(std::ostream& stream) override {
+        if constexpr (Impl::supports_saving) {
+            lib::UniqueTempDirectory tempdir{"svs_vamana_save"};
+            const auto config_dir = tempdir.get() / "config";
+            const auto graph_dir = tempdir.get() / "graph";
+            const auto data_dir = tempdir.get() / "data";
+            std::filesystem::create_directories(config_dir);
+            std::filesystem::create_directories(graph_dir);
+            std::filesystem::create_directories(data_dir);
+            save(config_dir, graph_dir, data_dir); // directory-based save into the temp dir
+            lib::DirectoryArchiver::pack(tempdir, stream);
+        } else {
+            throw ANNEXCEPTION("The current Vamana backend doesn't support saving!");
+        }
+    }
+
     ///// Reconstruction
 
     void reconstruct_at(data::SimpleDataView data, std::span ids) override {
@@ -363,6 +382,8 @@ class Vamana : public manager::IndexManager {
         impl_->save(config_directory, graph_directory, data_directory);
     }
 
+    void save(std::ostream& stream) const { impl_->save(stream); }
+
     void reconstruct_at(data::SimpleDataView data, std::span ids) {
         impl_->reconstruct_at(data, ids);
     }
@@ -440,6 +461,47 @@ class Vamana : public manager::IndexManager {
         }
     }
 
+    // Assembly from stream: unpack into a temp directory holding "config", "graph" and "data".
+    template <
+        manager::QueryTypeDefinition QueryTypes,
+        typename Data,
+        typename Distance,
+        typename ThreadPoolProto,
+        typename... DataLoaderArgs>
+    static Vamana assemble(
+        std::istream& stream,
+        const Distance& distance,
+        ThreadPoolProto threadpool_proto,
+        DataLoaderArgs&&... data_args
+    ) {
+        namespace fs = std::filesystem;
+        lib::UniqueTempDirectory tempdir{"svs_vamana_load"};
+        lib::DirectoryArchiver::unpack(stream, tempdir);
+
+        const auto config_path = tempdir.get() / "config";
+        if (!fs::is_directory(config_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!");
+        }
+
+        const auto graph_path = tempdir.get() / "graph";
+        if (!fs::is_directory(graph_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!");
+        }
+
+        const auto data_path = tempdir.get() / "data";
+        if (!fs::is_directory(data_path)) {
+            throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!");
+        }
+
+        return assemble(
+            config_path,
+            svs::GraphLoader{graph_path},
+            lib::load_from_disk(data_path, SVS_FWD(data_args)...),
+            distance,
+            threads::as_threadpool(std::move(threadpool_proto))
+        );
+    }
+
     ///
     /// @brief Construct the a Vamana Index for the given dataset.
     ///
diff --git a/tests/integration/exhaustive.cpp b/tests/integration/exhaustive.cpp
index 6bff5544e..578cf3f66 100644
--- a/tests/integration/exhaustive.cpp
+++ b/tests/integration/exhaustive.cpp
@@ -144,6 +144,25 @@ void test_flat(
         );
         batch_size_search_test(index);
     }
+
+    // Test save and load to a stream: save into a file-backed ostream, reassemble from an ifstream.
+    if constexpr (std::is_same_v, svs::Flat>) {
+        svs_test::prepare_temp_directory();
+        auto temp_dir = svs_test::temp_directory();
+        auto file = temp_dir / "flat_index.bin";
+        std::ofstream file_ostream(file, std::ios::binary);
+        CATCH_REQUIRE(file_ostream.good());
+        index.save(file_ostream);
+        file_ostream.close(); // close before the file is reopened for reading
+        std::ifstream file_istream(file, std::ios::binary);
+        CATCH_REQUIRE(file_istream.good());
+        index = svs::Flat::
+            assemble, svs::data::SimpleData>(
+                file_istream, distance_type, index.get_num_threads()
+            );
+        CATCH_REQUIRE(index.get_num_threads() == 2); // NOTE(review): assumes index had 2 threads
+        batch_size_search_test(index);
+    }
 }
 
 } // namespace
diff --git a/tests/integration/vamana/index_search.cpp b/tests/integration/vamana/index_search.cpp
index 3b96d5998..c050872e0 100644
--- a/tests/integration/vamana/index_search.cpp
+++ b/tests/integration/vamana/index_search.cpp
@@ -344,5 +344,43 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") {
         run_tests(
             index, queries, groundtruth, expected_results.config_and_recall_
         );
+
+        // Save/load to/from a single file.
+        svs_test::prepare_temp_directory();
+
+        // Set non-default values so the checks after reloading show they were saved and restored.
+        index.set_search_window_size(123);
+        index.set_alpha(1.2);
+        index.set_construction_window_size(456);
+        index.set_max_candidates(1001);
+
+        max_degree = index.get_graph_max_degree();
+        index.set_prune_to(max_degree - 2);
+        index.set_full_search_history(false);
+
+        auto file = temp_dir / "vamana_index.bin";
+        std::ofstream file_ostream(file, std::ios::binary);
+        CATCH_REQUIRE(file_ostream.good());
+        index.save(file_ostream);
+        file_ostream.close(); // close before the file is reopened for reading
+
+        std::ifstream file_istream(file, std::ios::binary);
+        CATCH_REQUIRE(file_istream.good());
+        index = svs::Vamana::assemble<
+            svs::lib::Types,
+            svs::data::SimpleData>(file_istream, distance_type, 2);
+
+        // Data Properties
+        CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET);
+        CATCH_REQUIRE(index.dimensions() == test_dataset::NUM_DIMENSIONS);
+        // Index Properties: must equal the values set before saving.
+        CATCH_REQUIRE(index.get_search_window_size() == 123);
+        CATCH_REQUIRE(index.get_alpha() == 1.2f);
+        CATCH_REQUIRE(index.get_construction_window_size() == 456);
+        CATCH_REQUIRE(index.get_max_candidates() == 1001);
+        CATCH_REQUIRE(index.get_graph_max_degree() == max_degree);
+        CATCH_REQUIRE(index.get_prune_to() == max_degree - 2);
+        CATCH_REQUIRE(index.get_full_search_history() == false);
+        run_tests(index, queries, groundtruth, expected_results.config_and_recall_);
     }
 }
diff --git a/tests/integration/vamana/scalar_search.cpp b/tests/integration/vamana/scalar_search.cpp
index 363cc6ae8..e0a247331 100644
--- a/tests/integration/vamana/scalar_search.cpp
+++ b/tests/integration/vamana/scalar_search.cpp
@@ -120,6 +120,25 @@ void test_search(
         CATCH_REQUIRE(reloaded.dimensions() == test_dataset::NUM_DIMENSIONS);
         run_search(index, queries, groundtruth, expected_results.config_and_recall_);
     }
+
+    // Reload via single file: stream-save the index, reassemble it, and search the reloaded copy.
+    {
+        svs_test::prepare_temp_directory();
+        auto file = svs_test::temp_directory() / "vamana_index.bin";
+        std::ofstream file_ostream(file, std::ios::binary);
+        CATCH_REQUIRE(file_ostream.good());
+        index.save(file_ostream);
+        file_ostream.close(); // close before the file is reopened for reading
+
+        std::ifstream file_istream(file, std::ios::binary);
+        CATCH_REQUIRE(file_istream.good());
+        auto reloaded = svs::Vamana::assemble(file_istream, distance, num_threads);
+
+        CATCH_REQUIRE(reloaded.get_num_threads() == num_threads);
+        CATCH_REQUIRE(reloaded.size() == test_dataset::VECTORS_IN_DATA_SET);
+        CATCH_REQUIRE(reloaded.dimensions() == test_dataset::NUM_DIMENSIONS);
+        run_search(reloaded, queries, groundtruth, expected_results.config_and_recall_);
+    }
 }
 
 } // namespace
diff --git a/tests/svs/lib/file.cpp b/tests/svs/lib/file.cpp
index df188e797..b7ec3adff 100644
--- a/tests/svs/lib/file.cpp
+++ b/tests/svs/lib/file.cpp
@@ -54,3 +54,45 @@ CATCH_TEST_CASE("Filesystem Handling", "[lib][files]") {
         );
     }
 }
+
+CATCH_TEST_CASE("DirectoryArchiver", "[lib][files]") {
+    namespace fs = std::filesystem;
+    using namespace svs::lib;
+
+    auto tempdir = svs_test::prepare_temp_directory_v2();
+    auto srcdir = tempdir / "src";
+    auto dstdir = tempdir / "dst";
+
+    // Create a source directory with one top-level file and one file in a subdirectory.
+    fs::create_directories(srcdir);
+    std::ofstream(srcdir / "file1.txt") << "Hello, World!" << std::endl;
+    fs::create_directories(srcdir / "subdir");
+    std::ofstream(srcdir / "subdir/file2.txt") << "This is a test." << std::endl;
+
+    CATCH_SECTION("Pack and Unpack") {
+        // Pack the directory into an in-memory stream.
+        std::stringstream ss;
+        auto bytes_written = DirectoryArchiver::pack(srcdir, ss);
+        CATCH_REQUIRE(bytes_written > 0);
+
+        // Unpack the directory; the byte count read must equal the byte count written.
+        fs::create_directories(dstdir);
+        auto bytes_read = DirectoryArchiver::unpack(ss, dstdir);
+        CATCH_REQUIRE(bytes_read == bytes_written);
+
+        // Check that the files exist in the destination directory.
+        CATCH_REQUIRE(fs::exists(dstdir / "file1.txt"));
+        CATCH_REQUIRE(fs::exists(dstdir / "subdir/file2.txt"));
+
+        // Check that the contents are correct (first line of each file).
+        std::ifstream in1(dstdir / "file1.txt");
+        std::string line1;
+        std::getline(in1, line1);
+        CATCH_REQUIRE(line1 == "Hello, World!");
+
+        std::ifstream in2(dstdir / "subdir/file2.txt");
+        std::string line2;
+        std::getline(in2, line2);
+        CATCH_REQUIRE(line2 == "This is a test.");
+    }
+}