beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From al...@apache.org
Subject [beam] branch master updated: [Beam-6696] GroupIntoBatches transform for Python SDK (#8914)
Date Fri, 21 Jun 2019 20:55:44 GMT
This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ad88aa8  [Beam-6696] GroupIntoBatches transform for Python SDK (#8914)
ad88aa8 is described below

commit ad88aa83977a99e2b1e602f944d62c966ac24e40
Author: Raheel Khan <raheelwp@gmail.com>
AuthorDate: Sat Jun 22 01:55:29 2019 +0500

    [Beam-6696] GroupIntoBatches transform for Python SDK (#8914)
    
    GroupIntoBatches transform in the Python SDK
---
 sdks/python/apache_beam/transforms/util.py      | 77 +++++++++++++++++++++++-
 sdks/python/apache_beam/transforms/util_test.py | 78 +++++++++++++++++++++++++
 2 files changed, 154 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 4388f6a..dd5817d 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -25,16 +25,19 @@ import collections
 import contextlib
 import random
 import time
+import warnings
 from builtins import object
 from builtins import range
 from builtins import zip
 
 from future.utils import itervalues
 
+from apache_beam import coders
 from apache_beam import typehints
 from apache_beam.metrics import Metrics
 from apache_beam.portability import common_urns
 from apache_beam.transforms import window
+from apache_beam.transforms.combiners import CountCombineFn
 from apache_beam.transforms.core import CombinePerKey
 from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.core import FlatMap
@@ -45,13 +48,19 @@ from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.core import Windowing
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.transforms.timeutil import TimeDomain
 from apache_beam.transforms.trigger import AccumulationMode
 from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.userstate import BagStateSpec
+from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.transforms.userstate import TimerSpec
+from apache_beam.transforms.userstate import on_timer
 from apache_beam.transforms.window import NonMergingWindowFn
 from apache_beam.transforms.window import TimestampCombiner
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.utils import windowed_value
 from apache_beam.utils.annotations import deprecated
+from apache_beam.utils.annotations import experimental
 
 __all__ = [
     'BatchElements',
@@ -64,7 +73,8 @@ __all__ = [
     'Reshuffle',
     'ToString',
     'Values',
-    'WithKeys'
+    'WithKeys',
+    'GroupIntoBatches'
     ]
 
 K = typehints.TypeVariable('K')
@@ -671,6 +681,71 @@ def WithKeys(pcoll, k):
   return pcoll | Map(lambda v: (k, v))
 
 
@experimental()
@typehints.with_input_types(typehints.KV[K, V])
class GroupIntoBatches(PTransform):
  """PTransform that batches the input into desired batch size.

  Elements are buffered until the buffer holds ``batch_size`` elements, at
  which point they are emitted as a single list into the output PCollection.

  Windows are preserved (batches will contain elements from the same window).

  GroupIntoBatches is experimental. It can only run on runners that support
  user state and timers.
  """

  def __init__(self, batch_size):
    """Create a new GroupIntoBatches with the given batch size.

    Arguments:
      batch_size: (required) How many elements should be in a batch.
    """
    warnings.warn('Use of GroupIntoBatches transform requires State/Timer '
                  'support from the runner')
    self.batch_size = batch_size

  def expand(self, pcoll):
    # The input coder is needed so buffered elements can be stored in the
    # runner's bag state between bundles.
    input_coder = coders.registry.get_coder(pcoll)
    return pcoll | ParDo(_pardo_group_into_batches(
        self.batch_size, input_coder))
+
+
def _pardo_group_into_batches(batch_size, input_coder):
  """Build a stateful DoFn that buffers elements into batches of batch_size.

  Arguments:
    batch_size: flush the buffered elements once this many are held.
    input_coder: coder for the input elements, used to store buffered
      elements in the runner's bag state.
  """
  ELEMENT_STATE = BagStateSpec('values', input_coder)
  # The running count is a plain integer, so it must be encoded with an
  # integer coder. The original code passed input_coder (the KV element
  # coder) here, which cannot encode the int accumulator.
  COUNT_STATE = CombiningValueStateSpec(
      'count', coders.VarIntCoder(), CountCombineFn())
  EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

  class _GroupIntoBatchesDoFn(DoFn):

    def process(self, element,
                window=DoFn.WindowParam,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
      # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      # Flush any remaining buffered elements when the watermark passes the
      # end of the window.
      expiry_timer.set(window.end)
      element_state.add(element)
      count_state.add(1)
      count = count_state.read()
      if count >= batch_size:
        batch = list(element_state.read())
        yield batch
        element_state.clear()
        count_state.clear()

    @on_timer(EXPIRY_TIMER)
    def expiry(self, element_state=DoFn.StateParam(ELEMENT_STATE),
               count_state=DoFn.StateParam(COUNT_STATE)):
      # End of window: emit whatever partial batch is left, if any.
      batch = list(element_state.read())
      if batch:
        yield batch
        element_state.clear()
        count_state.clear()

  return _GroupIntoBatchesDoFn()
+
+
 class ToString(object):
   """
   PTransform for converting a PCollection element, KV or PCollection Iterable
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index b655f66..ae952f6 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -20,7 +20,9 @@
 from __future__ import absolute_import
 from __future__ import division
 
+import itertools
 import logging
+import math
 import random
 import time
 import unittest
@@ -28,16 +30,19 @@ from builtins import object
 from builtins import range
 
 import apache_beam as beam
+from apache_beam import WindowInto
 from apache_beam.coders import coders
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import contains_in_any_order
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import util
 from apache_beam.transforms import window
+from apache_beam.transforms.window import FixedWindows
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import IntervalWindow
@@ -432,6 +437,79 @@ class WithKeysTest(unittest.TestCase):
     assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))
 
 
+class GroupIntoBatchesTest(unittest.TestCase):
+  NUM_ELEMENTS = 10
+  BATCH_SIZE = 5
+
+  @staticmethod
+  def _create_test_data():
+    scientists = [
+        "Einstein",
+        "Darwin",
+        "Copernicus",
+        "Pasteur",
+        "Curie",
+        "Faraday",
+        "Newton",
+        "Bohr",
+        "Galilei",
+        "Maxwell"
+    ]
+
+    data = []
+    for i in range(GroupIntoBatchesTest.NUM_ELEMENTS):
+      index = i % len(scientists)
+      data.append(("key", scientists[index]))
+    return data
+
+  def test_in_global_window(self):
+    pipeline = TestPipeline()
+    collection = pipeline \
+                 | beam.Create(GroupIntoBatchesTest._create_test_data()) \
+                 | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
+    num_batches = collection | beam.combiners.Count.Globally()
+    assert_that(num_batches,
+                equal_to([int(math.ceil(GroupIntoBatchesTest.NUM_ELEMENTS /
+                                        GroupIntoBatchesTest.BATCH_SIZE))]))
+    pipeline.run()
+
+  def test_in_streaming_mode(self):
+    timestamp_interval = 1
+    offset = itertools.count(0)
+    start_time = timestamp.Timestamp(0)
+    window_duration = 6
+    test_stream = (TestStream()
+                   .advance_watermark_to(start_time)
+                   .add_elements(
+                       [TimestampedValue(x, next(offset) * timestamp_interval)
+                        for x in GroupIntoBatchesTest._create_test_data()])
+                   .advance_watermark_to(start_time + (window_duration - 1))
+                   .advance_watermark_to(start_time + (window_duration + 1))
+                   .advance_watermark_to(start_time +
+                                         GroupIntoBatchesTest.NUM_ELEMENTS)
+                   .advance_watermark_to_infinity())
+    pipeline = TestPipeline()
+    #  window duration is 6 and batch size is 5, so output batch size should be
+    #  5 (flush because of batchSize reached)
+    expected_0 = 5
+    # there is only one element left in the window so batch size should be 1
+    # (flush because of end of window reached)
+    expected_1 = 1
+    #  collection is 10 elements, there is only 4 left, so batch size should be
+    #  4 (flush because end of collection reached)
+    expected_2 = 4
+
+    collection = pipeline | test_stream \
+                 | WindowInto(FixedWindows(window_duration)) \
+                 | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
+    num_elements_in_batches = collection | beam.Map(len)
+
+    result = pipeline.run()
+    result.wait_until_finish()
+    assert_that(num_elements_in_batches,
+                equal_to([expected_0, expected_1, expected_2]))
+
+
 class ToStringTest(unittest.TestCase):
 
   def test_tostring_elements(self):


Mime
View raw message