beam-commits mailing list archives

From dhalp...@apache.org
Subject [1/2] incubator-beam git commit: Refactoring code in avroio.py to allow for re-use.
Date Tue, 13 Sep 2016 17:00:07 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk bc32bc866 -> 4b584ca26


Refactoring code in avroio.py to allow for re-use.

* Making sure that _AvroUtils validates the sync_marker.
* This should detect corrupted or improperly formatted Avro files (see the
  sketch after this list).
* Simplifying block reading.
* Running snappy tests only when snappy is installed in the system.
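
For context: in the Avro container format, every data block ends with the
16-byte sync marker declared in the file header, so comparing a block's
trailing marker against the header's marker is a cheap corruption check. A
rough standalone sketch of the layout being validated (hypothetical helper,
not the committed code):

    def read_one_block(f, header_sync_marker):
      """Sketch: parse one Avro data block and validate its sync marker."""

      def read_long(stream):
        # Avro longs are zigzag-encoded varints, 7 bits per byte.
        shift = accum = 0
        while True:
          byte = ord(stream.read(1))
          accum |= (byte & 0x7F) << shift
          shift += 7
          if not byte & 0x80:
            break
        return (accum >> 1) ^ -(accum & 1)

      num_records = read_long(f)  # number of records in this block
      block_size = read_long(f)   # byte size of the (possibly compressed) data
      payload = f.read(block_size)
      marker = f.read(len(header_sync_marker))
      if marker != header_sync_marker:
        raise ValueError('Block sync marker does not match the file header.')
      return num_records, payload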


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/f5557c00
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/f5557c00
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/f5557c00

Branch: refs/heads/python-sdk
Commit: f5557c00cd46c8560cad5027678603056a6235ff
Parents: bc32bc8
Author: Gus Katsiapis <katsiapis@katsiapis-linux.mtv.corp.google.com>
Authored: Mon Sep 12 10:11:44 2016 -0700
Committer: Dan Halperin <dhalperi@google.com>
Committed: Tue Sep 13 09:59:17 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py      | 252 +++++++++++++------------
 sdks/python/apache_beam/io/avroio_test.py |  88 +++++++--
 2 files changed, 202 insertions(+), 138 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f5557c00/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 7ad3842..196e760 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 """Implements a source for reading Avro files."""
 
 import os
@@ -22,7 +21,7 @@ import StringIO
 import zlib
 
 from avro import datafile
-from avro import io as avro_io
+from avro import io as avroio
 from avro import schema
 
 from apache_beam.io import filebasedsource
@@ -75,159 +74,170 @@ class ReadFromAvro(PTransform):
     self._min_bundle_size = min_bundle_size
 
   def apply(self, pcoll):
-    return pcoll.pipeline | Read(_AvroSource(
-        file_pattern=self._file_pattern, min_bundle_size=self._min_bundle_size))
+    return pcoll.pipeline | Read(
+        _AvroSource(
+            file_pattern=self._file_pattern,
+            min_bundle_size=self._min_bundle_size))
 
 
-class _AvroSource(filebasedsource.FileBasedSource):
-  """A source for reading Avro files.
+class _AvroUtils(object):
 
-  ``_AvroSource`` is implemented using the file-based source framework available
-  in module 'filebasedsource'. Hence please refer to module 'filebasedsource'
-  to fully understand how this source implements operations common to all
-  file-based sources such as file-pattern expansion and splitting into bundles
-  for parallel processing.
-  """
+  @staticmethod
+  def read_meta_data_from_file(f):
+    """Reads metadata from a given Avro file.
 
-  def __init__(self, file_pattern, min_bundle_size=0):
-    super(_AvroSource, self).__init__(file_pattern, min_bundle_size)
-    self._avro_schema = None
-    self._codec = None
-    self._sync_marker = None
-
-  class AvroBlock(object):
-    """Represents a block of an Avro file."""
-
-    def __init__(self, block_bytes, num_records, avro_schema, avro_codec,
-                 offset):
-      self._block_bytes = block_bytes
-      self._num_records = num_records
-      self._avro_schema = avro_schema
-      self._avro_codec = avro_codec
-      self._offset = offset
-
-    def size(self):
-      return len(self._block_bytes)
-
-    def _decompress_bytes(self, encoding, data):
-      if encoding == 'null':
-        return data
-      elif encoding == 'deflate':
-        # zlib.MAX_WBITS is the window size. '-' sign indicates that this is
-        # raw data (without headers). See zlib and Avro documentations for more
-        # details.
-        return zlib.decompress(data, -zlib.MAX_WBITS)
-      else:
-        raise ValueError('Unsupported compression type: %r', encoding)
+    Args:
+      f: Avro file to read.
+    Returns:
+      A tuple containing the codec, schema, and the sync marker of the Avro
+      file.
 
-    def records(self):
-      decompressed_bytes = self._decompress_bytes(self._avro_codec,
-                                                  self._block_bytes)
-      decoder = avro_io.BinaryDecoder(StringIO.StringIO(decompressed_bytes))
-      reader = avro_io.DatumReader(
-          writers_schema=schema.parse(self._avro_schema),
-          readers_schema=schema.parse(self._avro_schema))
+    Raises:
+      ValueError: if the file does not start with the byte sequence defined in
+                  the specification.
+    """
+    if f.tell() > 0:
+      f.seek(0)
+    decoder = avroio.BinaryDecoder(f)
+    header = avroio.DatumReader().read_data(datafile.META_SCHEMA,
+                                            datafile.META_SCHEMA, decoder)
+    if header.get('magic') != datafile.MAGIC:
+      raise ValueError('Not an Avro file. File header should start with %s '
+                       'but started with %s instead.' %
+                       (datafile.MAGIC, header.get('magic')))
 
-      current_record = 0
-      while current_record < self._num_records:
-        yield reader.read(decoder)
-        current_record += 1
+    meta = header['meta']
 
-    def offset(self):
-      return self._offset
+    if datafile.CODEC_KEY in meta:
+      codec = meta[datafile.CODEC_KEY]
+    else:
+      codec = 'null'
 
-  def read_records(self, file_name, range_tracker):
-    start_offset = range_tracker.start_position()
-    if start_offset is None:
-      start_offset = 0
+    schema_string = meta[datafile.SCHEMA_KEY]
+    sync_marker = header['sync']
 
-    f = self.open_file(file_name)
-    try:
-      self._codec, self._avro_schema, self._sync_marker = (
-          AvroUtils.read_meta_data_from_file(f))
+    return codec, schema_string, sync_marker
 
-      # We have to start at current position if previous bundle ended at the
-      # end of a sync marker.
-      start_offset = max(0, start_offset - len(self._sync_marker))
+  @staticmethod
+  def read_block_from_file(f, codec, schema, expected_sync_marker):
+    """Reads a block from a given Avro file.
 
-      f.seek(start_offset)
-      while self.advance_pass_next_sync_marker(f):
-        if not range_tracker.try_claim(f.tell()):
-          return
-        next_block = self.read_next_block(f)
-        if next_block:
-          for record in next_block.records():
-            yield record
-        else:
-          return
-    finally:
-      f.close()
-
-  def advance_pass_next_sync_marker(self, f):
+    Args:
+      f: Avro file to read.
+      codec: Codec used to compress blocks of the file.
+      schema: Writer schema of the file, as a JSON string.
+      expected_sync_marker: Sync marker read from the header of the file.
+    Returns:
+      A single _AvroBlock.
+
+    Raises:
+      ValueError: If the block cannot be read properly because the file doesn't
+        match the specification.
+    """
+    decoder = avroio.BinaryDecoder(f)
+    num_records = decoder.read_long()
+    block_size = decoder.read_long()
+    block_bytes = decoder.read(block_size)
+    sync_marker = decoder.read(len(expected_sync_marker))
+    if sync_marker != expected_sync_marker:
+      raise ValueError('Unexpected sync marker (actual "%s" vs expected "%s"). '
+                       'Maybe the underlying Avro file is corrupted?' %
+                       (sync_marker, expected_sync_marker))
+    return _AvroBlock(block_bytes, num_records, codec, schema)
+
+  @staticmethod
+  def advance_file_past_next_sync_marker(f, sync_marker):
     buf_size = 10000
 
     data = f.read(buf_size)
     while data:
-      pos = data.find(self._sync_marker)
+      pos = data.find(sync_marker)
       if pos >= 0:
         # Adjusting the current position to the ending position of the sync
         # marker.
-        backtrack = len(data) - pos - len(self._sync_marker)
+        backtrack = len(data) - pos - len(sync_marker)
         f.seek(-1 * backtrack, os.SEEK_CUR)
         return True
       else:
-        if f.tell() >= len(self._sync_marker):
+        if f.tell() >= len(sync_marker):
           # Backtracking in case we partially read the sync marker during the
           # previous read. We only have to backtrack if there are at least
           # len(sync_marker) bytes before current position. We only have to
           # backtrack (len(sync_marker) - 1) bytes.
-          f.seek(-1 * (len(self._sync_marker) - 1), os.SEEK_CUR)
+          f.seek(-1 * (len(sync_marker) - 1), os.SEEK_CUR)
         data = f.read(buf_size)
 
-  def read_next_block(self, f):
-    decoder = avro_io.BinaryDecoder(f)
-    num_records = decoder.read_long()
-    block_size = decoder.read_long()
 
-    block_bytes = decoder.read(block_size)
-    return _AvroSource.AvroBlock(block_bytes, num_records,
-                                 self._avro_schema,
-                                 self._codec, f.tell()) if block_bytes else None
+class _AvroBlock(object):
+  """Represents a block of an Avro file."""
+
+  def __init__(self, block_bytes, num_records, codec, schema_string):
+    self._block_bytes = block_bytes
+    self._num_records = num_records
+    self._codec = codec
+    self._schema = schema.parse(schema_string)
+
+  def _decompress_bytes(self, data):
+    if self._codec == 'null':
+      return data
+    elif self._codec == 'deflate':
+      # zlib.MAX_WBITS is the window size. '-' sign indicates that this is
+      # raw data (without headers). See zlib and Avro documentations for more
+      # details.
+      return zlib.decompress(data, -zlib.MAX_WBITS)
+    elif self._codec == 'snappy':
+      # Snappy is an optional avro codec.
+      # See Snappy and Avro documentation for more details.
+      try:
+        import snappy
+      except ImportError:
+        raise ValueError('Snappy does not seem to be installed.')
+
+      # Compressed data includes a 4-byte CRC32 checksum which we verify.
+      result = snappy.decompress(data[:-4])
+      avroio.BinaryDecoder(StringIO.StringIO(data[-4:])).check_crc32(result)
+      return result
+    else:
+      raise ValueError('Unknown codec: %r' % self._codec)
 
+  def num_records(self):
+    return self._num_records
 
-class AvroUtils(object):
+  def records(self):
+    decompressed_bytes = self._decompress_bytes(self._block_bytes)
+    decoder = avroio.BinaryDecoder(StringIO.StringIO(decompressed_bytes))
+    reader = avroio.DatumReader(
+        writers_schema=self._schema, readers_schema=self._schema)
 
-  @staticmethod
-  def read_meta_data_from_file(f):
-    """Reads metadata from a given Avro file.
+    current_record = 0
+    while current_record < self._num_records:
+      yield reader.read(decoder)
+      current_record += 1
 
-    Args:
-      f: Avro file to read.
-    Returns:
-      a tuple containing the codec, schema, and the sync marker of the Avro
-      file.
 
-    Raises:
-      ValueError: if the file does not start with the byte sequence defined in
-                  the specification.
-    """
-    f.seek(0, 0)
-    header = avro_io.DatumReader().read_data(datafile.META_SCHEMA,
-                                             datafile.META_SCHEMA,
-                                             avro_io.BinaryDecoder(f))
-    if header.get('magic') != datafile.MAGIC:
-      raise ValueError('Not an Avro file. File header should start with %s but'
-                       'started with %s instead.',
-                       datafile.MAGIC, header.get('magic'))
+class _AvroSource(filebasedsource.FileBasedSource):
+  """A source for reading Avro files.
 
-    meta = header['meta']
+  ``_AvroSource`` is implemented using the file-based source framework available
+  in module 'filebasedsource'. Hence please refer to module 'filebasedsource'
+  to fully understand how this source implements operations common to all
+  file-based sources such as file-pattern expansion and splitting into bundles
+  for parallel processing.
+  """
 
-    if datafile.CODEC_KEY in meta:
-      codec = meta[datafile.CODEC_KEY]
-    else:
-      codec = 'null'
+  def read_records(self, file_name, range_tracker):
+    start_offset = range_tracker.start_position()
+    if start_offset is None:
+      start_offset = 0
 
-    schema_string = meta[datafile.SCHEMA_KEY]
-    sync_marker = header['sync']
+    with self.open_file(file_name) as f:
+      codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file(f)
 
-    return codec, schema_string, sync_marker
+      # We have to start at current position if previous bundle ended at the
+      # end of a sync marker.
+      start_offset = max(0, start_offset - len(sync_marker))
+      f.seek(start_offset)
+      _AvroUtils.advance_file_past_next_sync_marker(f, sync_marker)
+
+      while range_tracker.try_claim(f.tell()):
+        block = _AvroUtils.read_block_from_file(f, codec, schema_string,
+                                                sync_marker)
+        for record in block.records():
+          yield record
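
The refactored helpers compose roughly as follows when reading a local file
end to end (a sketch against the module-private classes above; 'records.avro'
is a stand-in path, so treat this as illustrative rather than supported API):

    import os

    file_size = os.path.getsize('records.avro')
    with open('records.avro', 'rb') as f:
      codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file(f)
      # The file header itself ends with the sync marker, so scanning for the
      # next marker from offset 0 lands exactly at the first data block.
      f.seek(0)
      _AvroUtils.advance_file_past_next_sync_marker(f, sync_marker)
      while f.tell() < file_size:
        block = _AvroUtils.read_block_from_file(f, codec, schema_string,
                                                sync_marker)
        for record in block.records():
          print record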
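
On the 'deflate' branch of _decompress_bytes: Avro stores raw DEFLATE data
with no zlib header or checksum, which is why the window-bits argument is
negative. A self-contained round trip illustrating that convention (example
values assumed, not taken from the commit):

    import zlib

    payload = 'example avro block payload'
    # Negative wbits selects a raw DEFLATE stream (no zlib header/trailer),
    # matching what Avro writers emit for the 'deflate' codec.
    compressor = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
    raw = compressor.compress(payload) + compressor.flush()
    assert zlib.decompress(raw, -zlib.MAX_WBITS) == payload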
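
Likewise for 'snappy': per the Avro spec, a snappy-compressed block carries a
trailing 4-byte big-endian CRC32 of the uncompressed bytes, which is what the
data[:-4] slice and the check_crc32 call above rely on. A small sketch of that
framing (helper names are hypothetical):

    import struct
    import zlib

    def split_snappy_block(data):
      # Layout: <snappy-compressed payload><4-byte big-endian CRC32 of the
      # uncompressed payload>.
      (crc,) = struct.unpack('>I', data[-4:])
      return data[:-4], crc

    def crc_matches(uncompressed, crc):
      return (zlib.crc32(uncompressed) & 0xffffffff) == crc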

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f5557c00/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index 51bc375..29c4209 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -39,14 +39,24 @@ class TestAvro(unittest.TestCase):
     # environments with limited amount of resources.
     filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
 
-  RECORDS = [{'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'},
-             {'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'},
-             {'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'},
-             {'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'},
-             {'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'},
-             {'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'}]
-
-  def _write_data(self, directory=None,
+  RECORDS = [{'name': 'Thomas',
+              'favorite_number': 1,
+              'favorite_color': 'blue'}, {'name': 'Henry',
+                                          'favorite_number': 3,
+                                          'favorite_color': 'green'},
+             {'name': 'Toby',
+              'favorite_number': 7,
+              'favorite_color': 'brown'}, {'name': 'Gordon',
+                                           'favorite_number': 4,
+                                           'favorite_color': 'blue'},
+             {'name': 'Emily',
+              'favorite_number': -1,
+              'favorite_color': 'Red'}, {'name': 'Percy',
+                                         'favorite_number': 6,
+                                         'favorite_color': 'Green'}]
+
+  def _write_data(self,
+                  directory=None,
                   prefix=tempfile.template,
                   codec='null',
                   count=len(RECORDS)):
@@ -83,24 +93,27 @@ class TestAvro(unittest.TestCase):
     file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
     return file_name_prefix + os.path.sep + 'mytemp*'
 
-  def _run_avro_test(
-      self, pattern, desired_bundle_size, perform_splitting, expected_result):
+  def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
+                     expected_result):
     source = AvroSource(pattern)
 
     read_records = []
     if perform_splitting:
       assert desired_bundle_size
-      splits = [split for split in source.split(
-          desired_bundle_size=desired_bundle_size)]
+      splits = [
+          split
+          for split in source.split(desired_bundle_size=desired_bundle_size)
+      ]
       if len(splits) < 2:
         raise ValueError('Test is trivial. Please adjust it so that at least '
                          'two splits get generated')
 
       sources_info = [
           (split.source, split.start_position, split.stop_position)
-          for split in splits]
-      source_test_utils.assertSourcesEqualReferenceSource(
-          (source, None, None), sources_info)
+          for split in splits
+      ]
+      source_test_utils.assertSourcesEqualReferenceSource((source, None, None),
+                                                          sources_info)
     else:
       read_records = source_test_utils.readFromSource(source, None, None)
       self.assertItemsEqual(expected_result, read_records)
@@ -135,6 +148,28 @@ class TestAvro(unittest.TestCase):
     expected_result = self.RECORDS
     self._run_avro_test(file_name, 100, True, expected_result)
 
+  def test_read_without_splitting_compressed_snappy(self):
+    try:
+      import snappy  # pylint: disable=unused-variable
+      file_name = self._write_data(codec='snappy')
+      expected_result = self.RECORDS
+      self._run_avro_test(file_name, None, False, expected_result)
+    except ImportError:
+      logging.warning(
+          'Skipped test_read_without_splitting_compressed_snappy since snappy '
+          'appears to not be installed.')
+
+  def test_read_with_splitting_compressed_snappy(self):
+    try:
+      import snappy  # pylint: disable=unused-variable
+      file_name = self._write_data(codec='snappy')
+      expected_result = self.RECORDS
+      self._run_avro_test(file_name, 100, True, expected_result)
+    except ImportError:
+      logging.warning(
+          'Skipped test_read_with_splitting_compressed_snappy since snappy '
+          'appears to not be installed.')
+
   def test_read_without_splitting_pattern(self):
     pattern = self._write_pattern(3)
     expected_result = self.RECORDS * 3
@@ -153,13 +188,32 @@ class TestAvro(unittest.TestCase):
       avro.datafile.SYNC_INTERVAL = 5
       file_name = self._write_data(count=20)
       source = AvroSource(file_name)
-      splits = [split for split in source.split(
-          desired_bundle_size=float('inf'))]
+      splits = [split
+                for split in source.split(desired_bundle_size=float('inf'))]
       assert len(splits) == 1
       source_test_utils.assertSplitAtFractionExhaustive(splits[0].source)
     finally:
       avro.datafile.SYNC_INTERVAL = old_sync_interval
 
+  def test_corrupted_file(self):
+    file_name = self._write_data()
+    with open(file_name, 'rb') as f:
+      data = bytearray(f.read())
+
+    # Corrupt the last character of the file, which is also the last
+    # character of the last sync_marker.
+    with tempfile.NamedTemporaryFile(
+        delete=False, prefix=tempfile.template) as f:
+      last_char_index = len(data) - 1
+      # bytearray items are ints, so compare and assign byte values.
+      data[last_char_index] = (
+          ord('A') if data[last_char_index] == ord('B') else ord('B'))
+      f.write(data)
+      corrupted_file_name = f.name
+
+    source = AvroSource(corrupted_file_name)
+    with self.assertRaises(ValueError) as exn:
+      source_test_utils.readFromSource(source, None, None)
+    self.assertEqual(0, exn.exception.message.find('Unexpected sync marker'))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
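
A side note on the snappy guards in the tests above: wrapping each test body
in try/except ImportError makes a missing snappy library look like a silent
pass. unittest's skip machinery would surface the skip in the test output
instead; one possible variant (a sketch assuming the _write_data and
_run_avro_test helpers from TestAvro, not what this commit does):

    import unittest

    try:
      import snappy  # pylint: disable=unused-import
      HAS_SNAPPY = True
    except ImportError:
      HAS_SNAPPY = False

    class TestAvro(unittest.TestCase):
      # ... existing helpers (_write_data, _run_avro_test) elided ...

      @unittest.skipIf(not HAS_SNAPPY, 'python-snappy is not installed')
      def test_read_without_splitting_compressed_snappy(self):
        file_name = self._write_data(codec='snappy')
        self._run_avro_test(file_name, None, False, self.RECORDS)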

