avro-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From cutt...@apache.org
Subject svn commit: r896176 - in /hadoop/avro/trunk: CHANGES.txt src/py/avro/__init__.py src/py/avro/ipc.py src/py/avro/protocol.py src/test/py/sample_ipc_client.py src/test/py/sample_ipc_server.py src/test/py/test_protocol.py
Date Tue, 05 Jan 2010 18:48:04 GMT
Author: cutting
Date: Tue Jan  5 18:48:04 2010
New Revision: 896176

URL: http://svn.apache.org/viewvc?rev=896176&view=rev
Log:
Rework Python RPC.  Contributed by Jeff Hammerbacher.

Added:
    hadoop/avro/trunk/src/py/avro/ipc.py
    hadoop/avro/trunk/src/py/avro/protocol.py
    hadoop/avro/trunk/src/test/py/sample_ipc_client.py
    hadoop/avro/trunk/src/test/py/sample_ipc_server.py
    hadoop/avro/trunk/src/test/py/test_protocol.py
Modified:
    hadoop/avro/trunk/CHANGES.txt
    hadoop/avro/trunk/src/py/avro/__init__.py

Modified: hadoop/avro/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/CHANGES.txt?rev=896176&r1=896175&r2=896176&view=diff
==============================================================================
--- hadoop/avro/trunk/CHANGES.txt (original)
+++ hadoop/avro/trunk/CHANGES.txt Tue Jan  5 18:48:04 2010
@@ -166,6 +166,8 @@
 
     AVRO-219. Rework Python API.  (Jeff Hammerbacher via cutting)
 
+    AVRO-264. Rework Python RPC.  (Jeff Hammerbacher via cutting)
+
   OPTIMIZATIONS
 
     AVRO-172. More efficient schema processing (massie)

Modified: hadoop/avro/trunk/src/py/avro/__init__.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/__init__.py?rev=896176&r1=896175&r2=896176&view=diff
==============================================================================
--- hadoop/avro/trunk/src/py/avro/__init__.py (original)
+++ hadoop/avro/trunk/src/py/avro/__init__.py Tue Jan  5 18:48:04 2010
@@ -14,5 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ['schema', 'io', 'datafile']
+__all__ = ['schema', 'io', 'datafile', 'protocol', 'ipc']
 

Added: hadoop/avro/trunk/src/py/avro/ipc.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/ipc.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/py/avro/ipc.py (added)
+++ hadoop/avro/trunk/src/py/avro/ipc.py Tue Jan  5 18:48:04 2010
@@ -0,0 +1,461 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+# http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Support for inter-process calls.
+"""
+import cStringIO
+import struct
+import socket
+from avro import io
+from avro import protocol
+from avro import schema
+
+#
+# Constants
+#
+
+# Record a requestor writes ahead of every call request
+# (see Requestor.write_handshake_request).
+HANDSHAKE_REQUEST_SCHEMA = schema.parse("""\
+{
+  "type": "record",
+  "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc",
+  "fields": [
+    {"name": "clientHash",
+     "type": {"type": "fixed", "name": "MD5", "size": 16}},
+    {"name": "clientProtocol", "type": ["null", "string"]},
+    {"name": "serverHash", "type": ["null", "MD5"]},
+    {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}
+  ]
+}""")
+
+# Record a responder writes ahead of every call response
+# (see Responder.process_handshake).
+HANDSHAKE_RESPONSE_SCHEMA = schema.parse("""\
+{
+  "type": "record",
+  "name": "HandshakeResponse", "namespace": "org.apache.avro.ipc",
+  "fields": [
+    {"name": "match",
+     "type": {"type": "enum", "name": "HandshakeMatch",
+              "symbols": ["BOTH", "CLIENT", "NONE"]}},
+    {"name": "serverProtocol", "type": ["null", "string"]},
+    {"name": "serverHash",
+     "type": ["null", {"type": "fixed", "name": "MD5", "size": 16}]},
+    {"name": "meta",
+     "type": ["null", {"type": "map", "values": "bytes"}]}
+  ]
+}
+""")
+
+# Shared datum writers/readers for the two handshake records: the requestor
+# writes requests and reads responses, the responder does the reverse.
+HANDSHAKE_REQUESTOR_WRITER = io.DatumWriter(HANDSHAKE_REQUEST_SCHEMA)
+HANDSHAKE_REQUESTOR_READER = io.DatumReader(HANDSHAKE_RESPONSE_SCHEMA)
+HANDSHAKE_RESPONDER_WRITER = io.DatumWriter(HANDSHAKE_RESPONSE_SCHEMA)
+HANDSHAKE_RESPONDER_READER = io.DatumReader(HANDSHAKE_REQUEST_SCHEMA)
+
+# Call metadata: a map with values of type bytes, written before the body of
+# every call request and call response.
+META_SCHEMA = schema.parse('{"type": "map", "values": "bytes"}')
+META_WRITER = io.DatumWriter(META_SCHEMA)
+META_READER = io.DatumReader(META_SCHEMA)
+
+# Error union used when a message declares no errors of its own.
+SYSTEM_ERROR_SCHEMA = schema.parse('["string"]')
+
+# protocol cache
+# Both dicts are keyed by transport.remote_name and are filled in by the
+# Requestor.remote_hash and Requestor.remote_protocol setters.
+REMOTE_HASHES = {}
+REMOTE_PROTOCOLS = {}
+
+# Framing: every buffer is preceded by its length packed with the '!I'
+# struct format (network byte order, unsigned int).
+BIG_ENDIAN_INT_STRUCT = struct.Struct('!I')
+# Number of bytes SocketTransport reads to obtain a buffer length.
+BUFFER_HEADER_LENGTH = 4
+# Largest payload SocketTransport.write_framed_message puts in one buffer.
+BUFFER_SIZE = 8192
+
+#
+# Exceptions
+#
+
+class AvroRemoteException(schema.AvroException):
+  """
+  Raised when an error message is sent by an Avro requestor or responder.
+  """
+  def __init__(self, fail_msg=None):
+    schema.AvroException.__init__(self, fail_msg)
+
+class ConnectionClosedException(schema.AvroException):
+  """
+  Raised by SocketTransport when a socket send or receive moves 0 bytes.
+  """
+  pass
+
+#
+# Base IPC Classes (Requestor/Responder)
+#
+
+class Requestor(object):
+  """Base class for the client side of a protocol interaction."""
+  def __init__(self, local_protocol, transport):
+    self._local_protocol = local_protocol
+    self._transport = transport
+    self._remote_protocol = None
+    self._remote_hash = None
+    self._send_protocol = None
+
+  # read-only properties
+  local_protocol = property(lambda self: self._local_protocol)
+  transport = property(lambda self: self._transport)
+
+  # read/write properties
+  def set_remote_protocol(self, new_remote_protocol):
+    self._remote_protocol = new_remote_protocol
+    REMOTE_PROTOCOLS[self.transport.remote_name] = self.remote_protocol
+  remote_protocol = property(lambda self: self._remote_protocol,
+                             set_remote_protocol)
+  def set_remote_hash(self, new_remote_hash):
+    self._remote_hash = new_remote_hash
+    REMOTE_HASHES[self.transport.remote_name] = self.remote_hash
+  remote_hash = property(lambda self: self._remote_hash, set_remote_hash)
+  def set_send_protocol(self, new_send_protocol):
+    self._send_protocol = new_send_protocol
+  send_protocol = property(lambda self: self._send_protocol, set_send_protocol)
+
+  def request(self, message_name, request_datum):
+    """
+    Writes a request message and reads a response or error message.
+    """
+    # build handshake and call request
+    buffer_writer = cStringIO.StringIO()
+    buffer_encoder = io.BinaryEncoder(buffer_writer)
+    self.write_handshake_request(buffer_encoder)
+    self.write_call_request(message_name, request_datum, buffer_encoder)
+
+    # send the handshake and call request;  block until call response
+    call_request = buffer_writer.getvalue()
+    call_response = self.transport.transceive(call_request)
+
+    # process the handshake and call response
+    buffer_decoder = io.BinaryDecoder(cStringIO.StringIO(call_response))
+    call_response_exists = self.read_handshake_response(buffer_decoder)
+    if call_response_exists:
+      return self.read_call_response(message_name, buffer_decoder)
+    else:
+      self.request(message_name, request_datum)
+
+  def write_handshake_request(self, encoder):
+    local_hash = self.local_protocol.md5
+    remote_name = self.transport.remote_name
+    remote_hash = REMOTE_HASHES.get(remote_name)
+    if remote_hash is None:
+      remote_hash = local_hash
+      self.remote_protocol = self.local_protocol
+    request_datum = {}
+    request_datum['clientHash'] = local_hash
+    request_datum['serverHash'] = remote_hash
+    if self.send_protocol:
+      request_datum['clientProtocol'] = str(self.local_protocol)
+    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
+
+  def write_call_request(self, message_name, request_datum, encoder):
+    """
+    The format of a call request is:
+      * request metadata, a map with values of type bytes
+      * the message name, an Avro string, followed by
+      * the message parameters. Parameters are serialized according to
+        the message's request declaration.
+    """
+    # request metadata (not yet implemented)
+    request_metadata = {}
+    META_WRITER.write(request_metadata, encoder)
+
+    # message name
+    message = self.local_protocol.messages.get(message_name)
+    if message is None:
+      raise schema.AvroException('Unknown message: %s' % message_name)
+    encoder.write_utf8(message.name)
+
+    # message parameters
+    self.write_request(message.request, request_datum, encoder)
+
+  def write_request(self, request_fields, request_datum, encoder):
+    """
+    Looks an awful lot like new_io.write_record, eh?
+    """
+    for field in request_fields:
+      datum_writer = io.DatumWriter(field.type)
+      datum_writer.write(request_datum.get(field.name), encoder)
+
+  def read_handshake_response(self, decoder):
+    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
+    match = handshake_response.get('match')
+    if match == 'BOTH':
+      self.send_protocol = False
+      return True
+    elif match == 'CLIENT':
+      if self.send_protocol:
+        raise schema.AvroException('Handshake failure.')
+      self.remote_protocol = handshake_response.get('serverProtocol')
+      self.remote_hash = handshake_response.get('serverHash')
+      self.send_protocol = False
+      return False
+    elif match == 'NONE':
+      if self.send_protocol:
+        raise schema.AvroException('Handshake failure.')
+      self.remote_protocol = handshake_response.get('serverProtocol')
+      self.remote_hash = handshake_response.get('serverHash')
+      self.send_protocol = True
+      return False
+    else:
+      raise schema.AvroException('Unexpected match: %s' % match)
+
+  def read_call_response(self, message_name, decoder):
+    """
+    The format of a call response is:
+      * response metadata, a map with values of type bytes
+      * a one-byte error flag boolean, followed by either:
+        o if the error flag is false,
+          the message response, serialized per the message's response schema.
+        o if the error flag is true, 
+          the error, serialized per the message's error union schema.
+    """
+    # response metadata
+    response_metadata = META_READER.read(decoder)
+
+    # remote response schema
+    remote_message_schema = self.remote_protocol.messages.get(message_name)
+    if remote_message_schema is None:
+      raise schema.AvroException('Unknown remote message: %s' % message_name)
+
+    # local response schema
+    local_message_schema = self.local_protocol.messages.get(message_name)
+    if local_message_schema is None:
+      raise schema.AvroException('Unknown local message: %s' % message_name)
+
+    # error flag
+    if not decoder.read_boolean():
+      writers_schema = remote_message_schema.response
+      readers_schema = local_message_schema.response
+      return self.read_response(writers_schema, readers_schema, decoder)
+    else:
+      writers_schema = remote_message_schema.errors or SYSTEM_ERROR_SCHEMA
+      readers_schema = local_message_schema.errors or SYSTEM_ERROR_SCHEMA
+      raise self.read_error(writers_schema, readers_schema, decoder)
+
+  def read_response(self, writers_schema, readers_schema, decoder):
+    datum_reader = io.DatumReader(writers_schema, readers_schema)
+    return datum_reader.read(decoder)
+
+  def read_error(self, writers_schema, readers_schema, decoder):
+    datum_reader = io.DatumReader(writers_schema, readers_schema)
+    return AvroRemoteException(datum_reader.read(decoder))
+
+class Responder(object):
+  """Base class for the server side of a protocol interaction."""
+  def __init__(self, local_protocol):
+    self._local_protocol = local_protocol
+    self._local_hash = self.local_protocol.md5
+    self._protocol_cache = {}
+    self.set_protocol_cache(self.local_hash, self.local_protocol)
+
+  # read-only properties
+  local_protocol = property(lambda self: self._local_protocol)
+  local_hash = property(lambda self: self._local_hash)
+  protocol_cache = property(lambda self: self._protocol_cache)
+
+  # utility functions to manipulate protocol cache
+  def get_protocol_cache(self, hash):
+    return self.protocol_cache.get(hash)
+  def set_protocol_cache(self, hash, protocol):
+    self.protocol_cache[hash] = protocol
+
+  def respond(self, transport):
+    """
+    Called by a server to deserialize a request, compute and serialize
+    a response or error. Compare to 'handle()' in Thrift.
+    """
+    call_request = transport.read_framed_message()
+    buffer_decoder = io.BinaryDecoder(cStringIO.StringIO(call_request))
+    buffer_writer = cStringIO.StringIO()
+    buffer_encoder = io.BinaryEncoder(buffer_writer)
+    error = None
+    response_metadata = {}
+    
+    try:
+      remote_protocol = self.process_handshake(transport, buffer_decoder,
+                                               buffer_encoder)
+      # handshake failure
+      if remote_protocol is None:  
+        return buffer_writer.getvalue()
+      
+      # read request using remote protocol
+      request_metadata = META_READER.read(buffer_decoder)
+      remote_message_name = buffer_decoder.read_utf8()
+
+      # get remote and local request schemas so we can do
+      # schema resolution (one fine day)
+      remote_message = remote_protocol.messages.get(remote_message_name)
+      if remote_message is None:
+        fail_msg = 'Unknown remote message: %s' % remote_message_name
+        raise schema.AvroException(fail_msg)
+      local_message = self.local_protocol.messages.get(remote_message_name)
+      if local_message is None:
+        fail_msg = 'Unknown local message: %s' % remote_message_name
+        raise schema.AvroException(fail_msg)
+      writers_fields = remote_message.request
+      # TODO(hammer) pass reader schema
+      request = self.read_request(writers_fields, buffer_decoder)
+      # perform server logic
+      try:
+        response = self.invoke(local_message, request)
+      except AvroRemoteException, e:
+        error = e
+      except Exception, e:
+        error = AvroRemoteException(str(e))
+
+      # write response using local protocol
+      META_WRITER.write(response_metadata, buffer_encoder)
+      buffer_encoder.write_boolean(error is not None)
+      if error is None:
+        writers_schema = local_message.response
+        self.write_response(writers_schema, response, buffer_encoder)
+      else:
+        writers_schema = local_message.errors or SYSTEM_ERROR_SCHEMA
+        self.write_error(writers_schema, error, buffer_encoder)
+    except schema.AvroException, e:
+      error = AvroRemoteException(str(e))
+      buffer_encoder = io.BinaryEncoder(cStringIO.StringIO())
+      META_WRITER.write(response_metadata, buffer_encoder)
+      buffer_encoder.write_boolean(True)
+      self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
+    return buffer_writer.getvalue()
+
+  def process_handshake(self, transport, decoder, encoder):
+    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
+    handshake_response = {}
+
+    # determine the remote protocol
+    client_hash = handshake_request.get('clientHash')
+    client_protocol = handshake_request.get('clientProtocol')
+    remote_protocol = self.get_protocol_cache(client_hash)
+    if remote_protocol is None and client_protocol is not None:
+      remote_protocol = protocol.parse(client_protocol)
+      self.set_protocol_cache(client_hash, remote_protocol)
+
+    # evaluate remote's guess of the local protocol
+    server_hash = handshake_request.get('serverHash')
+    if self.local_hash == server_hash:
+      if remote_protocol is None:
+        handshake_response['match'] = 'NONE'
+      else:
+        handshake_response['match'] = 'BOTH'
+    else:
+      if remote_protocol is None:
+        handshake_response['match'] = 'NONE'
+      else:
+        handshake_response['match'] = 'CLIENT'
+
+    if handshake_response['match'] != 'BOTH':
+      handshake_response['serverProtocol'] = str(self.local_protocol)
+      handshake_response['serverHash'] = self.local_hash
+
+    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
+    return remote_protocol
+
+  def invoke(self, local_message, request):
+    """
+    Aactual work done by server: cf. handler in thrift.
+    """
+    pass
+
+  def read_request(self, writers_fields, decoder):
+    """
+    Need to handle schema resolution here. Half-assing it now.
+    """
+    request_data = []
+    for field in writers_fields:
+      datum_reader = io.DatumReader(field.type)
+      request_data.append(datum_reader.read(decoder))
+    return request_data
+
+  def write_response(self, writers_schema, response_datum, encoder):
+    datum_writer = io.DatumWriter(writers_schema)
+    datum_writer.write(response_datum, encoder)
+
+  def write_error(self, writers_schema, error_exception, encoder):
+    datum_writer = io.DatumWriter(writers_schema)
+    datum_writer.write(str(error_exception), encoder)
+
+#
+# Transport Implementations
+#
+
+class SocketTransport(object):
+  """A simple socket-based Transport implementation."""
+  def __init__(self, sock):
+    self._sock = sock
+
+  # read-only properties
+  sock = property(lambda self: self._sock)
+  remote_name = property(lambda self: self.sock.getsockname())
+
+  def transceive(self, request):
+    self.write_framed_message(request)
+    return self.read_framed_message()
+
+  def read_framed_message(self):
+    message = []
+    while True:
+      buffer = cStringIO.StringIO()
+      buffer_length = self.read_buffer_length()
+      if buffer_length == 0:
+        return ''.join(message)
+      while buffer.tell() < buffer_length:
+        chunk = self.sock.recv(buffer_length - buffer.tell())
+        if chunk == '':
+          raise ConnectionClosedException("Socket read 0 bytes.")
+        buffer.write(chunk)
+      message.append(buffer.getvalue())
+
+  def write_framed_message(self, message):
+    message_length = len(message)
+    total_bytes_sent = 0
+    while message_length - total_bytes_sent > 0:
+      if message_length - total_bytes_sent > BUFFER_SIZE:
+        buffer_length = BUFFER_SIZE
+      else:
+        buffer_length = message_length - total_bytes_sent
+      self.write_buffer(message[total_bytes_sent:
+                                (total_bytes_sent + buffer_length)])
+      total_bytes_sent += buffer_length
+    # A message is always terminated by a zero-length buffer.
+    self.write_buffer_length(0)
+
+  def write_buffer(self, chunk):
+    buffer_length = len(chunk)
+    self.write_buffer_length(buffer_length)
+    total_bytes_sent = 0
+    while total_bytes_sent < buffer_length:
+      bytes_sent = self.sock.send(chunk[total_bytes_sent:])
+      if bytes_sent == 0:
+        raise ConnectionClosedException("Socket sent 0 bytes.")
+      total_bytes_sent += bytes_sent
+
+  def write_buffer_length(self, n):
+    bytes_sent = self.sock.sendall(BIG_ENDIAN_INT_STRUCT.pack(n))
+    if bytes_sent == 0:
+      raise ConnectionClosedException("socket sent 0 bytes")
+
+  def read_buffer_length(self):
+    read = self.sock.recv(BUFFER_HEADER_LENGTH)
+    if read == '':
+      raise ConnectionClosedException("Socket read 0 bytes.")
+    return BIG_ENDIAN_INT_STRUCT.unpack(read)[0]
+
+  def close(self):
+    self.sock.close()
+
+#
+# Server Implementations (none yet)
+#
+

Added: hadoop/avro/trunk/src/py/avro/protocol.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/py/avro/protocol.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/py/avro/protocol.py (added)
+++ hadoop/avro/trunk/src/py/avro/protocol.py Tue Jan  5 18:48:04 2010
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+# http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Protocol implementation.
+"""
+import cStringIO
+import md5
+try:
+  import simplejson as json
+except ImportError:
+  import json
+from avro import schema
+
+#
+# Constants
+#
+
+# Schema types accepted for the entries of a protocol's "types" list
+# (checked in Protocol._parse_types).
+# TODO(hammer): confirmed 'fixed' with Doug
+VALID_TYPE_SCHEMA_TYPES = ('enum', 'record', 'error', 'fixed')
+
+#
+# Exceptions
+#
+
+class ProtocolParseException(schema.AvroException):
+  """
+  Raised when JSON text or parsed JSON data cannot be turned into a Protocol.
+  """
+  pass
+
+#
+# Base Classes
+#
+
+class Protocol(object):
+  """An application protocol."""
+  def _parse_types(self, types, type_names):
+    type_objects = []
+    for type in types:
+      type_object = schema.make_avsc_object(type, type_names)
+      if type_object.type not in VALID_TYPE_SCHEMA_TYPES:
+        fail_msg = 'Type %s not an enum, record, or error.' % type
+        raise ProtocolParseException(fail_msg)
+      type_objects.append(type_object)
+    return type_objects
+
+  def _parse_messages(self, messages, names):
+    message_objects = {}
+    for name, body in messages.iteritems():
+      if message_objects.has_key(name):
+        fail_msg = 'Message name "%s" repeated.' % name
+        raise ProtocolParseException(fail_msg)
+      elif not(hasattr(body, 'get') and callable(body.get)):
+        fail_msg = 'Message name "%s" has non-object body %s.' % (name, body)
+        raise ProtocolParseException(fail_msg)
+
+      request = body.get('request')
+      response = body.get('response')
+      errors = body.get('errors')
+      message_objects[name] = Message(name, request, response, errors, names)
+    return message_objects
+
+  def __init__(self, name, namespace=None, types=None, messages=None):
+    # Ensure valid ctor args
+    if not name:
+      fail_msg = 'Protocols must have a non-empty name.'
+      raise ProtocolParseException(fail_msg)
+    elif not isinstance(name, basestring):
+      fail_msg = 'The name property must be a string.'
+      raise ProtocolParseException(fail_msg)
+    elif namespace is not None and not isinstance(namespace, basestring):
+      fail_msg = 'The namespace property must be a string.'
+      raise ProtocolParseException(fail_msg)
+    elif types is not None and not isinstance(types, list):
+      fail_msg = 'The types property must be a list.'
+      raise ProtocolParseException(fail_msg)
+    elif (messages is not None and 
+          not(hasattr(messages, 'get') and callable(messages.get))):
+      fail_msg = 'The messages property must be a JSON object.'
+      raise ProtocolParseException(fail_msg)
+
+    self._props = {}
+    self.set_prop('name', name)
+    if namespace is not None: self.set_prop('namespace', namespace)
+    type_names = {}
+    if types is not None:
+      self.set_prop('types', self._parse_types(types, type_names))
+    if messages is not None:
+      self.set_prop('messages', self._parse_messages(messages, type_names))
+    self._md5 = md5.new(str(self)).digest()
+
+  # read-only properties
+  name = property(lambda self: self.get_prop('name'))
+  namespace = property(lambda self: self.get_prop('namespace'))
+  fullname = property(lambda self: 
+                      schema.Name.make_fullname(self.name, self.namespace))
+  types = property(lambda self: self.get_prop('types'))
+  types_dict = property(lambda self: dict([(type.name, type)
+                                           for type in self.types]))
+  messages = property(lambda self: self.get_prop('messages'))
+  md5 = property(lambda self: self._md5)
+  props = property(lambda self: self._props)
+
+  # utility functions to manipulate properties dict
+  def get_prop(self, key):
+    return self.props.get(key)
+  def set_prop(self, key, value):
+    self.props[key] = value  
+
+  def __str__(self):
+    # until we implement a JSON encoder for Schema and Message objects,
+    # we'll have to go through and call str() by hand.
+    to_dump = {}
+    to_dump['protocol'] = self.name
+    if self.namespace: to_dump['namespace'] = self.namespace
+    if self.types:
+      to_dump['types'] = [json.loads(str(t)) for t in self.types]
+    if self.messages:
+      messages_dict = {}
+      for name, body in self.messages.iteritems():
+        messages_dict[name] = json.loads(str(body))
+      to_dump['messages'] = messages_dict
+    return json.dumps(to_dump)
+
+  def __eq__(self, that):
+    to_cmp = json.loads(str(self))
+    return to_cmp == json.loads(str(that))
+
+class Message(object):
+  """A Protocol message."""
+  def _parse_request(self, request, names):
+    if not isinstance(request, list):
+      fail_msg = 'Request property not a list: %s' % request
+      raise ProtocolParseException(fail_msg)
+    return schema.RecordSchema.make_field_objects(request, names)
+  
+  def _parse_response(self, response, names):
+    if isinstance(response, basestring) and names.has_key(response):
+      self._response_from_names = True
+      return names.get(response)
+    else:
+      return schema.make_avsc_object(response, names)
+
+  def _parse_errors(self, errors, names):
+    if not isinstance(errors, list):
+      fail_msg = 'Errors property not a list: %s' % errors
+      raise ProtocolParseException(fail_msg)
+    return schema.make_avsc_object(errors, names)
+
+  def __init__(self,  name, request, response, errors=None, names=None):
+    self._name = name
+    self._response_from_names = False
+
+    self._props = {}
+    self.set_prop('request', self._parse_request(request, names))
+    self.set_prop('response', self._parse_response(response, names))
+    if errors is not None:
+      self.set_prop('errors', self._parse_errors(errors, names))
+
+  # read-only properties
+  name = property(lambda self: self._name)
+  response_from_names = property(lambda self: self._response_from_names)
+  request = property(lambda self: self.get_prop('request'))
+  response = property(lambda self: self.get_prop('response'))
+  errors = property(lambda self: self.get_prop('errors'))
+  props = property(lambda self: self._props)
+
+  # utility functions to manipulate properties dict
+  def get_prop(self, key):
+    return self.props.get(key)
+  def set_prop(self, key, value):
+    self.props[key] = value  
+
+  # TODO(hammer): allow schemas and fields to be JSON Encoded!
+  def __str__(self):
+    to_dump = {}
+    to_dump['request'] = [json.loads(str(r)) for r in self.request]
+    if self.response_from_names:
+      to_dump['response'] = self.response.fullname
+    else:
+      to_dump['response'] = json.loads(str(self.response))
+    if self.errors:
+      to_dump['errors'] = json.loads(str(self.errors))
+    return json.dumps(to_dump)
+
+  def __eq__(self, that):
+    return self.name == that.name and self.props == that.props
+      
+def make_avpr_object(json_data):
+  """Build Avro Protocol from data parsed out of JSON string."""
+  if hasattr(json_data, 'get') and callable(json_data.get):
+    name = json_data.get('protocol')
+    namespace = json_data.get('namespace')
+    types = json_data.get('types')
+    messages = json_data.get('messages')
+    return Protocol(name, namespace, types, messages)
+  else:
+    raise ProtocolParseException('Not a JSON object: %s' % json_data)
+
+def parse(json_string):
+  """Constructs the Protocol from the JSON text."""
+  try:
+    json_data = json.loads(json_string)
+  except:
+    raise ProtocolParseException('Error parsing JSON: %s' % json_string)
+
+  # construct the Avro Protocol object
+  return make_avpr_object(json_data)
+

Added: hadoop/avro/trunk/src/test/py/sample_ipc_client.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/sample_ipc_client.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/sample_ipc_client.py (added)
+++ hadoop/avro/trunk/src/test/py/sample_ipc_client.py Tue Jan  5 18:48:04 2010
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import socket
+import sys
+
+from avro import ipc
+from avro import protocol
+from avro import schema
+
+# JSON text of the protocol this sample client speaks: a "send" message
+# taking one Message record and a "replay" message taking no parameters,
+# both answering with a string.
+MAIL_PROTOCOL_JSON = """\
+{"namespace": "example.proto",
+ "protocol": "Mail",
+
+ "types": [
+     {"name": "Message", "type": "record",
+      "fields": [
+          {"name": "to",   "type": "string"},
+          {"name": "from", "type": "string"},
+          {"name": "body", "type": "string"}
+      ]
+     }
+ ],
+
+ "messages": {
+     "send": {
+         "request": [{"name": "message", "type": "Message"}],
+         "response": "string"
+     },
+     "replay": {
+         "request": [],
+         "response": "string"
+     }
+ }
+}
+"""
+# Protocol object parsed from the text above (parsed at import time).
+MAIL_PROTOCOL = protocol.parse(MAIL_PROTOCOL_JSON)
+# (host, port) the client connects to.
+SERVER_ADDRESS = ('localhost', 9090)
+
+class UsageError(Exception):
+  def __init__(self, value):
+    self.value = value
+  def __str__(self):
+    return repr(self.value)
+
+def make_requestor(server_address, protocol):
+  sock = socket.socket()
+  sock.connect(server_address)
+  client = ipc.SocketTransport(sock)
+  return ipc.Requestor(protocol, client)
+
+if __name__ == '__main__':
+  if len(sys.argv) not in [4, 5]:
+    raise UsageError("Usage: <to> <from> <body> [<count>]")
+
+  # client code - attach to the server and send a message
+  # fill in the Message record
+  message = dict()
+  message['to'] = sys.argv[1]
+  message['from'] = sys.argv[2]
+  message['body'] = sys.argv[3]
+
+  try:
+    num_messages = int(sys.argv[4])
+  except:
+    num_messages = 1
+
+  # build the parameters for the request
+  params = {}
+  params['message'] = message
+   
+  # send the requests and print the result
+  for msg_count in range(num_messages):
+    requestor = make_requestor(SERVER_ADDRESS, MAIL_PROTOCOL)
+    result = requestor.request('send', params)
+    print("Result: " + result)
+
+  # try out a replay message
+  requestor = make_requestor(SERVER_ADDRESS, MAIL_PROTOCOL)
+  result = requestor.request('replay', dict())
+  print("Replay Result: " + result)

Added: hadoop/avro/trunk/src/test/py/sample_ipc_server.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/sample_ipc_server.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/sample_ipc_server.py (added)
+++ hadoop/avro/trunk/src/test/py/sample_ipc_server.py Tue Jan  5 18:48:04 2010
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from SocketServer import BaseRequestHandler, TCPServer
+from avro import ipc
+from avro import protocol
+from avro import schema
+
+MAIL_PROTOCOL_JSON = """\
+{"namespace": "example.proto",
+ "protocol": "Mail",
+
+ "types": [
+     {"name": "Message", "type": "record",
+      "fields": [
+          {"name": "to",   "type": "string"},
+          {"name": "from", "type": "string"},
+          {"name": "body", "type": "string"}
+      ]
+     }
+ ],
+
+ "messages": {
+     "send": {
+         "request": [{"name": "message", "type": "Message"}],
+         "response": "string"
+     },
+     "replay": {
+         "request": [],
+         "response": "string"
+     }
+ }
+}
+"""
+MAIL_PROTOCOL = protocol.parse(MAIL_PROTOCOL_JSON)  # parsed once at import; handed to ipc.Responder by MailResponder
+SERVER_ADDRESS = ('localhost', 9090)  # (host, port) the TCPServer in __main__ is created with
+
+class MailResponder(ipc.Responder):  # answers the Mail protocol's 'send' and 'replay' messages
+  def __init__(self):
+    ipc.Responder.__init__(self, MAIL_PROTOCOL)
+
+  def invoke(self, message, request):  # returns the response string; falls through to None for any other message name
+    if message.name == 'send':
+      request_content = request[0]  # presumably the decoded Message record (to/from/body) -- confirm against ipc.Responder
+      response = "Sent message to %(to)s from %(from)s with body %(body)s" % \
+                 request_content
+      return response
+    elif message.name == 'replay':
+      return 'replay'
+
+class MailHandler(BaseRequestHandler):  # handler class given to TCPServer in the __main__ block
+  def handle(self):
+    self.responder = MailResponder()  # a new responder is built for every handled request
+    self.transport = ipc.SocketTransport(self.request)  # self.request: the connected socket supplied by the server
+    self.transport.write_framed_message(self.responder.respond(self.transport))
+
+if __name__ == '__main__':
+  mail_server = TCPServer(SERVER_ADDRESS, MailHandler)
+  mail_server.serve_forever()  # blocks; requests are handled one at a time by MailHandler

Added: hadoop/avro/trunk/src/test/py/test_protocol.py
URL: http://svn.apache.org/viewvc/hadoop/avro/trunk/src/test/py/test_protocol.py?rev=896176&view=auto
==============================================================================
--- hadoop/avro/trunk/src/test/py/test_protocol.py (added)
+++ hadoop/avro/trunk/src/test/py/test_protocol.py Tue Jan  5 18:48:04 2010
@@ -0,0 +1,257 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+# http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Test the protocol parsing logic.
+"""
+import unittest
+from avro import protocol
+
+class ExampleProtocol(object):  # one protocol JSON string plus whether parsing it is expected to succeed
+  def __init__(self, protocol_string, valid, name='', comment=''):
+    self._protocol_string = protocol_string
+    self._valid = valid
+    self._name = name or protocol_string # default to protocol_string for name
+    self._comment = comment
+
+  # read-only properties
+  protocol_string = property(lambda self: self._protocol_string)
+  valid = property(lambda self: self._valid)
+  name = property(lambda self: self._name)
+
+  # read/write properties
+  def set_comment(self, new_comment): self._comment = new_comment
+  comment = property(lambda self: self._comment, set_comment)
+
+#
+# Example Protocols
+#
+
+EXAMPLES = [
+  ExampleProtocol("""\
+{
+  "namespace": "com.acme",
+  "protocol": "HelloWorld",
+
+  "types": [
+    {"name": "Greeting", "type": "record", "fields": [
+      {"name": "message", "type": "string"}]},
+    {"name": "Curse", "type": "error", "fields": [
+      {"name": "message", "type": "string"}]}
+  ],
+
+  "messages": {
+    "hello": {
+      "request": [{"name": "greeting", "type": "Greeting" }],
+      "response": "Greeting",
+      "errors": ["Curse"]
+    }
+  }
+}
+    """, True),
+  ExampleProtocol("""\
+{"namespace": "org.apache.avro.test",
+ "protocol": "Simple",
+
+ "types": [
+     {"name": "Kind", "type": "enum", "symbols": ["FOO","BAR","BAZ"]},
+
+     {"name": "MD5", "type": "fixed", "size": 16},
+
+     {"name": "TestRecord", "type": "record",
+      "fields": [
+          {"name": "name", "type": "string", "order": "ignore"},
+          {"name": "kind", "type": "Kind", "order": "descending"},
+          {"name": "hash", "type": "MD5"}
+      ]
+     },
+
+     {"name": "TestError", "type": "error", "fields": [
+         {"name": "message", "type": "string"}
+      ]
+     }
+
+ ],
+
+ "messages": {
+
+     "hello": {
+         "request": [{"name": "greeting", "type": "string"}],
+         "response": "string"
+     },
+
+     "echo": {
+         "request": [{"name": "record", "type": "TestRecord"}],
+         "response": "TestRecord"
+     },
+
+     "add": {
+         "request": [{"name": "arg1", "type": "int"}, {"name": "arg2", "type": "int"}],
+         "response": "int"
+     },
+
+     "echoBytes": {
+         "request": [{"name": "data", "type": "bytes"}],
+         "response": "bytes"
+     },
+
+     "error": {
+         "request": [],
+         "response": "null",
+         "errors": ["TestError"]
+     }
+ }
+
+}
+    """, True),
+  ExampleProtocol("""\
+{"namespace": "org.apache.avro.test.namespace",
+ "protocol": "TestNamespace",
+
+ "types": [
+     {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16},
+     {"name": "TestRecord", "type": "record",
+      "fields": [ {"name": "hash", "type": "org.apache.avro.test.util.MD5"} ]
+     },
+     {"name": "TestError", "namespace": "org.apache.avro.test.errors",
+      "type": "error", "fields": [ {"name": "message", "type": "string"} ]
+     }
+ ],
+
+ "messages": {
+     "echo": {
+         "request": [{"name": "record", "type": "TestRecord"}],
+         "response": "TestRecord"
+     },
+
+     "error": {
+         "request": [],
+         "response": "null",
+         "errors": ["org.apache.avro.test.errors.TestError"]
+     }
+
+ }
+
+}
+    """, True),
+  ExampleProtocol("""\
+{"namespace": "org.apache.avro.test",
+ "protocol": "BulkData",
+
+ "types": [],
+
+ "messages": {
+
+     "read": {
+         "request": [],
+         "response": "bytes"
+     },
+
+     "write": {
+         "request": [ {"name": "data", "type": "bytes"} ],
+         "response": "null"
+     }
+
+ }
+
+}
+    """, True),
+]
+
+VALID_EXAMPLES = [e for e in EXAMPLES if e.valid]  # currently equal to EXAMPLES: every entry is built with valid=True
+
+class TestProtocol(unittest.TestCase):  # parse, str() and round-trip checks over the example protocols
+  def test_parse(self):  # parsing must succeed exactly for the examples marked valid
+    print ''
+    print 'TEST PARSE'
+    print '=========='
+    print ''
+
+    num_correct = 0
+    for example in EXAMPLES:
+      try:
+        protocol.parse(example.protocol_string)
+        if example.valid: num_correct += 1
+        debug_msg = "%s: PARSE SUCCESS" % example.name
+      except Exception:
+        if not example.valid: num_correct += 1
+        debug_msg = "%s: PARSE FAILURE" % example.name
+      # print after the try rather than in finally: debug_msg is only set on the two handled paths
+      print debug_msg
+
+    fail_msg = "Parse behavior correct on %d out of %d protocols." % \
+      (num_correct, len(EXAMPLES))
+    self.assertEqual(num_correct, len(EXAMPLES), fail_msg)
+
+  def test_valid_cast_to_string_after_parse(self):
+    """
+    Test that the string generated by an Avro Protocol object
+    is, in fact, a valid Avro protocol.
+    """
+    print ''
+    print 'TEST CAST TO STRING'
+    print '==================='
+    print ''
+
+    num_correct = 0
+    for example in VALID_EXAMPLES:
+      protocol_data = protocol.parse(example.protocol_string)
+      try:
+        protocol.parse(str(protocol_data))
+        debug_msg = "%s: STRING CAST SUCCESS" % example.name
+        num_correct += 1
+      except Exception:
+        debug_msg = "%s: STRING CAST FAILURE" % example.name
+      # print after the try rather than in finally: debug_msg is only set on the two handled paths
+      print debug_msg
+
+    fail_msg = "Cast to string success on %d out of %d protocols" % \
+      (num_correct, len(VALID_EXAMPLES))
+    self.assertEqual(num_correct, len(VALID_EXAMPLES), fail_msg)
+
+  def test_equivalence_after_round_trip(self):
+    """
+    1. Given a string, parse it to get Avro protocol "original".
+    2. Serialize "original" to a string and parse that string
+         to generate Avro protocol "round trip".
+    3. Ensure "original" and "round trip" protocols are equivalent.
+    """
+    print ''
+    print 'TEST ROUND TRIP'
+    print '==============='
+    print ''
+
+    num_correct = 0
+    for example in VALID_EXAMPLES:
+      try:
+        original_protocol = protocol.parse(example.protocol_string)
+        round_trip_protocol = protocol.parse(str(original_protocol))
+
+        if original_protocol == round_trip_protocol:
+          num_correct += 1
+          debug_msg = "%s: ROUND TRIP SUCCESS" % example.name
+        else:
+          debug_msg = "%s: ROUND TRIP FAILURE" % example.name
+      except Exception:
+        debug_msg = "%s: ROUND TRIP FAILURE" % example.name
+      # print after the try rather than in finally: debug_msg is only set on the handled paths
+      print debug_msg
+
+    fail_msg = "Round trip success on %d out of %d protocols" % \
+      (num_correct, len(VALID_EXAMPLES))
+    self.assertEqual(num_correct, len(VALID_EXAMPLES), fail_msg)
+
+if __name__ == '__main__':
+  unittest.main()  # runs the TestCase classes defined in this module when executed as a script



Mime
View raw message