@@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root;
22use crate :: Error ;
33use serde:: Deserialize ;
44use serde_derive:: Serialize ;
5+ use std:: any:: TypeId ;
56use std:: marker:: PhantomData ;
7+ use std:: mem;
68use std:: ops:: { Deref , DerefMut , Index , IndexMut } ;
79use std:: slice:: SliceIndex ;
810use tree_hash:: Hash256 ;
@@ -288,7 +290,7 @@ impl<T, N: Unsigned> ssz::TryFromIter<T> for VariableList<T, N> {
288290
289291impl < T , N > ssz:: Decode for VariableList < T , N >
290292where
291- T : ssz:: Decode ,
293+ T : ssz:: Decode + ' static ,
292294 N : Unsigned ,
293295{
294296 fn is_ssz_fixed_len ( ) -> bool {
@@ -302,6 +304,26 @@ where
302304 return Ok ( Self :: default ( ) ) ;
303305 }
304306
307+ if TypeId :: of :: < T > ( ) == TypeId :: of :: < u8 > ( ) {
308+ if bytes. len ( ) > max_len {
309+ return Err ( ssz:: DecodeError :: BytesInvalid ( format ! (
310+ "VariableList of {} items exceeds maximum of {}" ,
311+ bytes. len( ) ,
312+ max_len
313+ ) ) ) ;
314+ }
315+
316+ // Safety: We've verified T is u8, so Vec<T> *is* Vec<u8>.
317+ let vec_u8 = bytes. to_vec ( ) ;
318+ let vec_t = unsafe { mem:: transmute :: < Vec < u8 > , Vec < T > > ( vec_u8) } ;
319+ return Self :: new ( vec_t) . map_err ( |e| {
320+ ssz:: DecodeError :: BytesInvalid ( format ! (
321+ "Wrong number of VariableList elements: {:?}" ,
322+ e
323+ ) )
324+ } ) ;
325+ }
326+
305327 if T :: is_ssz_fixed_len ( ) {
306328 let num_items = bytes
307329 . len ( )
@@ -315,20 +337,28 @@ where
315337 ) ) ) ;
316338 }
317339
318- bytes. chunks ( T :: ssz_fixed_len ( ) ) . try_fold (
319- Vec :: with_capacity ( num_items) ,
320- |mut vec, chunk| {
321- vec. push ( T :: from_ssz_bytes ( chunk) ?) ;
322- Ok ( vec)
323- } ,
324- )
340+ // Check that we have a whole number of items and that it is safe to use chunks_exact
341+ if !bytes. len ( ) . is_multiple_of ( T :: ssz_fixed_len ( ) ) {
342+ return Err ( ssz:: DecodeError :: BytesInvalid ( format ! (
343+ "VariableList of {} items has {} bytes" ,
344+ num_items,
345+ bytes. len( )
346+ ) ) ) ;
347+ }
348+
349+ let mut vec = Vec :: with_capacity ( num_items) ;
350+ for chunk in bytes. chunks_exact ( T :: ssz_fixed_len ( ) ) {
351+ vec. push ( T :: from_ssz_bytes ( chunk) ?) ;
352+ }
353+ Self :: new ( vec) . map_err ( |e| {
354+ ssz:: DecodeError :: BytesInvalid ( format ! (
355+ "Wrong number of VariableList elements: {:?}" ,
356+ e
357+ ) )
358+ } )
325359 } else {
326360 ssz:: decode_list_of_variable_length_items ( bytes, Some ( max_len) )
327- } ?
328- . try_into ( )
329- . map_err ( |e| {
330- ssz:: DecodeError :: BytesInvalid ( format ! ( "VariableList::try_from failed: {e:?}" ) )
331- } )
361+ }
332362 }
333363}
334364
@@ -452,17 +482,60 @@ mod test {
452482 assert_eq ! ( <VariableList <u16 , U2 > as Encode >:: ssz_fixed_len( ) , 4 ) ;
453483 }
454484
455- fn round_trip < T : Encode + Decode + std:: fmt:: Debug + PartialEq > ( item : T ) {
485+ fn ssz_round_trip < T : Encode + Decode + std:: fmt:: Debug + PartialEq > ( item : T ) {
456486 let encoded = & item. as_ssz_bytes ( ) ;
457487 assert_eq ! ( item. ssz_bytes_len( ) , encoded. len( ) ) ;
458488 assert_eq ! ( T :: from_ssz_bytes( encoded) , Ok ( item) ) ;
459489 }
460490
461491 #[ test]
462492 fn u16_len_8 ( ) {
463- round_trip :: < VariableList < u16 , U8 > > ( vec ! [ 42 ; 8 ] . try_into ( ) . unwrap ( ) ) ;
464- round_trip :: < VariableList < u16 , U8 > > ( vec ! [ 0 ; 8 ] . try_into ( ) . unwrap ( ) ) ;
465- round_trip :: < VariableList < u16 , U8 > > ( vec ! [ ] . try_into ( ) . unwrap ( ) ) ;
493+ ssz_round_trip :: < VariableList < u16 , U8 > > ( vec ! [ 42 ; 8 ] . try_into ( ) . unwrap ( ) ) ;
494+ ssz_round_trip :: < VariableList < u16 , U8 > > ( vec ! [ 0 ; 8 ] . try_into ( ) . unwrap ( ) ) ;
495+ ssz_round_trip :: < VariableList < u16 , U8 > > ( vec ! [ ] . try_into ( ) . unwrap ( ) ) ;
496+ }
497+
498+ #[ test]
499+ fn ssz_round_trip_u8_len_1024 ( ) {
500+ ssz_round_trip :: < VariableList < u8 , U1024 > > ( vec ! [ 42 ; 1024 ] . try_into ( ) . unwrap ( ) ) ;
501+ ssz_round_trip :: < VariableList < u8 , U1024 > > ( vec ! [ 0 ; 1024 ] . try_into ( ) . unwrap ( ) ) ;
502+ }
503+
504+ // bool is layout equivalent to u8 but must not use the same unsafe codepath because not all u8
505+ // values are valid bools.
506+ #[ test]
507+ fn ssz_round_trip_bool_len_1024 ( ) {
508+ assert_eq ! ( mem:: size_of:: <bool >( ) , 1 ) ;
509+ assert_eq ! ( mem:: align_of:: <bool >( ) , 1 ) ;
510+ ssz_round_trip :: < VariableList < bool , U1024 > > ( vec ! [ true ; 1024 ] . try_into ( ) . unwrap ( ) ) ;
511+ ssz_round_trip :: < VariableList < bool , U1024 > > ( vec ! [ false ; 1024 ] . try_into ( ) . unwrap ( ) ) ;
512+ }
513+
514+ // Decoding a u8 list as a list of bools must fail, if we aren't careful we could trigger UB.
515+ #[ test]
516+ fn ssz_u8_to_bool_len_1024 ( ) {
517+ let list_u8 = VariableList :: < u8 , U8 > :: new ( vec ! [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ) . unwrap ( ) ;
518+ VariableList :: < bool , U8 > :: from_ssz_bytes ( & list_u8. as_ssz_bytes ( ) ) . unwrap_err ( ) ;
519+ }
520+
521+ #[ test]
522+ fn ssz_u8_len_1024_too_long ( ) {
523+ assert_eq ! (
524+ VariableList :: <u8 , U1024 >:: from_ssz_bytes( & vec![ 42 ; 1025 ] ) . unwrap_err( ) ,
525+ ssz:: DecodeError :: BytesInvalid (
526+ "VariableList of 1025 items exceeds maximum of 1024" . into( )
527+ )
528+ ) ;
529+ }
530+
531+ #[ test]
532+ fn ssz_u64_len_1024_too_long ( ) {
533+ assert_eq ! (
534+ VariableList :: <u64 , U1024 >:: from_ssz_bytes( & vec![ 42 ; 8 * 1025 ] ) . unwrap_err( ) ,
535+ ssz:: DecodeError :: BytesInvalid (
536+ "VariableList of 1025 items exceeds maximum of 1024" . into( )
537+ )
538+ ) ;
466539 }
467540
468541 #[ test]
@@ -473,6 +546,21 @@ mod test {
473546 assert_eq ! ( VariableList :: from_ssz_bytes( & [ ] ) . unwrap( ) , empty_list) ;
474547 }
475548
549+ #[ test]
550+ fn ssz_bytes_u32_trailing ( ) {
551+ let bytes = [ 1 , 0 , 0 , 0 , 2 , 0 ] ;
552+ assert_eq ! (
553+ VariableList :: <u32 , U2 >:: from_ssz_bytes( & bytes) . unwrap_err( ) ,
554+ ssz:: DecodeError :: BytesInvalid ( "VariableList of 1 items has 6 bytes" . into( ) )
555+ ) ;
556+
557+ let bytes = [ 1 , 0 , 0 , 0 , 2 , 0 , 0 , 0 , 3 ] ;
558+ assert_eq ! (
559+ VariableList :: <u32 , U2 >:: from_ssz_bytes( & bytes) . unwrap_err( ) ,
560+ ssz:: DecodeError :: BytesInvalid ( "VariableList of 2 items has 9 bytes" . into( ) )
561+ ) ;
562+ }
563+
476564 fn root_with_length ( bytes : & [ u8 ] , len : usize ) -> Hash256 {
477565 let root = merkle_root ( bytes, 0 ) ;
478566 tree_hash:: mix_in_length ( & root, len)
0 commit comments