avro-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From fo...@apache.org
Subject [avro] branch master updated: [AVRO-1816] Add support for logical decimal type for python AVRO (#82)
Date Tue, 12 Feb 2019 12:11:03 GMT
This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new 114fee5  [AVRO-1816] Add support for logical decimal type for python AVRO (#82)
114fee5 is described below

commit 114fee5317ca2bb2b09864c66da99e86190faffb
Author: Prem Santosh <premsantosh@gmail.com>
AuthorDate: Tue Feb 12 04:10:57 2019 -0800

    [AVRO-1816] Add support for logical decimal type for python AVRO (#82)
    
    * Added decimal logical type support for python avro
    
    * Fixed issues with encoder and decoder [variable length]
    
    * Revamped logical schema design
    
    * Fixed bug with bytes packed in decimal encoder
    
    * Fixed exponent scale
---
 lang/py/src/avro/decimal_encoder.py |   0
 lang/py/src/avro/io.py              | 158 ++++++++++++++++++++++++++++++++++--
 lang/py/src/avro/schema.py          | 108 +++++++++++++++++++++++-
 lang/py/test/test_io.py             |  12 +++
 lang/py/test/test_schema.py         |  92 +++++++++++++++++++++
 5 files changed, 360 insertions(+), 10 deletions(-)

diff --git a/lang/py/src/avro/decimal_encoder.py b/lang/py/src/avro/decimal_encoder.py
new file mode 100644
index 0000000..e69de29
diff --git a/lang/py/src/avro/io.py b/lang/py/src/avro/io.py
index b2fd2f9..2901660 100644
--- a/lang/py/src/avro/io.py
+++ b/lang/py/src/avro/io.py
@@ -41,6 +41,9 @@ from avro import schema
 import sys
 from binascii import crc32
 
+from decimal import Decimal
+from decimal import getcontext
+
 try:
 	import json
 except ImportError:
@@ -68,11 +71,14 @@ else:
       return struct.unpack(self.format, *args)
   struct_class = SimpleStruct
 
-STRUCT_INT = struct_class('!I')     # big-endian unsigned int
-STRUCT_LONG = struct_class('!Q')    # big-endian unsigned long long
-STRUCT_FLOAT = struct_class('!f')   # big-endian float
-STRUCT_DOUBLE = struct_class('!d')  # big-endian double
-STRUCT_CRC32 = struct_class('>I')   # big-endian unsigned int
+STRUCT_INT = struct_class('!I')             # big-endian unsigned int
+STRUCT_LONG = struct_class('!Q')            # big-endian unsigned long long
+STRUCT_FLOAT = struct_class('!f')           # big-endian float
+STRUCT_DOUBLE = struct_class('!d')          # big-endian double
+STRUCT_CRC32 = struct_class('>I')           # big-endian unsigned int
+STRUCT_SIGNED_SHORT = struct_class('>h')    # big-endian signed short
+STRUCT_SIGNED_INT = struct_class('>i')      # big-endian signed int
+STRUCT_SIGNED_LONG = struct_class('>q')     # big-endian signed long
 
 #
 # Exceptions
@@ -108,6 +114,9 @@ def validate(expected_schema, datum):
   elif schema_type == 'string':
     return isinstance(datum, basestring)
   elif schema_type == 'bytes':
+    if (hasattr(expected_schema, 'logical_type') and
+            expected_schema.logical_type == 'decimal'):
+      return isinstance(datum, Decimal)
     return isinstance(datum, str)
   elif schema_type == 'int':
     return ((isinstance(datum, int) or isinstance(datum, long)) 
@@ -118,7 +127,11 @@ def validate(expected_schema, datum):
   elif schema_type in ['float', 'double']:
     return (isinstance(datum, int) or isinstance(datum, long)
             or isinstance(datum, float))
+  # Check for int, float, long and decimal
   elif schema_type == 'fixed':
+    if (hasattr(expected_schema, 'logical_type') and
+                    expected_schema.logical_type == 'decimal'):
+      return isinstance(datum, Decimal)
     return isinstance(datum, str) and len(datum) == expected_schema.size
   elif schema_type == 'enum':
     return datum in expected_schema.symbols
@@ -219,6 +232,41 @@ class BinaryDecoder(object):
       ((ord(self.read(1)) & 0xffL) << 56))
     return STRUCT_DOUBLE.unpack(STRUCT_LONG.pack(bits))[0]
 
+  def read_decimal_from_bytes(self, precision, scale):
+    """
+    Decimal bytes are decoded as signed short, int or long depending on the
+    size of bytes.
+    """
+    size = self.read_long()
+    return self.read_decimal_from_fixed(precision, scale, size)
+
+  def read_decimal_from_fixed(self, precision, scale, size):
+    """
+    Decimal is encoded as fixed. Fixed instances are encoded using the
+    number of bytes declared in the schema.
+    """
+    datum = self.read(size)
+    unscaled_datum = 0
+    msb = struct.unpack('!b', datum[0])[0]
+    leftmost_bit = (msb >> 7) & 1
+    if leftmost_bit == 1:
+      modified_first_byte = ord(datum[0]) ^ (1 << 7)
+      datum = chr(modified_first_byte) + datum[1:]
+      for offset in range(size):
+        unscaled_datum <<= 8
+        unscaled_datum += ord(datum[offset])
+      unscaled_datum += pow(-2, (size*8) - 1)
+    else:
+      for offset in range(size):
+        unscaled_datum <<= 8
+        unscaled_datum += ord(datum[offset])
+
+    original_prec = getcontext().prec
+    getcontext().prec = precision
+    scaled_datum = Decimal(unscaled_datum).scaleb(-scale)
+    getcontext().prec = original_prec
+    return scaled_datum
+
   def read_bytes(self):
     """
     Bytes are encoded as a long followed by that many bytes of data. 
@@ -341,6 +389,74 @@ class BinaryEncoder(object):
     self.write(chr((bits >> 48) & 0xFF))
     self.write(chr((bits >> 56) & 0xFF))
 
+  def write_decimal_bytes(self, datum, scale):
+    """
+    Decimal in bytes are encoded as long. Since size of packed value in bytes for
+    signed long is 8, 8 bytes are written.
+    """
+    sign, digits, exp = datum.as_tuple()
+    if exp > scale:
+      raise AvroTypeException('Scale provided in schema does not match the decimal')
+
+    unscaled_datum = 0
+    for digit in digits:
+      unscaled_datum = (unscaled_datum * 10) + digit
+
+    bits_req = unscaled_datum.bit_length() + 1
+    if sign:
+      unscaled_datum = (1 << bits_req) - unscaled_datum
+
+    bytes_req = bits_req // 8
+    padding_bits = ~((1 << bits_req) - 1) if sign else 0
+    packed_bits = padding_bits | unscaled_datum
+
+    bytes_req += 1 if (bytes_req << 3) < bits_req else 0
+    self.write_long(bytes_req)
+    for index in range(bytes_req-1, -1, -1):
+      bits_to_write = packed_bits >> (8 * index)
+      self.write(chr(bits_to_write & 0xff))
+
+  def write_decimal_fixed(self, datum, scale, size):
+    """
+    Decimal in fixed are encoded as size of fixed bytes.
+    """
+    sign, digits, exp = datum.as_tuple()
+    if exp > scale:
+      raise AvroTypeException('Scale provided in schema does not match the decimal')
+
+    unscaled_datum = 0
+    for digit in digits:
+      unscaled_datum = (unscaled_datum * 10) + digit
+
+    bits_req = unscaled_datum.bit_length() + 1
+    size_in_bits = size * 8
+    offset_bits = size_in_bits - bits_req
+
+    mask = 2 ** size_in_bits - 1
+    bit = 1
+    for i in range(bits_req):
+      mask ^= bit
+      bit <<= 1
+
+    if bits_req < 8:
+      bytes_req = 1
+    else:
+      bytes_req = bits_req // 8
+      if bits_req % 8 != 0:
+        bytes_req += 1
+    if sign:
+      unscaled_datum = (1 << bits_req) - unscaled_datum
+      unscaled_datum = mask | unscaled_datum
+      for index in range(size-1, -1, -1):
+        bits_to_write = unscaled_datum >> (8 * index)
+        self.write(chr(bits_to_write & 0xff))
+    else:
+      for i in range(offset_bits/8):
+        self.write(chr(0))
+      for index in range(bytes_req-1, -1, -1):
+        bits_to_write = unscaled_datum >> (8 * index)
+        self.write(chr(bits_to_write & 0xff))
+
   def write_bytes(self, datum):
     """
     Bytes are encoded as a long followed by that many bytes of data. 
@@ -475,8 +591,22 @@ class DatumReader(object):
     elif writers_schema.type == 'double':
       return decoder.read_double()
     elif writers_schema.type == 'bytes':
-      return decoder.read_bytes()
+      if (hasattr(writers_schema, 'logical_type') and
+                      writers_schema.logical_type == 'decimal'):
+        return decoder.read_decimal_from_bytes(
+          writers_schema.get_prop('precision'),
+          writers_schema.get_prop('scale')
+        )
+      else:
+        return decoder.read_bytes()
     elif writers_schema.type == 'fixed':
+      if (hasattr(writers_schema, 'logical_type') and
+                      writers_schema.logical_type == 'decimal'):
+        return decoder.read_decimal_from_fixed(
+          writers_schema.get_prop('precision'),
+          writers_schema.get_prop('scale'),
+          writers_schema.size
+        )
       return self.read_fixed(writers_schema, readers_schema, decoder)
     elif writers_schema.type == 'enum':
       return self.read_enum(writers_schema, readers_schema, decoder)
@@ -787,9 +917,21 @@ class DatumWriter(object):
     elif writers_schema.type == 'double':
       encoder.write_double(datum)
     elif writers_schema.type == 'bytes':
-      encoder.write_bytes(datum)
+      if (hasattr(writers_schema, 'logical_type') and
+                      writers_schema.logical_type == 'decimal'):
+        encoder.write_decimal_bytes(datum, writers_schema.get_prop('scale'))
+      else:
+        encoder.write_bytes(datum)
     elif writers_schema.type == 'fixed':
-      self.write_fixed(writers_schema, datum, encoder)
+      if (hasattr(writers_schema, 'logical_type') and
+                      writers_schema.logical_type == 'decimal'):
+        encoder.write_decimal_fixed(
+          datum,
+          writers_schema.get_prop('scale'),
+          writers_schema.get_prop('size')
+        )
+      else:
+        self.write_fixed(writers_schema, datum, encoder)
     elif writers_schema.type == 'enum':
       self.write_enum(writers_schema, datum, encoder)
     elif writers_schema.type == 'array':
diff --git a/lang/py/src/avro/schema.py b/lang/py/src/avro/schema.py
index 6a7fbbb..0737a94 100644
--- a/lang/py/src/avro/schema.py
+++ b/lang/py/src/avro/schema.py
@@ -33,6 +33,9 @@ A schema may be one of:
   A boolean; or
   Null.
 """
+from math import floor
+from math import log10
+
 try:
   import json
 except ImportError:
@@ -312,6 +315,42 @@ class NamedSchema(Schema):
   namespace = property(lambda self: self.get_prop('namespace'))
   fullname = property(lambda self: self._fullname)
 
+#
+# Logical type class
+#
+
+class LogicalSchema(object):
+  def __init__(self, logical_type):
+    self.logical_type = logical_type
+
+#
+# Decimal logical schema
+#
+
+class DecimalLogicalSchema(LogicalSchema):
+  def __init__(self, precision, scale=0):
+    max_precision = self._max_precision()
+    if not isinstance(precision, int) or precision <= 0:
+      raise SchemaParseException("""Precision is required for logical type
+                                DECIMAL and must be a positive integer but
+                                is %s.""" % precision)
+    elif precision > max_precision:
+      raise SchemaParseException("Cannot store precision digits. Max is %s"
+                                 %(max_precision))
+
+    if not isinstance(scale, int) or scale < 0:
+      raise SchemaParseException("Scale %s must be a positive Integer." % scale)
+
+    elif scale > precision:
+      raise SchemaParseException("Invalid DECIMAL scale %s. Cannot be greater than precision %s"
+                                 %(scale, precision))
+
+    LogicalSchema.__init__(self, 'decimal')
+
+  def _max_precision(self):
+    raise NotImplementedError()
+
+
 class Field(object):
   def __init__(self, type, name, has_default, default=None,
                order=None,names=None, doc=None, other_props=None):
@@ -405,14 +444,40 @@ class PrimitiveSchema(Schema):
     return self.props == that.props
 
 #
+# Decimal Bytes Type
+#
+
+class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
+  def __init__(self, precision, scale=0, other_props=None):
+    DecimalLogicalSchema.__init__(self, precision, scale)
+    PrimitiveSchema.__init__(self, 'bytes', other_props)
+    self.set_prop('precision', precision)
+    self.set_prop('scale', scale)
+
+  # read-only properties
+  precision = property(lambda self: self.get_prop('precision'))
+  scale = property(lambda self: self.get_prop('scale'))
+
+  def _max_precision(self):
+    # Considering the max 32 bit integer value
+    return (1 << 31) - 1
+
+  def to_json(self, names=None):
+    return self.props
+
+  def __eq__(self, that):
+    return self.props == that.props
+
+
+#
 # Complex Types (non-recursive)
 #
 
 class FixedSchema(NamedSchema):
   def __init__(self, name, namespace, size, names=None, other_props=None):
     # Ensure valid ctor args
-    if not isinstance(size, int):
-      fail_msg = 'Fixed Schema requires a valid integer for size property.'
+    if not isinstance(size, int) or size < 0:
+      fail_msg = 'Fixed Schema requires a valid positive integer for size property.'
       raise AvroException(fail_msg)
 
     # Call parent ctor
@@ -436,6 +501,31 @@ class FixedSchema(NamedSchema):
   def __eq__(self, that):
     return self.props == that.props
 
+#
+# Decimal Fixed Type
+#
+
+class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
+  def __init__(self, size, name, precision, scale=0, namespace=None, names=None, other_props=None):
+    FixedSchema.__init__(self, name, namespace, size, names, other_props)
+    DecimalLogicalSchema.__init__(self, precision, scale)
+    self.set_prop('precision', precision)
+    self.set_prop('scale', scale)
+
+  # read-only properties
+  precision = property(lambda self: self.get_prop('precision'))
+  scale = property(lambda self: self.get_prop('scale'))
+
+  def _max_precision(self):
+    return round(floor(log10(pow(2, (8 * self.size - 1)) - 1)))
+
+  def to_json(self, names=None):
+    return self.props
+
+  def __eq__(self, that):
+    return self.props == that.props
+
+
 class EnumSchema(NamedSchema):
   def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None):
     # Ensure valid ctor args
@@ -722,13 +812,27 @@ def make_avsc_object(json_data, names=None):
   if hasattr(json_data, 'get') and callable(json_data.get):
     type = json_data.get('type')
     other_props = get_other_props(json_data, SCHEMA_RESERVED_PROPS)
+    logical_type = None
+    if 'logicalType' in json_data:
+      logical_type = json_data.get('logicalType')
+      if logical_type != 'decimal':
+       raise SchemaParseException("Currently does not support %s logical type" % logical_type)
     if type in PRIMITIVE_TYPES:
+      if type == 'bytes':
+        if logical_type == 'decimal':
+          precision = json_data.get('precision')
+          scale = 0 if json_data.get('scale') is None else json_data.get('scale')
+          return BytesDecimalSchema(precision, scale, other_props)
       return PrimitiveSchema(type, other_props)
     elif type in NAMED_TYPES:
       name = json_data.get('name')
       namespace = json_data.get('namespace', names.default_namespace)
       if type == 'fixed':
         size = json_data.get('size')
+        if logical_type == 'decimal':
+          precision = json_data.get('precision')
+          scale = 0 if json_data.get('scale') is None else json_data.get('scale')
+          return FixedDecimalSchema(size, name, precision, scale, namespace, names, other_props)
         return FixedSchema(name, namespace, size, names, other_props)
       elif type == 'enum':
         symbols = json_data.get('symbols')
diff --git a/lang/py/test/test_io.py b/lang/py/test/test_io.py
index 1e79d3e..df8b180 100644
--- a/lang/py/test/test_io.py
+++ b/lang/py/test/test_io.py
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+
+from decimal import Decimal
+
 try:
   from cStringIO import StringIO
 except ImportError:
@@ -35,6 +38,12 @@ SCHEMAS_TO_VALIDATE = (
   ('"float"', 1234.0),
   ('"double"', 1234.0),
   ('{"type": "fixed", "name": "Test", "size": 1}', 'B'),
+  ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
+   Decimal('3.1415')),
+  ('{"type": "fixed", "logicalType": "decimal", "name": "Test", "size": 8, "precision": 5, "scale": 4}',
+   Decimal('-3.1415')),
+  ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('3.1415')),
+  ('{"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4}', Decimal('-3.1415')),
   ('{"type": "enum", "name": "Test", "symbols": ["A", "B"]}', 'B'),
   ('{"type": "array", "items": "long"}', [1, 3, 2]),
   ('{"type": "map", "values": "long"}', {'a': 1, 'b': 3, 'c': 2}),
@@ -199,6 +208,9 @@ class TestIO(unittest.TestCase):
       round_trip_datum = read_datum(writer, writers_schema)
 
       print 'Round Trip Datum: %s' % round_trip_datum
+      if isinstance(round_trip_datum, Decimal):
+        round_trip_datum = round_trip_datum.to_eng_string()
+        datum = str(datum)
       if datum == round_trip_datum: correct += 1
     self.assertEquals(correct, len(SCHEMAS_TO_VALIDATE))
 
diff --git a/lang/py/test/test_schema.py b/lang/py/test/test_schema.py
index 00e2a05..e556ae2 100644
--- a/lang/py/test/test_schema.py
+++ b/lang/py/test/test_schema.py
@@ -17,6 +17,9 @@
 Test the schema parsing logic.
 """
 import unittest
+
+from avro.schema import SchemaParseException, AvroException
+
 import set_avro_test_path
 
 from avro import schema
@@ -295,6 +298,21 @@ OTHER_PROP_EXAMPLES = [
     """, True)
 ]
 
+DECIMAL_LOGICAL_TYPE = [
+  ExampleSchema("""{
+  "type": "fixed",
+  "logicalType": "decimal",
+  "name": "TestDecimal",
+  "precision": 4,
+  "size": 10,
+  "scale": 2}""", True),
+  ExampleSchema("""{
+  "type": "bytes",
+  "logicalType": "decimal",
+  "precision": 4,
+  "scale": 2}""", True)
+]
+
 EXAMPLES = PRIMITIVE_EXAMPLES
 EXAMPLES += FIXED_EXAMPLES
 EXAMPLES += ENUM_EXAMPLES
@@ -303,6 +321,7 @@ EXAMPLES += MAP_EXAMPLES
 EXAMPLES += UNION_EXAMPLES
 EXAMPLES += RECORD_EXAMPLES
 EXAMPLES += DOC_EXAMPLES
+EXAMPLES += DECIMAL_LOGICAL_TYPE
 
 VALID_EXAMPLES = [e for e in EXAMPLES if e.valid]
 
@@ -491,5 +510,78 @@ class TestSchema(unittest.TestCase):
 
     self.assertTrue(caught_exception, 'Exception was not caught')
 
+  def test_decimal_invalid_schema(self):
+    invalid_schemas = [
+      ExampleSchema("""{
+      "type": "bytes",
+      "logicalType": "decimal",
+      "precision": 2,
+      "scale": -2}""", True),
+
+      ExampleSchema("""{
+      "type": "bytes",
+      "logicalType": "decimal",
+      "precision": -2,
+      "scale": 2}""", True),
+
+      ExampleSchema("""{
+      "type": "bytes",
+      "logicalType": "decimal",
+      "precision": 2,
+      "scale": 3}""", True),
+
+      ExampleSchema("""{
+      "type": "fixed",
+      "logicalType": "decimal",
+      "name": "TestDecimal",
+      "precision": -10,
+      "scale": 2,
+      "size": 5}""", True),
+
+
+      ExampleSchema("""{
+      "type": "fixed",
+      "logicalType": "decimal",
+      "name": "TestDecimal",
+      "precision": 2,
+      "scale": 3,
+      "size": 2}""", True)
+    ]
+
+    for invalid_schema in invalid_schemas:
+      self.assertRaises(SchemaParseException, schema.parse, invalid_schema.schema_string)
+
+    fixed_invalid_schema_size = ExampleSchema("""{
+                                "type": "fixed",
+                                "logicalType": "decimal",
+                                "name": "TestDecimal",
+                                "precision": 2,
+                                "scale": 2,
+                                "size": -2}""", True)
+    self.assertRaises(AvroException, schema.parse, fixed_invalid_schema_size.schema_string)
+
+  def test_decimal_valid_type(self):
+    fixed_decimal_schema = ExampleSchema("""{
+    "type": "fixed",
+    "logicalType": "decimal",
+    "name": "TestDecimal",
+    "precision": 4,
+    "scale": 2,
+    "size": 2}""", True)
+
+    bytes_decimal_schema = ExampleSchema("""{
+    "type": "bytes",
+    "logicalType": "decimal",
+    "precision": 4}""", True)
+
+    fixed_decimal = schema.parse(fixed_decimal_schema.schema_string)
+    self.assertEqual(4, fixed_decimal.get_prop('precision'))
+    self.assertEqual(2, fixed_decimal.get_prop('scale'))
+    self.assertEqual(2, fixed_decimal.get_prop('size'))
+
+    bytes_decimal = schema.parse(bytes_decimal_schema.schema_string)
+    self.assertEqual(4, bytes_decimal.get_prop('precision'))
+    self.assertEqual(0, bytes_decimal.get_prop('scale'))
+
 if __name__ == '__main__':
   unittest.main()


Mime
View raw message