beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From boyu...@apache.org
Subject [beam] branch master updated: Add GroupIntoBatches to runner API; add Dataflow override in Python SDK
Date Wed, 09 Dec 2020 03:36:38 GMT
This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 05c8471  Add GroupIntoBatches to runner API; add Dataflow override in Python SDK
     new 2c96aeb  Merge pull request #13405 from [BEAM-10703, BEAM-10475] Add GroupIntoBatches to runner API; add Dataflow override in Python SDK
05c8471 is described below

commit 05c8471b27e03e5611a2a13137c4a785f2d17fc9
Author: sychen <sychen@google.com>
AuthorDate: Mon Nov 9 21:16:50 2020 -0800

    Add GroupIntoBatches to runner API; add Dataflow override in Python SDK
---
 .../pipeline/src/main/proto/beam_runner_api.proto  |  18 ++++
 sdks/python/apache_beam/portability/common_urns.py |   1 +
 .../runners/dataflow/dataflow_runner.py            |  32 +++++-
 .../runners/dataflow/dataflow_runner_test.py       |  62 ++++++++++++
 .../apache_beam/runners/dataflow/internal/names.py |   5 +
 .../runners/dataflow/ptransform_overrides.py       |  45 +++++++++
 sdks/python/apache_beam/transforms/util.py         | 111 ++++++++++++++++-----
 sdks/python/apache_beam/transforms/util_test.py    |  41 ++++++++
 8 files changed, 288 insertions(+), 27 deletions(-)

diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index cbfd817..ce561d3 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -344,6 +344,10 @@ message StandardPTransforms {
 
     // Payload: PubSubWritePayload.
     PUBSUB_WRITE = 5 [(beam_urn) = "beam:transform:pubsub_write:v1"];
+
+    // Represents the GroupIntoBatches.WithShardedKey operation.
+    // Payload: GroupIntoBatchesPayload
+    GROUP_INTO_BATCHES_WITH_SHARDED_KEY = 6 [(beam_urn) = "beam:transform:group_into_batches_with_sharded_key:v1"];
   }
   // Payload for all of these: CombinePayload
   enum CombineComponents {
@@ -414,6 +418,10 @@ message StandardPTransforms {
     // Output: KV(KV(element, restriction), size).
     TRUNCATE_SIZED_RESTRICTION = 3 [(beam_urn) = "beam:transform:sdf_truncate_sized_restrictions:v1"];
   }
+  // Payload for all of these: GroupIntoBatchesPayload
+  enum GroupIntoBatchesComponents {
+    GROUP_INTO_BATCHES = 0 [(beam_urn) = "beam:transform:group_into_batches:v1"];
+  }
 }
 
 message StandardSideInputTypes {
@@ -706,6 +714,16 @@ message PubSubWritePayload {
   string id_attribute = 3;
 }
 
+// Payload for GroupIntoBatches composite transform.
+message GroupIntoBatchesPayload {
+
+  // (Required) Max size of a batch.
+  int64 batch_size = 1;
+
+  // (Optional) Max duration a batch is allowed to be cached in states.
+  int64 max_buffering_duration_millis = 2;
+}
+
 // A coder, the binary format for serialization and deserialization of data in
 // a pipeline.
 message Coder {
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 819264d..e356319 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -42,6 +42,7 @@ deprecated_primitives = StandardPTransforms.DeprecatedPrimitives
 composites = StandardPTransforms.Composites
 combine_components = StandardPTransforms.CombineComponents
 sdf_components = StandardPTransforms.SplittableParDoComponents
+group_into_batches_components = StandardPTransforms.GroupIntoBatchesComponents
 
 side_inputs = StandardSideInputTypes.Enum
 coders = StandardCoders.Enum
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 06853f4..92d9c19 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -505,7 +505,11 @@ class DataflowRunner(PipelineRunner):
     pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
 
     from apache_beam.runners.dataflow.ptransform_overrides import WriteToBigQueryPTransformOverride
-    pipeline.replace_all([WriteToBigQueryPTransformOverride(pipeline, options)])
+    from apache_beam.runners.dataflow.ptransform_overrides import GroupIntoBatchesWithShardedKeyPTransformOverride
+    pipeline.replace_all([
+        WriteToBigQueryPTransformOverride(pipeline, options),
+        GroupIntoBatchesWithShardedKeyPTransformOverride(self, options)
+    ])
 
     if use_fnapi and not apiclient._use_unified_worker(options):
       pipeline.replace_all(DataflowRunner._JRH_PTRANSFORM_OVERRIDES)
@@ -727,6 +731,18 @@ class DataflowRunner(PipelineRunner):
       window_coder = None
     return self._get_typehint_based_encoding(element_type, window_coder)
 
+  def get_pcoll_with_auto_sharding(self):
+    if not hasattr(self, '_pcoll_with_auto_sharding'):
+      return set()
+    return self._pcoll_with_auto_sharding
+
+  def add_pcoll_with_auto_sharding(self, applied_ptransform):
+    if not hasattr(self, '_pcoll_with_auto_sharding'):
+      self.__setattr__('_pcoll_with_auto_sharding', set())
+    output = DataflowRunner._only_element(applied_ptransform.outputs.keys())
+    self._pcoll_with_auto_sharding.add(
+        applied_ptransform.outputs[output]._unique_name())
+
   def _add_step(self, step_kind, step_label, transform_node, side_tags=()):
     """Creates a Step object and adds it to the cache."""
     # Import here to avoid adding the dependency for local running scenarios.
@@ -1112,6 +1128,20 @@ class DataflowRunner(PipelineRunner):
       if is_stateful_dofn:
         step.add_property(PropertyNames.USES_KEYED_STATE, 'true')
 
+        # Also checks whether the step allows shardable keyed states.
+        # TODO(BEAM-11360): remove this when migrated to portable job
+        #  submission since we only consider supporting the property in runner
+        #  v2.
+        for pcoll in transform_node.outputs.values():
+          if pcoll._unique_name() in self.get_pcoll_with_auto_sharding():
+            step.add_property(PropertyNames.ALLOWS_SHARDABLE_STATE, 'true')
+            # Currently we only allow auto-sharding to be enabled through the
+            # GroupIntoBatches transform. So we also add the following property
+            # which GroupIntoBatchesDoFn has, to allow the backend to perform
+            # graph optimization.
+            step.add_property(PropertyNames.PRESERVES_KEYS, 'true')
+            break
+
   @staticmethod
   def _pardo_fn_data(transform_node, get_label):
     transform = transform_node.transform
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index 8d64540..436d5d1 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -48,6 +48,7 @@ from apache_beam.runners import TestDataflowRunner
 from apache_beam.runners import create_runner
 from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
 from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
+from apache_beam.runners.dataflow.dataflow_runner import PropertyNames
 from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
 from apache_beam.runners.runner import PipelineState
 from apache_beam.testing.extra_assertions import ExtraAssertionsMixin
@@ -787,6 +788,67 @@ class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
     except ValueError:
       self.fail('ValueError raised unexpectedly')
 
+  def _run_group_into_batches_and_get_step_properties(
+      self, with_sharded_key, additional_properties):
+    self.default_properties.append('--streaming')
+    self.default_properties.append(
+        '--experiment=enable_streaming_auto_sharding')
+    for property in additional_properties:
+      self.default_properties.append(property)
+
+    runner = DataflowRunner()
+    with beam.Pipeline(runner=runner,
+                       options=PipelineOptions(self.default_properties)) as p:
+      # pylint: disable=expression-not-assigned
+      input = p | beam.Create([('a', 1), ('a', 1), ('b', 3), ('b', 4)])
+      if with_sharded_key:
+        (
+            input | beam.GroupIntoBatches.WithShardedKey(2)
+            | beam.Map(lambda key_values: (key_values[0].key, key_values[1])))
+        step_name = (
+            u'WithShardedKey/GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)')
+      else:
+        input | beam.GroupIntoBatches(2)
+        step_name = u'GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)'
+
+    return self._find_step(runner.job, step_name)['properties']
+
+  def test_group_into_batches_translation(self):
+    properties = self._run_group_into_batches_and_get_step_properties(
+        True, ['--enable_streaming_engine', '--experiment=use_runner_v2'])
+    self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
+    self.assertEqual(properties[PropertyNames.ALLOWS_SHARDABLE_STATE], u'true')
+    self.assertEqual(properties[PropertyNames.PRESERVES_KEYS], u'true')
+
+  def test_group_into_batches_translation_non_se(self):
+    properties = self._run_group_into_batches_and_get_step_properties(
+        True, ['--experiment=use_runner_v2'])
+    self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
+    self.assertFalse(PropertyNames.ALLOWS_SHARDABLE_STATE in properties)
+    self.assertFalse(PropertyNames.PRESERVES_KEYS in properties)
+
+  def test_group_into_batches_translation_non_sharded(self):
+    properties = self._run_group_into_batches_and_get_step_properties(
+        False, ['--enable_streaming_engine', '--experiment=use_runner_v2'])
+    self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
+    self.assertFalse(PropertyNames.ALLOWS_SHARDABLE_STATE in properties)
+    self.assertFalse(PropertyNames.PRESERVES_KEYS in properties)
+
+  def test_group_into_batches_translation_non_unified_worker(self):
+    # non-portable
+    properties = self._run_group_into_batches_and_get_step_properties(
+        True, ['--enable_streaming_engine'])
+    self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
+    self.assertFalse(PropertyNames.ALLOWS_SHARDABLE_STATE in properties)
+    self.assertFalse(PropertyNames.PRESERVES_KEYS in properties)
+
+    # JRH
+    properties = self._run_group_into_batches_and_get_step_properties(
+        True, ['--enable_streaming_engine', '--experiment=beam_fn_api'])
+    self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
+    self.assertFalse(PropertyNames.ALLOWS_SHARDABLE_STATE in properties)
+    self.assertFalse(PropertyNames.PRESERVES_KEYS in properties)
+
 
 class CustomMergingWindowFn(window.WindowFn):
   def assign(self, assign_context):
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py
index 79ef7c1..7a2ed09 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/names.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py
@@ -69,6 +69,8 @@ class PropertyNames(object):
 
   Property strings as they are expected in the CloudWorkflow protos.
   """
+  # If uses_keyed_state, whether the state can be sharded.
+  ALLOWS_SHARDABLE_STATE = 'allows_shardable_state'
   BIGQUERY_CREATE_DISPOSITION = 'create_disposition'
   BIGQUERY_DATASET = 'dataset'
   BIGQUERY_EXPORT_FORMAT = 'bigquery_export_format'
@@ -98,6 +100,9 @@ class PropertyNames(object):
   OUTPUT_NAME = 'output_name'
   PARALLEL_INPUT = 'parallel_input'
   PIPELINE_PROTO_TRANSFORM_ID = 'pipeline_proto_transform_id'
+  # If the input element is a key/value pair, then the output element(s) all
+  # have the same key as the input.
+  PRESERVES_KEYS = 'preserves_keys'
   PUBSUB_ID_LABEL = 'pubsub_id_label'
   PUBSUB_SERIALIZED_ATTRIBUTES_FN = 'pubsub_serialized_attributes_fn'
   PUBSUB_SUBSCRIPTION = 'pubsub_subscription'
diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
index f91c016..402a4ed 100644
--- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
+++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
@@ -22,6 +22,7 @@
 from __future__ import absolute_import
 
 from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import GoogleCloudOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.pipeline import PTransformOverride
 
@@ -329,3 +330,47 @@ class WriteToBigQueryPTransformOverride(PTransformOverride):
         return {key: out for key in self.outputs}
 
     return WriteToBigQuery(ptransform, self.outputs)
+
+
+class GroupIntoBatchesWithShardedKeyPTransformOverride(PTransformOverride):
+  """A ``PTransformOverride`` for ``GroupIntoBatches.WithShardedKey``.
+
+  This override simply returns the original transform but additionally records
+  the output PCollection in order to append required step properties during
+  graph translation.
+  """
+  def __init__(self, dataflow_runner, options):
+    self.dataflow_runner = dataflow_runner
+    self.options = options
+
+  def matches(self, applied_ptransform):
+    # Imported here to avoid circular dependencies.
+    # pylint: disable=wrong-import-order, wrong-import-position
+    from apache_beam import util
+
+    transform = applied_ptransform.transform
+
+    if not isinstance(transform, util.GroupIntoBatches.WithShardedKey):
+      return False
+
+    # The replacement is only valid for portable Streaming Engine jobs with
+    # runner v2.
+    standard_options = self.options.view_as(StandardOptions)
+    if not standard_options.streaming:
+      return False
+    google_cloud_options = self.options.view_as(GoogleCloudOptions)
+    if not google_cloud_options.enable_streaming_engine:
+      return False
+
+    from apache_beam.runners.dataflow.internal import apiclient
+    if not apiclient._use_unified_worker(self.options):
+      return False
+    experiments = self.options.view_as(DebugOptions).experiments or []
+    if 'enable_streaming_auto_sharding' not in experiments:
+      return False
+
+    self.dataflow_runner.add_pcoll_with_auto_sharding(applied_ptransform)
+    return True
+
+  def get_replacement_transform_for_applied_ptransform(self, ptransform):
+    return ptransform.transform
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 85681f4..2e8d7c9 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -50,6 +50,7 @@ from apache_beam import coders
 from apache_beam import typehints
 from apache_beam.metrics import Metrics
 from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms import window
 from apache_beam.transforms.combiners import CountCombineFn
 from apache_beam.transforms.core import CombinePerKey
@@ -763,15 +764,12 @@ class GroupIntoBatches(PTransform):
       max_buffering_duration_secs: (optional) How long in seconds at most an
         incomplete batch of elements is allowed to be buffered in the states.
         The duration must be a positive second duration and should be given as
-        an int or float.
+        an int or float. Setting this parameter to zero effectively means no
+        buffering limit.
       clock: (optional) an alternative to time.time (mostly for testing)
     """
-    self.batch_size = batch_size
-
-    if max_buffering_duration_secs is not None:
-      assert max_buffering_duration_secs > 0, (
-          'max buffering duration should be a positive value')
-    self.max_buffering_duration_secs = max_buffering_duration_secs
+    self.params = _GroupIntoBatchesParams(
+        batch_size, max_buffering_duration_secs)
     self.clock = clock
 
   def expand(self, pcoll):
@@ -779,11 +777,25 @@ class GroupIntoBatches(PTransform):
     return pcoll | ParDo(
         _pardo_group_into_batches(
             input_coder,
-            self.batch_size,
-            self.max_buffering_duration_secs,
+            self.params.batch_size,
+            self.params.max_buffering_duration_secs,
             self.clock))
 
-  @experimental()
+  def to_runner_api_parameter(
+      self,
+      unused_context  # type: PipelineContext
+  ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
+    return (
+        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
+        self.params.get_payload())
+
+  @staticmethod
+  @PTransform.register_urn(
+      common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
+      beam_runner_api_pb2.GroupIntoBatchesPayload)
+  def from_runner_api_parameter(unused_ptransform, proto, unused_context):
+    return GroupIntoBatches(*_GroupIntoBatchesParams.parse_payload(proto))
+
   @typehints.with_input_types(Tuple[K, V])
   @typehints.with_output_types(Tuple[K, Iterable[V]])
   class WithShardedKey(PTransform):
@@ -796,21 +808,11 @@ class GroupIntoBatches(PTransform):
     execution time.
     """
     def __init__(self, batch_size, max_buffering_duration_secs=None):
-      """Create a new GroupIntoBatches.WithShardedKey.
-
-      Arguments:
-        batch_size: (required) How many elements should be in a batch
-        max_buffering_duration_secs: (optional) How long in seconds at most an
-          incomplete batch of elements is allowed to be buffered in the states.
-          The duration must be a positive second duration and should be given as
-          an int or float.
+      """Create a new GroupIntoBatches with sharded output.
+      See ``GroupIntoBatches`` transform for a description of input parameters.
       """
-      self.batch_size = batch_size
-
-      if max_buffering_duration_secs is not None:
-        assert max_buffering_duration_secs > 0, (
-            'max buffering duration should be a positive value')
-      self.max_buffering_duration_secs = max_buffering_duration_secs
+      self.params = _GroupIntoBatchesParams(
+          batch_size, max_buffering_duration_secs)
 
     _shard_id_prefix = uuid.uuid4().bytes
 
@@ -825,7 +827,64 @@ class GroupIntoBatches(PTransform):
               key_value[1]))
       return (
           sharded_pcoll
-          | GroupIntoBatches(self.batch_size, self.max_buffering_duration_secs))
+          | GroupIntoBatches(
+              self.params.batch_size, self.params.max_buffering_duration_secs))
+
+    def to_runner_api_parameter(
+        self,
+        unused_context  # type: PipelineContext
+    ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
+      return (
+          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
+          self.params.get_payload())
+
+    @staticmethod
+    @PTransform.register_urn(
+        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
+        beam_runner_api_pb2.GroupIntoBatchesPayload)
+    def from_runner_api_parameter(unused_ptransform, proto, unused_context):
+      return GroupIntoBatches.WithShardedKey(
+          *_GroupIntoBatchesParams.parse_payload(proto))
+
+
+class _GroupIntoBatchesParams:
+  """This class represents the parameters for
+  :class:`apache_beam.utils.GroupIntoBatches` transform, used to define how
+  elements should be batched.
+  """
+  def __init__(self, batch_size, max_buffering_duration_secs):
+    self.batch_size = batch_size
+    self.max_buffering_duration_secs = (
+        0
+        if max_buffering_duration_secs is None else max_buffering_duration_secs)
+    self._validate()
+
+  def __eq__(self, other):
+    if other is None or not isinstance(other, _GroupIntoBatchesParams):
+      return False
+    return (
+        self.batch_size == other.batch_size and
+        self.max_buffering_duration_secs == other.max_buffering_duration_secs)
+
+  def _validate(self):
+    assert self.batch_size is not None and self.batch_size > 0, (
+        'batch_size must be a positive value')
+    assert (
+        self.max_buffering_duration_secs is not None and
+        self.max_buffering_duration_secs >= 0), (
+            'max_buffering_duration must be a non-negative value')
+
+  def get_payload(self):
+    return beam_runner_api_pb2.GroupIntoBatchesPayload(
+        batch_size=self.batch_size,
+        max_buffering_duration_millis=int(
+            self.max_buffering_duration_secs * 1000))
+
+  @staticmethod
+  def parse_payload(
+      proto  # type: beam_runner_api_pb2.GroupIntoBatchesPayload
+  ):
+    return proto.batch_size, proto.max_buffering_duration_millis / 1000
 
 
 def _pardo_group_into_batches(
@@ -850,7 +909,7 @@ def _pardo_group_into_batches(
       element_state.add(element)
       count_state.add(1)
       count = count_state.read()
-      if count == 1 and max_buffering_duration_secs is not None:
+      if count == 1 and max_buffering_duration_secs > 0:
         # This is the first element in batch. Start counting buffering time if a
         # limit was set.
         buffering_timer.set(clock() + max_buffering_duration_secs)
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 26dcdb1..0022130 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -43,6 +43,9 @@ from apache_beam import WindowInto
 from apache_beam.coders import coders
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.runners import pipeline_context
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import TestWindowedValue
@@ -62,6 +65,7 @@ from apache_beam.transforms.window import IntervalWindow
 from apache_beam.transforms.window import Sessions
 from apache_beam.transforms.window import SlidingWindows
 from apache_beam.transforms.window import TimestampedValue
+from apache_beam.utils import proto_utils
 from apache_beam.utils import timestamp
 from apache_beam.utils.timestamp import MAX_TIMESTAMP
 from apache_beam.utils.timestamp import MIN_TIMESTAMP
@@ -773,6 +777,43 @@ class GroupIntoBatchesTest(unittest.TestCase):
       # the global window ends.
       assert_that(num_elements_per_batch, equal_to([9, 1]))
 
+  def _test_runner_api_round_trip(self, transform, urn):
+    context = pipeline_context.PipelineContext()
+    proto = transform.to_runner_api(context)
+    self.assertEqual(urn, proto.urn)
+    payload = (
+        proto_utils.parse_Bytes(
+            proto.payload, beam_runner_api_pb2.GroupIntoBatchesPayload))
+    self.assertEqual(transform.params.batch_size, payload.batch_size)
+    self.assertEqual(
+        transform.params.max_buffering_duration_secs * 1000,
+        payload.max_buffering_duration_millis)
+
+    transform_from_proto = (
+        transform.__class__.from_runner_api_parameter(None, payload, None))
+    self.assertTrue(isinstance(transform_from_proto, transform.__class__))
+    self.assertEqual(transform.params, transform_from_proto.params)
+
+  def test_runner_api(self):
+    batch_size = 10
+    max_buffering_duration_secs = [None, 0, 5]
+
+    for duration in max_buffering_duration_secs:
+      self._test_runner_api_round_trip(
+          util.GroupIntoBatches(batch_size, duration),
+          common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)
+    self._test_runner_api_round_trip(
+        util.GroupIntoBatches(batch_size),
+        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn)
+
+    for duration in max_buffering_duration_secs:
+      self._test_runner_api_round_trip(
+          util.GroupIntoBatches.WithShardedKey(batch_size, duration),
+          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)
+    self._test_runner_api_round_trip(
+        util.GroupIntoBatches.WithShardedKey(batch_size),
+        common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn)
+
 
 class ToStringTest(unittest.TestCase):
   def test_tostring_elements(self):


Mime
View raw message