Skip to content

Commit de18fd3

Browse files
committed
Use a flag instead of a separate read_unused function
1 parent 480f2c9 commit de18fd3

3 files changed

Lines changed: 58 additions & 64 deletions

File tree

src/asn1.py

Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class Classes(Asn1Enum):
9797
Context = 0x80
9898
Private = 0xc0
9999

100+
class ReadFlags(IntEnum):
101+
OnlyValue = 0x00
102+
WithUnused = 0x01
103+
100104

101105
Tag = collections.namedtuple('Tag', 'nr typ cls')
102106
"""
@@ -620,12 +624,12 @@ class Decoder(object):
620624

621625
def __init__(self): # type: () -> None
622626
"""Constructor."""
623-
self._stream = None # type: Union[io.RawIOBase, None] # Input stream
624-
self._byte = bytes() # type: bytes # Cached byte (to be able to implement eof)
625-
self._position = 0 # type: int # Due to caching, tell does not give the right position
626-
self._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
627-
self._levels = 0 # type: int # Number of recursive calls
628-
self._enters = 0 # type int # Number of enter calls without leave
627+
self._stream = None # type: Union[io.RawIOBase, None] # Input stream
628+
self._byte = bytes() # type: bytes # Cached byte (to be able to implement eof)
629+
self._position = 0 # type: int # Due to caching, tell does not give the right position
630+
self._tag = None # type: Union[Tag, None] # Cached Tag (to be able to implement peek)
631+
self._levels = 0 # type: int # Number of recursive calls
632+
self._enters = 0 # type int # Number of enter calls without leave
629633

630634
def start(self, stream): # type: (Union[io.RawIOBase, bytes]) -> None
631635
"""
@@ -689,26 +693,7 @@ def peek(self): # type: () -> Union[Tag, None]
689693
self._tag = self._decode_tag()
690694
return self._tag
691695

692-
def read(self): # type: () -> Tuple[Union[Tag, None], Any]
693-
"""
694-
This method decodes one ASN.1 tag from the input and returns it as a
695-
``(tag, value, unused)`` tuple. ``tag`` is a 3-tuple ``(nr, typ, cls)``,
696-
while ``value`` is a Python object representing the ASN.1 value.
697-
The offset in the input is increased so that the next `Decoder.read()`
698-
or `Decoder.read_with_unused()` call will return the next tag. In case
699-
no more data is available from the input, this method returns ``None``
700-
to signal end-of-file.
701-
702-
Returns:
703-
`Tag`, value: The current ASN.1 tag and its value.
704-
705-
Raises:
706-
`Error`
707-
"""
708-
tag, value, _ = self.read_unused()
709-
return tag, value
710-
711-
def read_unused(self): # type: () -> Tuple[Union[Tag, None], Any, int]
696+
def read(self, flags=ReadFlags.OnlyValue): # type: (ReadFlags) -> Tuple[Union[Tag, None], Any]
712697
"""
713698
This method decodes one ASN.1 tag from the input and returns it as a
714699
``(tag, value, unused)`` tuple. ``tag`` is a 3-tuple ``(nr, typ, cls)``,
@@ -719,9 +704,12 @@ def read_unused(self): # type: () -> Tuple[Union[Tag, None], Any, int]
719704
no more data is available from the input, this method returns ``None``
720705
to signal end-of-file.
721706
707+
Args:
708+
flags (DecoderFlags): Return only the value or the value and the number of unused bits as a tuple.
709+
722710
Returns:
723-
`Tag`, value, unused: The current ASN.1 tag, its value and the
724-
number of unused bits.
711+
`Tag`, value: The current ASN.1 tag, its value and the
712+
number of unused bits if requested in `start`.
725713
726714
Raises:
727715
`Error`
@@ -731,10 +719,10 @@ def read_unused(self): # type: () -> Tuple[Union[Tag, None], Any, int]
731719
tag = self.peek()
732720
self._tag = None
733721
if tag is None:
734-
return None, None, 0
722+
return None, None
735723
length = self._decode_length(tag.typ)
736-
value, unused = self._decode_value(tag.cls, tag.typ, tag.nr, length)
737-
return tag, value, unused
724+
value = self._decode_value(tag.cls, tag.typ, tag.nr, length, flags)
725+
return tag, value
738726

739727
def eof(self): # type: () -> bool
740728
"""
@@ -889,7 +877,7 @@ def _decode_length(self, typ): # type: (int) -> int
889877
length = byte
890878
return length
891879

892-
def _decode_value(self, cls, typ, nr, length): # type: (int, int, int, int) -> Tuple[Any, int]
880+
def _decode_value(self, cls, typ, nr, length, flags): # type: (int, int, int, int, ReadFlags) -> Any
893881
"""Read a value from the input. length is -1 in case of an indefinite form."""
894882
# A primitive type shall have a definite form
895883
if typ == Types.Primitive and length == INDEFINITE_FORM:
@@ -909,35 +897,38 @@ def _decode_value(self, cls, typ, nr, length): # type: (int, int, int, int) ->
909897
raise Error('ASN1 decoding error: the Universal tag number {} shall have a constructed encoding'.format(nr))
910898

911899
if cls != Classes.Universal:
912-
return self._decode_bytes(typ, 0, length), 0
900+
return self._decode_bytes(typ, 0, length)
913901

914902
# Primitive encoding
915903
if nr == Numbers.Boolean:
916-
return self._decode_boolean(length), 0
904+
return self._decode_boolean(length)
917905
if nr in (Numbers.Integer, Numbers.Enumerated):
918-
return self._decode_integer(length), 0
906+
return self._decode_integer(length)
919907
if nr == Numbers.Null:
920-
return self._decode_null(length), 0
908+
return self._decode_null(length)
921909
if nr == Numbers.Real:
922-
return self._decode_real(length), 0
910+
return self._decode_real(length)
923911
if nr == Numbers.ObjectIdentifier:
924-
return self._decode_object_identifier(length), 0
912+
return self._decode_object_identifier(length)
925913

926914
# Primitive or Constructed encoding
927915
if nr == Numbers.OctetString:
928-
return self._decode_octet_string(typ, length), 0
916+
return self._decode_octet_string(typ, length)
929917
if nr in (Numbers.PrintableString, Numbers.IA5String,
930918
Numbers.UTF8String, Numbers.UTCTime,
931919
Numbers.GeneralizedTime):
932-
return self._decode_printable_string(typ, nr, length), 0
920+
return self._decode_printable_string(typ, nr, length)
933921
if nr == Numbers.BitString:
934-
return self._decode_bitstring(typ, length)
922+
value, unused = self._decode_bitstring(typ, length, flags)
923+
if flags & ReadFlags.WithUnused:
924+
return value, unused
925+
return value
935926

936927
# Constructed types
937928
if nr == Numbers.Sequence:
938-
return self._decode_sequence(length), 0
929+
return self._decode_sequence(length)
939930

940-
return self._decode_bytes(typ, 0, length), 0
931+
return self._decode_bytes(typ, 0, length)
941932

942933
@staticmethod
943934
def _check_length(actual_length, expected_length): # type: (int, int) -> None
@@ -1101,9 +1092,9 @@ def _is_eoc(self): # type: () -> bool
11011092
self._tag = None
11021093
return True
11031094

1104-
def _decode_bitstring(self, typ, length): # type: (int, int) -> Tuple[bytes, int]
1095+
def _decode_bitstring(self, typ, length, flags): # type: (int, int, ReadFlags) -> Tuple[bytes, int]
11051096
"""Decode a bitstring."""
1106-
data, unused = self._decode_bytes_constructed(Numbers.BitString, length) \
1097+
data, unused = self._decode_bytes_constructed(Numbers.BitString, length, flags) \
11071098
if typ == Types.Constructed else self._decode_bitstring_primitive(length)
11081099
if self._levels == 0:
11091100
data = shift_bits_right(data, unused)
@@ -1123,46 +1114,49 @@ def _decode_bitstring_primitive(self, length): # type: (int) -> Tuple[bytes, in
11231114

11241115
return content[1:], num_unused_bits
11251116

1126-
def _decode_bytes_constructed(self, nr, length): # type: (int, int) -> Tuple[bytes, int]
1117+
def _decode_bytes_constructed(self, nr, length, flags): # type: (int, int, ReadFlags) -> Tuple[bytes, int]
11271118
if length == INDEFINITE_FORM:
1128-
return self._decode_bytes_indefinite(nr)
1119+
return self._decode_bytes_indefinite(nr, flags)
11291120
else:
1130-
return self._decode_bytes_definite(nr, length)
1121+
return self._decode_bytes_definite(nr, length, flags)
11311122

1132-
def _decode_bytes_indefinite(self, nr): # type: (int) -> Tuple[bytes, int]
1123+
def _decode_bytes_indefinite(self, nr, flags): # type: (int, ReadFlags) -> Tuple[bytes, int]
11331124
self._levels += 1
11341125
value = b''
11351126
unused = 0
11361127
while not self._is_eoc():
11371128
if unused > 0:
11381129
raise Error('ASN1 decoding error: unused bits shall be 0 unless it is the last segment')
1139-
segment_value, segment_unused = self._decode_bytes_segment(nr)
1130+
segment_value, segment_unused = self._decode_bytes_segment(nr, flags)
11401131
value += segment_value
11411132
unused = segment_unused
11421133
self._levels -= 1
11431134
return value, unused
11441135

1145-
def _decode_bytes_segment(self, nr): # type: (int) -> Tuple[bytes, int]
1136+
def _decode_bytes_segment(self, nr, flags): # type: (int, ReadFlags) -> Tuple[bytes, int]
11461137
tag = self.peek()
11471138
if tag is None:
11481139
raise Error("ASN1 decoding error: premature end of input.")
11491140
if nr != 0 and (tag.nr != nr or tag.cls != Classes.Universal):
11501141
raise Error(
11511142
'ASN1 decoding error: invalid tag (cls={}, nr={}) in a constructed type.'.format(tag.cls, tag.nr))
1152-
tag, value, unused = self.read_unused()
1143+
tag, value = self.read(flags)
1144+
unused = 0
1145+
if flags & ReadFlags.WithUnused:
1146+
value, unused = value
11531147
if not isinstance(value, bytes):
11541148
raise Error('ASN1 decoding error: bytes were expected instead of the type {}'.format(type(value)))
11551149
return value, unused
11561150

1157-
def _decode_bytes_definite(self, nr, length): # type: (int, int) -> Tuple[bytes, int]
1151+
def _decode_bytes_definite(self, nr, length, flags): # type: (int, int, ReadFlags) -> Tuple[bytes, int]
11581152
self._levels += 1
11591153
value = b''
11601154
unused = 0
11611155
start_position = self._get_current_position()
11621156
while self._get_current_position() - start_position < length:
11631157
if unused > 0:
11641158
raise Error('ASN1 decoding error: unused bits shall be 0 unless it is the last segment')
1165-
segment_value, segment_unused = self._decode_bytes_segment(nr)
1159+
segment_value, segment_unused = self._decode_bytes_segment(nr, flags)
11661160
value += segment_value
11671161
unused = segment_unused
11681162

@@ -1176,15 +1170,15 @@ def _decode_octet_string(self, typ, length): # type: (int, int) -> bytes
11761170
"""Decode an octet string."""
11771171
if typ == Types.Primitive:
11781172
return self._read_bytes(length)
1179-
value, unused = self._decode_bytes_constructed(Numbers.OctetString, length)
1173+
value, unused = self._decode_bytes_constructed(Numbers.OctetString, length, ReadFlags.OnlyValue)
11801174
assert unused == 0
11811175
return value
11821176

11831177
def _decode_bytes(self, typ, nr, length): # type: (int, int, int) -> bytes
11841178
"""Decode a string."""
11851179
if typ == Types.Primitive:
11861180
return self._read_bytes(length)
1187-
value, unused = self._decode_bytes_constructed(nr, length)
1181+
value, unused = self._decode_bytes_constructed(nr, length, ReadFlags.OnlyValue)
11881182
assert unused == 0
11891183
return value
11901184

tests/test_asn1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def test_bitstring_primitive(self):
639639
dec.start(buf)
640640
tag = dec.peek()
641641
assert tag == (asn1.Numbers.BitString, asn1.Types.Primitive, asn1.Classes.Universal)
642-
tag, val, unused = dec.read_unused()
642+
tag, (val, unused) = dec.read(asn1.ReadFlags.WithUnused)
643643
assert val == b'\x12\x34\x56'
644644
assert unused == 0
645645

@@ -649,7 +649,7 @@ def test_bitstring_constructed(self):
649649
dec.start(buf)
650650
tag = dec.peek()
651651
assert tag == (asn1.Numbers.BitString, asn1.Types.Constructed, asn1.Classes.Universal)
652-
tag, val, unused = dec.read_unused()
652+
tag, (val, unused) = dec.read(asn1.ReadFlags.WithUnused)
653653
assert val == b'\x00\xB0\xB0'
654654
assert unused == 4
655655

@@ -659,7 +659,7 @@ def test_bitstring_unused_bits(self):
659659
dec.start(buf)
660660
tag = dec.peek()
661661
assert tag == (asn1.Numbers.BitString, asn1.Types.Primitive, asn1.Classes.Universal)
662-
ttag, val, unused = dec.read_unused()
662+
tag, (val, unused) = dec.read(asn1.ReadFlags.WithUnused)
663663
assert val == b'\x01\x23\x45'
664664
assert unused == 4
665665

tests/test_suite.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,15 +385,15 @@ def test_case_36(self, decode_ber):
385385
Using of "unused bits" in internal BIT STRINGs with constructive form of encoding.
386386
"""
387387
with pytest.raises(asn1.Error) as e:
388-
decode_ber.read()
388+
decode_ber.read(asn1.ReadFlags.WithUnused)
389389
assert "unused bits shall be 0" in str(e.value)
390390

391391
@pytest.mark.parametrize('filename', ['tc37.ber'])
392392
def test_case_37(self, decode_ber):
393393
"""
394394
Using of definite form of length block in case of constructive form of encoding.
395395
"""
396-
tag, value, unused = decode_ber.read_unused()
396+
tag, (value, unused) = decode_ber.read(asn1.ReadFlags.WithUnused)
397397
assert tag == (asn1.Numbers.BitString, asn1.Types.Constructed, asn1.Classes.Universal)
398398
assert isinstance(value, bytes)
399399
assert value == b'\x00\x10\x10'
@@ -404,7 +404,7 @@ def test_case_38(self, decode_ber):
404404
"""
405405
Using of indefinite form of length block in case of constructive form of encoding.
406406
"""
407-
tag, value, unused = decode_ber.read_unused()
407+
tag, (value, unused) = decode_ber.read(asn1.ReadFlags.WithUnused)
408408
assert tag == (asn1.Numbers.BitString, asn1.Types.Constructed, asn1.Classes.Universal)
409409
assert isinstance(value, bytes)
410410
assert value == b'\x00\xA3\xB5\xF2\x91\xCD'
@@ -415,7 +415,7 @@ def test_case_39(self, decode_ber):
415415
"""
416416
Using of constructive form of encoding for empty BIT STRING.
417417
"""
418-
tag, value, unused = decode_ber.read_unused()
418+
tag, (value, unused) = decode_ber.read(asn1.ReadFlags.WithUnused)
419419
assert tag == (asn1.Numbers.BitString, asn1.Types.Constructed, asn1.Classes.Universal)
420420
assert isinstance(value, bytes)
421421
assert value == b''
@@ -429,7 +429,7 @@ def test_case_40(self, decode_ber):
429429
So this case raises an exception
430430
"""
431431
with pytest.raises(asn1.Error) as e:
432-
decode_ber.read_unused()
432+
decode_ber.read(asn1.ReadFlags.WithUnused)
433433
assert "initial byte is missing" in str(e.value)
434434

435435
class TestSuiteOctetString:

0 commit comments

Comments
 (0)