Skip to content

Commit 89f90fd

Browse files
Optimise SSZ encoding and decoding (#55)
* Optimise FixedVector Decode
* Add encoding benchmarks
* u8 encoding benchmark
* Update benchmarks to include encoding, optimise VariableList
* Remove path patch
* Remove unnecessary pub
* Fix trailing bytes bug and add more tests
* Remove junk
* More tests for oversize FixedVector
* Remove temp ByteVector from benches
* Fix tests
* Update comments
* Fix Clippy
* Test bool on unsafe codepath
* Add test demonstrating bool UB
* Fix UB by using TypeId
* Fix test name

Co-authored-by: Paul Hauner <[email protected]>

---------

Co-authored-by: Paul Hauner <[email protected]>
1 parent 8e5b6df commit 89f90fd

File tree

2 files changed

+191
-26
lines changed

2 files changed

+191
-26
lines changed

src/fixed_vector.rs

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root;
22
use crate::Error;
33
use serde::Deserialize;
44
use serde_derive::Serialize;
5+
use std::any::TypeId;
56
use std::marker::PhantomData;
7+
use std::mem;
68
use std::ops::{Deref, DerefMut, Index, IndexMut};
79
use std::slice::SliceIndex;
810
use tree_hash::Hash256;
@@ -283,7 +285,7 @@ impl<T, N: Unsigned> ssz::TryFromIter<T> for FixedVector<T, N> {
283285

284286
impl<T, N: Unsigned> ssz::Decode for FixedVector<T, N>
285287
where
286-
T: ssz::Decode,
288+
T: ssz::Decode + 'static,
287289
{
288290
fn is_ssz_fixed_len() -> bool {
289291
T::is_ssz_fixed_len()
@@ -305,6 +307,24 @@ where
305307
len: 0,
306308
expected: 1,
307309
})
310+
} else if TypeId::of::<T>() == TypeId::of::<u8>() {
311+
if bytes.len() != fixed_len {
312+
return Err(ssz::DecodeError::BytesInvalid(format!(
313+
"FixedVector of {} items has {} items",
314+
fixed_len,
315+
bytes.len(),
316+
)));
317+
}
318+
319+
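// SAFETY: `TypeId::of::<T>() == TypeId::of::<u8>()` was checked above, so `Vec<T>` and `Vec<u8>` are the same type and the transmute is an identity conversion.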
320+
let vec_u8 = bytes.to_vec();
321+
let vec_t = unsafe { mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
322+
Self::new(vec_t).map_err(|e| {
323+
ssz::DecodeError::BytesInvalid(format!(
324+
"Wrong number of FixedVector elements: {:?}",
325+
e
326+
))
327+
})
308328
} else if T::is_ssz_fixed_len() {
309329
let num_items = bytes
310330
.len()
@@ -314,17 +334,24 @@ where
314334
if num_items != fixed_len {
315335
return Err(ssz::DecodeError::BytesInvalid(format!(
316336
"FixedVector of {} items has {} items",
317-
num_items, fixed_len
337+
fixed_len, num_items
318338
)));
319339
}
320340

321-
let vec = bytes.chunks(T::ssz_fixed_len()).try_fold(
322-
Vec::with_capacity(num_items),
323-
|mut vec, chunk| {
324-
vec.push(T::from_ssz_bytes(chunk)?);
325-
Ok(vec)
326-
},
327-
)?;
341+
// Reject a length that is not a whole number of items; `chunks_exact` would otherwise silently skip the leftover bytes.
342+
if !bytes.len().is_multiple_of(T::ssz_fixed_len()) {
343+
return Err(ssz::DecodeError::BytesInvalid(format!(
344+
"FixedVector of {} items has {} bytes",
345+
num_items,
346+
bytes.len()
347+
)));
348+
}
349+
350+
let mut vec = Vec::with_capacity(num_items);
351+
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
352+
vec.push(T::from_ssz_bytes(chunk)?);
353+
}
354+
328355
Self::new(vec).map_err(|e| {
329356
ssz::DecodeError::BytesInvalid(format!(
330357
"Wrong number of FixedVector elements: {:?}",
@@ -479,6 +506,56 @@ mod test {
479506
ssz_round_trip::<FixedVector<u16, U8>>(vec![0; 8].try_into().unwrap());
480507
}
481508

509+
// Test byte decoding (we have a specialised code path with unsafe code that NEEDS coverage).
510+
#[test]
511+
fn ssz_round_trip_u8_len_1024() {
512+
ssz_round_trip::<FixedVector<u8, U1024>>(vec![42; 1024].try_into().unwrap());
513+
ssz_round_trip::<FixedVector<u8, U1024>>(vec![0; 1024].try_into().unwrap());
514+
}
515+
516+
// bool is layout equivalent to u8 but must not use the same unsafe codepath because not all u8
517+
// values are valid bools.
518+
#[test]
519+
fn ssz_round_trip_bool_len_1024() {
520+
assert_eq!(mem::size_of::<bool>(), 1);
521+
assert_eq!(mem::align_of::<bool>(), 1);
522+
ssz_round_trip::<FixedVector<bool, U1024>>(vec![true; 1024].try_into().unwrap());
523+
ssz_round_trip::<FixedVector<bool, U1024>>(vec![false; 1024].try_into().unwrap());
524+
}
525+
526+
// Decoding a u8 vector as a vector of bools must fail; if we aren't careful we could trigger UB.
527+
#[test]
528+
fn ssz_u8_to_bool_len_1024() {
529+
let list_u8 = FixedVector::<u8, U8>::new(vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
530+
FixedVector::<bool, U8>::from_ssz_bytes(&list_u8.as_ssz_bytes()).unwrap_err();
531+
}
532+
533+
#[test]
534+
fn ssz_u8_len_1024_too_long() {
535+
assert_eq!(
536+
FixedVector::<u8, U1024>::from_ssz_bytes(&vec![42; 1025]).unwrap_err(),
537+
ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into())
538+
);
539+
}
540+
541+
#[test]
542+
fn ssz_u64_len_1024_too_long() {
543+
assert_eq!(
544+
FixedVector::<u64, U1024>::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(),
545+
ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into())
546+
);
547+
}
548+
549+
// Decoding an input with invalid trailing bytes MUST fail.
550+
#[test]
551+
fn ssz_bytes_u32_trailing() {
552+
let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 1];
553+
assert_eq!(
554+
FixedVector::<u32, U2>::from_ssz_bytes(&bytes).unwrap_err(),
555+
ssz::DecodeError::BytesInvalid("FixedVector of 2 items has 9 bytes".into())
556+
);
557+
}
558+
482559
#[test]
483560
fn tree_hash_u8() {
484561
let fixed: FixedVector<u8, U0> = FixedVector::try_from(vec![]).unwrap();

src/variable_list.rs

Lines changed: 105 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root;
22
use crate::Error;
33
use serde::Deserialize;
44
use serde_derive::Serialize;
5+
use std::any::TypeId;
56
use std::marker::PhantomData;
7+
use std::mem;
68
use std::ops::{Deref, DerefMut, Index, IndexMut};
79
use std::slice::SliceIndex;
810
use tree_hash::Hash256;
@@ -288,7 +290,7 @@ impl<T, N: Unsigned> ssz::TryFromIter<T> for VariableList<T, N> {
288290

289291
impl<T, N> ssz::Decode for VariableList<T, N>
290292
where
291-
T: ssz::Decode,
293+
T: ssz::Decode + 'static,
292294
N: Unsigned,
293295
{
294296
fn is_ssz_fixed_len() -> bool {
@@ -302,6 +304,26 @@ where
302304
return Ok(Self::default());
303305
}
304306

307+
if TypeId::of::<T>() == TypeId::of::<u8>() {
308+
if bytes.len() > max_len {
309+
return Err(ssz::DecodeError::BytesInvalid(format!(
310+
"VariableList of {} items exceeds maximum of {}",
311+
bytes.len(),
312+
max_len
313+
)));
314+
}
315+
316+
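// SAFETY: `TypeId::of::<T>() == TypeId::of::<u8>()` was checked above, so `Vec<T>` and `Vec<u8>` are the same type and the transmute is an identity conversion.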
317+
let vec_u8 = bytes.to_vec();
318+
let vec_t = unsafe { mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
319+
return Self::new(vec_t).map_err(|e| {
320+
ssz::DecodeError::BytesInvalid(format!(
321+
"Wrong number of VariableList elements: {:?}",
322+
e
323+
))
324+
});
325+
}
326+
305327
if T::is_ssz_fixed_len() {
306328
let num_items = bytes
307329
.len()
@@ -315,20 +337,28 @@ where
315337
)));
316338
}
317339

318-
bytes.chunks(T::ssz_fixed_len()).try_fold(
319-
Vec::with_capacity(num_items),
320-
|mut vec, chunk| {
321-
vec.push(T::from_ssz_bytes(chunk)?);
322-
Ok(vec)
323-
},
324-
)
340+
// Reject a length that is not a whole number of items; `chunks_exact` would otherwise silently skip the leftover bytes.
341+
if !bytes.len().is_multiple_of(T::ssz_fixed_len()) {
342+
return Err(ssz::DecodeError::BytesInvalid(format!(
343+
"VariableList of {} items has {} bytes",
344+
num_items,
345+
bytes.len()
346+
)));
347+
}
348+
349+
let mut vec = Vec::with_capacity(num_items);
350+
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
351+
vec.push(T::from_ssz_bytes(chunk)?);
352+
}
353+
Self::new(vec).map_err(|e| {
354+
ssz::DecodeError::BytesInvalid(format!(
355+
"Wrong number of VariableList elements: {:?}",
356+
e
357+
))
358+
})
325359
} else {
326360
ssz::decode_list_of_variable_length_items(bytes, Some(max_len))
327-
}?
328-
.try_into()
329-
.map_err(|e| {
330-
ssz::DecodeError::BytesInvalid(format!("VariableList::try_from failed: {e:?}"))
331-
})
361+
}
332362
}
333363
}
334364

@@ -452,17 +482,60 @@ mod test {
452482
assert_eq!(<VariableList<u16, U2> as Encode>::ssz_fixed_len(), 4);
453483
}
454484

455-
fn round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
485+
fn ssz_round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
456486
let encoded = &item.as_ssz_bytes();
457487
assert_eq!(item.ssz_bytes_len(), encoded.len());
458488
assert_eq!(T::from_ssz_bytes(encoded), Ok(item));
459489
}
460490

461491
#[test]
462492
fn u16_len_8() {
463-
round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
464-
round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
465-
round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
493+
ssz_round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
494+
ssz_round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
495+
ssz_round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
496+
}
497+
498+
#[test]
499+
fn ssz_round_trip_u8_len_1024() {
500+
ssz_round_trip::<VariableList<u8, U1024>>(vec![42; 1024].try_into().unwrap());
501+
ssz_round_trip::<VariableList<u8, U1024>>(vec![0; 1024].try_into().unwrap());
502+
}
503+
504+
// bool is layout equivalent to u8 but must not use the same unsafe codepath because not all u8
505+
// values are valid bools.
506+
#[test]
507+
fn ssz_round_trip_bool_len_1024() {
508+
assert_eq!(mem::size_of::<bool>(), 1);
509+
assert_eq!(mem::align_of::<bool>(), 1);
510+
ssz_round_trip::<VariableList<bool, U1024>>(vec![true; 1024].try_into().unwrap());
511+
ssz_round_trip::<VariableList<bool, U1024>>(vec![false; 1024].try_into().unwrap());
512+
}
513+
514+
// Decoding a u8 list as a list of bools must fail; if we aren't careful we could trigger UB.
515+
#[test]
516+
fn ssz_u8_to_bool_len_1024() {
517+
let list_u8 = VariableList::<u8, U8>::new(vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
518+
VariableList::<bool, U8>::from_ssz_bytes(&list_u8.as_ssz_bytes()).unwrap_err();
519+
}
520+
521+
#[test]
522+
fn ssz_u8_len_1024_too_long() {
523+
assert_eq!(
524+
VariableList::<u8, U1024>::from_ssz_bytes(&vec![42; 1025]).unwrap_err(),
525+
ssz::DecodeError::BytesInvalid(
526+
"VariableList of 1025 items exceeds maximum of 1024".into()
527+
)
528+
);
529+
}
530+
531+
#[test]
532+
fn ssz_u64_len_1024_too_long() {
533+
assert_eq!(
534+
VariableList::<u64, U1024>::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(),
535+
ssz::DecodeError::BytesInvalid(
536+
"VariableList of 1025 items exceeds maximum of 1024".into()
537+
)
538+
);
466539
}
467540

468541
#[test]
@@ -473,6 +546,21 @@ mod test {
473546
assert_eq!(VariableList::from_ssz_bytes(&[]).unwrap(), empty_list);
474547
}
475548

549+
#[test]
550+
fn ssz_bytes_u32_trailing() {
551+
let bytes = [1, 0, 0, 0, 2, 0];
552+
assert_eq!(
553+
VariableList::<u32, U2>::from_ssz_bytes(&bytes).unwrap_err(),
554+
ssz::DecodeError::BytesInvalid("VariableList of 1 items has 6 bytes".into())
555+
);
556+
557+
let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 3];
558+
assert_eq!(
559+
VariableList::<u32, U2>::from_ssz_bytes(&bytes).unwrap_err(),
560+
ssz::DecodeError::BytesInvalid("VariableList of 2 items has 9 bytes".into())
561+
);
562+
}
563+
476564
fn root_with_length(bytes: &[u8], len: usize) -> Hash256 {
477565
let root = merkle_root(bytes, 0);
478566
tree_hash::mix_in_length(&root, len)

0 commit comments

Comments
 (0)