beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Work logged] (BEAM-2687) Python SDK support for Stateful Processing
Date Thu, 20 Sep 2018 08:28:00 GMT

     [ https://issues.apache.org/jira/browse/BEAM-2687?focusedWorklogId=145950&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-145950
]

ASF GitHub Bot logged work on BEAM-2687:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 20/Sep/18 08:27
            Start Date: 20/Sep/18 08:27
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #6349: [BEAM-2687] Implement State over
the Fn API
URL: https://github.com/apache/beam/pull/6349
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 2186df4e635..b09142a5ae4 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -271,6 +271,10 @@ def is_splittable_dofn(self):
   def is_stateful_dofn(self):
     return self._is_stateful_dofn
 
+  def has_timers(self):
+    _, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
+    return bool(all_timer_specs)
+
 
 class DoFnInvoker(object):
   """An abstraction that can be used to execute DoFn methods.
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index a19b50f3a7b..00e37f3c9fd 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -105,8 +105,8 @@ def visit_transform(self, applied_ptransform):
           # The FnApiRunner does not support execution of SplittableDoFns.
           if DoFnSignature(dofn).is_splittable_dofn():
             self.supported_by_fnapi_runner = False
-          # The FnApiRunner does not support execution of Stateful DoFns.
-          if DoFnSignature(dofn).is_stateful_dofn():
+          # The FnApiRunner does not support execution of DoFns with timers.
+          if DoFnSignature(dofn).has_timers():
             self.supported_by_fnapi_runner = False
           # The FnApiRunner does not support execution of CombineFns with
           # deferred side inputs.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 9e39ca82c84..da6d79d9dc9 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -1010,21 +1010,24 @@ def State(self, request_stream, context=None):
       # Note that this eagerly mutates state, assuming any failures are fatal.
       # Thus it is safe to ignore instruction_reference.
       for request in request_stream:
-        if request.get:
+        request_type = request.WhichOneof('request')
+        if request_type == 'get':
           yield beam_fn_api_pb2.StateResponse(
               id=request.id,
               get=beam_fn_api_pb2.StateGetResponse(
                   data=self.blocking_get(request.state_key)))
-        elif request.append:
+        elif request_type == 'append':
           self.blocking_append(request.state_key, request.append.data)
           yield beam_fn_api_pb2.StateResponse(
               id=request.id,
-              append=beam_fn_api_pb2.AppendResponse())
-        elif request.clear:
+              append=beam_fn_api_pb2.StateAppendResponse())
+        elif request_type == 'clear':
           self.blocking_clear(request.state_key)
           yield beam_fn_api_pb2.StateResponse(
               id=request.id,
-              clear=beam_fn_api_pb2.ClearResponse())
+              clear=beam_fn_api_pb2.StateClearResponse())
+        else:
+          raise NotImplementedError('Unknown state request: %s' % request_type)
 
   class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
     """A singleton cache for a StateServicer."""
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 1064f626d66..57e270baa0d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -36,6 +36,7 @@
 from apache_beam.runners.worker import statesampler
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import userstate
 from apache_beam.transforms import window
 
 if statesampler.FAST_SAMPLER:
@@ -228,6 +229,29 @@ def cross_product(elem, sides):
           pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(derived)),
           equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
 
+  def test_pardo_state_only(self):
+
+    index_state_spec = userstate.CombiningValueStateSpec(
+        'index', beam.coders.VarIntCoder(), sum)
+
+    # TODO(ccy): State isn't detected with Map/FlatMap.
+    class AddIndex(beam.DoFn):
+      def process(self, kv, index=beam.DoFn.StateParam(index_state_spec)):
+        k, v = kv
+        index.add(1)
+        yield k, v, index.read()
+
+    inputs = [('A', 'a')] * 2 + [('B', 'b')] * 3
+    expected = [('A', 'a', 1),
+                ('A', 'a', 2),
+                ('B', 'b', 1),
+                ('B', 'b', 2),
+                ('B', 'b', 3)]
+
+    with self.create_pipeline() as p:
+      assert_that(p | beam.Create(inputs) | beam.ParDo(AddIndex()),
+                  equal_to(expected))
+
   def test_group_by_key(self):
     with self.create_pipeline() as p:
       res = (p
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 4b7e9cda1bf..d119d3d1f68 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -25,6 +25,7 @@
 import collections
 import json
 import logging
+import random
 import re
 from builtins import next
 from builtins import object
@@ -46,6 +47,7 @@
 from apache_beam.runners.worker import operations
 from apache_beam.runners.worker import statesampler
 from apache_beam.transforms import sideinputs
+from apache_beam.transforms import userstate
 from apache_beam.utils import counters
 from apache_beam.utils import proto_utils
 
@@ -121,6 +123,23 @@ def process_encoded(self, encoded_windowed_values):
       self.output(decoded_value)
 
 
+class _StateBackedIterable(object):
+  def __init__(self, state_handler, state_key, coder):
+    self._state_handler = state_handler
+    self._state_key = state_key
+    self._coder_impl = coder.get_impl()
+
+  def __iter__(self):
+    # TODO(robertwb): Support pagination.
+    input_stream = coder_impl.create_InputStream(
+        self._state_handler.blocking_get(self._state_key))
+    while input_stream.size() > 0:
+      yield self._coder_impl.decode_from_stream(input_stream, True)
+
+  def __reduce__(self):
+    return list, (list(self),)
+
+
 class StateBackedSideInputMap(object):
   def __init__(self, state_handler, transform_id, tag, side_input_data, coder):
     self._state_handler = state_handler
@@ -145,23 +164,9 @@ def __getitem__(self, window):
       state_handler = self._state_handler
       access_pattern = self._side_input_data.access_pattern
 
-      class AllElements(object):
-        def __init__(self, state_key, coder):
-          self._state_key = state_key
-          self._coder_impl = coder.get_impl()
-
-        def __iter__(self):
-          # TODO(robertwb): Support pagination.
-          input_stream = coder_impl.create_InputStream(
-              state_handler.blocking_get(self._state_key))
-          while input_stream.size() > 0:
-            yield self._coder_impl.decode_from_stream(input_stream, True)
-
-        def __reduce__(self):
-          return list, (list(self),)
-
       if access_pattern == common_urns.side_inputs.ITERABLE.urn:
-        raw_view = AllElements(state_key, self._element_coder)
+        raw_view = _StateBackedIterable(
+            state_handler, state_key, self._element_coder)
 
       elif (access_pattern == common_urns.side_inputs.MULTIMAP.urn or
             access_pattern ==
@@ -177,7 +182,8 @@ def __getitem__(self, key):
               keyed_state_key.CopyFrom(state_key)
               keyed_state_key.multimap_side_input.key = (
                   key_coder_impl.encode_nested(key))
-              cache[key] = AllElements(keyed_state_key, value_coder)
+              cache[key] = _StateBackedIterable(
+                  state_handler, keyed_state_key, value_coder)
             return cache[key]
 
           def __reduce__(self):
@@ -198,6 +204,87 @@ def is_globally_windowed(self):
             == sideinputs._global_window_mapping_fn)
 
 
+class CombiningValueRuntimeState(userstate.RuntimeState):
+  def __init__(self, underlying_bag_state, combinefn):
+    self._combinefn = combinefn
+    self._underlying_bag_state = underlying_bag_state
+
+  def _read_accumulator(self, rewrite=True):
+    merged_accumulator = self._combinefn.merge_accumulators(
+        self._underlying_bag_state.read())
+    if rewrite:
+      self._underlying_bag_state.clear()
+      self._underlying_bag_state.add(merged_accumulator)
+    return merged_accumulator
+
+  def read(self):
+    return self._combinefn.extract_output(self._read_accumulator())
+
+  def add(self, value):
+    # Prefer blind writes, but don't let them grow unboundedly.
+    # This should be tuned to be much lower, but for now exercise
+    # both paths well.
+    if random.random() < 0.5:
+      accumulator = self._read_accumulator(False)
+      self._underlying_bag_state.clear()
+    else:
+      accumulator = self._combinefn.create_accumulator()
+    self._underlying_bag_state.add(
+        self._combinefn.add_input(accumulator, value))
+
+  def clear(self):
+    self._underlying_bag_state.clear()
+
+
+# TODO(BEAM-5428): Implement cross-bundle state caching.
+class SynchronousBagRuntimeState(userstate.RuntimeState):
+  def __init__(self, state_handler, state_key, value_coder):
+    self._state_handler = state_handler
+    self._state_key = state_key
+    self._value_coder = value_coder
+
+  def read(self):
+    return _StateBackedIterable(
+        self._state_handler, self._state_key, self._value_coder)
+
+  def add(self, value):
+    self._state_handler.blocking_append(
+        self._state_key, self._value_coder.encode(value))
+
+  def clear(self):
+    self._state_handler.blocking_clear(self._state_key)
+
+
+class FnApiUserStateContext(userstate.UserStateContext):
+  def __init__(self, state_handler, transform_id, key_coder, window_coder):
+    self._state_handler = state_handler
+    self._transform_id = transform_id
+    self._key_coder = key_coder
+    self._window_coder = window_coder
+
+  def get_timer(self, timer_spec, key, window):
+    raise NotImplementedError
+
+  def get_state(self, state_spec, key, window):
+    if isinstance(state_spec,
+                  (userstate.BagStateSpec, userstate.CombiningValueStateSpec)):
+      bag_state = SynchronousBagRuntimeState(
+          self._state_handler,
+          state_key=beam_fn_api_pb2.StateKey(
+              bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
+                  ptransform_id=self._transform_id,
+                  user_state_id=state_spec.name,
+                  window=self._window_coder.encode(window),
+                  key=self._key_coder.encode(key))),
+          value_coder=state_spec.coder)
+      if isinstance(state_spec, userstate.BagStateSpec):
+        return bag_state
+      else:
+        return CombiningValueRuntimeState(bag_state, state_spec.combine_fn)
+    else:
+      raise NotImplementedError(state_spec)
+
+
 def memoize(func):
   cache = {}
   missing = object()
@@ -545,6 +632,16 @@ def mutate_tag(tag):
         factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
     serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
 
+  if userstate.is_stateful_dofn(dofn_data[0]):
+    input_coder = factory.get_only_input_coder(transform_proto)
+    user_state_context = FnApiUserStateContext(
+        factory.state_handler,
+        transform_id,
+        input_coder.key_coder(),
+        input_coder.window_coder)
+  else:
+    user_state_context = None
+
   output_coders = factory.get_output_coders(transform_proto)
   spec = operation_specs.WorkerDoFn(
       serialized_fn=serialized_fn,
@@ -552,13 +649,15 @@ def mutate_tag(tag):
       input=None,
       side_inputs=None,  # Fn API uses proto definitions and the Fn State API
       output_coders=[output_coders[tag] for tag in output_tags])
+
   return factory.augment_oldstyle_op(
       operations.DoOperation(
           transform_proto.unique_name,
           spec,
           factory.counter_factory,
           factory.state_sampler,
-          side_input_maps),
+          side_input_maps,
+          user_state_context),
       transform_proto.unique_name,
       consumers,
       output_tags)
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index f27de8d0170..848be52bf26 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -75,6 +75,7 @@ cdef class DoOperation(Operation):
   cdef Receiver dofn_receiver
   cdef object tagged_receivers
   cdef object side_input_maps
+  cdef object user_state_context
 
 
 cdef class CombineOperation(Operation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 0488fe928d3..efb9450e4e7 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -272,9 +272,11 @@ class DoOperation(Operation):
   """A Do operation that will execute a custom DoFn for each input element."""
 
   def __init__(
-      self, name, spec, counter_factory, sampler, side_input_maps=None):
+      self, name, spec, counter_factory, sampler, side_input_maps=None,
+      user_state_context=None):
     super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
     self.side_input_maps = side_input_maps
+    self.user_state_context = user_state_context
     self.tagged_receivers = None
 
   def _read_side_inputs(self, tags_and_types):
@@ -375,6 +377,7 @@ def start(self):
           tagged_receivers=self.tagged_receivers,
           step_name=self.name_context.logging_name(),
           state=state,
+          user_state_context=self.user_state_context,
           operation_name=self.name_context.metrics_name())
 
       self.dofn_receiver = (self.dofn_runner
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 438eee687da..1785408b1af 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -43,6 +43,7 @@
 from apache_beam.portability import python_urns
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms import ptransform
+from apache_beam.transforms import userstate
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.transforms.display import HasDisplayData
 from apache_beam.transforms.ptransform import PTransform
@@ -637,7 +638,12 @@ def from_callable(fn):
 
   @staticmethod
   def maybe_from_callable(fn):
-    return fn if isinstance(fn, CombineFn) else CallableWrapperCombineFn(fn)
+    if isinstance(fn, CombineFn):
+      return fn
+    elif callable(fn):
+      return CallableWrapperCombineFn(fn)
+    else:
+      raise TypeError('Expected a CombineFn or callable, got %r' % fn)
 
   def get_accumulator_coder(self):
     return coders.registry.get_coder(object)
@@ -942,6 +948,7 @@ def to_runner_api_parameter(self, context):
     assert isinstance(self, ParDo), \
         "expected instance of ParDo, but got %s" % self.__class__
     picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data())
+    state_specs, timer_specs = userstate.get_dofn_specs(self.fn)
     return (
         common_urns.primitives.PAR_DO.urn,
         beam_runner_api_pb2.ParDoPayload(
@@ -950,6 +957,10 @@ def to_runner_api_parameter(self, context):
                 spec=beam_runner_api_pb2.FunctionSpec(
                     urn=python_urns.PICKLED_DOFN_INFO,
                     payload=picked_pardo_fn_data)),
+            state_specs={spec.name: spec.to_runner_api(context)
+                         for spec in state_specs},
+            timer_specs={spec.name: spec.to_runner_api(context)
+                         for spec in timer_specs},
             # It'd be nice to name these according to their actual
             # names/positions in the original argument list, but such a
             # transformation is currently irreversible given how
diff --git a/sdks/python/apache_beam/transforms/timeutil.py b/sdks/python/apache_beam/transforms/timeutil.py
index bf30a131392..55c7921cbe8 100644
--- a/sdks/python/apache_beam/transforms/timeutil.py
+++ b/sdks/python/apache_beam/transforms/timeutil.py
@@ -25,6 +25,8 @@
 
 from future.utils import with_metaclass
 
+from apache_beam.portability.api import beam_runner_api_pb2
+
 __all__ = [
     'TimeDomain',
     ]
@@ -37,6 +39,13 @@ class TimeDomain(object):
   REAL_TIME = 'REAL_TIME'
   DEPENDENT_REAL_TIME = 'DEPENDENT_REAL_TIME'
 
+  _RUNNER_API_MAPPING = {
+      WATERMARK: beam_runner_api_pb2.TimeDomain.EVENT_TIME,
+      REAL_TIME: beam_runner_api_pb2.TimeDomain.PROCESSING_TIME,
+      DEPENDENT_REAL_TIME:
+      beam_runner_api_pb2.TimeDomain.SYNCHRONIZED_PROCESSING_TIME,
+  }
+
   @staticmethod
   def from_string(domain):
     if domain in (TimeDomain.WATERMARK,
@@ -45,6 +54,10 @@ def from_string(domain):
       return domain
     raise ValueError('Unknown time domain: %s' % domain)
 
+  @staticmethod
+  def to_runner_api(domain):
+    return TimeDomain._RUNNER_API_MAPPING[domain]
+
 
 class TimestampCombinerImpl(with_metaclass(ABCMeta, object)):
   """Implementation of TimestampCombiner."""
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index 6c2eabcd558..c7fc96b1364 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -27,6 +27,7 @@
 from builtins import object
 
 from apache_beam.coders import Coder
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms.timeutil import TimeDomain
 
 
@@ -39,6 +40,9 @@ def __init__(self):
   def __repr__(self):
     return '%s(%s)' % (self.__class__.__name__, self.name)
 
+  def to_runner_api(self, context):
+    raise NotImplementedError
+
 
 class BagStateSpec(StateSpec):
   """Specification for a user DoFn bag state cell."""
@@ -49,6 +53,11 @@ def __init__(self, name, coder):
     self.name = name
     self.coder = coder
 
+  def to_runner_api(self, context):
+    return beam_runner_api_pb2.StateSpec(
+        bag_spec=beam_runner_api_pb2.BagStateSpec(
+            element_coder_id=context.coders.get_id(self.coder)))
+
 
 class CombiningValueStateSpec(StateSpec):
   """Specification for a user DoFn combining value state cell."""
@@ -59,11 +68,16 @@ def __init__(self, name, coder, combine_fn):
 
     assert isinstance(name, str)
     assert isinstance(coder, Coder)
-    assert isinstance(combine_fn, CombineFn)
     self.name = name
     # The coder here should be for the accumulator type of the given CombineFn.
     self.coder = coder
-    self.combine_fn = combine_fn
+    self.combine_fn = CombineFn.maybe_from_callable(combine_fn)
+
+  def to_runner_api(self, context):
+    return beam_runner_api_pb2.StateSpec(
+        combining_spec=beam_runner_api_pb2.CombiningStateSpec(
+            combine_fn=self.combine_fn.to_runner_api(context),
+            accumulator_coder_id=context.coders.get_id(self.coder)))
 
 
 class TimerSpec(object):
@@ -79,6 +93,10 @@ def __init__(self, name, time_domain):
   def __repr__(self):
     return '%s(%s)' % (self.__class__.__name__, self.name)
 
+  def to_runner_api(self, context):
+    return beam_runner_api_pb2.TimerSpec(
+        time_domain=TimeDomain.to_runner_api(self.time_domain))
+
 
 def on_timer(timer_spec):
   """Decorator for timer firing DoFn method.
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 479e66c138c..d3d592f7c6f 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -104,7 +104,7 @@ def test_spec_construction(self):
     CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
     with self.assertRaises(AssertionError):
       CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10))
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(TypeError):
       CombiningValueStateSpec('statename', VarIntCoder(), object())
     # BagStateSpec('bag', )
     # TODO: add more spec tests


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 145950)
    Time Spent: 3h 50m  (was: 3h 40m)

> Python SDK support for Stateful Processing
> ------------------------------------------
>
>                 Key: BEAM-2687
>                 URL: https://issues.apache.org/jira/browse/BEAM-2687
>             Project: Beam
>          Issue Type: New Feature
>          Components: sdk-py-core
>            Reporter: Ahmet Altay
>            Assignee: Charles Chen
>            Priority: Major
>          Time Spent: 3h 50m
>  Remaining Estimate: 0h
>
> Python SDK should support stateful processing (https://beam.apache.org/blog/2017/02/13/stateful-processing.html)
> In the meantime, runner capability matrix should be updated to show the lack of this
feature (https://beam.apache.org/documentation/runners/capability-matrix/)
> Use this as an umbrella issue for all related issues.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Mime
View raw message