Skip to content

Commit 817675a

Browse files
committed
Remake format guess functions
1 parent 3a355aa commit 817675a

2 files changed

Lines changed: 115 additions & 51 deletions

File tree

src/snappy/snappy.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,16 @@ def __init__(self):
149149
self.remains = None
150150

151151
@staticmethod
def check_format(fin):
    """Check whether *fin* starts with the snappy framing-format
    stream identifier block.

    Reads from the current position of *fin*; the position is NOT
    restored afterwards.

    :param fin: binary file-like object.
    :return: True if the header block matches, else False.
    """
    # Fix: the old docstring claimed "Raises UncompressError" and
    # ":return: None", but this implementation returns a bool and
    # never raises; also narrowed the bare `except:` so that
    # KeyboardInterrupt/SystemExit are not swallowed.
    try:
        header = fin.read(len(_STREAM_HEADER_BLOCK))
    except Exception:
        # Unreadable input simply means "not this format".
        return False
    return header == _STREAM_HEADER_BLOCK
169162

170163
def decompress(self, data: bytes):
171164
"""Decompress 'data', returning a string containing the uncompressed
@@ -233,14 +226,23 @@ def __init__(self):
233226
self.remains = b""
234227

235228
@staticmethod
236-
def check_format(data):
229+
def check_format(fin):
237230
"""Checks that there are enough bytes for a hadoop header
238231
239232
We cannot actually determine if the data is really hadoop-snappy
240233
"""
241-
if len(data) < 8:
242-
raise UncompressError("Too short data length")
243-
chunk_length = int.from_bytes(data[4:8], "big")
234+
try:
235+
from snappy.snappy_formats import check_unframed_format
236+
size = fin.seek(0, 2)
237+
fin.seek(0)
238+
assert size >= 8
239+
240+
chunk_length = int.from_bytes(fin.read(4), "big")
241+
assert chunk_length < size
242+
fin.read(4)
243+
return check_unframed_format(fin)
244+
except:
245+
return False
244246

245247
def decompress(self, data: bytes):
246248
"""Decompress 'data', returning a string containing the uncompressed
@@ -319,16 +321,43 @@ def stream_decompress(src,
319321
decompressor.flush() # makes sure the stream ended well
320322

321323

322-
def check_format(fin=None, chunk=None,
323-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
324-
decompressor_cls=StreamDecompressor):
325-
ok = True
326-
if chunk is None:
327-
chunk = fin.read(blocksize)
328-
if not chunk:
329-
raise UncompressError("Empty input stream")
330-
try:
331-
decompressor_cls.check_format(chunk)
332-
except UncompressError as err:
333-
ok = False
334-
return ok, chunk
324+
def hadoop_stream_decompress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Decompress hadoop-snappy data from *src* into *dst*.

    :param src: binary file-like object with compressed input.
    :param dst: binary file-like object receiving decompressed output.
    :param blocksize: number of bytes to read from *src* per iteration.
    """
    decompressor = HadoopStreamDecompressor()
    while True:
        data = src.read(blocksize)
        if not data:
            break
        buf = decompressor.decompress(data)
        if buf:
            # Fix: `dst.write()` was called without the decompressed
            # buffer — a TypeError at runtime that dropped all output.
            dst.write(buf)
    dst.flush()
338+
339+
340+
def hadoop_stream_compress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Compress data from *src* into *dst* in hadoop-snappy format.

    :param src: binary file-like object with uncompressed input.
    :param dst: binary file-like object receiving compressed output.
    :param blocksize: number of bytes to read from *src* per iteration.
    """
    compressor = HadoopStreamCompressor()
    while True:
        data = src.read(blocksize)
        if not data:
            break
        buf = compressor.compress(data)
        if buf:
            # Fix: `dst.write()` was called without the compressed
            # buffer — a TypeError at runtime that dropped all output.
            dst.write(buf)
    dst.flush()
354+
355+
356+
def raw_stream_decompress(src, dst):
    """Decompress the whole of *src* as one raw snappy block into *dst*."""
    dst.write(decompress(src.read()))
359+
360+
361+
def raw_stream_compress(src, dst):
    """Compress the whole of *src* as one raw snappy block into *dst*."""
    dst.write(compress(src.read()))

src/snappy/snappy_formats.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,73 @@
88
from __future__ import absolute_import
99

1010
from .snappy import (
11-
stream_compress, stream_decompress, check_format, UncompressError)
12-
11+
HadoopStreamDecompressor, StreamDecompressor,
12+
hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress,
13+
raw_stream_decompress, stream_compress, stream_decompress,
14+
UncompressError
15+
)
1316

14-
FRAMING_FORMAT = 'framing'
1517

1618
# Means format auto detection.
1719
# For compression will be used framing format.
1820
# In case of decompression will try to detect a format from the input stream
1921
# header.
20-
FORMAT_AUTO = 'auto'
21-
22-
DEFAULT_FORMAT = FORMAT_AUTO
22+
DEFAULT_FORMAT = "auto"
2323

24-
ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, FORMAT_AUTO]
24+
ALL_SUPPORTED_FORMATS = ["framing", "auto"]
2525

2626
_COMPRESS_METHODS = {
27-
FRAMING_FORMAT: stream_compress,
27+
"framing": stream_compress,
28+
"hadoop": hadoop_stream_compress,
29+
"raw": raw_stream_compress
2830
}
2931

3032
_DECOMPRESS_METHODS = {
31-
FRAMING_FORMAT: stream_decompress,
33+
"framing": stream_decompress,
34+
"hadoop": hadoop_stream_decompress,
35+
"raw": raw_stream_decompress
3236
}
3337

3438
# We will use framing format as the default to compression.
3539
# And for decompression, if it's not defined explicitly, we will try to
3640
# guess the format from the file header.
37-
_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT
41+
_DEFAULT_COMPRESS_FORMAT = "framing"
42+
43+
44+
def uvarint(fin):
    """Decode one unsigned base-128 varint read byte-by-byte from *fin*.

    Each byte carries 7 payload bits, least-significant group first;
    a set high bit marks continuation.
    """
    value = 0
    shift = 0
    while True:
        octet = fin.read(1)[0]
        value |= (octet & 0x7F) << shift
        if not octet & 0x80:
            return value
        shift += 7
54+
55+
56+
def check_unframed_format(fin):
    """Heuristically check whether *fin* holds raw (unframed) snappy data.

    Raw snappy begins with a varint-encoded uncompressed length followed
    by element tags; the first element must be a literal (low two tag
    bits == 0).

    :param fin: seekable binary file-like object; it is repositioned.
    :return: True if the content plausibly is raw snappy, else False.
    """
    fin.seek(0)
    try:
        length = uvarint(fin)
        first_tag = fin.read(1)[0]
        total = fin.seek(0, 2)  # seek-to-end returns the stream size
    except Exception:
        # Truncated or unreadable input: not this format.
        return False
    # Fix: these checks were `assert`s inside a bare `except:`; asserts
    # are stripped under `python -O`, so the validation could silently
    # disappear. Explicit comparisons always run.
    if length >= 2**32 - 1:
        return False
    if length >= total:
        return False
    return (first_tag & 0b11) == 0  # first element must be a literal
68+
3869

3970
# The tuple contains an ordered sequence of a format checking function and
4071
# a format-specific decompression function.
4172
# Framing format has its header, which may be recognized.
42-
_DECOMPRESS_FORMAT_FUNCS = (
43-
(check_format, stream_decompress),
44-
)
73+
_DECOMPRESS_FORMAT_FUNCS = {
74+
"framed": stream_decompress,
75+
"hadoop": hadoop_stream_decompress,
76+
"raw": raw_stream_decompress
77+
}
4578

4679

4780
def guess_format_by_header(fin):
    """Detect the snappy container format from the stream header.

    Tries the framing format first (it has a magic header), then the
    hadoop heuristic, then the raw (unframed) heuristic. The check
    functions may reposition *fin*.

    :param fin: seekable binary file-like object.
    :return: tuple ``(format_name, decompress_function)``.
    :raises UncompressError: if no known format is recognized.
    """
    # Fix: the old docstring still described the previous contract
    # ("decompression method and a chunk"); this returns the format
    # name plus its decompress function.
    if StreamDecompressor.check_format(fin):
        form = "framed"
    elif HadoopStreamDecompressor.check_format(fin):
        form = "hadoop"
    elif check_unframed_format(fin):
        form = "raw"
    else:
        raise UncompressError("Can't detect format")
    # NOTE(review): get_decompress_function unpacks this result as
    # (decompress_func, read_chunk) — confirm callers really expect
    # the (name, function) ordering returned here.
    return form, _DECOMPRESS_FORMAT_FUNCS[form]
6095

6196

6297
def get_decompress_function(specified_format, fin):
    """Resolve the stream-decompress callable for *specified_format*.

    For "auto" the format is guessed from the header of *fin*; any
    other format is looked up directly and paired with None.
    """
    if specified_format == "auto":
        return guess_format_by_header(fin)
    return _DECOMPRESS_METHODS[specified_format], None
67102

68103

69104
def get_compress_function(specified_format):
70-
if specified_format == FORMAT_AUTO:
105+
if specified_format == "auto":
71106
return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT]
72107
return _COMPRESS_METHODS[specified_format]

0 commit comments

Comments
 (0)