From f4c3d43ff2344a127131205f31bb8c8a7e672086 Mon Sep 17 00:00:00 2001
From: Pragyan Poudyal
Date: Tue, 8 Apr 2025 14:03:08 +0530
Subject: [PATCH 01/11] Impl `try_clone` for `Repository`

Signed-off-by: Pragyan Poudyal
---
 src/repository.rs | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/repository.rs b/src/repository.rs
index d30fc9f3..a522b780 100644
--- a/src/repository.rs
+++ b/src/repository.rs
@@ -2,7 +2,7 @@ use std::{
     collections::HashSet,
     ffi::CStr,
     fs::File,
-    io::{ErrorKind, Read, Write},
+    io::{self, ErrorKind, Read, Write},
     os::fd::{AsFd, OwnedFd},
     path::{Path, PathBuf},
 };
@@ -46,6 +46,12 @@ impl Repository {
         )
     }
 
+    pub fn try_clone(&self) -> io::Result<Self> {
+        Ok(Self {
+            repository: self.repository.try_clone()?,
+        })
+    }
+
     pub fn open_path(dirfd: impl AsFd, path: impl AsRef<Path>) -> Result<Self> {
         let path = path.as_ref();

From 6e750385f099b94b3dcded24fc505c9ebdeca629 Mon Sep 17 00:00:00 2001
From: Pragyan Poudyal
Date: Tue, 8 Apr 2025 14:03:33 +0530
Subject: [PATCH 02/11] Add `*.tar` to .gitignore

Signed-off-by: Pragyan Poudyal
---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 1b72444a..e5acaa84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /Cargo.lock
 /target
+*.tar

From cf634762f7aab58cf478575a248ac4b871952aa1 Mon Sep 17 00:00:00 2001
From: Pragyan Poudyal
Date: Mon, 7 Apr 2025 14:27:46 +0530
Subject: [PATCH 03/11] Separate out splitstream writer into its own struct

Separating out the writer allows us to abstract the parallelism of
writing the splitstream.

Signed-off-by: Pragyan Poudyal
---
 src/lib.rs          |  1 +
 src/repository.rs   |  4 +-
 src/splitstream.rs  | 78 ++++++++++----------------------
 src/zstd_encoder.rs | 92 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 114 insertions(+), 61 deletions(-)
 create mode 100644 src/zstd_encoder.rs

diff --git a/src/lib.rs b/src/lib.rs
index 9edd9a12..4e1d9345 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,7 @@ pub mod repository;
 pub mod selabel;
 pub mod splitstream;
 pub mod util;
+pub mod zstd_encoder;
 
 /// All files that contain 64 or fewer bytes (size <= INLINE_CONTENT_MAX) should be stored inline
 /// in the erofs image (and also in splitstreams).
All files with 65 or more bytes (size > MAX) diff --git a/src/repository.rs b/src/repository.rs index a522b780..90303fdc 100644 --- a/src/repository.rs +++ b/src/repository.rs @@ -171,7 +171,7 @@ impl Repository { Ok(result) } - fn format_object_path(id: &Sha256HashValue) -> String { + pub fn format_object_path(id: &Sha256HashValue) -> String { format!("objects/{:02x}/{}", id[0], hex::encode(&id[1..])) } @@ -236,7 +236,7 @@ impl Repository { writer: SplitStreamWriter, reference: Option<&str>, ) -> Result { - let Some((.., ref sha256)) = writer.sha256 else { + let Some((.., ref sha256)) = writer.get_sha_builder() else { bail!("Writer doesn't have sha256 enabled"); }; let stream_path = format!("streams/{}", hex::encode(sha256)); diff --git a/src/splitstream.rs b/src/splitstream.rs index 319ffe09..7b8ff20d 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -6,13 +6,14 @@ use std::io::{BufReader, Read, Write}; use anyhow::{bail, Result}; -use sha2::{Digest, Sha256}; -use zstd::stream::{read::Decoder, write::Encoder}; +use sha2::Sha256; +use zstd::stream::read::Decoder; use crate::{ fsverity::{FsVerityHashValue, Sha256HashValue}, repository::Repository, util::read_exactish, + zstd_encoder::ZstdWriter, }; #[derive(Debug)] @@ -60,9 +61,8 @@ impl DigestMap { pub struct SplitStreamWriter<'a> { repo: &'a Repository, - inline_content: Vec, - writer: Encoder<'a, Vec>, - pub sha256: Option<(Sha256, Sha256HashValue)>, + pub(crate) inline_content: Vec, + writer: ZstdWriter, } impl std::fmt::Debug for SplitStreamWriter<'_> { @@ -71,7 +71,6 @@ impl std::fmt::Debug for SplitStreamWriter<'_> { f.debug_struct("SplitStreamWriter") .field("repo", &self.repo) .field("inline_content", &self.inline_content) - .field("sha256", &self.sha256) .finish() } } @@ -82,85 +81,46 @@ impl SplitStreamWriter<'_> { refs: Option, sha256: Option, ) -> SplitStreamWriter { - // SAFETY: we surely can't get an error writing the header to a Vec - let mut writer = Encoder::new(vec![], 0).unwrap(); - - match refs { - Some(DigestMap { map }) => { - writer.write_all(&(map.len() as u64).to_le_bytes()).unwrap(); - for ref entry in map { - writer.write_all(&entry.body).unwrap(); - writer.write_all(&entry.verity).unwrap(); - } - } - None => { - writer.write_all(&0u64.to_le_bytes()).unwrap(); - } - } - SplitStreamWriter { repo, inline_content: vec![], - writer, - sha256: sha256.map(|x| (Sha256::new(), x)), + writer: ZstdWriter::new(sha256, refs), } } - fn write_fragment(writer: &mut impl Write, size: usize, data: &[u8]) -> Result<()> { - writer.write_all(&(size as u64).to_le_bytes())?; - Ok(writer.write_all(data)?) + pub fn get_sha_builder(&self) -> &Option<(Sha256, Sha256HashValue)> { + &self.writer.sha256_builder } /// flush any buffered inline data, taking new_value as the new value of the buffer fn flush_inline(&mut self, new_value: Vec) -> Result<()> { - if !self.inline_content.is_empty() { - SplitStreamWriter::write_fragment( - &mut self.writer, - self.inline_content.len(), - &self.inline_content, - )?; - self.inline_content = new_value; - } + self.writer.flush_inline(&self.inline_content)?; + self.inline_content = new_value; Ok(()) } /// really, "add inline content to the buffer" /// you need to call .flush_inline() later pub fn write_inline(&mut self, data: &[u8]) { - if let Some((ref mut sha256, ..)) = self.sha256 { - sha256.update(data); - } + self.writer.update_sha(data); self.inline_content.extend(data); } - /// write a reference to external data to the stream. 
If the external data had padding in the - /// stream which is not stored in the object then pass it here as well and it will be stored - /// inline after the reference. - fn write_reference(&mut self, reference: Sha256HashValue, padding: Vec) -> Result<()> { - // Flush the inline data before we store the external reference. Any padding from the - // external data becomes the start of a new inline block. - self.flush_inline(padding)?; + pub fn write_external(&mut self, data: &[u8], padding: Vec) -> Result<()> { + let id = self.repo.ensure_object(&data)?; - SplitStreamWriter::write_fragment(&mut self.writer, 0, &reference) - } + self.writer.update_sha(data); + self.writer.update_sha(&padding); + self.writer.flush_inline(&padding)?; - pub fn write_external(&mut self, data: &[u8], padding: Vec) -> Result<()> { - if let Some((ref mut sha256, ..)) = self.sha256 { - sha256.update(data); - sha256.update(&padding); - } - let id = self.repo.ensure_object(data)?; - self.write_reference(id, padding) + self.writer.write_fragment(0, &id)?; + Ok(()) } pub fn done(mut self) -> Result { self.flush_inline(vec![])?; - if let Some((context, expected)) = self.sha256 { - if Into::::into(context.finalize()) != expected { - bail!("Content doesn't have expected SHA256 hash value!"); - } - } + self.writer.finalize_sha256_builder()?; self.repo.ensure_object(&self.writer.finish()?) } diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs new file mode 100644 index 00000000..08e57685 --- /dev/null +++ b/src/zstd_encoder.rs @@ -0,0 +1,92 @@ +use std::io::{self, Write}; + +use sha2::{Digest, Sha256}; + +use anyhow::{bail, Result}; + +use crate::{ + fsverity::{FsVerityHashValue, Sha256HashValue}, + splitstream::DigestMap, +}; + +pub(crate) struct ZstdWriter { + writer: zstd::Encoder<'static, Vec>, + pub(crate) sha256_builder: Option<(Sha256, Sha256HashValue)>, +} + +impl ZstdWriter { + pub fn new(sha256: Option, refs: Option) -> Self { + Self { + writer: ZstdWriter::instantiate_writer(refs), + sha256_builder: sha256.map(|x| (Sha256::new(), x)), + } + } + + fn instantiate_writer(refs: Option) -> zstd::Encoder<'static, Vec> { + let mut writer = zstd::Encoder::new(vec![], 0).unwrap(); + + match refs { + Some(DigestMap { map }) => { + writer.write_all(&(map.len() as u64).to_le_bytes()).unwrap(); + + for ref entry in map { + writer.write_all(&entry.body).unwrap(); + writer.write_all(&entry.verity).unwrap(); + } + } + + None => { + writer.write_all(&0u64.to_le_bytes()).unwrap(); + } + } + + return writer; + } + + pub(crate) fn write_fragment(&mut self, size: usize, data: &[u8]) -> Result<()> { + self.writer.write_all(&(size as u64).to_le_bytes())?; + Ok(self.writer.write_all(data)?) 
+    }
+
+    pub(crate) fn update_sha(&mut self, data: &[u8]) {
+        if let Some((sha256, ..)) = &mut self.sha256_builder {
+            sha256.update(&data);
+        }
+    }
+
+    pub(crate) fn flush_inline(&mut self, inline_content: &Vec<u8>) -> Result<()> {
+        if inline_content.is_empty() {
+            return Ok(());
+        }
+
+        self.write_fragment(inline_content.len(), &inline_content)?;
+
+        Ok(())
+    }
+
+    pub(crate) fn finalize_sha256_builder(&mut self) -> Result<Sha256HashValue> {
+        let sha256_builder = std::mem::replace(&mut self.sha256_builder, None);
+
+        let mut sha = Sha256HashValue::EMPTY;
+
+        if let Some((context, expected)) = sha256_builder {
+            let final_sha = Into::<Sha256HashValue>::into(context.finalize());
+
+            if final_sha != expected {
+                bail!(
+                    "Content doesn't have expected SHA256 hash value!\nExpected: {}, final: {}",
+                    hex::encode(expected),
+                    hex::encode(final_sha)
+                );
+            }
+
+            sha = final_sha;
+        }
+
+        return Ok(sha);
+    }
+
+    pub(crate) fn finish(self) -> io::Result<Vec<u8>> {
+        self.writer.finish()
+    }
+}

From 87b7796f767e7d81bf0ea73713075ca726669dfe Mon Sep 17 00:00:00 2001
From: Pragyan Poudyal
Date: Mon, 7 Apr 2025 14:55:06 +0530
Subject: [PATCH 04/11] Parallelize writing of splitstream

Add dependencies on Rayon and Crossbeam.

Have two modes, single- and multi-threaded, for the Zstd Encoder.

Spawn threads in the splitstream writer for writing external objects
and spawn separate threads for the Zstd Encoder. We handle
communication between these threads using channels.

Any image's layers will be handed off to one of the Encoder threads.

For now we only have one channel for external object writing, but
multiple for Encoders. The reasoning for this is that Encoder threads
are usually CPU bound while the object writer threads are more IO
bound.

Signed-off-by: Pragyan Poudyal
---
 Cargo.toml          |   2 +
 src/oci/mod.rs      | 232 ++++++++++++++++++++++++++-------
 src/oci/tar.rs      |  21 ++-
 src/repository.rs   |  10 +-
 src/splitstream.rs  | 148 +++++++++++++++++++--
 src/zstd_encoder.rs | 311 ++++++++++++++++++++++++++++++++++++++++++--
 6 files changed, 647 insertions(+), 77 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index d5680c2d..e2ede780 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,11 +21,13 @@ anyhow = { version = "1.0.97", default-features = false }
 async-compression = { version = "0.4.22", default-features = false, features = ["tokio", "zstd", "gzip"] }
 clap = { version = "4.5.32", default-features = false, features = ["std", "help", "usage", "derive"] }
 containers-image-proxy = "0.7.0"
+crossbeam = "0.8.4"
 env_logger = "0.11.7"
 hex = "0.4.3"
 indicatif = { version = "0.17.11", features = ["tokio"] }
 log = "0.4.27"
 oci-spec = "0.7.1"
+rayon = "1.10.0"
 regex-automata = { version = "0.4.9", default-features = false }
 rustix = { version = "1.0.3", features = ["fs", "mount", "process"] }
 serde = "1.0.219"
diff --git a/src/oci/mod.rs b/src/oci/mod.rs
index 8bc8fdc2..bc8edf21 100644
--- a/src/oci/mod.rs
+++ b/src/oci/mod.rs
@@ -18,8 +18,12 @@ use crate::{
     fsverity::Sha256HashValue,
     oci::tar::{get_entry, split_async},
     repository::Repository,
-    splitstream::DigestMap,
+    splitstream::{
+        handle_external_object, DigestMap, EnsureObjectMessages, ResultChannelReceiver,
+        ResultChannelSender, WriterMessages,
+    },
     util::parse_sha256,
+    zstd_encoder,
 };
 
 pub fn import_layer(
@@ -83,6 +87,7 @@ impl<'repo> ImageOp<'repo> {
         let proxy = containers_image_proxy::ImageProxy::new_with_config(config).await?;
         let img = proxy.open_image(imgref).await.context("Opening image")?;
         let progress = MultiProgress::new();
+
         Ok(ImageOp {
             repo,
             proxy,
@@ -95,47 +100,49 @@ impl<'repo> ImageOp<'repo> {
         &self,
layer_sha256: &Sha256HashValue, descriptor: &Descriptor, - ) -> Result { + layer_num: usize, + object_sender: crossbeam::channel::Sender, + ) -> Result<()> { // We need to use the per_manifest descriptor to download the compressed layer but it gets // stored in the repository via the per_config descriptor. Our return value is the // fsverity digest for the corresponding splitstream. - if let Some(layer_id) = self.repo.check_stream(layer_sha256)? { - self.progress - .println(format!("Already have layer {}", hex::encode(layer_sha256)))?; - Ok(layer_id) - } else { - // Otherwise, we need to fetch it... - let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?; - - // See https://github.com/containers/containers-image-proxy-rs/issues/71 - let blob_reader = blob_reader.take(descriptor.size()); - - let bar = self.progress.add(ProgressBar::new(descriptor.size())); - bar.set_style(ProgressStyle::with_template("[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}") - .unwrap() - .progress_chars("##-")); - let progress = bar.wrap_async_read(blob_reader); - self.progress - .println(format!("Fetching layer {}", hex::encode(layer_sha256)))?; + // Otherwise, we need to fetch it... + let (blob_reader, driver) = self.proxy.get_descriptor(&self.img, descriptor).await?; + + // See https://github.com/containers/containers-image-proxy-rs/issues/71 + let blob_reader = blob_reader.take(descriptor.size()); + + let bar = self.progress.add(ProgressBar::new(descriptor.size())); + bar.set_style( + ProgressStyle::with_template( + "[eta {eta}] {bar:40.cyan/blue} {decimal_bytes:>7}/{decimal_total_bytes:7} {msg}", + ) + .unwrap() + .progress_chars("##-"), + ); + let progress = bar.wrap_async_read(blob_reader); + self.progress + .println(format!("Fetching layer {}", hex::encode(layer_sha256)))?; + + let mut splitstream = + self.repo + .create_stream(Some(*layer_sha256), None, Some(object_sender)); + match descriptor.media_type() { + MediaType::ImageLayer => { + split_async(progress, &mut splitstream, layer_num).await?; + } + MediaType::ImageLayerGzip => { + split_async(GzipDecoder::new(progress), &mut splitstream, layer_num).await?; + } + MediaType::ImageLayerZstd => { + split_async(ZstdDecoder::new(progress), &mut splitstream, layer_num).await?; + } + other => bail!("Unsupported layer media type {:?}", other), + }; + driver.await?; - let mut splitstream = self.repo.create_stream(Some(*layer_sha256), None); - match descriptor.media_type() { - MediaType::ImageLayer => { - split_async(progress, &mut splitstream).await?; - } - MediaType::ImageLayerGzip => { - split_async(GzipDecoder::new(progress), &mut splitstream).await?; - } - MediaType::ImageLayerZstd => { - split_async(ZstdDecoder::new(progress), &mut splitstream).await?; - } - other => bail!("Unsupported layer media type {:?}", other), - }; - let layer_id = self.repo.write_stream(splitstream, None)?; - driver.await?; - Ok(layer_id) - } + Ok(()) } pub async fn ensure_config( @@ -154,7 +161,6 @@ impl<'repo> ImageOp<'repo> { } else { // We need to add the config to the repo. We need to parse the config and make sure we // have all of the layers first. 
- // self.progress .println(format!("Fetching config {}", hex::encode(config_sha256)))?; @@ -169,19 +175,35 @@ impl<'repo> ImageOp<'repo> { let raw_config = config?; let config = ImageConfiguration::from_reader(&raw_config[..])?; + let (done_chan_sender, done_chan_recver, object_sender) = self.spawn_threads(&config); + let mut config_maps = DigestMap::new(); - for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) { + + for (idx, (mld, cld)) in zip(manifest_layers, config.rootfs().diff_ids()).enumerate() { let layer_sha256 = sha256_from_digest(cld)?; - let layer_id = self - .ensure_layer(&layer_sha256, mld) - .await - .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?; + + if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? { + self.progress + .println(format!("Already have layer {}", hex::encode(layer_sha256)))?; + + config_maps.insert(&layer_sha256, &layer_id); + } else { + self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone()) + .await + .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?; + } + } + + drop(done_chan_sender); + + while let Ok(res) = done_chan_recver.recv() { + let (layer_sha256, layer_id) = res?; config_maps.insert(&layer_sha256, &layer_id); } - let mut splitstream = self - .repo - .create_stream(Some(config_sha256), Some(config_maps)); + let mut splitstream = + self.repo + .create_stream(Some(config_sha256), Some(config_maps), None); splitstream.write_inline(&raw_config); let config_id = self.repo.write_stream(splitstream, None)?; @@ -189,6 +211,121 @@ impl<'repo> ImageOp<'repo> { } } + fn spawn_threads( + &self, + config: &ImageConfiguration, + ) -> ( + ResultChannelSender, + ResultChannelReceiver, + crossbeam::channel::Sender, + ) { + use crossbeam::channel::{unbounded, Receiver, Sender}; + + let encoder_threads = 2; + let external_object_writer_threads = 4; + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(encoder_threads + external_object_writer_threads) + .build() + .unwrap(); + + // We need this as writers have internal state that can't be shared between threads + // + // We'll actually need as many writers (not writer threads, but writer instances) as there are layers. 
+ let zstd_writer_channels: Vec<(Sender, Receiver)> = + (0..encoder_threads).map(|_| unbounded()).collect(); + + let (object_sender, object_receiver) = unbounded::(); + + // (layer_sha256, layer_id) + let (done_chan_sender, done_chan_recver) = + std::sync::mpsc::channel::>(); + + let chunk_len = (config.rootfs().diff_ids().len() + encoder_threads - 1) / encoder_threads; + + // Divide the layers into chunks of some specific size so each worker + // thread can work on multiple deterministic layers + let mut chunks: Vec> = config + .rootfs() + .diff_ids() + .iter() + .map(|x| sha256_from_digest(x).unwrap()) + .collect::>() + .chunks(chunk_len) + .map(|x| x.to_vec()) + .collect(); + + // Mapping from layer_id -> index in writer_channels + // This is to make sure that all messages relating to a particular layer + // always reach the same writer + let layers_to_chunks = chunks + .iter() + .enumerate() + .map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::>()) + .flatten() + .collect::>(); + + let _ = (0..encoder_threads) + .map(|i| { + let repository = self.repo.try_clone().unwrap(); + let object_sender = object_sender.clone(); + let done_chan_sender = done_chan_sender.clone(); + let chunk = std::mem::take(&mut chunks[i]); + let receiver = zstd_writer_channels[i].1.clone(); + + pool.spawn({ + move || { + let start = i * (chunk_len); + let end = start + chunk_len; + + let enc = zstd_encoder::MultipleZstdWriters::new( + chunk, + repository, + object_sender, + done_chan_sender, + ); + + if let Err(e) = enc.recv_data(receiver, start, end) { + eprintln!("zstd_encoder returned with error: {}", e.to_string()); + return; + } + } + }); + }) + .collect::>(); + + let _ = (0..external_object_writer_threads) + .map(|_| { + pool.spawn({ + let repository = self.repo.try_clone().unwrap(); + let zstd_writer_channels = zstd_writer_channels + .iter() + .map(|(s, _)| s.clone()) + .collect::>(); + let layers_to_chunks = layers_to_chunks.clone(); + let external_object_receiver = object_receiver.clone(); + + move || { + if let Err(e) = handle_external_object( + repository, + external_object_receiver, + zstd_writer_channels, + layers_to_chunks, + ) { + eprintln!( + "handle_external_object returned with error: {}", + e.to_string() + ); + return; + } + } + }); + }) + .collect::>(); + + return (done_chan_sender, done_chan_recver, object_sender); + } + pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> { let (_manifest_digest, raw_manifest) = self .proxy @@ -201,6 +338,7 @@ impl<'repo> ImageOp<'repo> { let manifest = ImageManifest::from_reader(raw_manifest.as_slice())?; let config_descriptor = manifest.config(); let layers = manifest.layers(); + self.ensure_config(layers, config_descriptor) .await .with_context(|| format!("Failed to pull config {config_descriptor:?}")) @@ -280,7 +418,7 @@ pub fn write_config( let json = config.to_string()?; let json_bytes = json.as_bytes(); let sha256 = hash(json_bytes); - let mut stream = repo.create_stream(Some(sha256), Some(refs)); + let mut stream = repo.create_stream(Some(sha256), Some(refs), None); stream.write_inline(json_bytes); let id = repo.write_stream(stream, None)?; Ok((sha256, id)) diff --git a/src/oci/tar.rs b/src/oci/tar.rs index 718a1670..88d6694b 100644 --- a/src/oci/tar.rs +++ b/src/oci/tar.rs @@ -16,7 +16,9 @@ use tokio::io::{AsyncRead, AsyncReadExt}; use crate::{ dumpfile, image::{LeafContent, RegularFile, Stat}, - splitstream::{SplitStreamData, SplitStreamReader, SplitStreamWriter}, + splitstream::{ + EnsureObjectMessages, 
FinishMessage, SplitStreamData, SplitStreamReader, SplitStreamWriter, + }, util::{read_exactish, read_exactish_async}, INLINE_CONTENT_MAX, }; @@ -60,7 +62,7 @@ pub fn split(tar_stream: &mut R, writer: &mut SplitStreamWriter) -> Res if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX { // non-empty regular file: store the data in the object store let padding = buffer.split_off(actual_size); - writer.write_external(&buffer, padding)?; + writer.write_external(buffer, padding, 0, 0)?; } else { // else: store the data inline in the split stream writer.write_inline(&buffer); @@ -72,7 +74,10 @@ pub fn split(tar_stream: &mut R, writer: &mut SplitStreamWriter) -> Res pub async fn split_async( mut tar_stream: impl AsyncRead + Unpin, writer: &mut SplitStreamWriter<'_>, + layer_num: usize, ) -> Result<()> { + let mut seq_num = 0; + while let Some(header) = read_header_async(&mut tar_stream).await? { // the header always gets stored as inline data writer.write_inline(header.as_bytes()); @@ -90,12 +95,22 @@ pub async fn split_async( if header.entry_type() == EntryType::Regular && actual_size > INLINE_CONTENT_MAX { // non-empty regular file: store the data in the object store let padding = buffer.split_off(actual_size); - writer.write_external(&buffer, padding)?; + writer.write_external(buffer, padding, seq_num, layer_num)?; + seq_num += 1; } else { // else: store the data inline in the split stream writer.write_inline(&buffer); } } + + if let Some(sender) = &writer.object_sender { + sender.send(EnsureObjectMessages::Finish(FinishMessage { + data: std::mem::take(&mut writer.inline_content), + total_msgs: seq_num, + layer_num, + }))?; + } + Ok(()) } diff --git a/src/repository.rs b/src/repository.rs index 90303fdc..0ea466eb 100644 --- a/src/repository.rs +++ b/src/repository.rs @@ -23,7 +23,7 @@ use crate::{ Sha256HashValue, }, mount::mount_composefs_at, - splitstream::{DigestMap, SplitStreamReader, SplitStreamWriter}, + splitstream::{DigestMap, EnsureObjectMessages, SplitStreamReader, SplitStreamWriter}, util::{parse_sha256, proc_self_fd}, }; @@ -143,12 +143,13 @@ impl Repository { /// Creates a SplitStreamWriter for writing a split stream. /// You should write the data to the returned object and then pass it to .store_stream() to /// store the result. - pub fn create_stream( + pub(crate) fn create_stream( &self, sha256: Option, maps: Option, + object_sender: Option>, ) -> SplitStreamWriter { - SplitStreamWriter::new(self, maps, sha256) + SplitStreamWriter::new(self, maps, sha256, object_sender) } fn parse_object_path(path: impl AsRef<[u8]>) -> Result { @@ -239,6 +240,7 @@ impl Repository { let Some((.., ref sha256)) = writer.get_sha_builder() else { bail!("Writer doesn't have sha256 enabled"); }; + let stream_path = format!("streams/{}", hex::encode(sha256)); let object_id = writer.done()?; let object_path = Repository::format_object_path(&object_id); @@ -286,7 +288,7 @@ impl Repository { let object_id = match self.has_stream(sha256)? 
{ Some(id) => id, None => { - let mut writer = self.create_stream(Some(*sha256), None); + let mut writer = self.create_stream(Some(*sha256), None, None); callback(&mut writer)?; let object_id = writer.done()?; diff --git a/src/splitstream.rs b/src/splitstream.rs index 7b8ff20d..49c966e6 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -5,7 +5,9 @@ use std::io::{BufReader, Read, Write}; -use anyhow::{bail, Result}; +use crossbeam::channel::{Receiver as CrossbeamReceiver, Sender as CrossbeamSender}; + +use anyhow::{bail, Context, Result}; use sha2::Sha256; use zstd::stream::read::Decoder; @@ -63,6 +65,7 @@ pub struct SplitStreamWriter<'a> { repo: &'a Repository, pub(crate) inline_content: Vec, writer: ZstdWriter, + pub(crate) object_sender: Option>, } impl std::fmt::Debug for SplitStreamWriter<'_> { @@ -75,16 +78,116 @@ impl std::fmt::Debug for SplitStreamWriter<'_> { } } +#[derive(Debug)] +pub(crate) struct FinishMessage { + pub(crate) data: Vec, + pub(crate) total_msgs: usize, + pub(crate) layer_num: usize, +} + +#[derive(Eq, Debug)] +pub(crate) struct WriterMessagesData { + pub(crate) digest: Sha256HashValue, + pub(crate) object_data: SplitStreamWriterSenderData, +} + +#[derive(Debug)] +pub(crate) enum WriterMessages { + WriteData(WriterMessagesData), + Finish(FinishMessage), +} + +impl PartialEq for WriterMessagesData { + fn eq(&self, other: &Self) -> bool { + self.object_data.seq_num.eq(&other.object_data.seq_num) + } +} + +impl PartialOrd for WriterMessagesData { + fn partial_cmp(&self, other: &Self) -> Option { + self.object_data + .seq_num + .partial_cmp(&other.object_data.seq_num) + } +} + +impl Ord for WriterMessagesData { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.object_data.seq_num.cmp(&other.object_data.seq_num) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct SplitStreamWriterSenderData { + pub(crate) external_data: Vec, + pub(crate) inline_content: Vec, + pub(crate) seq_num: usize, + pub(crate) layer_num: usize, +} +pub(crate) enum EnsureObjectMessages { + Data(SplitStreamWriterSenderData), + Finish(FinishMessage), +} + +pub(crate) type ResultChannelSender = + std::sync::mpsc::Sender>; +pub(crate) type ResultChannelReceiver = + std::sync::mpsc::Receiver>; + +pub(crate) fn handle_external_object( + repository: Repository, + external_object_receiver: CrossbeamReceiver, + zstd_writer_channels: Vec>, + layers_to_chunks: Vec, +) -> Result<()> { + while let Ok(data) = external_object_receiver.recv() { + match data { + EnsureObjectMessages::Data(data) => { + let digest = repository.ensure_object(&data.external_data)?; + let layer_num = data.layer_num; + let writer_chan_sender = &zstd_writer_channels[layers_to_chunks[layer_num]]; + + let msg = WriterMessagesData { + digest, + object_data: data, + }; + + // `send` only fails if all receivers are dropped + writer_chan_sender + .send(WriterMessages::WriteData(msg)) + .with_context(|| format!("Failed to send message for layer {layer_num}"))?; + } + + EnsureObjectMessages::Finish(final_msg) => { + let layer_num = final_msg.layer_num; + let writer_chan_sender = &zstd_writer_channels[layers_to_chunks[layer_num]]; + + writer_chan_sender + .send(WriterMessages::Finish(final_msg)) + .with_context(|| { + format!("Failed to send final message for layer {layer_num}") + })?; + } + } + } + + Ok(()) +} + impl SplitStreamWriter<'_> { - pub fn new( + pub(crate) fn new( repo: &Repository, refs: Option, sha256: Option, + object_sender: Option>, ) -> SplitStreamWriter { + let inline_content = vec![]; + 
SplitStreamWriter { repo, - inline_content: vec![], - writer: ZstdWriter::new(sha256, refs), + inline_content, + object_sender, + writer: ZstdWriter::new(sha256, refs, repo.try_clone().unwrap()), } } @@ -102,26 +205,45 @@ impl SplitStreamWriter<'_> { /// really, "add inline content to the buffer" /// you need to call .flush_inline() later pub fn write_inline(&mut self, data: &[u8]) { - self.writer.update_sha(data); self.inline_content.extend(data); } - pub fn write_external(&mut self, data: &[u8], padding: Vec) -> Result<()> { - let id = self.repo.ensure_object(&data)?; + pub fn write_external( + &mut self, + data: Vec, + padding: Vec, + seq_num: usize, + layer_num: usize, + ) -> Result<()> { + match &self.object_sender { + Some(sender) => { + let inline_content = std::mem::replace(&mut self.inline_content, padding); + + if let Err(e) = + sender.send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { + external_data: data, + inline_content, + seq_num, + layer_num, + })) + { + println!("Falied to send message. Err: {}", e.to_string()); + } + } - self.writer.update_sha(data); - self.writer.update_sha(&padding); - self.writer.flush_inline(&padding)?; + None => { + let id = self.repo.ensure_object(&data)?; + self.writer.flush_inline(&padding)?; + self.writer.write_fragment(0, &id)?; + } + }; - self.writer.write_fragment(0, &id)?; Ok(()) } pub fn done(mut self) -> Result { self.flush_inline(vec![])?; - self.writer.finalize_sha256_builder()?; - self.repo.ensure_object(&self.writer.finish()?) } } diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index 08e57685..e79c6d26 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -1,27 +1,177 @@ -use std::io::{self, Write}; +use std::{ + cmp::Reverse, + collections::BinaryHeap, + io::{self, Write}, +}; use sha2::{Digest, Sha256}; -use anyhow::{bail, Result}; +use anyhow::{bail, Context, Result}; +use zstd::Encoder; use crate::{ fsverity::{FsVerityHashValue, Sha256HashValue}, - splitstream::DigestMap, + repository::Repository, + splitstream::{ + DigestMap, EnsureObjectMessages, FinishMessage, ResultChannelSender, + SplitStreamWriterSenderData, WriterMessages, WriterMessagesData, + }, }; pub(crate) struct ZstdWriter { writer: zstd::Encoder<'static, Vec>, + repository: Repository, pub(crate) sha256_builder: Option<(Sha256, Sha256HashValue)>, + mode: WriterMode, +} + +pub(crate) struct MultiThreadedState { + last: usize, + heap: BinaryHeap>, + final_sha: Option, + final_message: Option, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, +} + +pub(crate) enum WriterMode { + SingleThreaded, + MultiThreaded(MultiThreadedState), +} + +pub(crate) struct MultipleZstdWriters { + writers: Vec, + final_result_sender: ResultChannelSender, +} + +impl MultipleZstdWriters { + pub fn new( + sha256: Vec, + repository: Repository, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, + ) -> Self { + Self { + final_result_sender: final_result_sender.clone(), + + writers: sha256 + .iter() + .map(|sha| { + ZstdWriter::new_threaded( + Some(*sha), + None, + repository.try_clone().unwrap(), + object_sender.clone(), + final_result_sender.clone(), + ) + }) + .collect(), + } + } + + pub fn recv_data( + mut self, + enc_chan_recvr: crossbeam::channel::Receiver, + layer_num_start: usize, + layer_num_end: usize, + ) -> Result<()> { + assert!(layer_num_end >= layer_num_start); + + let total_writers = self.writers.len(); + + // layers_to_writers[layer_num] = writer_idx + // Faster than a hash map 
+ let mut layers_to_writers: Vec = vec![0; layer_num_end]; + + for (idx, i) in (layer_num_start..layer_num_end).enumerate() { + layers_to_writers[i] = idx + } + + let mut finished_writers = 0; + + while let Ok(data) = enc_chan_recvr.recv() { + let layer_num = match &data { + WriterMessages::WriteData(d) => d.object_data.layer_num, + WriterMessages::Finish(d) => d.layer_num, + }; + + assert!(layer_num >= layer_num_start && layer_num <= layer_num_end); + + match self.writers[layers_to_writers[layer_num]].handle_received_data(data) { + Ok(t) => { + if t { + finished_writers += 1 + } + } + + Err(e) => self + .final_result_sender + .send(Err(e)) + .context("Failed to send result on channel")?, + } + + if finished_writers == total_writers { + break; + } + } + + Ok(()) + } } impl ZstdWriter { - pub fn new(sha256: Option, refs: Option) -> Self { + pub fn new_threaded( + sha256: Option, + refs: Option, + repository: Repository, + object_sender: crossbeam::channel::Sender, + final_result_sender: ResultChannelSender, + ) -> Self { Self { writer: ZstdWriter::instantiate_writer(refs), + repository, sha256_builder: sha256.map(|x| (Sha256::new(), x)), + + mode: WriterMode::MultiThreaded(MultiThreadedState { + final_sha: None, + last: 0, + heap: BinaryHeap::new(), + final_message: None, + object_sender, + final_result_sender, + }), } } + pub fn new( + sha256: Option, + refs: Option, + repository: Repository, + ) -> Self { + Self { + writer: ZstdWriter::instantiate_writer(refs), + repository, + sha256_builder: sha256.map(|x| (Sha256::new(), x)), + mode: WriterMode::SingleThreaded, + } + } + + fn get_state(&self) -> &MultiThreadedState { + let WriterMode::MultiThreaded(state) = &self.mode else { + panic!("`get_state` called on a single threaded writer") + }; + + return state; + } + + fn get_state_mut(&mut self) -> &mut MultiThreadedState { + let WriterMode::MultiThreaded(state) = &mut self.mode else { + panic!("`get_state_mut` called on a single threaded writer") + }; + + return state; + } + fn instantiate_writer(refs: Option) -> zstd::Encoder<'static, Vec> { let mut writer = zstd::Encoder::new(vec![], 0).unwrap(); @@ -48,22 +198,67 @@ impl ZstdWriter { Ok(self.writer.write_all(data)?) 
} - pub(crate) fn update_sha(&mut self, data: &[u8]) { - if let Some((sha256, ..)) = &mut self.sha256_builder { - sha256.update(&data); - } - } - pub(crate) fn flush_inline(&mut self, inline_content: &Vec) -> Result<()> { if inline_content.is_empty() { return Ok(()); } + if let Some((sha256, ..)) = &mut self.sha256_builder { + sha256.update(&inline_content); + } + self.write_fragment(inline_content.len(), &inline_content)?; Ok(()) } + fn write_message(&mut self) -> Result<()> { + loop { + // Gotta keep lifetime of the destructring inside the loop + let state = self.get_state_mut(); + + let Some(data) = state.heap.peek() else { + break; + }; + + if data.0.object_data.seq_num != state.last { + break; + } + + let data = state.heap.pop().unwrap(); + state.last += 1; + + self.flush_inline(&data.0.object_data.inline_content)?; + + if let Some((sha256, ..)) = &mut self.sha256_builder { + sha256.update(data.0.object_data.external_data); + } + + if let Err(e) = self.write_fragment(0, &data.0.digest) { + println!("write_fragment err while writing external content: {e:?}"); + } + } + + let final_msg = self.get_state_mut().final_message.take(); + + if let Some(final_msg) = final_msg { + // Haven't received all the messages so we reset the final_message field + if self.get_state().last < final_msg.total_msgs { + self.get_state_mut().final_message = Some(final_msg); + return Ok(()); + } + + let sha = self.handle_final_message(final_msg).unwrap(); + self.get_state_mut().final_sha = Some(sha); + } + + Ok(()) + } + + fn add_message_to_heap(&mut self, recv_data: WriterMessagesData) { + self.get_state_mut().heap.push(Reverse(recv_data)); + } + pub(crate) fn finalize_sha256_builder(&mut self) -> Result { let sha256_builder = std::mem::replace(&mut self.sha256_builder, None); @@ -89,4 +284,100 @@ impl ZstdWriter { pub(crate) fn finish(self) -> io::Result> { self.writer.finish() } + + fn handle_final_message(&mut self, final_message: FinishMessage) -> Result { + self.flush_inline(&final_message.data)?; + + let writer = std::mem::replace(&mut self.writer, Encoder::new(vec![], 0).unwrap()); + let finished = writer.finish()?; + + let sha = self.finalize_sha256_builder()?; + + if let Err(e) = self + .get_state() + .object_sender + .send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { + external_data: finished, + inline_content: vec![], + seq_num: 0, + layer_num: final_message.layer_num, + })) + { + println!("Failed to finish writer. 
Err: {e}"); + }; + + Ok(sha) + } + + // Cannot `take` ownership of self, as we'll need it later + // returns whether finished or not + fn handle_received_data(&mut self, data: WriterMessages) -> Result { + match data { + WriterMessages::WriteData(recv_data) => { + if let Some(final_sha) = self.get_state().final_sha { + // We've already received the final messae + let stream_path = format!("streams/{}", hex::encode(final_sha)); + + let object_path = Repository::format_object_path(&recv_data.digest); + self.repository.ensure_symlink(&stream_path, &object_path)?; + + // if let Some(name) = reference { + // let reference_path = format!("streams/refs/{name}"); + // self.symlink(&reference_path, &stream_path)?; + // } + + if let Err(e) = self + .get_state() + .final_result_sender + .send(Ok((final_sha, recv_data.digest))) + { + println!("Failed to send final digest with err: {e:?}"); + } + + return Ok(true); + } + + let seq_num = recv_data.object_data.seq_num; + + self.add_message_to_heap(recv_data); + + if seq_num != self.get_state().last { + return Ok(false); + } + + self.write_message()?; + } + + WriterMessages::Finish(final_msg) => { + if self.get_state().final_message.is_some() { + panic!( + "Received two finalize messages for layer {}. Previous final message {:?}", + final_msg.layer_num, + self.get_state().final_message + ); + } + + // write all pending messages + if !self.get_state().heap.is_empty() { + self.write_message()?; + } + + let total_msgs = final_msg.total_msgs; + + if self.get_state().last >= total_msgs { + // We have received all the messages + // Finalize + let final_sha = self.handle_final_message(final_msg).unwrap(); + self.get_state_mut().final_sha = Some(final_sha); + } else { + // Haven't received all messages. Store the final message until we have + // received all + let state = self.get_state_mut(); + state.final_message = Some(final_msg); + } + } + } + + return Ok(false); + } } From 3135f6ac68e39373eebfe292fd9c9131f4a54358 Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Mon, 7 Apr 2025 15:34:25 +0530 Subject: [PATCH 05/11] Add documentation for writer functions Signed-off-by: Pragyan Poudyal --- src/zstd_encoder.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index e79c6d26..7edfb6b1 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -198,6 +198,7 @@ impl ZstdWriter { Ok(self.writer.write_all(data)?) 
} + /// Writes all the data in `inline_content`, updating the internal SHA pub(crate) fn flush_inline(&mut self, inline_content: &Vec) -> Result<()> { if inline_content.is_empty() { return Ok(()); @@ -212,6 +213,8 @@ impl ZstdWriter { Ok(()) } + /// Keeps popping from the heap until it reaches the message with the largest seq_num, n, + /// given we have every message with seq_num < n fn write_message(&mut self) -> Result<()> { loop { // Gotta keep lifetime of the destructring inside the loop @@ -234,9 +237,7 @@ impl ZstdWriter { sha256.update(data.0.object_data.external_data); } - if let Err(e) = self.write_fragment(0, &data.0.digest) { - println!("write_fragment err while writing external content: {e:?}"); - } + self.write_fragment(0, &data.0.digest)?; } let final_msg = self.get_state_mut().final_message.take(); @@ -281,6 +282,7 @@ impl ZstdWriter { return Ok(sha); } + /// Calls `finish` on the internal writer pub(crate) fn finish(self) -> io::Result> { self.writer.finish() } @@ -310,7 +312,8 @@ impl ZstdWriter { } // Cannot `take` ownership of self, as we'll need it later - // returns whether finished or not + // + /// Returns whether we have finished writing all the data or not fn handle_received_data(&mut self, data: WriterMessages) -> Result { match data { WriterMessages::WriteData(recv_data) => { @@ -321,11 +324,6 @@ impl ZstdWriter { let object_path = Repository::format_object_path(&recv_data.digest); self.repository.ensure_symlink(&stream_path, &object_path)?; - // if let Some(name) = reference { - // let reference_path = format!("streams/refs/{name}"); - // self.symlink(&reference_path, &stream_path)?; - // } - if let Err(e) = self .get_state() .final_result_sender From 78e42496a4d83c78812c1bc07e46304799926967 Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Mon, 7 Apr 2025 15:41:40 +0530 Subject: [PATCH 06/11] Wrap channel send errors with proper context Signed-off-by: Pragyan Poudyal --- src/oci/tar.rs | 14 ++++++++------ src/splitstream.rs | 10 +++++----- src/zstd_encoder.rs | 22 +++++++++------------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/oci/tar.rs b/src/oci/tar.rs index 88d6694b..ea118c6e 100644 --- a/src/oci/tar.rs +++ b/src/oci/tar.rs @@ -8,7 +8,7 @@ use std::{ path::PathBuf, }; -use anyhow::{bail, ensure, Result}; +use anyhow::{bail, ensure, Context, Result}; use rustix::fs::makedev; use tar::{EntryType, Header, PaxExtensions}; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -104,11 +104,13 @@ pub async fn split_async( } if let Some(sender) = &writer.object_sender { - sender.send(EnsureObjectMessages::Finish(FinishMessage { - data: std::mem::take(&mut writer.inline_content), - total_msgs: seq_num, - layer_num, - }))?; + sender + .send(EnsureObjectMessages::Finish(FinishMessage { + data: std::mem::take(&mut writer.inline_content), + total_msgs: seq_num, + layer_num, + })) + .with_context(|| format!("Failed to send final message for layer {layer_num}"))?; } Ok(()) diff --git a/src/splitstream.rs b/src/splitstream.rs index 49c966e6..686bb400 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -219,16 +219,16 @@ impl SplitStreamWriter<'_> { Some(sender) => { let inline_content = std::mem::replace(&mut self.inline_content, padding); - if let Err(e) = - sender.send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { + sender + .send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { external_data: data, inline_content, seq_num, layer_num, })) - { - println!("Falied to send message. 
Err: {}", e.to_string()); - } + .with_context(|| { + format!("Failed to send message to writer for layer {layer_num}") + })?; } None => { diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index 7edfb6b1..36781837 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -98,8 +98,8 @@ impl MultipleZstdWriters { assert!(layer_num >= layer_num_start && layer_num <= layer_num_end); match self.writers[layers_to_writers[layer_num]].handle_received_data(data) { - Ok(t) => { - if t { + Ok(finished) => { + if finished { finished_writers += 1 } } @@ -295,18 +295,15 @@ impl ZstdWriter { let sha = self.finalize_sha256_builder()?; - if let Err(e) = self - .get_state() + self.get_state() .object_sender .send(EnsureObjectMessages::Data(SplitStreamWriterSenderData { external_data: finished, inline_content: vec![], - seq_num: 0, + seq_num: final_message.total_msgs, layer_num: final_message.layer_num, })) - { - println!("Failed to finish writer. Err: {e}"); - }; + .with_context(|| format!("Failed to send object finalize message"))?; Ok(sha) } @@ -324,13 +321,12 @@ impl ZstdWriter { let object_path = Repository::format_object_path(&recv_data.digest); self.repository.ensure_symlink(&stream_path, &object_path)?; - if let Err(e) = self - .get_state() + self.get_state() .final_result_sender .send(Ok((final_sha, recv_data.digest))) - { - println!("Failed to send final digest with err: {e:?}"); - } + .with_context(|| { + format!("Failed to send result for layer {final_sha:?}") + })?; return Ok(true); } From 0137bb9e872bc88ff1f45df30d57b4d6fdc01d1d Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Mon, 7 Apr 2025 17:09:20 +0530 Subject: [PATCH 07/11] Fix bug in single threaded writer Expose an `update_sha` method from ZstdWriter so we can update the rolling SHA256 hash value from SplitStreamWriter Signed-off-by: Pragyan Poudyal --- src/splitstream.rs | 4 +++- src/zstd_encoder.rs | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/splitstream.rs b/src/splitstream.rs index 686bb400..e88e0156 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -232,8 +232,10 @@ impl SplitStreamWriter<'_> { } None => { + self.flush_inline(padding)?; + self.writer.update_sha(&data); + let id = self.repo.ensure_object(&data)?; - self.writer.flush_inline(&padding)?; self.writer.write_fragment(0, &id)?; } }; diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index 36781837..634a62a5 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -10,7 +10,7 @@ use anyhow::{bail, Context, Result}; use zstd::Encoder; use crate::{ - fsverity::{FsVerityHashValue, Sha256HashValue}, + fsverity::Sha256HashValue, repository::Repository, splitstream::{ DigestMap, EnsureObjectMessages, FinishMessage, ResultChannelSender, @@ -198,15 +198,19 @@ impl ZstdWriter { Ok(self.writer.write_all(data)?) 
} + pub(crate) fn update_sha(&mut self, data: &[u8]) { + if let Some((sha256, ..)) = &mut self.sha256_builder { + sha256.update(&data); + } + } + /// Writes all the data in `inline_content`, updating the internal SHA pub(crate) fn flush_inline(&mut self, inline_content: &Vec) -> Result<()> { if inline_content.is_empty() { return Ok(()); } - if let Some((sha256, ..)) = &mut self.sha256_builder { - sha256.update(&inline_content); - } + self.update_sha(inline_content); self.write_fragment(inline_content.len(), &inline_content)?; @@ -263,8 +267,6 @@ impl ZstdWriter { pub(crate) fn finalize_sha256_builder(&mut self) -> Result { let sha256_builder = std::mem::replace(&mut self.sha256_builder, None); - let mut sha = Sha256HashValue::EMPTY; - if let Some((context, expected)) = sha256_builder { let final_sha = Into::::into(context.finalize()); @@ -276,10 +278,10 @@ impl ZstdWriter { ); } - sha = final_sha; + return Ok(final_sha); } - return Ok(sha); + bail!("SHA not enabled for writer"); } /// Calls `finish` on the internal writer From 50720e97936a392071e78c03f16ff73ac4ac2f8a Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Mon, 7 Apr 2025 17:28:48 +0530 Subject: [PATCH 08/11] Fix clippy warnings Signed-off-by: Pragyan Poudyal --- src/oci/mod.rs | 16 +++++----------- src/splitstream.rs | 4 +--- src/zstd_encoder.rs | 18 +++++++++--------- 3 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/oci/mod.rs b/src/oci/mod.rs index bc8edf21..70d6c298 100644 --- a/src/oci/mod.rs +++ b/src/oci/mod.rs @@ -241,7 +241,7 @@ impl<'repo> ImageOp<'repo> { let (done_chan_sender, done_chan_recver) = std::sync::mpsc::channel::>(); - let chunk_len = (config.rootfs().diff_ids().len() + encoder_threads - 1) / encoder_threads; + let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads); // Divide the layers into chunks of some specific size so each worker // thread can work on multiple deterministic layers @@ -261,8 +261,7 @@ impl<'repo> ImageOp<'repo> { let layers_to_chunks = chunks .iter() .enumerate() - .map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::>()) - .flatten() + .flat_map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::>()) .collect::>(); let _ = (0..encoder_threads) @@ -286,8 +285,7 @@ impl<'repo> ImageOp<'repo> { ); if let Err(e) = enc.recv_data(receiver, start, end) { - eprintln!("zstd_encoder returned with error: {}", e.to_string()); - return; + eprintln!("zstd_encoder returned with error: {}", e) } } }); @@ -312,18 +310,14 @@ impl<'repo> ImageOp<'repo> { zstd_writer_channels, layers_to_chunks, ) { - eprintln!( - "handle_external_object returned with error: {}", - e.to_string() - ); - return; + eprintln!("handle_external_object returned with error: {}", e); } } }); }) .collect::>(); - return (done_chan_sender, done_chan_recver, object_sender); + (done_chan_sender, done_chan_recver, object_sender) } pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> { diff --git a/src/splitstream.rs b/src/splitstream.rs index e88e0156..833540ba 100644 --- a/src/splitstream.rs +++ b/src/splitstream.rs @@ -105,9 +105,7 @@ impl PartialEq for WriterMessagesData { impl PartialOrd for WriterMessagesData { fn partial_cmp(&self, other: &Self) -> Option { - self.object_data - .seq_num - .partial_cmp(&other.object_data.seq_num) + Some(self.cmp(other)) } } diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index 634a62a5..fd22c52f 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -161,7 +161,7 @@ impl ZstdWriter { panic!("`get_state` 
called on a single threaded writer") }; - return state; + state } fn get_state_mut(&mut self) -> &mut MultiThreadedState { @@ -169,7 +169,7 @@ impl ZstdWriter { panic!("`get_state_mut` called on a single threaded writer") }; - return state; + state } fn instantiate_writer(refs: Option) -> zstd::Encoder<'static, Vec> { @@ -190,7 +190,7 @@ impl ZstdWriter { } } - return writer; + writer } pub(crate) fn write_fragment(&mut self, size: usize, data: &[u8]) -> Result<()> { @@ -200,19 +200,19 @@ impl ZstdWriter { pub(crate) fn update_sha(&mut self, data: &[u8]) { if let Some((sha256, ..)) = &mut self.sha256_builder { - sha256.update(&data); + sha256.update(data); } } /// Writes all the data in `inline_content`, updating the internal SHA - pub(crate) fn flush_inline(&mut self, inline_content: &Vec) -> Result<()> { + pub(crate) fn flush_inline(&mut self, inline_content: &[u8]) -> Result<()> { if inline_content.is_empty() { return Ok(()); } self.update_sha(inline_content); - self.write_fragment(inline_content.len(), &inline_content)?; + self.write_fragment(inline_content.len(), inline_content)?; Ok(()) } @@ -265,7 +265,7 @@ impl ZstdWriter { } pub(crate) fn finalize_sha256_builder(&mut self) -> Result { - let sha256_builder = std::mem::replace(&mut self.sha256_builder, None); + let sha256_builder = self.sha256_builder.take(); if let Some((context, expected)) = sha256_builder { let final_sha = Into::::into(context.finalize()); @@ -305,7 +305,7 @@ impl ZstdWriter { seq_num: final_message.total_msgs, layer_num: final_message.layer_num, })) - .with_context(|| format!("Failed to send object finalize message"))?; + .context("Failed to send object finalize message")?; Ok(sha) } @@ -374,6 +374,6 @@ impl ZstdWriter { } } - return Ok(false); + Ok(false) } } From bc09a83651297f0e9c2276d5120942e85252c9d9 Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Tue, 8 Apr 2025 13:21:57 +0530 Subject: [PATCH 09/11] Return Err from ZstdEncoder rather than sending it on the result channel If anything goes wrong at any point, short circuting and stopping the thread would be the better option Signed-off-by: Pragyan Poudyal --- src/zstd_encoder.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/zstd_encoder.rs b/src/zstd_encoder.rs index fd22c52f..065b5ca5 100644 --- a/src/zstd_encoder.rs +++ b/src/zstd_encoder.rs @@ -41,7 +41,6 @@ pub(crate) enum WriterMode { pub(crate) struct MultipleZstdWriters { writers: Vec, - final_result_sender: ResultChannelSender, } impl MultipleZstdWriters { @@ -52,8 +51,6 @@ impl MultipleZstdWriters { final_result_sender: ResultChannelSender, ) -> Self { Self { - final_result_sender: final_result_sender.clone(), - writers: sha256 .iter() .map(|sha| { @@ -104,10 +101,9 @@ impl MultipleZstdWriters { } } - Err(e) => self - .final_result_sender - .send(Err(e)) - .context("Failed to send result on channel")?, + Err(e) => { + return Err(e); + } } if finished_writers == total_writers { From 44e180249da5f24040b8e61beecd713b706a9e95 Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Tue, 8 Apr 2025 13:25:28 +0530 Subject: [PATCH 10/11] Fix deadlock when we have more threads than layers If we have more threads than we have unprocessed layers, some of the cloned senders aren't dropped and the main thread hangs on the result receiving loop. 
We make sure here to not spawn more threads than the number of unhandled layers Signed-off-by: Pragyan Poudyal --- src/oci/mod.rs | 168 +++++++++++++++++++++++++++---------------------- 1 file changed, 92 insertions(+), 76 deletions(-) diff --git a/src/oci/mod.rs b/src/oci/mod.rs index 70d6c298..91937d23 100644 --- a/src/oci/mod.rs +++ b/src/oci/mod.rs @@ -175,11 +175,14 @@ impl<'repo> ImageOp<'repo> { let raw_config = config?; let config = ImageConfiguration::from_reader(&raw_config[..])?; - let (done_chan_sender, done_chan_recver, object_sender) = self.spawn_threads(&config); + let (done_chan_sender, done_chan_recver, object_sender) = + self.spawn_threads(&config)?; let mut config_maps = DigestMap::new(); - for (idx, (mld, cld)) in zip(manifest_layers, config.rootfs().diff_ids()).enumerate() { + let mut idx = 0; + + for (mld, cld) in zip(manifest_layers, config.rootfs().diff_ids()) { let layer_sha256 = sha256_from_digest(cld)?; if let Some(layer_id) = self.repo.check_stream(&layer_sha256)? { @@ -191,6 +194,8 @@ impl<'repo> ImageOp<'repo> { self.ensure_layer(&layer_sha256, mld, idx, object_sender.clone()) .await .with_context(|| format!("Failed to fetch layer {cld} via {mld:?}"))?; + + idx += 1; } } @@ -214,43 +219,39 @@ impl<'repo> ImageOp<'repo> { fn spawn_threads( &self, config: &ImageConfiguration, - ) -> ( + ) -> Result<( ResultChannelSender, ResultChannelReceiver, crossbeam::channel::Sender, - ) { + )> { use crossbeam::channel::{unbounded, Receiver, Sender}; - let encoder_threads = 2; + let mut encoder_threads = 2; let external_object_writer_threads = 4; - let pool = rayon::ThreadPoolBuilder::new() - .num_threads(encoder_threads + external_object_writer_threads) - .build() - .unwrap(); - - // We need this as writers have internal state that can't be shared between threads - // - // We'll actually need as many writers (not writer threads, but writer instances) as there are layers. 
- let zstd_writer_channels: Vec<(Sender, Receiver)> = - (0..encoder_threads).map(|_| unbounded()).collect(); - - let (object_sender, object_receiver) = unbounded::(); - - // (layer_sha256, layer_id) - let (done_chan_sender, done_chan_recver) = - std::sync::mpsc::channel::>(); - let chunk_len = config.rootfs().diff_ids().len().div_ceil(encoder_threads); // Divide the layers into chunks of some specific size so each worker // thread can work on multiple deterministic layers - let mut chunks: Vec> = config + let diff_ids: Vec = config .rootfs() .diff_ids() .iter() - .map(|x| sha256_from_digest(x).unwrap()) - .collect::>() + .map(|x| sha256_from_digest(x)) + .collect::, _>>()?; + + let mut unhandled_layers = vec![]; + + // This becomes pretty unreadable with a filter,map chain + for id in diff_ids { + let layer_exists = self.repo.check_stream(&id)?; + + if layer_exists.is_none() { + unhandled_layers.push(id); + } + } + + let mut chunks: Vec> = unhandled_layers .chunks(chunk_len) .map(|x| x.to_vec()) .collect(); @@ -264,60 +265,75 @@ impl<'repo> ImageOp<'repo> { .flat_map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::>()) .collect::>(); - let _ = (0..encoder_threads) - .map(|i| { - let repository = self.repo.try_clone().unwrap(); - let object_sender = object_sender.clone(); - let done_chan_sender = done_chan_sender.clone(); - let chunk = std::mem::take(&mut chunks[i]); - let receiver = zstd_writer_channels[i].1.clone(); - - pool.spawn({ - move || { - let start = i * (chunk_len); - let end = start + chunk_len; - - let enc = zstd_encoder::MultipleZstdWriters::new( - chunk, - repository, - object_sender, - done_chan_sender, - ); - - if let Err(e) = enc.recv_data(receiver, start, end) { - eprintln!("zstd_encoder returned with error: {}", e) - } + encoder_threads = encoder_threads.min(chunks.len()); + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(encoder_threads + external_object_writer_threads) + .build() + .unwrap(); + + // We need this as writers have internal state that can't be shared between threads + // + // We'll actually need as many writers (not writer threads, but writer instances) as there are layers. 
+ let zstd_writer_channels: Vec<(Sender, Receiver)> = + (0..encoder_threads).map(|_| unbounded()).collect(); + + let (object_sender, object_receiver) = unbounded::(); + + // (layer_sha256, layer_id) + let (done_chan_sender, done_chan_recver) = + std::sync::mpsc::channel::>(); + + for i in 0..encoder_threads { + let repository = self.repo.try_clone().unwrap(); + let object_sender = object_sender.clone(); + let done_chan_sender = done_chan_sender.clone(); + let chunk = std::mem::take(&mut chunks[i]); + let receiver = zstd_writer_channels[i].1.clone(); + + pool.spawn({ + move || { + let start = i * (chunk_len); + let end = start + chunk_len; + + let enc = zstd_encoder::MultipleZstdWriters::new( + chunk, + repository, + object_sender, + done_chan_sender, + ); + + if let Err(e) = enc.recv_data(receiver, start, end) { + eprintln!("zstd_encoder returned with error: {}", e) } - }); - }) - .collect::>(); - - let _ = (0..external_object_writer_threads) - .map(|_| { - pool.spawn({ - let repository = self.repo.try_clone().unwrap(); - let zstd_writer_channels = zstd_writer_channels - .iter() - .map(|(s, _)| s.clone()) - .collect::>(); - let layers_to_chunks = layers_to_chunks.clone(); - let external_object_receiver = object_receiver.clone(); - - move || { - if let Err(e) = handle_external_object( - repository, - external_object_receiver, - zstd_writer_channels, - layers_to_chunks, - ) { - eprintln!("handle_external_object returned with error: {}", e); - } + } + }); + } + + for _ in 0..external_object_writer_threads { + pool.spawn({ + let repository = self.repo.try_clone().unwrap(); + let zstd_writer_channels = zstd_writer_channels + .iter() + .map(|(s, _)| s.clone()) + .collect::>(); + let layers_to_chunks = layers_to_chunks.clone(); + let external_object_receiver = object_receiver.clone(); + + move || { + if let Err(e) = handle_external_object( + repository, + external_object_receiver, + zstd_writer_channels, + layers_to_chunks, + ) { + eprintln!("handle_external_object returned with error: {}", e); } - }); - }) - .collect::>(); + } + }); + } - (done_chan_sender, done_chan_recver, object_sender) + Ok((done_chan_sender, done_chan_recver, object_sender)) } pub async fn pull(&self) -> Result<(Sha256HashValue, Sha256HashValue)> { From dc3dd52891301ff7c9dffb29052f28ddfc89bd60 Mon Sep 17 00:00:00 2001 From: Pragyan Poudyal Date: Tue, 8 Apr 2025 13:57:30 +0530 Subject: [PATCH 11/11] Fix more clippy warnings Signed-off-by: Pragyan Poudyal --- src/oci/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oci/mod.rs b/src/oci/mod.rs index 91937d23..8d36d359 100644 --- a/src/oci/mod.rs +++ b/src/oci/mod.rs @@ -262,7 +262,7 @@ impl<'repo> ImageOp<'repo> { let layers_to_chunks = chunks .iter() .enumerate() - .flat_map(|(i, chunk)| std::iter::repeat(i).take(chunk.len()).collect::>()) + .flat_map(|(i, chunk)| std::iter::repeat_n(i, chunk.len()).collect::>()) .collect::>(); encoder_threads = encoder_threads.min(chunks.len());
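
For readers following the series, the reordering at the heart of
`ZstdWriter::write_message` can be seen in isolation in the sketch
below. This is not code from the series: the `Msg` type, its payloads,
and the `println!` stand in for `WriterMessagesData` and the real zstd
fragment writes; only the min-heap-plus-sequence-number technique is
the same. Each layer's writer parks out-of-order results in a min-heap
keyed by `seq_num` and drains only a contiguous run starting at the
next expected sequence number.

use std::cmp::Reverse;
use std::collections::BinaryHeap;

// One unit of work as an encoder thread sees it: a per-layer sequence
// number plus a payload that must be emitted in sequence order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Msg {
    seq_num: usize,
    payload: &'static str,
}

fn main() {
    // Results arrive out of order because the object-writer threads
    // finish whenever their I/O happens to complete.
    let arrivals = vec![
        Msg { seq_num: 2, payload: "c" },
        Msg { seq_num: 0, payload: "a" },
        Msg { seq_num: 3, payload: "d" },
        Msg { seq_num: 1, payload: "b" },
    ];

    // `Reverse` turns the std max-heap into a min-heap keyed on `seq_num`.
    let mut heap: BinaryHeap<Reverse<Msg>> = BinaryHeap::new();
    let mut next = 0; // next sequence number allowed to be emitted

    for msg in arrivals {
        heap.push(Reverse(msg));

        // Drain only while the heap yields a contiguous run starting at
        // `next`; everything else stays parked until its turn comes.
        while heap.peek().map(|Reverse(m)| m.seq_num) == Some(next) {
            let Reverse(m) = heap.pop().unwrap();
            println!("emit {} -> {}", m.seq_num, m.payload);
            next += 1;
        }
    }

    assert!(heap.is_empty(), "all messages were emitted in order");
}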