beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Work logged] (BEAM-5264) Reference DirectRunner implementation of Python user state and timers API
Date Tue, 18 Sep 2018 16:36:00 GMT

     [ https://issues.apache.org/jira/browse/BEAM-5264?focusedWorklogId=145357&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-145357 ]

ASF GitHub Bot logged work on BEAM-5264:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 18/Sep/18 16:35
            Start Date: 18/Sep/18 16:35
    Worklog Time Spent: 10m 
      Work Description: charlesccychen closed pull request #6304: [BEAM-5264] Reference DirectRunner implementation of Python User State and Timers API
URL: https://github.com/apache/beam/pull/6304
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a pull request from a fork, the diff is reproduced below,
because GitHub no longer displays it once the pull request is merged:

diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index a53f604e7e8..49f4c44aba1 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -34,6 +34,9 @@ cdef class MethodWrapper(object):
   cdef public object args
   cdef public object defaults
   cdef public object method_value
+  cdef bint has_userstate_arguments
+  cdef object state_args_to_replace
+  cdef object timer_args_to_replace
 
 
 cdef class DoFnSignature(object):
@@ -45,11 +48,14 @@ cdef class DoFnSignature(object):
   cdef public MethodWrapper create_tracker_method
   cdef public MethodWrapper split_method
   cdef public object do_fn
+  cdef public object timer_methods
+  cdef bint _is_stateful_dofn
 
 
 cdef class DoFnInvoker(object):
   cdef public DoFnSignature signature
   cdef OutputProcessor output_processor
+  cdef object user_state_context
 
   cpdef invoke_process(self, WindowedValue windowed_value,
                        restriction_tracker=*,
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index a714eaba121..2186df4e635 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -38,8 +38,8 @@
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
+from apache_beam.transforms import userstate
 from apache_beam.transforms.core import RestrictionProvider
-from apache_beam.transforms.userstate import UserStateUtils
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.transforms.window import WindowFn
@@ -159,6 +159,29 @@ def __init__(self, obj_to_invoke, method_name):
     self.args = args
     self.defaults = defaults
 
+    self.has_userstate_arguments = False
+    self.state_args_to_replace = {}
+    self.timer_args_to_replace = {}
+    for kw, v in zip(args[-len(defaults):], defaults):
+      if isinstance(v, core.DoFn.StateParam):
+        self.state_args_to_replace[kw] = v.state_spec
+        self.has_userstate_arguments = True
+      elif isinstance(v, core.DoFn.TimerParam):
+        self.timer_args_to_replace[kw] = v.timer_spec
+        self.has_userstate_arguments = True
+
+  def invoke_timer_callback(self, user_state_context, key, window):
+    # TODO(ccy): support WindowParam, TimestampParam and side inputs.
+    if self.has_userstate_arguments:
+      kwargs = {}
+      for kw, state_spec in self.state_args_to_replace.items():
+        kwargs[kw] = user_state_context.get_state(state_spec, key, window)
+      for kw, timer_spec in self.timer_args_to_replace.items():
+        kwargs[kw] = user_state_context.get_timer(timer_spec, key, window)
+      return self.method_value(**kwargs)
+    else:
+      return self.method_value()
+
 
 class DoFnSignature(object):
   """Represents the signature of a given ``DoFn`` object.
@@ -198,6 +221,16 @@ def __init__(self, do_fn):
 
     self._validate()
 
+    # Handle stateful DoFns.
+    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
+    self.timer_methods = {}
+    if self._is_stateful_dofn:
+      # Populate timer firing methods, keyed by TimerSpec.
+      _, all_timer_specs = userstate.get_dofn_specs(do_fn)
+      for timer_spec in all_timer_specs:
+        method = timer_spec._attached_callback
+        self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)
+
   def _get_restriction_provider(self, do_fn):
     result = _find_param_with_default(self.process_method,
                                       default_as_type=RestrictionProvider)
@@ -229,12 +262,15 @@ def _validate_bundle_method(self, method_wrapper):
             (param, method_wrapper))
 
   def _validate_stateful_dofn(self):
-    UserStateUtils.validate_stateful_dofn(self.do_fn)
+    userstate.validate_stateful_dofn(self.do_fn)
 
   def is_splittable_dofn(self):
     return any([isinstance(default, RestrictionProvider) for default in
                 self.process_method.defaults])
 
+  def is_stateful_dofn(self):
+    return self._is_stateful_dofn
+
 
 class DoFnInvoker(object):
   """An abstraction that can be used to execute DoFn methods.
@@ -245,13 +281,15 @@ class DoFnInvoker(object):
   def __init__(self, output_processor, signature):
     self.output_processor = output_processor
     self.signature = signature
+    self.user_state_context = None
 
   @staticmethod
   def create_invoker(
       signature,
       output_processor=None,
       context=None, side_inputs=None, input_args=None, input_kwargs=None,
-      process_invocation=True):
+      process_invocation=True,
+      user_state_context=None):
     """ Creates a new DoFnInvoker based on given arguments.
 
     Args:
@@ -271,18 +309,21 @@ def create_invoker(
         process_invocation: If True, this function may return an invoker that
                             performs extra optimizations for invoking process()
                             method efficiently.
+        user_state_context: The UserStateContext instance for the current
+                            Stateful DoFn.
     """
     side_inputs = side_inputs or []
     default_arg_values = signature.process_method.defaults
     use_simple_invoker = not process_invocation or (
         not side_inputs and not input_args and not input_kwargs and
-        not default_arg_values)
+        not default_arg_values and not signature.is_stateful_dofn())
     if use_simple_invoker:
       return SimpleInvoker(output_processor, signature)
     else:
       return PerWindowInvoker(
           output_processor,
-          signature, context, side_inputs, input_args, input_kwargs)
+          signature, context, side_inputs, input_args, input_kwargs,
+          user_state_context)
 
   def invoke_process(self, windowed_value, restriction_tracker=None,
                      output_processor=None,
@@ -313,6 +354,12 @@ def invoke_finish_bundle(self):
     self.output_processor.finish_bundle_outputs(
         self.signature.finish_bundle_method.method_value())
 
+  def invoke_user_timer(self, timer_spec, key, window, timestamp):
+    self.output_processor.process_outputs(
+        WindowedValue(None, timestamp, (window,)),
+        self.signature.timer_methods[timer_spec].invoke_timer_callback(
+            self.user_state_context, key, window))
+
   def invoke_split(self, element, restriction):
     return self.signature.split_method.method_value(element, restriction)
 
@@ -368,7 +415,7 @@ class PerWindowInvoker(DoFnInvoker):
   """An invoker that processes elements considering windowing information."""
 
   def __init__(self, output_processor, signature, context,
-               side_inputs, input_args, input_kwargs):
+               side_inputs, input_args, input_kwargs, user_state_context):
     super(PerWindowInvoker, self).__init__(output_processor, signature)
     self.side_inputs = side_inputs
     self.context = context
@@ -376,7 +423,9 @@ def __init__(self, output_processor, signature, context,
     default_arg_values = signature.process_method.defaults
     self.has_windowed_inputs = (
         not all(si.is_globally_windowed() for si in side_inputs) or
-        (core.DoFn.WindowParam in default_arg_values))
+        (core.DoFn.WindowParam in default_arg_values) or
+        signature.is_stateful_dofn())
+    self.user_state_context = user_state_context
 
     # Try to prepare all the arguments that can just be filled in
     # without any additional work. in the process function.
@@ -423,6 +472,10 @@ def __init__(self, placeholder):
         except StopIteration:
           if a not in input_kwargs:
             raise ValueError("Value for sideinput %s not provided" % a)
+      elif isinstance(d, core.DoFn.StateParam):
+        args_with_placeholders.append(ArgPlaceholder(d))
+      elif isinstance(d, core.DoFn.TimerParam):
+        args_with_placeholders.append(ArgPlaceholder(d))
       else:
         # If no more args are present then the value must be passed via kwarg
         try:
@@ -498,6 +551,19 @@ def _invoke_per_window(
     else:
       args_for_process, kwargs_for_process = (
           self.args_for_process, self.kwargs_for_process)
+
+    # Extract key in the case of a stateful DoFn. Note that in the case of a
+    # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
+    # True. Therefore, windows will be exploded coming into this method, and
+    # we can rely on the window variable being set above.
+    if self.user_state_context:
+      try:
+        key, unused_value = windowed_value.value
+      except (TypeError, ValueError):
+        raise ValueError(
+            ('Input value to a stateful DoFn must be a KV tuple; instead, '
+             'got %s.') % (windowed_value.value,))
+
     # TODO(sourabhbajaj): Investigate why we can't use `is` instead of ==
     for i, p in self.placeholders:
       if p == core.DoFn.ElementParam:
@@ -506,6 +572,12 @@ def _invoke_per_window(
         args_for_process[i] = window
       elif p == core.DoFn.TimestampParam:
         args_for_process[i] = windowed_value.timestamp
+      elif isinstance(p, core.DoFn.StateParam):
+        args_for_process[i] = (
+            self.user_state_context.get_state(p.state_spec, key, window))
+      elif isinstance(p, core.DoFn.TimerParam):
+        args_for_process[i] = (
+            self.user_state_context.get_timer(p.timer_spec, key, window))
 
     if additional_kwargs:
       if kwargs_for_process is None:
@@ -540,7 +612,8 @@ def __init__(self,
                logging_context=None,
                state=None,
                scoped_metrics_container=None,
-               operation_name=None):
+               operation_name=None,
+               user_state_context=None):
     """Initializes a DoFnRunner.
 
     Args:
@@ -555,6 +628,8 @@ def __init__(self,
       state: handle for accessing DoFn state
       scoped_metrics_container: DEPRECATED
       operation_name: The system name assigned by the runner for this operation.
+      user_state_context: The UserStateContext instance for the current
+                          Stateful DoFn.
     """
     # Need to support multiple iterations.
     side_inputs = list(side_inputs)
@@ -581,9 +656,15 @@ def __init__(self,
         windowing.windowfn, main_receivers, tagged_receivers,
         per_element_output_counter)
 
+    if do_fn_signature.is_stateful_dofn() and not user_state_context:
+      raise Exception(
+          'Requested execution of a stateful DoFn, but no user state context '
+          'is available. This likely means that the current runner does not '
+          'support the execution of stateful DoFns.')
+
     self.do_fn_invoker = DoFnInvoker.create_invoker(
         do_fn_signature, output_processor, self.context, side_inputs, args,
-        kwargs)
+        kwargs, user_state_context=user_state_context)
 
   def receive(self, windowed_value):
     self.process(windowed_value)
@@ -594,6 +675,12 @@ def process(self, windowed_value):
     except BaseException as exn:
       self._reraise_augmented(exn)
 
+  def process_user_timer(self, timer_spec, key, window, timestamp):
+    try:
+      self.do_fn_invoker.invoke_user_timer(timer_spec, key, window, timestamp)
+    except BaseException as exn:
+      self._reraise_augmented(exn)
+
   def _invoke_bundle_method(self, bundle_method):
     try:
       self.context.set_element(None)
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 4131c9b0c65..a19b50f3a7b 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -105,6 +105,9 @@ def visit_transform(self, applied_ptransform):
           # The FnApiRunner does not support execution of SplittableDoFns.
           if DoFnSignature(dofn).is_splittable_dofn():
             self.supported_by_fnapi_runner = False
+          # The FnApiRunner does not support execution of Stateful DoFns.
+          if DoFnSignature(dofn).is_stateful_dofn():
+            self.supported_by_fnapi_runner = False
           # The FnApiRunner does not support execution of CombineFns with
           # deferred side inputs.
           if isinstance(dofn, CombineValuesDoFn):
diff --git a/sdks/python/apache_beam/runners/direct/direct_userstate.py b/sdks/python/apache_beam/runners/direct/direct_userstate.py
new file mode 100644
index 00000000000..f0fd9b8e91f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/direct/direct_userstate.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Support for user state in the BundleBasedDirectRunner."""
+from __future__ import absolute_import
+
+from apache_beam.transforms import userstate
+from apache_beam.transforms.trigger import _ListStateTag
+
+
+class DirectUserStateContext(userstate.UserStateContext):
+  """userstate.UserStateContext for the BundleBasedDirectRunner.
+
+  The DirectUserStateContext buffers up updates that are to be committed
+  by the TransformEvaluator after running a DoFn.
+  """
+
+  def __init__(self, step_context, dofn, key_coder):
+    self.step_context = step_context
+    self.dofn = dofn
+    self.key_coder = key_coder
+
+    self.all_state_specs, self.all_timer_specs = (
+        userstate.get_dofn_specs(dofn))
+    self.state_tags = {}
+    for state_spec in self.all_state_specs:
+      state_key = 'user/%s' % state_spec.name
+      if isinstance(state_spec, userstate.BagStateSpec):
+        state_tag = _ListStateTag(state_key)
+      elif isinstance(state_spec, userstate.CombiningValueStateSpec):
+        state_tag = _ListStateTag(state_key)
+      else:
+        raise ValueError('Invalid state spec: %s' % state_spec)
+      self.state_tags[state_spec] = state_tag
+
+    self.cached_states = {}
+    self.cached_timers = {}
+
+  def get_timer(self, timer_spec, key, window):
+    assert timer_spec in self.all_timer_specs
+    encoded_key = self.key_coder.encode(key)
+    cache_key = (encoded_key, window, timer_spec)
+    if cache_key not in self.cached_timers:
+      self.cached_timers[cache_key] = userstate.RuntimeTimer(timer_spec)
+    return self.cached_timers[cache_key]
+
+  def get_state(self, state_spec, key, window):
+    assert state_spec in self.all_state_specs
+    encoded_key = self.key_coder.encode(key)
+    cache_key = (encoded_key, window, state_spec)
+    if cache_key not in self.cached_states:
+      state_tag = self.state_tags[state_spec]
+      value_accessor = (
+          lambda: self._get_underlying_state(state_spec, key, window))
+      self.cached_states[cache_key] = userstate.RuntimeState.for_spec(
+          state_spec, state_tag, value_accessor)
+    return self.cached_states[cache_key]
+
+  def _get_underlying_state(self, state_spec, key, window):
+    state_tag = self.state_tags[state_spec]
+    encoded_key = self.key_coder.encode(key)
+    return (self.step_context.get_keyed_state(encoded_key)
+            .get_state(window, state_tag))
+
+  def commit(self):
+    # Commit state modifications.
+    for cache_key, runtime_state in self.cached_states.items():
+      encoded_key, window, state_spec = cache_key
+      state = self.step_context.get_keyed_state(encoded_key)
+      state_tag = self.state_tags[state_spec]
+      if isinstance(state_spec, userstate.BagStateSpec):
+        if runtime_state._cleared:
+          state.clear_state(window, state_tag)
+        for new_value in runtime_state._new_values:
+          state.add_state(window, state_tag, new_value)
+      elif isinstance(state_spec, userstate.CombiningValueStateSpec):
+        if runtime_state._modified:
+          state.clear_state(window, state_tag)
+          state.add_state(
+              window, state_tag,
+              state_spec.coder.encode(runtime_state._current_accumulator))
+      else:
+        raise ValueError('Invalid state spec: %s' % state_spec)
+
+    # Commit new timers.
+    for cache_key, runtime_timer in self.cached_timers.items():
+      encoded_key, window, timer_spec = cache_key
+      state = self.step_context.get_keyed_state(encoded_key)
+      timer_name = 'user/%s' % timer_spec.name
+      if runtime_timer._cleared:
+        state.clear_timer(window, timer_name, timer_spec.time_domain)
+      if runtime_timer._new_timestamp is not None:
+        # TODO(ccy): add corresponding watermark holds after the DirectRunner
+        # allows for keyed watermark holds.
+        state.set_timer(window, timer_name, timer_spec.time_domain,
+                        runtime_timer._new_timestamp)
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index 22aedce2bb0..ef12e2cb02f 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -20,6 +20,7 @@
 from __future__ import absolute_import
 
 import collections
+import logging
 import random
 import time
 from builtins import object
@@ -38,6 +39,7 @@
 from apache_beam.runners.direct.direct_runner import _DirectReadFromPubSub
 from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow
 from apache_beam.runners.direct.direct_runner import _StreamingGroupByKeyOnly
+from apache_beam.runners.direct.direct_userstate import DirectUserStateContext
 from apache_beam.runners.direct.sdf_direct_runner import ProcessElements
 from apache_beam.runners.direct.sdf_direct_runner import ProcessFn
 from apache_beam.runners.direct.sdf_direct_runner import SDFProcessElementInvoker
@@ -53,6 +55,8 @@
 from apache_beam.transforms.trigger import _CombiningValueStateTag
 from apache_beam.transforms.trigger import _ListStateTag
 from apache_beam.transforms.trigger import create_trigger_driver
+from apache_beam.transforms.userstate import get_dofn_specs
+from apache_beam.transforms.userstate import is_stateful_dofn
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import WindowedValue
 from apache_beam.typehints.typecheck import TypeCheckError
@@ -142,11 +146,16 @@ def should_execute_serially(self, applied_ptransform):
     Returns:
       True if executor should execute applied_ptransform serially.
     """
-    return isinstance(applied_ptransform.transform,
-                      (core._GroupByKeyOnly,
-                       _StreamingGroupByKeyOnly,
-                       _StreamingGroupAlsoByWindow,
-                       _NativeWrite))
+    if isinstance(applied_ptransform.transform,
+                  (core._GroupByKeyOnly,
+                   _StreamingGroupByKeyOnly,
+                   _StreamingGroupAlsoByWindow,
+                   _NativeWrite)):
+      return True
+    elif (isinstance(applied_ptransform.transform, core.ParDo) and
+          is_stateful_dofn(applied_ptransform.transform.dofn)):
+      return True
+    return False
 
 
 class RootBundleProvider(object):
@@ -199,6 +208,7 @@ def __init__(self, evaluation_context, applied_ptransform,
     self._expand_outputs()
     self._execution_context = evaluation_context.get_execution_context(
         applied_ptransform)
+    self._step_context = self._execution_context.get_step_context()
 
   def _expand_outputs(self):
     outputs = set()
@@ -252,7 +262,7 @@ def process_timer_wrapper(self, timer_firing):
     timer and passes it to process_element().  Evaluator subclasses which
     desire different timer delivery semantics can override process_timer().
     """
-    state = self.step_context.get_keyed_state(timer_firing.encoded_key)
+    state = self._step_context.get_keyed_state(timer_firing.encoded_key)
     state.clear_timer(
         timer_firing.window, timer_firing.name, timer_firing.time_domain)
     self.process_timer(timer_firing)
@@ -565,15 +575,40 @@ def start_bundle(self):
     args = transform.args if hasattr(transform, 'args') else []
     kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}
 
+    self.user_state_context = None
+    self.user_timer_map = {}
+    if is_stateful_dofn(dofn):
+      kv_type_hint = self._applied_ptransform.inputs[0].element_type
+      if kv_type_hint and kv_type_hint != typehints.Any:
+        coder = coders.registry.get_coder(kv_type_hint)
+        self.key_coder = coder.key_coder()
+      else:
+        self.key_coder = coders.registry.get_coder(typehints.Any)
+
+      self.user_state_context = DirectUserStateContext(
+          self._step_context, dofn, self.key_coder)
+      _, all_timer_specs = get_dofn_specs(dofn)
+      for timer_spec in all_timer_specs:
+        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec
+
     self.runner = DoFnRunner(
         dofn, args, kwargs,
         self._side_inputs,
         self._applied_ptransform.inputs[0].windowing,
         tagged_receivers=self._tagged_receivers,
         step_name=self._applied_ptransform.full_label,
-        state=DoFnState(self._counter_factory))
+        state=DoFnState(self._counter_factory),
+        user_state_context=self.user_state_context)
     self.runner.start()
 
+  def process_timer(self, timer_firing):
+    if timer_firing.name not in self.user_timer_map:
+      logging.warning('Unknown timer fired: %s', timer_firing)
+    timer_spec = self.user_timer_map[timer_firing.name]
+    self.runner.process_user_timer(
+        timer_spec, self.key_coder.decode(timer_firing.encoded_key),
+        timer_firing.window, timer_firing.timestamp)
+
   def process_element(self, element):
     self.runner.process(element)
 
@@ -581,6 +616,8 @@ def finish_bundle(self):
     self.runner.finish()
     bundles = list(self._tagged_receivers.values())
     result_counters = self._counter_factory.get_counters()
+    if self.user_state_context:
+      self.user_state_context.commit()
     return TransformResult(
         self, bundles, [], result_counters, None)
 
@@ -604,8 +641,7 @@ def _is_final_bundle(self):
             == WatermarkManager.WATERMARK_POS_INF)
 
   def start_bundle(self):
-    self.step_context = self._execution_context.get_step_context()
-    self.global_state = self.step_context.get_keyed_state(None)
+    self.global_state = self._step_context.get_keyed_state(None)
 
     assert len(self._outputs) == 1
     self.output_pcollection = list(self._outputs)[0]
@@ -630,7 +666,7 @@ def process_element(self, element):
         and len(element.value) == 2):
       k, v = element.value
       encoded_k = self.key_coder.encode(k)
-      state = self.step_context.get_keyed_state(encoded_k)
+      state = self._step_context.get_keyed_state(encoded_k)
       state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
     else:
       raise TypeCheckError('Input to _GroupByKeyOnly must be a PCollection of '
@@ -648,12 +684,12 @@ def finish_bundle(self):
         gbk_result = []
         # TODO(ccy): perhaps we can clean this up to not use this
         # internal attribute of the DirectStepContext.
-        for encoded_k in self.step_context.existing_keyed_state:
+        for encoded_k in self._step_context.existing_keyed_state:
           # Ignore global state.
           if encoded_k is None:
             continue
           k = self.key_coder.decode(encoded_k)
-          state = self.step_context.get_keyed_state(encoded_k)
+          state = self._step_context.get_keyed_state(encoded_k)
           vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG)
           gbk_result.append(GlobalWindows.windowed_value((k, vs)))
 
@@ -749,7 +785,6 @@ def __init__(self, evaluation_context, applied_ptransform,
   def start_bundle(self):
     assert len(self._outputs) == 1
     self.output_pcollection = list(self._outputs)[0]
-    self.step_context = self._execution_context.get_step_context()
     self.driver = create_trigger_driver(
         self._applied_ptransform.transform.windowing,
         clock=self._evaluation_context._watermark_manager._clock)
@@ -769,7 +804,7 @@ def process_element(self, element):
     encoded_k, timer_firings, vs = (
         kwi.encoded_key, kwi.timer_firings, kwi.elements)
     k = self.key_coder.decode(encoded_k)
-    state = self.step_context.get_keyed_state(encoded_k)
+    state = self._step_context.get_keyed_state(encoded_k)
 
     for timer_firing in timer_firings:
       for wvalue in self.driver.process_timer(
@@ -819,8 +854,7 @@ def _has_already_produced_output(self):
             == WatermarkManager.WATERMARK_POS_INF)
 
   def start_bundle(self):
-    self.step_context = self._execution_context.get_step_context()
-    self.global_state = self.step_context.get_keyed_state(None)
+    self.global_state = self._step_context.get_keyed_state(None)
 
   def process_timer(self, timer_firing):
     # We do not need to emit a KeyedWorkItem to process_element().
@@ -885,8 +919,7 @@ def __init__(self, evaluation_context, applied_ptransform,
 
     assert isinstance(self._process_fn, ProcessFn)
 
-    self.step_context = self._execution_context.get_step_context()
-    self._process_fn.step_context = self.step_context
+    self._process_fn.step_context = self._step_context
 
     process_element_invoker = (
         SDFProcessElementInvoker(
@@ -917,7 +950,7 @@ def process_element(self, element):
 
     self._par_do_evaluator.process_element(element)
 
-    state = self.step_context.get_keyed_state(key)
+    state = self._step_context.get_keyed_state(key)
     self.keyed_holds[key] = state.get_state(
         window, self._process_fn.watermark_hold_tag)
 
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index fa867e5231d..438eee687da 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -21,6 +21,7 @@
 
 import copy
 import inspect
+import logging
 import random
 import re
 import types
@@ -850,7 +851,7 @@ def __init__(self, fn, *args, **kwargs):
 
     # Validate the DoFn by creating a DoFnSignature
     from apache_beam.runners.common import DoFnSignature
-    DoFnSignature(self.fn)
+    self._signature = DoFnSignature(self.fn)
 
   def default_type_hints(self):
     return self.fn.get_type_hints()
@@ -873,6 +874,27 @@ def display_data(self):
             'fn_dd': self.fn}
 
   def expand(self, pcoll):
+    # In the case of a stateful DoFn, warn if the key coder is not
+    # deterministic.
+    if self._signature.is_stateful_dofn():
+      kv_type_hint = pcoll.element_type
+      if kv_type_hint and kv_type_hint != typehints.Any:
+        coder = coders.registry.get_coder(kv_type_hint)
+        if not coder.is_kv_coder():
+          raise ValueError(
+              'Input elements to the transform %s with stateful DoFn must be '
+              'key-value pairs.' % self)
+        key_coder = coder.key_coder()
+      else:
+        key_coder = coders.registry.get_coder(typehints.Any)
+
+      if not key_coder.is_deterministic():
+        logging.warning(
+            'Key coder %s for transform %s with stateful DoFn may not '
+            'be deterministic. This may cause incorrect behavior for complex '
+            'key types. Consider adding an input type hint for this transform.',
+            key_coder, self)
+
     return pvalue.PCollection(pcoll.pipeline)
 
   def with_outputs(self, *tags, **main_kw):
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index 0f99da246a5..6c2eabcd558 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -22,6 +22,7 @@
 
 from __future__ import absolute_import
 
+import itertools
 import types
 from builtins import object
 
@@ -52,16 +53,17 @@ def __init__(self, name, coder):
 class CombiningValueStateSpec(StateSpec):
   """Specification for a user DoFn combining value state cell."""
 
-  def __init__(self, name, coder, combiner):
+  def __init__(self, name, coder, combine_fn):
     # Avoid circular import.
     from apache_beam.transforms.core import CombineFn
 
     assert isinstance(name, str)
     assert isinstance(coder, Coder)
-    assert isinstance(combiner, CombineFn)
+    assert isinstance(combine_fn, CombineFn)
     self.name = name
+    # The coder here should be for the accumulator type of the given CombineFn.
     self.coder = coder
-    self.combiner = combiner
+    self.combine_fn = combine_fn
 
 
 class TimerSpec(object):
@@ -107,57 +109,197 @@ def _inner(method):
   return _inner
 
 
-class UserStateUtils(object):
+def get_dofn_specs(dofn):
+  """Gets the state and timer specs for a DoFn, if any."""
 
-  @staticmethod
-  def validate_stateful_dofn(dofn):
-    # Avoid circular import.
-    from apache_beam.runners.common import MethodWrapper
-    from apache_beam.transforms.core import _DoFnParam
-    from apache_beam.transforms.core import _StateDoFnParam
-    from apache_beam.transforms.core import _TimerDoFnParam
-
-    all_state_specs = set()
-    all_timer_specs = set()
-
-    # Validate params to process(), start_bundle(), finish_bundle() and to
-    # any on_timer callbacks.
-    for method_name in dir(dofn):
-      if not isinstance(getattr(dofn, method_name, None), types.MethodType):
-        continue
-      method = MethodWrapper(dofn, method_name)
-      param_ids = [d.param_id for d in method.defaults
-                   if isinstance(d, _DoFnParam)]
-      if len(param_ids) != len(set(param_ids)):
-        raise ValueError(
-            'DoFn %r has duplicate %s method parameters: %s.' % (
-                dofn, method_name, param_ids))
-      for d in method.defaults:
-        if isinstance(d, _StateDoFnParam):
-          all_state_specs.add(d.state_spec)
-        elif isinstance(d, _TimerDoFnParam):
-          all_timer_specs.add(d.timer_spec)
-
-    # Reject DoFns that have multiple state or timer specs with the same name.
-    if len(all_state_specs) != len(set(s.name for s in all_state_specs)):
+  # Avoid circular import.
+  from apache_beam.runners.common import MethodWrapper
+  from apache_beam.transforms.core import _DoFnParam
+  from apache_beam.transforms.core import _StateDoFnParam
+  from apache_beam.transforms.core import _TimerDoFnParam
+
+  all_state_specs = set()
+  all_timer_specs = set()
+
+  # Validate params to process(), start_bundle(), finish_bundle() and to
+  # any on_timer callbacks.
+  for method_name in dir(dofn):
+    if not isinstance(getattr(dofn, method_name, None), types.MethodType):
+      continue
+    method = MethodWrapper(dofn, method_name)
+    param_ids = [d.param_id for d in method.defaults
+                 if isinstance(d, _DoFnParam)]
+    if len(param_ids) != len(set(param_ids)):
+      raise ValueError(
+          'DoFn %r has duplicate %s method parameters: %s.' % (
+              dofn, method_name, param_ids))
+    for d in method.defaults:
+      if isinstance(d, _StateDoFnParam):
+        all_state_specs.add(d.state_spec)
+      elif isinstance(d, _TimerDoFnParam):
+        all_timer_specs.add(d.timer_spec)
+
+  return all_state_specs, all_timer_specs
+
+
+def is_stateful_dofn(dofn):
+  """Determines whether a given DoFn is a stateful DoFn."""
+
+  # A Stateful DoFn is a DoFn that uses user state or timers.
+  all_state_specs, all_timer_specs = get_dofn_specs(dofn)
+  return bool(all_state_specs or all_timer_specs)
+
+
+def validate_stateful_dofn(dofn):
+  """Validates the proper specification of a stateful DoFn."""
+
+  # Get state and timer specs.
+  all_state_specs, all_timer_specs = get_dofn_specs(dofn)
+
+  # Reject DoFns that have multiple state or timer specs with the same name.
+  if len(all_state_specs) != len(set(s.name for s in all_state_specs)):
+    raise ValueError(
+        'DoFn %r has multiple StateSpecs with the same name: %s.' % (
+            dofn, all_state_specs))
+  if len(all_timer_specs) != len(set(s.name for s in all_timer_specs)):
+    raise ValueError(
+        'DoFn %r has multiple TimerSpecs with the same name: %s.' % (
+            dofn, all_timer_specs))
+
+  # Reject DoFns that use timer specs without corresponding timer callbacks.
+  for timer_spec in all_timer_specs:
+    if not timer_spec._attached_callback:
       raise ValueError(
-          'DoFn %r has multiple StateSpecs with the same name: %s.' % (
-              dofn, all_state_specs))
-    if len(all_timer_specs) != len(set(s.name for s in all_timer_specs)):
+          ('DoFn %r has a TimerSpec without an associated on_timer '
+           'callback: %s.') % (dofn, timer_spec))
+    method_name = timer_spec._attached_callback.__name__
+    if (timer_spec._attached_callback !=
+        getattr(dofn, method_name, None).__func__):
       raise ValueError(
-          'DoFn %r has multiple TimerSpecs with the same name: %s.' % (
-              dofn, all_timer_specs))
-
-    # Reject DoFns that use timer specs without corresponding timer callbacks.
-    for timer_spec in all_timer_specs:
-      if not timer_spec._attached_callback:
-        raise ValueError(
-            ('DoFn %r has a TimerSpec without an associated on_timer '
-             'callback: %s.') % (dofn, timer_spec))
-      method_name = timer_spec._attached_callback.__name__
-      if (timer_spec._attached_callback !=
-          getattr(dofn, method_name, None).__func__):
-        raise ValueError(
-            ('The on_timer callback for %s is not the specified .%s method '
-             'for DoFn %r (perhaps it was overwritten?).') % (
-                 timer_spec, method_name, dofn))
+          ('The on_timer callback for %s is not the specified .%s method '
+           'for DoFn %r (perhaps it was overwritten?).') % (
+               timer_spec, method_name, dofn))
+
+
+class RuntimeTimer(object):
+  """Timer interface object passed to user code."""
+
+  def __init__(self, timer_spec):
+    self._cleared = False
+    self._new_timestamp = None
+
+  def clear(self):
+    self._cleared = True
+    self._new_timestamp = None
+
+  def set(self, timestamp):
+    self._new_timestamp = timestamp
+
+
+class RuntimeState(object):
+  """State interface object passed to user code."""
+
+  def __init__(self, state_spec, state_tag, current_value_accessor):
+    self._state_spec = state_spec
+    self._state_tag = state_tag
+    self._current_value_accessor = current_value_accessor
+
+  @staticmethod
+  def for_spec(state_spec, state_tag, current_value_accessor):
+    if isinstance(state_spec, BagStateSpec):
+      return BagRuntimeState(state_spec, state_tag, current_value_accessor)
+    elif isinstance(state_spec, CombiningValueStateSpec):
+      return CombiningValueRuntimeState(state_spec, state_tag,
+                                        current_value_accessor)
+    else:
+      raise ValueError('Invalid state spec: %s' % state_spec)
+
+  def _encode(self, value):
+    return self._state_spec.coder.encode(value)
+
+  def _decode(self, value):
+    return self._state_spec.coder.decode(value)
+
+  def prefetch(self):
+    # The default implementation here does nothing.
+    pass
+
+
+# Sentinel designating an unread value.
+UNREAD_VALUE = object()
+
+
+class BagRuntimeState(RuntimeState):
+  """Bag state interface object passed to user code."""
+
+  def __init__(self, state_spec, state_tag, current_value_accessor):
+    super(BagRuntimeState, self).__init__(
+        state_spec, state_tag, current_value_accessor)
+    self._cached_value = UNREAD_VALUE
+    self._cleared = False
+    self._new_values = []
+
+  def read(self):
+    if self._cached_value is UNREAD_VALUE:
+      self._cached_value = self._current_value_accessor()
+    if not self._cleared:
+      encoded_values = itertools.chain(self._cached_value, self._new_values)
+    else:
+      encoded_values = self._new_values
+    return (self._decode(v) for v in encoded_values)
+
+  def add(self, value):
+    self._new_values.append(self._encode(value))
+
+  def clear(self):
+    self._cleared = True
+    self._cached_value = []
+    self._new_values = []
+
+
+class CombiningValueRuntimeState(RuntimeState):
+  """Combining value state interface object passed to user code."""
+
+  def __init__(self, state_spec, state_tag, current_value_accessor):
+    super(CombiningValueRuntimeState, self).__init__(
+        state_spec, state_tag, current_value_accessor)
+    self._current_accumulator = UNREAD_VALUE
+    self._modified = False
+    self._combine_fn = state_spec.combine_fn
+
+  def _read_initial_value(self):
+    if self._current_accumulator is UNREAD_VALUE:
+      existing_accumulators = list(
+          self._decode(a) for a in self._current_value_accessor())
+      if existing_accumulators:
+        self._current_accumulator = self._combine_fn.merge_accumulators(
+            existing_accumulators)
+      else:
+        self._current_accumulator = self._combine_fn.create_accumulator()
+
+  def read(self):
+    self._read_initial_value()
+    return self._combine_fn.extract_output(self._current_accumulator)
+
+  def add(self, value):
+    self._read_initial_value()
+    self._modified = True
+    self._current_accumulator = self._combine_fn.add_input(
+        self._current_accumulator, value)
+
+  def clear(self):
+    self._modified = True
+    self._current_accumulator = self._combine_fn.create_accumulator()
+
+
+class UserStateContext(object):
+  """Wrapper allowing user state and timers to be accessed by a DoFnInvoker."""
+
+  def get_timer(self, timer_spec, key, window):
+    raise NotImplementedError()
+
+  def get_state(self, state_spec, key, window):
+    raise NotImplementedError()
+
+  def commit(self):
+    raise NotImplementedError()
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index b891e628178..479e66c138c 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -22,17 +22,25 @@
 
 import mock
 
+import apache_beam as beam
 from apache_beam.coders import BytesCoder
+from apache_beam.coders import IterableCoder
 from apache_beam.coders import VarIntCoder
 from apache_beam.runners.common import DoFnSignature
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.test_stream import TestStream
+from apache_beam.transforms import userstate
+from apache_beam.transforms.combiners import ToListCombineFn
 from apache_beam.transforms.combiners import TopCombineFn
 from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.timeutil import TimeDomain
 from apache_beam.transforms.userstate import BagStateSpec
 from apache_beam.transforms.userstate import CombiningValueStateSpec
 from apache_beam.transforms.userstate import TimerSpec
-from apache_beam.transforms.userstate import UserStateUtils
+from apache_beam.transforms.userstate import get_dofn_specs
+from apache_beam.transforms.userstate import is_stateful_dofn
 from apache_beam.transforms.userstate import on_timer
+from apache_beam.transforms.userstate import validate_stateful_dofn
 
 
 class TestStatefulDoFn(DoFn):
@@ -80,14 +88,14 @@ def _validate_dofn(self, dofn):
     # Construction of DoFnSignature performs validation of the given DoFn.
     # In particular, it ends up calling userstate._validate_stateful_dofn.
     # That behavior is explicitly tested below in test_validate_dofn()
-    DoFnSignature(dofn)
+    return DoFnSignature(dofn)
 
   @mock.patch(
-      'apache_beam.transforms.userstate.UserStateUtils.validate_stateful_dofn')
+      'apache_beam.transforms.userstate.validate_stateful_dofn')
   def test_validate_dofn(self, unused_mock):
     dofn = TestStatefulDoFn()
     self._validate_dofn(dofn)
-    UserStateUtils.validate_stateful_dofn.assert_called_with(dofn)
+    userstate.validate_stateful_dofn.assert_called_with(dofn)
 
   def test_spec_construction(self):
     BagStateSpec('statename', VarIntCoder())
@@ -116,6 +124,10 @@ def test_param_construction(self):
     with self.assertRaises(ValueError):
       DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
 
+  def test_stateful_dofn_detection(self):
+    self.assertFalse(is_stateful_dofn(DoFn()))
+    self.assertTrue(is_stateful_dofn(TestStatefulDoFn()))
+
   def test_good_signatures(self):
     class BasicStatefulDoFn(DoFn):
       BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
@@ -129,8 +141,36 @@ def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
       def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
         yield element
 
-    self._validate_dofn(BasicStatefulDoFn())
-    self._validate_dofn(TestStatefulDoFn())
+    # Validate get_dofn_specs() and timer callbacks in
+    # DoFnSignature.
+    stateful_dofn = BasicStatefulDoFn()
+    signature = self._validate_dofn(stateful_dofn)
+    expected_specs = (set([BasicStatefulDoFn.BUFFER_STATE]),
+                      set([BasicStatefulDoFn.EXPIRY_TIMER]))
+    self.assertEqual(expected_specs,
+                     get_dofn_specs(stateful_dofn))
+    self.assertEqual(
+        stateful_dofn.expiry_callback,
+        signature.timer_methods[BasicStatefulDoFn.EXPIRY_TIMER].method_value)
+
+    stateful_dofn = TestStatefulDoFn()
+    signature = self._validate_dofn(stateful_dofn)
+    expected_specs = (set([TestStatefulDoFn.BUFFER_STATE_1,
+                           TestStatefulDoFn.BUFFER_STATE_2]),
+                      set([TestStatefulDoFn.EXPIRY_TIMER_1,
+                           TestStatefulDoFn.EXPIRY_TIMER_2,
+                           TestStatefulDoFn.EXPIRY_TIMER_3]))
+    self.assertEqual(expected_specs,
+                     get_dofn_specs(stateful_dofn))
+    self.assertEqual(
+        stateful_dofn.on_expiry_1,
+        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_1].method_value)
+    self.assertEqual(
+        stateful_dofn.on_expiry_2,
+        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_2].method_value)
+    self.assertEqual(
+        stateful_dofn.on_expiry_3,
+        signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_3].method_value)
 
   def test_bad_signatures(self):
     # (1) The same state parameter is duplicated on the process method.
@@ -222,7 +262,7 @@ def process(self, element,
       def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
         yield 'expired1'
 
-      # Note that we mistakenly reuse the "on_expiry_2" name; this is valid
+      # Note that we mistakenly reuse the "on_expiry_1" name; this is valid
       # syntactically in Python.
       @on_timer(EXPIRY_TIMER_2)
       def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
@@ -238,7 +278,7 @@ def __repr__(self):
         (r'The on_timer callback for TimerSpec\(expiry1\) is not the '
          r'specified .on_expiry_1 method for DoFn '
          r'StatefulDoFnWithTimerWithTypo2 \(perhaps it was overwritten\?\).')):
-      UserStateUtils.validate_stateful_dofn(dofn)
+      validate_stateful_dofn(dofn)
 
     # (2) Here, the user forgot to add an on_timer decorator for 'expiry2'
     class StatefulDoFnWithTimerWithTypo3(DoFn):
@@ -267,7 +307,226 @@ def __repr__(self):
         ValueError,
         (r'DoFn StatefulDoFnWithTimerWithTypo3 has a TimerSpec without an '
          r'associated on_timer callback: TimerSpec\(expiry2\).')):
-      UserStateUtils.validate_stateful_dofn(dofn)
+      validate_stateful_dofn(dofn)
+
+
+class StatefulDoFnOnDirectRunnerTest(unittest.TestCase):
+  # pylint: disable=expression-not-assigned
+  all_records = None
+
+  def setUp(self):
+    # Use state on the TestCase class, since other references would be pickled
+    # into a closure and not have the desired side effects.
+    #
+    # TODO(BEAM-5295): Use assert_that after it works for the cases here in
+    # streaming mode.
+    StatefulDoFnOnDirectRunnerTest.all_records = []
+
+  def record_dofn(self):
+    class RecordDoFn(DoFn):
+      def process(self, element):
+        StatefulDoFnOnDirectRunnerTest.all_records.append(element)
+
+    return RecordDoFn()
+
+  def test_simple_stateful_dofn(self):
+    class SimpleTestStatefulDoFn(DoFn):
+      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
+      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
+
+      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
+                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
+        unused_key, value = element
+        buffer.add('A' + str(value))
+        timer1.set(20)
+
+      @on_timer(EXPIRY_TIMER)
+      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE),
+                          timer=DoFn.TimerParam(EXPIRY_TIMER)):
+        yield ''.join(sorted(buffer.read()))
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([1, 2])
+                     .add_elements([3])
+                     .advance_watermark_to(25)
+                     .add_elements([4]))
+      (p
+       | test_stream
+       | beam.Map(lambda x: ('mykey', x))
+       | beam.ParDo(SimpleTestStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    # Two firings should occur: once after element 3 since the timer should
+    # fire after the watermark passes time 20, and another time after element
+    # 4, since the timer issued at that point should fire immediately.
+    self.assertEqual(
+        ['A1A2A3', 'A1A2A3A4'],
+        StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_stateful_dofn_nonkeyed_input(self):
+    p = TestPipeline()
+    values = p | beam.Create([1, 2, 3])
+    with self.assertRaisesRegexp(
+        ValueError,
+        ('Input elements to the transform .* with stateful DoFn must be '
+         'key-value pairs.')):
+      values | beam.ParDo(TestStatefulDoFn())
+
+  def test_simple_stateful_dofn_combining(self):
+    class SimpleTestStatefulDoFn(DoFn):
+      BUFFER_STATE = CombiningValueStateSpec(
+          'buffer',
+          IterableCoder(VarIntCoder()), ToListCombineFn())
+      EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)
+
+      def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
+                  timer1=DoFn.TimerParam(EXPIRY_TIMER)):
+        unused_key, value = element
+        buffer.add(value)
+        timer1.set(20)
+
+      @on_timer(EXPIRY_TIMER)
+      def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE),
+                          timer=DoFn.TimerParam(EXPIRY_TIMER)):
+        yield ''.join(str(x) for x in sorted(buffer.read()))
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([1, 2])
+                     .add_elements([3])
+                     .advance_watermark_to(25)
+                     .add_elements([4]))
+      (p
+       | test_stream
+       | beam.Map(lambda x: ('mykey', x))
+       | beam.ParDo(SimpleTestStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual(
+        ['123', '1234'],
+        StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_timer_output_timestamp(self):
+    class TimerEmittingStatefulDoFn(DoFn):
+      EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK)
+      EMIT_TIMER_2 = TimerSpec('emit2', TimeDomain.WATERMARK)
+      EMIT_TIMER_3 = TimerSpec('emit3', TimeDomain.WATERMARK)
+
+      def process(self, element,
+                  timer1=DoFn.TimerParam(EMIT_TIMER_1),
+                  timer2=DoFn.TimerParam(EMIT_TIMER_2),
+                  timer3=DoFn.TimerParam(EMIT_TIMER_3)):
+        timer1.set(10)
+        timer2.set(20)
+        timer3.set(30)
+
+      @on_timer(EMIT_TIMER_1)
+      def emit_callback_1(self):
+        yield 'timer1'
+
+      @on_timer(EMIT_TIMER_2)
+      def emit_callback_2(self):
+        yield 'timer2'
+
+      @on_timer(EMIT_TIMER_3)
+      def emit_callback_3(self):
+        yield 'timer3'
+
+    class TimestampReifyingDoFn(DoFn):
+      def process(self, element, ts=DoFn.TimestampParam):
+        yield (element, int(ts))
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([1]))
+      (p
+       | test_stream
+       | beam.Map(lambda x: ('mykey', x))
+       | beam.ParDo(TimerEmittingStatefulDoFn())
+       | beam.ParDo(TimestampReifyingDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual(
+        [('timer1', 10), ('timer2', 20), ('timer3', 30)],
+        sorted(StatefulDoFnOnDirectRunnerTest.all_records))
+
+  def test_index_assignment(self):
+    class IndexAssigningStatefulDoFn(DoFn):
+      INDEX_STATE = BagStateSpec('index', VarIntCoder())
+
+      def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
+        unused_key, value = element
+        next_index, = list(state.read()) or [0]
+        yield (value, next_index)
+        state.clear()
+        state.add(next_index + 1)
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements(['A', 'B'])
+                     .add_elements(['C'])
+                     .advance_watermark_to(25)
+                     .add_elements(['D']))
+      (p
+       | test_stream
+       | beam.Map(lambda x: ('mykey', x))
+       | beam.ParDo(IndexAssigningStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual(
+        [('A', 0), ('B', 1), ('C', 2), ('D', 3)],
+        StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_hash_join(self):
+    class HashJoinStatefulDoFn(DoFn):
+      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
+      UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK)
+
+      def process(self, element, state=DoFn.StateParam(BUFFER_STATE),
+                  timer=DoFn.TimerParam(UNMATCHED_TIMER)):
+        key, value = element
+        existing_values = list(state.read())
+        if not existing_values:
+          state.add(value)
+          timer.set(100)
+        else:
+          yield 'Record<%s,%s,%s>' % (key, existing_values[0], value)
+          state.clear()
+          timer.clear()
+
+      @on_timer(UNMATCHED_TIMER)
+      def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)):
+        buffered = list(state.read())
+        assert len(buffered) == 1, buffered
+        state.clear()
+        yield 'Unmatched<%s>' % (buffered[0],)
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([('A', 'a'), ('B', 'b')])
+                     .add_elements([('A', 'aa'), ('C', 'c')])
+                     .advance_watermark_to(25)
+                     .add_elements([('A', 'aaa'), ('B', 'bb')])
+                     .add_elements([('D', 'd'), ('D', 'dd'), ('D', 'ddd'),
+                                    ('D', 'dddd')])
+                     .advance_watermark_to(125)
+                     .add_elements([('C', 'cc')]))
+      (p
+       | test_stream
+       | beam.ParDo(HashJoinStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual(
+        ['Record<A,a,aa>', 'Record<B,b,bb>', 'Record<D,d,dd>',
+         'Record<D,ddd,dddd>', 'Unmatched<aaa>', 'Unmatched<c>',
+         'Unmatched<cc>'],
+        sorted(StatefulDoFnOnDirectRunnerTest.all_records))
 
 
 if __name__ == '__main__':


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 145357)
    Time Spent: 8h 20m  (was: 8h 10m)

> Reference DirectRunner implementation of Python user state and timers API
> -------------------------------------------------------------------------
>
>                 Key: BEAM-5264
>                 URL: https://issues.apache.org/jira/browse/BEAM-5264
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-py-core
>    Affects Versions: 2.6.0
>            Reporter: Charles Chen
>            Assignee: Charles Chen
>            Priority: Major
>          Time Spent: 8h 20m
>  Remaining Estimate: 0h
>
> This issue tracks the reference DirectRunner implementation of the Beam Python User State and Timer API, described here: [https://s.apache.org/beam-python-user-state-and-timers].



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Mime
View raw message