beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From al...@apache.org
Subject [beam] branch master updated: Several performance improvements to Beam's Combiners. (#7838)
Date Thu, 14 Feb 2019 17:33:52 GMT
This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 30dfa0b  Several performance improvements to Beam's Combiners. (#7838)
30dfa0b is described below

commit 30dfa0b03b783328a6f654236fa974df097e8034
Author: Ahmet Altay <aaltay@gmail.com>
AuthorDate: Thu Feb 14 09:33:40 2019 -0800

    Several performance improvements to Beam's Combiners. (#7838)
    
    Several performance improvements to Beam's Combiners.
    
    * Updating the efficient implementation of a Global Top to also work with keys.
    * Updating the TopCombineFn implementation to use a heap for improved performance.
    * Making use of some additional caching in TopCombineFn.
    * Documenting use of CombineFn.compact in PartialGroupByKeyCombiningValues.
    * Making ComparableValue pure-python Cython optimized.
---
 .../runners/direct/helper_transforms.py            |   2 +
 sdks/python/apache_beam/transforms/combiners.py    | 213 +++++++++++----------
 .../apache_beam/transforms/combiners_test.py       |   4 -
 .../python/apache_beam/transforms/cy_combiners.pxd |  12 ++
 sdks/python/apache_beam/transforms/cy_combiners.py |  31 +++
 5 files changed, 152 insertions(+), 110 deletions(-)

diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py b/sdks/python/apache_beam/runners/direct/helper_transforms.py
index 60b3ad3..51377ba 100644
--- a/sdks/python/apache_beam/runners/direct/helper_transforms.py
+++ b/sdks/python/apache_beam/runners/direct/helper_transforms.py
@@ -69,6 +69,8 @@ class PartialGroupByKeyCombiningValues(beam.DoFn):
 
   def finish_bundle(self):
     for (k, w), va in self._cache.items():
+      # We compact the accumulator since a GBK (which necessitates encoding)
+      # will follow.
       yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w,))
 
   def default_type_hints(self):
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index 65e098e..3976a8e 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -25,7 +25,6 @@ import operator
 import random
 from builtins import object
 from builtins import zip
-from functools import cmp_to_key
 
 from past.builtins import long
 
@@ -197,7 +196,7 @@ class Top(object):
     def expand(self, pcoll):
       compare = self._compare
       if (not self._args and not self._kwargs and
-          not self._key and pcoll.windowing.is_default()):
+          pcoll.windowing.is_default()):
         if self._reverse:
           if compare is None or compare is operator.lt:
             compare = operator.gt
@@ -205,7 +204,8 @@ class Top(object):
             original_compare = compare
             compare = lambda a, b: original_compare(b, a)
         # This is a more efficient global algorithm.
-        top_per_bundle = pcoll | core.ParDo(_TopPerBundle(self._n, compare))
+        top_per_bundle = pcoll | core.ParDo(
+            _TopPerBundle(self._n, compare, self._key))
         # If pcoll is empty, we can't guarantee that top_per_bundle
         # won't be empty, so inject at least one empty accumulator
         # so that downstream is guaranteed to produce non-empty output.
@@ -213,7 +213,7 @@ class Top(object):
         return (
             (top_per_bundle, empty_bundle) | core.Flatten()
             | core.GroupByKey()
-            | core.ParDo(_MergeTopPerBundle(self._n, compare)))
+            | core.ParDo(_MergeTopPerBundle(self._n, compare, self._key)))
       else:
         return pcoll | core.CombineGlobally(
             TopCombineFn(self._n, compare, self._key, self._reverse),
@@ -295,34 +295,21 @@ class Top(object):
     return pcoll | Top.PerKey(n, reverse=True)
 
 
-class _ComparableValue(object):
-
-  __slots__ = ('value', 'less_than')
-
-  def __init__(self, value, less_than):
-    self.value = value
-    self.less_than = less_than
-
-  def __lt__(self, other):
-    return self.less_than(self.value, other.value)
-
-  def __repr__(self):
-    return "_ComparableValue[%s]" % self.value
-
-
 @with_input_types(T)
 @with_output_types(KV[None, List[T]])
 class _TopPerBundle(core.DoFn):
-  def __init__(self, n, less_than):
+  def __init__(self, n, less_than, key):
     self._n = n
     self._less_than = None if less_than is operator.le else less_than
+    self._key = key
 
   def start_bundle(self):
     self._heap = []
 
   def process(self, element):
-    if self._less_than is not None:
-      element = _ComparableValue(element, self._less_than)
+    if self._less_than or self._key:
+      element = cy_combiners.ComparableValue(
+          element, self._less_than, self._key)
     if len(self._heap) < self._n:
       heapq.heappush(self._heap, element)
     else:
@@ -336,7 +323,7 @@ class _TopPerBundle(core.DoFn):
     self._heap.sort()
 
     # Unwrap to avoid serialization via pickle.
-    if self._less_than:
+    if self._less_than or self._key:
       yield window.GlobalWindows.windowed_value(
           (None, [wrapper.value for wrapper in self._heap]))
     else:
@@ -347,27 +334,30 @@ class _TopPerBundle(core.DoFn):
 @with_input_types(KV[None, Iterable[List[T]]])
 @with_output_types(List[T])
 class _MergeTopPerBundle(core.DoFn):
-  def __init__(self, n, less_than):
+  def __init__(self, n, less_than, key):
     self._n = n
-    self._less_than = None if less_than is operator.le else less_than
+    self._less_than = None if less_than is operator.lt else less_than
+    self._key = key
 
   def process(self, key_and_bundles):
     _, bundles = key_and_bundles
     heap = []
     for bundle in bundles:
       if not heap:
-        if self._less_than:
+        if self._less_than or self._key:
           heap = [
-              _ComparableValue(element, self._less_than) for element in bundle]
+              cy_combiners.ComparableValue(element, self._less_than, self._key)
+              for element in bundle]
         else:
           heap = bundle
         continue
       for element in reversed(bundle):
-        if self._less_than is not None:
-          element = _ComparableValue(element, self._less_than)
+        if self._less_than or self._key:
+          element = cy_combiners.ComparableValue(
+              element, self._less_than, self._key)
         if len(heap) < self._n:
           heapq.heappush(heap, element)
-        elif element <= heap[0]:
+        elif element < heap[0]:
           # Because _TopPerBundle returns sorted lists, all other elements
           # will also be smaller.
           break
@@ -375,7 +365,7 @@ class _MergeTopPerBundle(core.DoFn):
           heapq.heappushpop(heap, element)
 
     heap.sort()
-    if self._less_than:
+    if self._less_than or self._key:
       yield [wrapper.value for wrapper in reversed(heap)]
     else:
       yield heap[::-1]
@@ -398,15 +388,9 @@ class TopCombineFn(core.CombineFn):
         than largest to smallest
   """
 
-  _MIN_BUFFER_OVERSIZE = 100
-  _MAX_BUFFER_OVERSIZE = 1000
-
-  # TODO(robertwb): Allow taking a key rather than a compare.
+  # TODO(robertwb): For Python 3, remove compare and only keep key.
   def __init__(self, n, compare=None, key=None, reverse=False):
     self._n = n
-    self._buffer_size = max(
-        min(2 * n, n + TopCombineFn._MAX_BUFFER_OVERSIZE),
-        n + TopCombineFn._MIN_BUFFER_OVERSIZE)
 
     if compare is operator.lt:
       compare = None
@@ -422,18 +406,8 @@ class TopCombineFn(core.CombineFn):
     else:
       self._compare = operator.gt if reverse else operator.lt
 
-    self._key_fn = key
-    self._reverse = reverse
-
-  def _sort_buffer(self, buffer, lt):
-    if lt in (operator.gt, operator.lt):
-      buffer.sort(key=self._key_fn, reverse=self._reverse)
-    elif self._key_fn:
-      buffer.sort(key=cmp_to_key(
-          (lambda a, b: (not lt(self._key_fn(a), self._key_fn(b)))
-           - (not lt(self._key_fn(b), self._key_fn(a))))))
-    else:
-      buffer.sort(key=cmp_to_key(lambda a, b: (not lt(a, b)) - (not lt(b, a))))
+    self._less_than = None
+    self._key = key
 
   def display_data(self):
     return {'n': self._n,
@@ -442,83 +416,110 @@ class TopCombineFn(core.CombineFn):
                                        else self._compare.__class__.__name__)
                        .drop_if_none()}
 
-  # The accumulator type is a tuple (threshold, buffer), where threshold
-  # is the smallest element [key] that could possibly be in the top n based
-  # on the elements observed so far, and buffer is a (periodically sorted)
-  # list of candidates of bounded size.
-
+  # The accumulator type is a tuple
+  # (bool, Union[List[T], List[ComparableValue[T]])
+  # where the boolean indicates whether the second slot contains a List of T
+  # (False) or List of ComparableValue[T] (True). In either case, the List
+  # maintains heap invariance.
+  # This accumulator representation allows us to minimize the data encoding
+  # overheads. Creation of ComparableValues is also elided when there is no need
+  # for complicated comparison functions.
   def create_accumulator(self, *args, **kwargs):
-    return None, []
+    return (False, [])
 
   def add_input(self, accumulator, element, *args, **kwargs):
-    if args or kwargs:
-      lt = lambda a, b: self._compare(a, b, *args, **kwargs)
+    # Caching to avoid paying the price of variadic expansion of args / kwargs
+    # when it's not needed (for the 'if' case below).
+    if self._less_than is None:
+      if args or kwargs:
+        self._less_than = lambda a, b: self._compare(a, b, *args, **kwargs)
+      else:
+        self._less_than = self._compare
+
+    holds_comparables, heap = accumulator
+    if self._less_than is not operator.lt or self._key:
+      if not holds_comparables:
+        heap = [
+            cy_combiners.ComparableValue(value, self._less_than, self._key)
+            for value in heap]
+        holds_comparables = True
     else:
-      lt = self._compare
+      assert not holds_comparables
 
-    threshold, buffer = accumulator
-    element_key = self._key_fn(element) if self._key_fn else element
+    comparable = (
+        cy_combiners.ComparableValue(element, self._less_than, self._key)
+        if holds_comparables else element)
 
-    if len(buffer) < self._n:
-      if not buffer:
-        return element_key, [element]
-      buffer.append(element)
-      if lt(element_key, threshold):  # element_key < threshold
-        return element_key, buffer
-      else:
-        return accumulator  # with mutated buffer
-    elif lt(threshold, element_key):  # threshold < element_key
-      buffer.append(element)
-      if len(buffer) < self._buffer_size:
-        return accumulator  # with mutated buffer
-      else:
-        self._sort_buffer(buffer, lt)
-        min_element = buffer[-self._n]
-        threshold = self._key_fn(min_element) if self._key_fn else min_element
-        return threshold, buffer[-self._n:]
+    if len(heap) < self._n:
+      heapq.heappush(heap, comparable)
     else:
-      return accumulator
+      heapq.heappushpop(heap, comparable)
+    return (holds_comparables, heap)
 
   def merge_accumulators(self, accumulators, *args, **kwargs):
     if args or kwargs:
+      self._less_than = lambda a, b: self._compare(a, b, *args, **kwargs)
       add_input = lambda accumulator, element: self.add_input(
           accumulator, element, *args, **kwargs)
     else:
+      self._less_than = self._compare
       add_input = self.add_input
 
-    total_accumulator = None
+    result_heap = None
+    holds_comparables = None
     for accumulator in accumulators:
-      if total_accumulator is None:
-        total_accumulator = accumulator
+      holds_comparables, heap = accumulator
+      if self._less_than is not operator.lt or self._key:
+        if not holds_comparables:
+          heap = [
+              cy_combiners.ComparableValue(value, self._less_than, self._key)
+              for value in heap]
+          holds_comparables = True
       else:
-        for element in accumulator[1]:
-          total_accumulator = add_input(total_accumulator, element)
-    return total_accumulator
+        assert not holds_comparables
 
-  def compact(self, accumulator, *args, **kwargs):
-    if args or kwargs:
-      lt = lambda a, b: self._compare(a, b, *args, **kwargs)
-    else:
-      lt = self._compare
+      if result_heap is None:
+        result_heap = heap
+      else:
+        for comparable in heap:
+          _, result_heap = add_input(
+              (holds_comparables, result_heap),
+              comparable.value if holds_comparables else comparable)
 
-    _, buffer = accumulator
-    if len(buffer) <= self._n:
-      return accumulator  # No compaction needed.
+    assert result_heap is not None and holds_comparables is not None
+    return (holds_comparables, result_heap)
+
+  def compact(self, accumulator, *args, **kwargs):
+    holds_comparables, heap = accumulator
+    # Unwrap to avoid serialization via pickle.
+    if holds_comparables:
+      return (False, [comparable.value for comparable in heap])
     else:
-      self._sort_buffer(buffer, lt)
-      min_element = buffer[-self._n]
-      threshold = self._key_fn(min_element) if self._key_fn else min_element
-      return threshold, buffer[-self._n:]
+      return accumulator
 
   def extract_output(self, accumulator, *args, **kwargs):
     if args or kwargs:
-      lt = lambda a, b: self._compare(a, b, *args, **kwargs)
+      self._less_than = lambda a, b: self._compare(a, b, *args, **kwargs)
+    else:
+      self._less_than = self._compare
+
+    holds_comparables, heap = accumulator
+    if self._less_than is not operator.lt or self._key:
+      if not holds_comparables:
+        heap = [
+            cy_combiners.ComparableValue(value, self._less_than, self._key)
+            for value in heap
+        ]
+        holds_comparables = True
     else:
-      lt = self._compare
+      assert not holds_comparables
 
-    _, buffer = accumulator
-    self._sort_buffer(buffer, lt)
-    return buffer[:-self._n-1:-1]  # tail, reversed
+    assert len(heap) <= self._n
+    heap.sort(reverse=True)
+    return [
+        comparable.value if holds_comparables else comparable
+        for comparable in heap
+    ]
 
 
 class Largest(TopCombineFn):
@@ -777,15 +778,15 @@ class PhasedCombineFnExecutor(object):
     else:
       raise ValueError('Unexpected phase: %s' % phase)
 
-  def full_combine(self, elements):  # pylint: disable=invalid-name
+  def full_combine(self, elements):
     return self.combine_fn.apply(elements)
 
-  def add_only(self, elements):  # pylint: disable=invalid-name
+  def add_only(self, elements):
     return self.combine_fn.add_inputs(
         self.combine_fn.create_accumulator(), elements)
 
-  def merge_only(self, accumulators):  # pylint: disable=invalid-name
+  def merge_only(self, accumulators):
     return self.combine_fn.merge_accumulators(accumulators)
 
-  def extract_only(self, accumulator):  # pylint: disable=invalid-name
+  def extract_only(self, accumulator):
     return self.combine_fn.extract_output(accumulator)
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index 3db019a..f070c32 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -41,10 +41,6 @@ from apache_beam.transforms.ptransform import PTransform
 
 class CombineTest(unittest.TestCase):
 
-  def setUp(self):
-    # Sort more often for more rigorous testing on small data sets.
-    combine.TopCombineFn._MIN_BUFFER_OVERSIZE = 1
-
   def test_builtin_combines(self):
     pipeline = TestPipeline()
 
diff --git a/sdks/python/apache_beam/transforms/cy_combiners.pxd b/sdks/python/apache_beam/transforms/cy_combiners.pxd
index 4fc03a7..bfbaa2c 100644
--- a/sdks/python/apache_beam/transforms/cy_combiners.pxd
+++ b/sdks/python/apache_beam/transforms/cy_combiners.pxd
@@ -27,6 +27,7 @@ cdef class CountAccumulator(object):
   @cython.locals(accumulator=CountAccumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class SumInt64Accumulator(object):
   cdef readonly int64_t value
   cpdef add_input(self, int64_t element)
@@ -39,12 +40,14 @@ cdef class MinInt64Accumulator(object):
   @cython.locals(accumulator=MinInt64Accumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class MaxInt64Accumulator(object):
   cdef readonly int64_t value
   cpdef add_input(self, int64_t element)
   @cython.locals(accumulator=MaxInt64Accumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class MeanInt64Accumulator(object):
   cdef readonly int64_t sum
   cdef readonly int64_t count
@@ -59,18 +62,21 @@ cdef class SumDoubleAccumulator(object):
   @cython.locals(accumulator=SumDoubleAccumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class MinDoubleAccumulator(object):
   cdef readonly double value
   cpdef add_input(self, double element)
   @cython.locals(accumulator=MinDoubleAccumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class MaxDoubleAccumulator(object):
   cdef readonly double value
   cpdef add_input(self, double element)
   @cython.locals(accumulator=MaxDoubleAccumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class MeanDoubleAccumulator(object):
   cdef readonly double sum
   cdef readonly int64_t count
@@ -85,8 +91,14 @@ cdef class AllAccumulator(object):
   @cython.locals(accumulator=AllAccumulator)
   cpdef merge(self, accumulators)
 
+
 cdef class AnyAccumulator(object):
   cdef readonly bint value
   cpdef add_input(self, bint element)
   @cython.locals(accumulator=AnyAccumulator)
   cpdef merge(self, accumulators)
+
+
+cdef class ComparableValue(object):
+  cdef readonly object value, _less_than_fn, _comparable_value
+
diff --git a/sdks/python/apache_beam/transforms/cy_combiners.py b/sdks/python/apache_beam/transforms/cy_combiners.py
index 6695c73..5f5fbe4 100644
--- a/sdks/python/apache_beam/transforms/cy_combiners.py
+++ b/sdks/python/apache_beam/transforms/cy_combiners.py
@@ -25,6 +25,7 @@ For internal use only; no backwards-compatibility guarantees.
 from __future__ import absolute_import
 from __future__ import division
 
+import operator
 from builtins import object
 
 from apache_beam.transforms import core
@@ -335,3 +336,33 @@ class DataflowDistributionCounterFn(AccumulatorCombineFn):
   version.
   """
   _accumulator_type = DataflowDistributionCounter
+
+
+class ComparableValue(object):
+  """A way to allow comparing elements in a rich fashion."""
+
+  __slots__ = ('value', '_less_than_fn', '_comparable_value')
+
+  def __init__(self, value, less_than_fn, key_fn, _from_pickle=False):
+    self.value = value
+    self._less_than_fn = less_than_fn if less_than_fn else operator.lt
+    self._comparable_value = key_fn(value) if key_fn else value
+
+    # TODO(b/123368592): Remove this limitation by making ComparableValue
+    # hydratable with less_than_fn and perhaps key_fn post construction, and
+    # updating TopCombineFn appropriately.
+    assert not _from_pickle  # See comments in __reduce__ below.
+
+  def __lt__(self, other):
+    return self._less_than_fn(self._comparable_value, other._comparable_value)
+
+  def __repr__(self):
+    return 'ComparableValue[%s]' % str(self.value)
+
+  def __reduce__(self):
+    # ComparableValues might need to be encoded for sizing estimation, but
+    # should otherwise never be instantiated from their encoded representation
+    # and compared with each other (since we are not always able to
+    # serialize self._less_than_fn and/or self._key_fn) and this is verified in
+    # __init__ by asserting on _from_pickle.
+    return ComparableValue, (self.value, None, None, True)


Mime
View raw message