beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [beam] branch master updated: [BEAM -7741] Implement SetState for Python SDK (#9090)
Date Thu, 08 Aug 2019 14:17:57 GMT
This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c9e5ea8  [BEAM -7741] Implement SetState for Python SDK (#9090)
c9e5ea8 is described below

commit c9e5ea843841ac4898d0104e536bd4b2fc297d33
Author: Rakesh Kumar <rakeshkumar@lyft.com>
AuthorDate: Thu Aug 8 07:17:26 2019 -0700

    [BEAM -7741] Implement SetState for Python SDK (#9090)
---
 .../apache_beam/runners/direct/direct_userstate.py |   9 ++
 .../apache_beam/runners/worker/bundle_processor.py |  70 +++++++++
 sdks/python/apache_beam/transforms/trigger.py      |  16 +++
 sdks/python/apache_beam/transforms/userstate.py    |  51 +++++++
 .../apache_beam/transforms/userstate_test.py       | 156 ++++++++++++++++++++-
 5 files changed, 301 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/runners/direct/direct_userstate.py b/sdks/python/apache_beam/runners/direct/direct_userstate.py
index f0fd9b8..b764ea4 100644
--- a/sdks/python/apache_beam/runners/direct/direct_userstate.py
+++ b/sdks/python/apache_beam/runners/direct/direct_userstate.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 
 from apache_beam.transforms import userstate
 from apache_beam.transforms.trigger import _ListStateTag
+from apache_beam.transforms.trigger import _SetStateTag
 
 
 class DirectUserStateContext(userstate.UserStateContext):
@@ -43,6 +44,8 @@ class DirectUserStateContext(userstate.UserStateContext):
         state_tag = _ListStateTag(state_key)
       elif isinstance(state_spec, userstate.CombiningValueStateSpec):
         state_tag = _ListStateTag(state_key)
+      elif isinstance(state_spec, userstate.SetStateSpec):
+        state_tag = _SetStateTag(state_key)
       else:
         raise ValueError('Invalid state spec: %s' % state_spec)
       self.state_tags[state_spec] = state_tag
@@ -93,6 +96,12 @@ class DirectUserStateContext(userstate.UserStateContext):
           state.add_state(
               window, state_tag,
               state_spec.coder.encode(runtime_state._current_accumulator))
+      elif isinstance(state_spec, userstate.SetStateSpec):
+        if runtime_state.is_modified():
+          state.clear_state(window, state_tag)
+          for new_value in runtime_state._current_accumulator:
+            state.add_state(
+                window, state_tag, state_spec.coder.encode(new_value))
       else:
         raise ValueError('Invalid state spec: %s' % state_spec)
 
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index a85e4da..72a41b7 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -208,6 +208,7 @@ class _StateBackedIterable(object):
       self._coder_impl = coder_or_impl
 
   def __iter__(self):
+    # blocking_get also returns a continuation token, which may be useful.
     data, continuation_token = self._state_handler.blocking_get(self._state_key)
     while True:
       input_stream = coder_impl.create_InputStream(data)
@@ -379,6 +380,65 @@ class SynchronousBagRuntimeState(userstate.RuntimeState):
       self._state_handler.blocking_append(self._state_key, out.get())
 
 
+# TODO(BEAM-5428): Implement cross-bundle state caching.
+class SynchronousSetRuntimeState(userstate.RuntimeState):
+
+  def __init__(self, state_handler, state_key, value_coder):
+    self._state_handler = state_handler
+    self._state_key = state_key
+    self._value_coder = value_coder
+    self._cleared = False
+    self._added_elements = set()
+
+  def _compact_data(self, rewrite=True):
+    accumulator = set(_ConcatIterable(
+        set() if self._cleared else _StateBackedIterable(
+            self._state_handler, self._state_key, self._value_coder),
+        self._added_elements))
+
+    if rewrite and accumulator:
+      self._state_handler.blocking_clear(self._state_key)
+
+      value_coder_impl = self._value_coder.get_impl()
+      out = coder_impl.create_OutputStream()
+      for element in accumulator:
+        value_coder_impl.encode_to_stream(element, out, True)
+      self._state_handler.blocking_append(self._state_key, out.get())
+
+      # Since everything is already committed, we can safely reinitialize
+      # added_elements here.
+      self._added_elements = set()
+
+    return accumulator
+
+  def read(self):
+    return self._compact_data(rewrite=False)
+
+  def add(self, value):
+    if self._cleared:
+      # This is a good time to explicitly clear.
+      self._state_handler.blocking_clear(self._state_key)
+      self._cleared = False
+
+    self._added_elements.add(value)
+    if random.random() > 0.5:
+      self._compact_data()
+
+  def clear(self):
+    self._cleared = True
+    self._added_elements = set()
+
+  def _commit(self):
+    if self._cleared:
+      self._state_handler.blocking_clear(self._state_key)
+    if self._added_elements:
+      value_coder_impl = self._value_coder.get_impl()
+      out = coder_impl.create_OutputStream()
+      for element in self._added_elements:
+        value_coder_impl.encode_to_stream(element, out, True)
+      self._state_handler.blocking_append(self._state_key, out.get())
+
+
 class OutputTimer(object):
   def __init__(self, key, window, receiver):
     self._key = key
@@ -454,6 +514,16 @@ class FnApiUserStateContext(userstate.UserStateContext):
         return bag_state
       else:
         return CombiningValueRuntimeState(bag_state, state_spec.combine_fn)
+    elif isinstance(state_spec, userstate.SetStateSpec):
+      return SynchronousSetRuntimeState(
+          self._state_handler,
+          state_key=beam_fn_api_pb2.StateKey(
+              bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
+                  ptransform_id=self._transform_id,
+                  user_state_id=state_spec.name,
+                  window=self._window_coder.encode(window),
+                  key=self._key_coder.encode(key))),
+          value_coder=state_spec.coder)
     else:
       raise NotImplementedError(state_spec)
 
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index f0d2a1c..dbda3cc 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -93,6 +93,16 @@ class _ValueStateTag(_StateTag):
     return _ValueStateTag(prefix + self.tag)
 
 
+class _SetStateTag(_StateTag):
+  """StateTag pointing to a set of elements."""
+
+  def __repr__(self):
+    return 'SetStateTag({tag})'.format(tag=self.tag)
+
+  def with_prefix(self, prefix):
+    return _SetStateTag(prefix + self.tag)
+
+
 class _CombiningValueStateTag(_StateTag):
   """StateTag pointing to an element, accumulated with a combiner.
 
@@ -865,6 +875,8 @@ class MergeableStateAdapter(SimpleState):
           original_tag.combine_fn.merge_accumulators(values))
     elif isinstance(tag, _ListStateTag):
       return [v for vs in values for v in vs]
+    elif isinstance(tag, _SetStateTag):
+      return {v for vs in values for v in vs}
     elif isinstance(tag, _WatermarkHoldStateTag):
       return tag.timestamp_combiner_impl.combine_all(values)
     else:
@@ -1226,6 +1238,8 @@ class InMemoryUnmergedState(UnmergedState):
       self.state[window][tag.tag].append(value)
     elif isinstance(tag, _ListStateTag):
       self.state[window][tag.tag].append(value)
+    elif isinstance(tag, _SetStateTag):
+      self.state[window][tag.tag].append(value)
     elif isinstance(tag, _WatermarkHoldStateTag):
       self.state[window][tag.tag].append(value)
     else:
@@ -1239,6 +1253,8 @@ class InMemoryUnmergedState(UnmergedState):
       return tag.combine_fn.apply(values)
     elif isinstance(tag, _ListStateTag):
       return values
+    elif isinstance(tag, _SetStateTag):
+      return values
     elif isinstance(tag, _WatermarkHoldStateTag):
       return tag.timestamp_combiner_impl.combine_all(values)
     else:
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index aa4e866..4662d13 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -60,6 +60,23 @@ class BagStateSpec(StateSpec):
             element_coder_id=context.coders.get_id(self.coder)))
 
 
+class SetStateSpec(StateSpec):
+  """Specification for a user DoFn set state cell."""
+
+  def __init__(self, name, coder):
+    if not isinstance(name, str):
+      raise TypeError("SetState name is not a string")
+    if not isinstance(coder, Coder):
+      raise TypeError("SetState coder is not of type Coder")
+    self.name = name
+    self.coder = coder
+
+  def to_runner_api(self, context):
+    return beam_runner_api_pb2.StateSpec(
+        set_spec=beam_runner_api_pb2.SetStateSpec(
+            element_coder_id=context.coders.get_id(self.coder)))
+
+
 class CombiningValueStateSpec(StateSpec):
   """Specification for a user DoFn combining value state cell."""
 
@@ -264,6 +281,8 @@ class RuntimeState(object):
     elif isinstance(state_spec, CombiningValueStateSpec):
       return CombiningValueRuntimeState(state_spec, state_tag,
                                         current_value_accessor)
+    elif isinstance(state_spec, SetStateSpec):
+      return SetRuntimeState(state_spec, state_tag, current_value_accessor)
     else:
       raise ValueError('Invalid state spec: %s' % state_spec)
 
@@ -310,6 +329,38 @@ class BagRuntimeState(RuntimeState):
     self._new_values = []
 
 
+class SetRuntimeState(RuntimeState):
+  """Set state interface object passed to user code."""
+
+  def __init__(self, state_spec, state_tag, current_value_accessor):
+    super(SetRuntimeState, self).__init__(
+        state_spec, state_tag, current_value_accessor)
+    self._current_accumulator = UNREAD_VALUE
+    self._modified = False
+
+  def _read_initial_value(self):
+    if self._current_accumulator is UNREAD_VALUE:
+      self._current_accumulator = {
+          self._decode(a) for a in self._current_value_accessor()
+      }
+
+  def read(self):
+    self._read_initial_value()
+    return self._current_accumulator
+
+  def add(self, value):
+    self._read_initial_value()
+    self._modified = True
+    self._current_accumulator.add(value)
+
+  def clear(self):
+    self._current_accumulator = set()
+    self._modified = True
+
+  def is_modified(self):
+    return self._modified
+
+
 class CombiningValueRuntimeState(RuntimeState):
   """Combining value state interface object passed to user code."""
 
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 0d98337..8e55cee 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -31,6 +31,7 @@ from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.runners.common import DoFnSignature
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
+from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import trigger
 from apache_beam.transforms import userstate
@@ -41,6 +42,7 @@ from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.timeutil import TimeDomain
 from apache_beam.transforms.userstate import BagStateSpec
 from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.transforms.userstate import SetStateSpec
 from apache_beam.transforms.userstate import TimerSpec
 from apache_beam.transforms.userstate import get_dofn_specs
 from apache_beam.transforms.userstate import is_stateful_dofn
@@ -114,7 +116,13 @@ class InterfaceTest(unittest.TestCase):
       CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10))
     with self.assertRaises(TypeError):
       CombiningValueStateSpec('statename', VarIntCoder(), object())
-    # BagStateSpec('bag', )
+    SetStateSpec('setstatename', VarIntCoder())
+
+    with self.assertRaises(TypeError):
+      SetStateSpec(123, VarIntCoder())
+    with self.assertRaises(TypeError):
+      SetStateSpec('setstatename', object())
+
     # TODO: add more spec tests
     with self.assertRaises(ValueError):
       DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
@@ -415,6 +423,152 @@ class StatefulDoFnOnDirectRunnerTest(unittest.TestCase):
         ['extra'],
         StatefulDoFnOnDirectRunnerTest.all_records)
 
+  def test_simple_set_stateful_dofn(self):
+    class SimpleTestSetStatefulDoFn(DoFn):
+      BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
+      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
+
+      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
+                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
+        unused_key, value = element
+        buffer.add(value)
+        timer1.set(20)
+
+      @on_timer(EXPIRY_TIMER)
+      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
+        yield sorted(buffer.read())
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([1, 2, 3])
+                     .add_elements([2])
+                     .advance_watermark_to(24))
+      (p
+       | test_stream
+       | beam.Map(lambda x: ('mykey', x))
+       | beam.ParDo(SimpleTestSetStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    # A single firing should occur, once the watermark passes time 20. The
+    # duplicate element 2 is absorbed by the set state, so the timer callback
+    # emits the sorted, de-duplicated contents [1, 2, 3].
+    self.assertEqual(
+        [[1, 2, 3]],
+        StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_clearing_set_state(self):
+    class SetStateClearingStatefulDoFn(beam.DoFn):
+
+      SET_STATE = SetStateSpec('buffer', StrUtf8Coder())
+      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+      CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  set_state=beam.DoFn.StateParam(SET_STATE),
+                  emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
+                  clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
+        value = element[1]
+        set_state.add(value)
+        clear_timer.set(100)
+        emit_timer.set(1000)
+
+      @on_timer(EMIT_TIMER)
+      def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+        for value in set_state.read():
+          yield value
+
+      @on_timer(CLEAR_TIMER)
+      def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+        set_state.clear()
+        set_state.add('different-value')
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(0)
+                     .add_elements([('key1', 'value1')])
+                     .advance_watermark_to(100))
+
+      _ = (p
+           | test_stream
+           | beam.ParDo(SetStateClearingStatefulDoFn())
+           | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual(['different-value'],
+                     StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_stateful_set_state_portably(self):
+
+    class SetStatefulDoFn(beam.DoFn):
+
+      SET_STATE = SetStateSpec('buffer', VarIntCoder())
+
+      def process(self,
+                  element,
+                  set_state=beam.DoFn.StateParam(SET_STATE)):
+        _, value = element
+        aggregated_value = 0
+        set_state.add(value)
+        for saved_value in set_state.read():
+          aggregated_value += saved_value
+        yield aggregated_value
+
+    p = TestPipeline()
+    values = p | beam.Create([('key', 1),
+                              ('key', 2),
+                              ('key', 3),
+                              ('key', 4),
+                              ('key', 3)])
+    actual_values = (values
+                     | beam.ParDo(SetStatefulDoFn()))
+
+    assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))
+
+    result = p.run()
+    result.wait_until_finish()
+
+  def test_stateful_set_state_clean_portably(self):
+
+    class SetStateClearingStatefulDoFn(beam.DoFn):
+
+      SET_STATE = SetStateSpec('buffer', VarIntCoder())
+      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  set_state=beam.DoFn.StateParam(SET_STATE),
+                  emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
+        _, value = element
+        set_state.add(value)
+
+        all_elements = [element for element in set_state.read()]
+
+        if len(all_elements) == 5:
+          set_state.clear()
+          set_state.add(100)
+          emit_timer.set(1)
+
+      @on_timer(EMIT_TIMER)
+      def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+        yield sorted(set_state.read())
+
+    p = TestPipeline()
+    values = p | beam.Create([('key', 1),
+                              ('key', 2),
+                              ('key', 3),
+                              ('key', 4),
+                              ('key', 5)])
+    actual_values = (values
+                     | beam.Map(lambda t: window.TimestampedValue(t, 1))
+                     | beam.WindowInto(window.FixedWindows(1))
+                     | beam.ParDo(SetStateClearingStatefulDoFn()))
+
+    assert_that(actual_values, equal_to([[100]]))
+
+    result = p.run()
+    result.wait_until_finish()
+
   def test_stateful_dofn_nonkeyed_input(self):
     p = TestPipeline()
     values = p | beam.Create([1, 2, 3])


Mime
View raw message