@@ -97,6 +97,7 @@ class Classes(Asn1Enum):
9797 Context = 0x80
9898 Private = 0xc0
9999
100+
100101class ReadFlags (IntEnum ):
101102 OnlyValue = 0x00
102103 WithUnused = 0x01
@@ -177,6 +178,14 @@ def is_nan(x):
177178 return math .isnan (x )
178179
179180
181+ def is_iterable (value ):
182+ try :
183+ iter (value )
184+ return True
185+ except TypeError :
186+ return False
187+
188+
180189class Error (Exception ):
181190 """ASN.11 encoding or decoding error."""
182191
@@ -325,8 +334,13 @@ def write(self, value, nr=None, typ=None, cls=None): # type: (Any, Union[int, N
325334 if cls != Classes .Universal and nr is None :
326335 raise Error ('Specify a tag number (nr) when using classes Application, Context or Private' )
327336
337+ # Constructed
338+ if nr is None and not isinstance (value , str ) and not isinstance (value , bytes ) and is_iterable (value ):
339+ nr = Numbers .Sequence
340+ if typ is None :
341+ typ = Types .Constructed
328342 # Primitive
329- if nr is None :
343+ elif nr is None :
330344 if isinstance (value , bool ):
331345 nr = Numbers .Boolean
332346 elif isinstance (value , int ):
@@ -343,21 +357,19 @@ def write(self, value, nr=None, typ=None, cls=None): # type: (Any, Union[int, N
343357 if typ is None :
344358 typ = Types .Primitive
345359
346- # Constructed
347- if nr is None and isinstance (value , List ):
348- nr = Numbers .Sequence
349- if typ is None :
350- typ = Types .Constructed
351-
352360 if typ is None :
353361 typ = Types .Primitive
354362
355363 self ._check_type (nr , typ , cls )
356364
357- encoded = self ._encode_value (cls , nr , value )
358- self ._emit_tag (nr , typ , cls )
359- self ._emit_length (len (encoded ))
360- self ._write_bytes (encoded )
365+ if typ == Types .Primitive :
366+ encoded = self ._encode_value (cls , nr , value )
367+ self ._emit_tag (nr , typ , cls )
368+ self ._emit_length (len (encoded ))
369+ self ._write_bytes (encoded )
370+ else :
371+ self ._emit_tag (nr , typ , cls )
372+ self ._emit_sequence (value )
361373
362374 def output (self ): # type: () -> bytes
363375 """
@@ -426,7 +438,9 @@ def _emit_tag_long(self, nr, typ, cls): # type: (int, int, int) -> None
426438
427439 def _emit_length (self , length ): # type: (int) -> None
428440 """Emit length bytes."""
429- if length < 128 :
441+ if length == INDEFINITE_FORM :
442+ self ._emit_indefinite_length ()
443+ elif length < 128 :
430444 self ._emit_length_short (length )
431445 else :
432446 self ._emit_length_long (length )
@@ -484,6 +498,7 @@ def _encode_value(self, cls, nr, value): # type: (int, int, Any) -> bytes
484498 return self ._encode_real (value )
485499 if nr == Numbers .ObjectIdentifier :
486500 return self ._encode_object_identifier (value )
501+
487502 return value
488503
489504 @staticmethod
@@ -618,18 +633,27 @@ def _encode_object_identifier(self, oid): # type: (str) -> bytes
618633 result .reverse ()
619634 return bytes (result )
620635
636+ def _emit_sequence (self , value ): # type: (List) -> bytes
637+ if not is_iterable (value ):
638+ raise Error ('value must be an iterable' )
639+
640+ self ._emit_indefinite_length ()
641+ for item in iter (value ):
642+ self .write (item )
643+ self ._emit_eoc ()
644+
621645
622646class Decoder (object ):
623647 """ASN.1 decoder. Understands BER (and DER which is a subset)."""
624648
625649 def __init__ (self ): # type: () -> None
626650 """Constructor."""
627- self ._stream = None # type: Union[io.RawIOBase, None] # Input stream
628- self ._byte = bytes () # type: bytes # Cached byte (to be able to implement eof)
629- self ._position = 0 # type: int # Due to caching, tell does not give the right position
630- self ._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
631- self ._levels = 0 # type: int # Number of recursive calls
632- self ._enters = 0 # type int # Number of enter calls without leave
651+ self ._stream = None # type: Union[io.RawIOBase, io.BufferedIOBase , None] # Input stream
652+ self ._byte = bytes () # type: bytes # Cached byte (to be able to implement eof)
653+ self ._position = 0 # type: int # Due to caching, tell does not give the right position
654+ self ._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
655+ self ._levels = 0 # type: int # Number of recursive calls
656+ self ._ends = [] # type: List[int] # End of the current element (or INDEFINITE_FORM) for enter / leave
633657
634658 def start (self , stream ): # type: (Union[io.RawIOBase, bytes]) -> None
635659 """
@@ -651,15 +675,15 @@ def start(self, stream): # type: (Union[io.RawIOBase, bytes]) -> None
651675 Raises:
652676 `Error`
653677 """
654- if not isinstance (stream , bytes ) and not isinstance (stream , io .RawIOBase ):
678+ if not isinstance (stream , bytes ) and not isinstance (stream , io .RawIOBase ) and not isinstance ( stream , io . BufferedIOBase ) :
655679 raise Error ('Expecting bytes or a subclass of io.RawIOBase.' )
656680
657681 self ._stream = io .BytesIO (stream ) if isinstance (stream , bytes ) else stream # type: ignore
658682 self ._tag = None
659683 self ._byte = bytes ()
660684 self ._position = 0
661685 self ._levels = 0
662- self ._enters = 0
686+ self ._ends = []
663687
664688 def peek (self ): # type: () -> Union[Tag, None]
665689 """
@@ -690,7 +714,12 @@ def peek(self): # type: () -> Union[Tag, None]
690714 return self ._tag
691715 if self .eof ():
692716 return None
717+ end = self ._ends [- 1 ] if len (self ._ends ) > 0 else None
718+ if end is not None and end != INDEFINITE_FORM and self ._get_current_position () >= end :
719+ return None
693720 self ._tag = self ._decode_tag ()
721+ if end == INDEFINITE_FORM and self ._tag == (0 , 0 , 0 ):
722+ return None
694723 return self ._tag
695724
696725 def read (self , flags = ReadFlags .OnlyValue ): # type: (ReadFlags) -> Tuple[Union[Tag, None], Any]
@@ -762,9 +791,9 @@ def enter(self): # type: () -> None
762791 return
763792 if tag .typ != Types .Constructed :
764793 raise Error ('Cannot enter a primitive tag.' )
765- self ._decode_length (tag .typ )
794+ length = self ._decode_length (tag .typ )
766795 self ._tag = None
767- self ._enters += 1
796+ self ._ends . append ( self . _get_current_position () + length )
768797
769798 def leave (self ): # type: () -> None
770799 """
@@ -782,10 +811,10 @@ def leave(self): # type: () -> None
782811 """
783812 if self ._stream is None :
784813 raise Error ('No input selected. Call start() first.' )
785- if self ._enters <= 0 :
814+ if len ( self ._ends ) <= 0 :
786815 raise Error ('Call to leave() without a corresponding enter() call.' )
787816 self ._tag = None
788- self ._enters -= 1
817+ self ._ends . pop ()
789818
790819 def _get_current_position (self ): # type: () -> int
791820 return 0 if self ._stream is None else self ._position
0 commit comments