Skip to content

Commit ac4ddd3

Browse files
authored
Merge pull request #144 from martindurant/guesses
Remake format guess functions
2 parents 3a355aa + 0906e06 commit ac4ddd3

3 files changed

Lines changed: 165 additions & 74 deletions

File tree

src/snappy/snappy.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,15 @@ def __init__(self):
149149
self.remains = None
150150

151151
@staticmethod
def check_format(fin):
    """Does this stream start with a stream header block?

    True indicates that the stream can likely be decoded using this
    class.

    :param fin: readable binary file-like object positioned at the
        start of the stream.  Note: this reads from ``fin`` and does
        not rewind it.
    :return: True if the stream begins with the framing-format stream
        identifier, False otherwise.
    """
    try:
        return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK
    except Exception:
        # Unreadable / truncated input -> not this format.  A bare
        # ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        return False

170162
def decompress(self, data: bytes):
171163
"""Decompress 'data', returning a string containing the uncompressed
@@ -233,14 +225,21 @@ def __init__(self):
233225
self.remains = b""
234226

235227
@staticmethod
def check_format(fin):
    """Does this look like a hadoop snappy stream?

    Checks that the stream is at least 8 bytes long, that the leading
    32-bit big-endian uncompressed-chunk length is plausible, and that
    the payload after the two length words starts like a raw snappy
    block.

    :param fin: seekable, readable binary file-like object; its
        position is moved by this check.
    :return: True if the stream plausibly uses hadoop framing.
    """
    try:
        from snappy.snappy_formats import check_unframed_format

        size = fin.seek(0, 2)  # seek to end to learn the total size
        fin.seek(0)
        # NOTE: explicit checks, not ``assert`` -- asserts are stripped
        # under ``python -O`` which would change detection behavior.
        if size < 8:
            # need at least the two 4-byte length headers
            return False
        chunk_length = int.from_bytes(fin.read(4), "big")
        if chunk_length >= size:
            return False
        fin.read(4)  # skip the compressed-chunk length word
        return check_unframed_format(fin)
    except Exception:
        # unseekable/unreadable stream -> not this format
        return False

245244
def decompress(self, data: bytes):
246245
"""Decompress 'data', returning a string containing the uncompressed
@@ -319,16 +318,43 @@ def stream_decompress(src,
319318
decompressor.flush() # makes sure the stream ended well
320319

321320

322-
def check_format(fin=None, chunk=None,
323-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
324-
decompressor_cls=StreamDecompressor):
325-
ok = True
326-
if chunk is None:
327-
chunk = fin.read(blocksize)
328-
if not chunk:
329-
raise UncompressError("Empty input stream")
330-
try:
331-
decompressor_cls.check_format(chunk)
332-
except UncompressError as err:
333-
ok = False
334-
return ok, chunk
def hadoop_stream_decompress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Decompress a hadoop-snappy stream from ``src`` into ``dst``.

    :param src: readable binary file-like object holding compressed data.
    :param dst: writable binary file-like object receiving the output.
    :param blocksize: how many bytes to read from ``src`` per iteration.
    """
    decompressor = HadoopStreamDecompressor()
    chunk = src.read(blocksize)
    while chunk:
        out = decompressor.decompress(chunk)
        if out:
            dst.write(out)
        chunk = src.read(blocksize)
    dst.flush()
def hadoop_stream_compress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Compress ``src`` into ``dst`` using hadoop-snappy framing.

    :param src: readable binary file-like object with uncompressed data.
    :param dst: writable binary file-like object receiving the output.
    :param blocksize: how many bytes to read from ``src`` per iteration.
    """
    compressor = HadoopStreamCompressor()
    chunk = src.read(blocksize)
    while chunk:
        out = compressor.compress(chunk)
        if out:
            dst.write(out)
        chunk = src.read(blocksize)
    dst.flush()
def raw_stream_decompress(src, dst):
    """Treat all of ``src`` as one raw snappy block and decompress it to ``dst``."""
    dst.write(decompress(src.read()))
def raw_stream_compress(src, dst):
    """Compress all of ``src`` as one raw snappy block and write it to ``dst``."""
    dst.write(compress(src.read()))

src/snappy/snappy_formats.py

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,65 +8,107 @@
88
from __future__ import absolute_import
99

1010
from .snappy import (
11-
stream_compress, stream_decompress, check_format, UncompressError)
12-
11+
HadoopStreamDecompressor, StreamDecompressor,
12+
hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress,
13+
raw_stream_decompress, stream_compress, stream_decompress,
14+
UncompressError
15+
)
1316

14-
FRAMING_FORMAT = 'framing'
1517

# Means format auto detection.
# For compression will be used framing format.
# In case of decompression will try to detect a format from the input stream
# header.
DEFAULT_FORMAT = "auto"

# Formats a caller may request explicitly.  "hadoop" and "raw" are
# registered in both method tables below, so they belong here too
# (previously they were missing, making them unselectable wherever this
# list gates the choice).
ALL_SUPPORTED_FORMATS = ["framing", "hadoop", "raw", "auto"]

_COMPRESS_METHODS = {
    "framing": stream_compress,
    "hadoop": hadoop_stream_compress,
    "raw": raw_stream_compress,
}

_DECOMPRESS_METHODS = {
    "framing": stream_decompress,
    "hadoop": hadoop_stream_decompress,
    "raw": raw_stream_decompress,
}

# We will use framing format as the default to compression.
# And for decompression, if it's not defined explicitly, we will try to
# guess the format from the file header.
_DEFAULT_COMPRESS_FORMAT = "framing"
def uvarint(fin):
    """Read a uint64 number from varint encoding in a stream.

    Consumes one byte at a time; the low 7 bits of each byte contribute
    to the result (little-endian groups) and the high bit marks
    continuation.

    :param fin: readable binary file-like object.
    :return: the decoded non-negative integer.
    :raises IndexError: if the stream ends mid-varint (``read`` returns
        ``b""``); callers wrap this in a broad try/except.
    """
    result = 0
    shift = 0
    while True:
        byte = fin.read(1)[0]  # IndexError on EOF -- see docstring
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # continuation bit clear: last byte
            break
        shift += 7
    return result
def check_unframed_format(fin, reset=False):
    """Can this be read using the raw codec?

    This function will return True for all snappy raw streams, but
    True does not mean that we can necessarily decode the stream.

    :param fin: seekable, readable binary file-like object; its
        position is moved by this check.
    :param reset: if True, rewind ``fin`` before checking.
    :return: True if the stream plausibly is a single raw snappy block.
    """
    if reset:
        fin.seek(0)
    try:
        size = uvarint(fin)  # declared uncompressed length preamble
        next_byte = fin.read(1)[0]
        end = fin.seek(0, 2)  # total stream size
    except Exception:
        # truncated or unreadable stream
        return False
    # Explicit checks rather than ``assert`` -- asserts are stripped
    # under ``python -O``, which would change detection behavior.
    if size >= 2**32 - 1:  # raw format caps the uncompressed size
        return False
    if size >= end:
        return False
    # The first element must be a literal block (low tag bits 00).
    return next_byte & 0b11 == 0
# Maps a detected format name, as returned by guess_format_by_header,
# to its format-specific stream-decompression function.
# NOTE(review): the key here is "framed" while the public method tables
# use "framing" -- confirm this divergence is intentional before reuse.
_DECOMPRESS_FORMAT_FUNCS = {
    "framed": stream_decompress,
    "hadoop": hadoop_stream_decompress,
    "raw": raw_stream_decompress,
}

4686

4787
def guess_format_by_header(fin):
4888
"""Tries to guess a compression format for the given input file by it's
4989
header.
50-
:return: tuple of decompression method and a chunk that was taken from the
51-
input for format detection.
90+
91+
:return: format name (str), stream decompress function (callable)
5292
"""
53-
chunk = None
54-
for check_method, decompress_func in _DECOMPRESS_FORMAT_FUNCS:
55-
ok, chunk = check_method(fin=fin, chunk=chunk)
56-
if not ok:
57-
continue
58-
return decompress_func, chunk
59-
raise UncompressError("Can't detect archive format")
93+
if StreamDecompressor.check_format(fin):
94+
form = "framed"
95+
elif HadoopStreamDecompressor.check_format(fin):
96+
form = "hadoop"
97+
elif check_unframed_format(fin, reset=True):
98+
form = "raw"
99+
else:
100+
raise UncompressError("Can't detect format")
101+
return form, _DECOMPRESS_FORMAT_FUNCS[form]
60102

61103

62104
def get_decompress_function(specified_format, fin):
63-
if specified_format == FORMAT_AUTO:
64-
decompress_func, read_chunk = guess_format_by_header(fin)
65-
return decompress_func, read_chunk
66-
return _DECOMPRESS_METHODS[specified_format], None
105+
if specified_format == "auto":
106+
format, decompress_func = guess_format_by_header(fin)
107+
return decompress_func
108+
return _DECOMPRESS_METHODS[specified_format]
67109

68110

69111
def get_compress_function(specified_format):
70-
if specified_format == FORMAT_AUTO:
112+
if specified_format == "auto":
71113
return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT]
72114
return _COMPRESS_METHODS[specified_format]

test_formats.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from unittest import TestCase
44

55
from snappy import snappy_formats as formats
6-
from snappy.snappy import _CHUNK_MAX, UncompressError
76

87

98
class TestFormatBase(TestCase):
10-
compress_format = formats.FORMAT_AUTO
11-
decompress_format = formats.FORMAT_AUTO
9+
compress_format = "auto"
10+
decompress_format = "auto"
1211
success = True
1312

1413
def runTest(self):
@@ -18,34 +17,58 @@ def runTest(self):
1817
compressed_stream = io.BytesIO()
1918
compress_func(instream, compressed_stream)
2019
compressed_stream.seek(0)
21-
decompress_func, read_chunk = formats.get_decompress_function(
20+
decompress_func = formats.get_decompress_function(
2221
self.decompress_format, compressed_stream
2322
)
23+
compressed_stream.seek(0)
2424
decompressed_stream = io.BytesIO()
2525
decompress_func(
2626
compressed_stream,
2727
decompressed_stream,
28-
start_chunk=read_chunk
2928
)
3029
decompressed_stream.seek(0)
3130
self.assertEqual(data, decompressed_stream.read())
3231

3332

3433
class TestFormatFramingFraming(TestFormatBase):
    """Explicit framing format for both compression and decompression."""
    compress_format = "framing"
    decompress_format = "framing"
    success = True


class TestFormatFramingAuto(TestFormatBase):
    """Framing compression; decompression format auto-detected."""
    compress_format = "framing"
    decompress_format = "auto"
    success = True


class TestFormatAutoFraming(TestFormatBase):
    """Auto compression (defaults to framing); explicit framing decompression."""
    compress_format = "auto"
    decompress_format = "framing"
    success = True


class TestFormatHadoop(TestFormatBase):
    """Explicit hadoop framing on both sides."""
    compress_format = "hadoop"
    decompress_format = "hadoop"
    success = True


class TestFormatRaw(TestFormatBase):
    """Raw (unframed) snappy on both sides."""
    compress_format = "raw"
    decompress_format = "raw"
    success = True


class TestFormatHadoopAuto(TestFormatBase):
    """Hadoop compression; decompression format auto-detected."""
    compress_format = "hadoop"
    decompress_format = "auto"
    success = True


class TestFormatRawAuto(TestFormatBase):
    """Raw compression; decompression format auto-detected."""
    compress_format = "raw"
    decompress_format = "auto"
    success = True
5073

5174

0 commit comments

Comments
 (0)