Skip to content

Commit d0f3981

Browse files
committed
Add support for Sequence (and iterrable in general)
1 parent de18fd3 commit d0f3981

File tree

1 file changed

+53
-24
lines changed

1 file changed

+53
-24
lines changed

src/asn1.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Classes(Asn1Enum):
9797
Context = 0x80
9898
Private = 0xc0
9999

100+
100101
class ReadFlags(IntEnum):
101102
OnlyValue = 0x00
102103
WithUnused = 0x01
@@ -177,6 +178,14 @@ def is_nan(x):
177178
return math.isnan(x)
178179

179180

181+
def is_iterable(value):
182+
try:
183+
iter(value)
184+
return True
185+
except TypeError:
186+
return False
187+
188+
180189
class Error(Exception):
181190
"""ASN.11 encoding or decoding error."""
182191

@@ -325,8 +334,13 @@ def write(self, value, nr=None, typ=None, cls=None): # type: (Any, Union[int, N
325334
if cls != Classes.Universal and nr is None:
326335
raise Error('Specify a tag number (nr) when using classes Application, Context or Private')
327336

337+
# Constructed
338+
if nr is None and not isinstance(value, str) and not isinstance(value, bytes) and is_iterable(value):
339+
nr = Numbers.Sequence
340+
if typ is None:
341+
typ = Types.Constructed
328342
# Primitive
329-
if nr is None:
343+
elif nr is None:
330344
if isinstance(value, bool):
331345
nr = Numbers.Boolean
332346
elif isinstance(value, int):
@@ -343,21 +357,19 @@ def write(self, value, nr=None, typ=None, cls=None): # type: (Any, Union[int, N
343357
if typ is None:
344358
typ = Types.Primitive
345359

346-
# Constructed
347-
if nr is None and isinstance(value, List):
348-
nr = Numbers.Sequence
349-
if typ is None:
350-
typ = Types.Constructed
351-
352360
if typ is None:
353361
typ = Types.Primitive
354362

355363
self._check_type(nr, typ, cls)
356364

357-
encoded = self._encode_value(cls, nr, value)
358-
self._emit_tag(nr, typ, cls)
359-
self._emit_length(len(encoded))
360-
self._write_bytes(encoded)
365+
if typ == Types.Primitive:
366+
encoded = self._encode_value(cls, nr, value)
367+
self._emit_tag(nr, typ, cls)
368+
self._emit_length(len(encoded))
369+
self._write_bytes(encoded)
370+
else:
371+
self._emit_tag(nr, typ, cls)
372+
self._emit_sequence(value)
361373

362374
def output(self): # type: () -> bytes
363375
"""
@@ -426,7 +438,9 @@ def _emit_tag_long(self, nr, typ, cls): # type: (int, int, int) -> None
426438

427439
def _emit_length(self, length): # type: (int) -> None
428440
"""Emit length bytes."""
429-
if length < 128:
441+
if length == INDEFINITE_FORM:
442+
self._emit_indefinite_length()
443+
elif length < 128:
430444
self._emit_length_short(length)
431445
else:
432446
self._emit_length_long(length)
@@ -484,6 +498,7 @@ def _encode_value(self, cls, nr, value): # type: (int, int, Any) -> bytes
484498
return self._encode_real(value)
485499
if nr == Numbers.ObjectIdentifier:
486500
return self._encode_object_identifier(value)
501+
487502
return value
488503

489504
@staticmethod
@@ -618,18 +633,27 @@ def _encode_object_identifier(self, oid): # type: (str) -> bytes
618633
result.reverse()
619634
return bytes(result)
620635

636+
def _emit_sequence(self, value): # type: (List) -> bytes
637+
if not is_iterable(value):
638+
raise Error('value must be an iterable')
639+
640+
self._emit_indefinite_length()
641+
for item in iter(value):
642+
self.write(item)
643+
self._emit_eoc()
644+
621645

622646
class Decoder(object):
623647
"""ASN.1 decoder. Understands BER (and DER which is a subset)."""
624648

625649
def __init__(self): # type: () -> None
626650
"""Constructor."""
627-
self._stream = None # type: Union[io.RawIOBase, None] # Input stream
628-
self._byte = bytes() # type: bytes # Cached byte (to be able to implement eof)
629-
self._position = 0 # type: int # Due to caching, tell does not give the right position
630-
self._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
631-
self._levels = 0 # type: int # Number of recursive calls
632-
self._enters = 0 # type int # Number of enter calls without leave
651+
self._stream = None # type: Union[io.RawIOBase, io.BufferedIOBase, None] # Input stream
652+
self._byte = bytes() # type: bytes # Cached byte (to be able to implement eof)
653+
self._position = 0 # type: int # Due to caching, tell does not give the right position
654+
self._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
655+
self._levels = 0 # type: int # Number of recursive calls
656+
self._ends = [] # type: List[int] # End of the current element (or INDEFINITE_FORM) for enter / leave
633657

634658
def start(self, stream): # type: (Union[io.RawIOBase, bytes]) -> None
635659
"""
@@ -651,15 +675,15 @@ def start(self, stream): # type: (Union[io.RawIOBase, bytes]) -> None
651675
Raises:
652676
`Error`
653677
"""
654-
if not isinstance(stream, bytes) and not isinstance(stream, io.RawIOBase):
678+
if not isinstance(stream, bytes) and not isinstance(stream, io.RawIOBase) and not isinstance(stream, io.BufferedIOBase):
655679
raise Error('Expecting bytes or a subclass of io.RawIOBase.')
656680

657681
self._stream = io.BytesIO(stream) if isinstance(stream, bytes) else stream # type: ignore
658682
self._tag = None
659683
self._byte = bytes()
660684
self._position = 0
661685
self._levels = 0
662-
self._enters = 0
686+
self._ends = []
663687

664688
def peek(self): # type: () -> Union[Tag, None]
665689
"""
@@ -690,7 +714,12 @@ def peek(self): # type: () -> Union[Tag, None]
690714
return self._tag
691715
if self.eof():
692716
return None
717+
end = self._ends[-1] if len(self._ends) > 0 else None
718+
if end is not None and end != INDEFINITE_FORM and self._get_current_position() >= end:
719+
return None
693720
self._tag = self._decode_tag()
721+
if end == INDEFINITE_FORM and self._tag == (0, 0, 0):
722+
return None
694723
return self._tag
695724

696725
def read(self, flags=ReadFlags.OnlyValue): # type: (ReadFlags) -> Tuple[Union[Tag, None], Any]
@@ -762,9 +791,9 @@ def enter(self): # type: () -> None
762791
return
763792
if tag.typ != Types.Constructed:
764793
raise Error('Cannot enter a primitive tag.')
765-
self._decode_length(tag.typ)
794+
length = self._decode_length(tag.typ)
766795
self._tag = None
767-
self._enters += 1
796+
self._ends.append(self._get_current_position() + length)
768797

769798
def leave(self): # type: () -> None
770799
"""
@@ -782,10 +811,10 @@ def leave(self): # type: () -> None
782811
"""
783812
if self._stream is None:
784813
raise Error('No input selected. Call start() first.')
785-
if self._enters <= 0:
814+
if len(self._ends) <= 0:
786815
raise Error('Call to leave() without a corresponding enter() call.')
787816
self._tag = None
788-
self._enters -= 1
817+
self._ends.pop()
789818

790819
def _get_current_position(self): # type: () -> int
791820
return 0 if self._stream is None else self._position

0 commit comments

Comments
 (0)