beam-commits mailing list archives

From lc...@apache.org
Subject [beam] branch master updated: Move TestStream implementation to replacement transform
Date Wed, 26 Feb 2020 19:13:33 GMT
This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 437c88f  Move TestStream implementation to replacement transform
     new 7a4cdec  Merge pull request #10892 from rohdesamuel/teststream_merge
437c88f is described below

commit 437c88f1c9754fdac9769d734d08f281318fb343
Author: Sam Rohde <rohde.samuel@gmail.com>
AuthorDate: Thu Feb 13 16:38:43 2020 -0800

    Move TestStream implementation to replacement transform
    
    * This also moves the DirectRunner's TestStream implementation to a
    replacement transform, because the TestStream relies on getting the
    output_tags from the PTransform proto.
    
    Change-Id: Ibd80b0d25cd8cc5ff5c28e127f7313638e6664da
---
 sdks/python/apache_beam/io/iobase.py               |  2 +-
 sdks/python/apache_beam/pipeline.py                |  2 +-
 .../apache_beam/runners/direct/direct_runner.py    | 70 ++++++----------------
 .../apache_beam/runners/direct/test_stream_impl.py | 61 ++++++++++++++++++-
 .../runners/portability/expansion_service.py       |  3 +-
 .../runners/portability/expansion_service_test.py  | 29 +++++----
 sdks/python/apache_beam/testing/test_stream.py     | 53 ++++++++++++----
 .../python/apache_beam/testing/test_stream_test.py | 51 ++++++++++++++++
 sdks/python/apache_beam/transforms/core.py         | 17 +++---
 .../apache_beam/transforms/external_it_test.py     |  2 +-
 sdks/python/apache_beam/transforms/ptransform.py   | 16 ++---
 sdks/python/apache_beam/transforms/util.py         |  3 +-
 12 files changed, 214 insertions(+), 95 deletions(-)
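
The change that threads through most of these files is the signature of the
registered from_runner_api_parameter constructors: they now receive the
enclosing PTransform proto in addition to the payload and the context, which
is what lets TestStream recover its output tags. A schematic of the
convention, with hypothetical shortened names (the real signatures appear in
the hunks below):

# Old convention: only the decoded payload and the pipeline context.
def from_runner_api_parameter_old(payload, context):
  pass


# New convention: the enclosing PTransform proto comes first, so an
# implementation can read proto-level fields such as ptransform.outputs.
def from_runner_api_parameter_new(unused_ptransform, payload, context):
  pass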

diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index e320c4a..1edcac9 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -925,7 +925,7 @@ class Read(ptransform.PTransform):
             beam_runner_api_pb2.IsBounded.UNBOUNDED))
 
   @staticmethod
-  def from_runner_api_parameter(parameter, context):
+  def from_runner_api_parameter(unused_ptransform, parameter, context):
     # type: (beam_runner_api_pb2.ReadPayload, PipelineContext) -> Read
     return Read(SourceBase.from_runner_api(parameter.source, context))
 
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 50c141a..22aa2a3 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -1099,7 +1099,7 @@ class AppliedPTransform(object):
         id in proto.inputs.items() if is_side_input(tag)
     ]
     side_inputs = [si for _, si in sorted(indexed_side_inputs)]
-    transform = ptransform.PTransform.from_runner_api(proto.spec, context)
+    transform = ptransform.PTransform.from_runner_api(proto, context)
     result = AppliedPTransform(
         parent=None,
         transform=transform,
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 1ddda51..f207779 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -73,60 +73,12 @@ class SwitchingDirectRunner(PipelineRunner):
   def is_fnapi_compatible(self):
     return BundleBasedDirectRunner.is_fnapi_compatible()
 
-  def apply_TestStream(self, transform, pbegin, options):
-    """Expands the TestStream into the DirectRunner implementation.
-
-    Takes the TestStream transform and creates a _TestStream -> multiplexer ->
-    _WatermarkController.
-    """
-
-    from apache_beam.runners.direct.test_stream_impl import _TestStream
-    from apache_beam.runners.direct.test_stream_impl import _WatermarkController
-    from apache_beam import pvalue
-    assert isinstance(pbegin, pvalue.PBegin)
-
-    # If there is only one tag there is no need to add the multiplexer.
-    if len(transform.output_tags) == 1:
-      return (
-          pbegin
-          | _TestStream(transform.output_tags, events=transform._events)
-          | _WatermarkController())
-
-    # This multiplexing the  multiple output PCollections.
-    def mux(event):
-      if event.tag:
-        yield pvalue.TaggedOutput(event.tag, event)
-      else:
-        yield event
-
-    mux_output = (
-        pbegin
-        | _TestStream(transform.output_tags, events=transform._events)
-        | 'TestStream Multiplexer' >> beam.ParDo(mux).with_outputs())
-
-    # Apply a way to control the watermark per output. It is necessary to
-    # have an individual _WatermarkController per PCollection because the
-    # calculation of the input watermark of a transform is based on the event
-    # timestamp of the elements flowing through it. Meaning, it is impossible
-    # to control the output watermarks of the individual PCollections solely
-    # on the event timestamps.
-    outputs = {}
-    for tag in transform.output_tags:
-      label = '_WatermarkController[{}]'.format(tag)
-      outputs[tag] = (mux_output[tag] | label >> _WatermarkController())
-
-    return outputs
-
-  # We must mark this method as not a test or else its name is a matcher for
-  # nosetest tests.
-  apply_TestStream.__test__ = False
-
   def run_pipeline(self, pipeline, options):
 
     from apache_beam.pipeline import PipelineVisitor
     from apache_beam.runners.dataflow.native_io.iobase import NativeSource
     from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
-    from apache_beam.runners.direct.test_stream_impl import _TestStream
+    from apache_beam.testing.test_stream import TestStream
 
     class _FnApiRunnerSupportVisitor(PipelineVisitor):
       """Visitor determining if a Pipeline can be run on the FnApiRunner."""
@@ -138,7 +90,7 @@ class SwitchingDirectRunner(PipelineRunner):
       def visit_transform(self, applied_ptransform):
         transform = applied_ptransform.transform
         # The FnApiRunner does not support streaming execution.
-        if isinstance(transform, _TestStream):
+        if isinstance(transform, TestStream):
           self.supported_by_fnapi_runner = False
         # The FnApiRunner does not support reads from NativeSources.
         if (isinstance(transform, beam.io.Read) and
@@ -195,7 +147,8 @@ class _StreamingGroupByKeyOnly(_GroupByKeyOnly):
 
   @staticmethod
   @PTransform.register_urn(urn, None)
-  def from_runner_api_parameter(unused_payload, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_payload, unused_context):
     return _StreamingGroupByKeyOnly()
 
 
@@ -214,7 +167,7 @@ class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow):
 
   @staticmethod
   @PTransform.register_urn(urn, wrappers_pb2.BytesValue)
-  def from_runner_api_parameter(payload, context):
+  def from_runner_api_parameter(unused_ptransform, payload, context):
     return _StreamingGroupAlsoByWindow(
         context.windowing_strategies.get_by_id(payload.value))
 
@@ -271,10 +224,21 @@ def _get_transform_overrides(pipeline_options):
       transform = _StreamingGroupAlsoByWindow(transform.dofn.windowing)
       return transform
 
+  class TestStreamOverride(PTransformOverride):
+    def matches(self, applied_ptransform):
+      from apache_beam.testing.test_stream import TestStream
+      self.applied_ptransform = applied_ptransform
+      return isinstance(applied_ptransform.transform, TestStream)
+
+    def get_replacement_transform(self, transform):
+      from apache_beam.runners.direct.test_stream_impl import _ExpandableTestStream
+      return _ExpandableTestStream(transform)
+
   overrides = [
       SplittableParDoOverride(),
       ProcessKeyedElementsViaKeyedWorkItemsOverride(),
-      CombinePerKeyOverride()
+      CombinePerKeyOverride(),
+      TestStreamOverride(),
   ]
 
   # Add streaming overrides, if necessary.
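
TestStreamOverride above follows Beam's standard PTransformOverride
mechanism: the pipeline walks every AppliedPTransform, and any node for which
matches() returns True is re-expanded through get_replacement_transform().
A minimal, self-contained sketch of that mechanism follows; the _Marker*
names are hypothetical and not part of this patch.

import apache_beam as beam
from apache_beam.pipeline import PTransformOverride


# Hypothetical placeholder transform to be replaced.
class _MarkerTransform(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | beam.Map(lambda x: x)


# Hypothetical override; mirrors the shape of TestStreamOverride above.
class _MarkerOverride(PTransformOverride):
  def matches(self, applied_ptransform):
    # Each override is asked whether it wants to take over this node.
    return isinstance(applied_ptransform.transform, _MarkerTransform)

  def get_replacement_transform(self, transform):
    # The replacement is expanded in place of the matched transform.
    return beam.Map(lambda x: (x, 'replaced'))


p = beam.Pipeline()
_ = p | beam.Create([1, 2, 3]) | _MarkerTransform()
# Runners apply their overrides before executing; this is the hook through
# which the DirectRunner swaps TestStream for _ExpandableTestStream.
p.replace_all([_MarkerOverride()])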
diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
index 9b91154..8cee1fc 100644
--- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py
+++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
@@ -27,6 +27,7 @@ tagged PCollection.
 
 from __future__ import absolute_import
 
+from apache_beam import ParDo
 from apache_beam import coders
 from apache_beam import pvalue
 from apache_beam.testing.test_stream import WatermarkEvent
@@ -45,11 +46,69 @@ class _WatermarkController(PTransform):
    - If the instance receives an ElementEvent, it emits all specified elements
      to the Global Window with the event time set to the element's timestamp.
   """
+  def __init__(self, output_tag):
+    self.output_tag = output_tag
+
   def get_windowing(self, _):
     return core.Windowing(window.GlobalWindows())
 
   def expand(self, pcoll):
-    return pvalue.PCollection.from_(pcoll)
+    ret = pvalue.PCollection.from_(pcoll)
+    ret.tag = self.output_tag
+    return ret
+
+
+class _ExpandableTestStream(PTransform):
+  def __init__(self, test_stream):
+    self.test_stream = test_stream
+
+  def expand(self, pbegin):
+    """Expands the TestStream into the DirectRunner implementation.
+
+
+    Takes the TestStream transform and creates a _TestStream -> multiplexer ->
+    _WatermarkController.
+    """
+
+    assert isinstance(pbegin, pvalue.PBegin)
+
+    # If there is only one tag there is no need to add the multiplexer.
+    if len(self.test_stream.output_tags) == 1:
+      return (
+          pbegin
+          | _TestStream(
+              self.test_stream.output_tags,
+              events=self.test_stream._events,
+              coder=self.test_stream.coder)
+          | _WatermarkController(list(self.test_stream.output_tags)[0]))
+
+    # Multiplex to the correct PCollection based upon the event tag.
+    def mux(event):
+      if event.tag:
+        yield pvalue.TaggedOutput(event.tag, event)
+      else:
+        yield event
+
+    mux_output = (
+        pbegin
+        | _TestStream(
+            self.test_stream.output_tags,
+            events=self.test_stream._events,
+            coder=self.test_stream.coder)
+        | 'TestStream Multiplexer' >> ParDo(mux).with_outputs())
+
+    # Apply a way to control the watermark per output. It is necessary to
+    # have an individual _WatermarkController per PCollection because the
+    # calculation of the input watermark of a transform is based on the event
+    # timestamp of the elements flowing through it. Meaning, it is impossible
+    # to control the output watermarks of the individual PCollections solely
+    # on the event timestamps.
+    outputs = {}
+    for tag in self.test_stream.output_tags:
+      label = '_WatermarkController[{}]'.format(tag)
+      outputs[tag] = (mux_output[tag] | label >> _WatermarkController(tag))
+
+    return outputs
 
 
 class _TestStream(PTransform):
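
The multiplexer added above is an ordinary ParDo with tagged outputs: each
event carries an optional tag, and tagged events are emitted as TaggedOutput
so that every tag gets its own _WatermarkController downstream. The same
routing pattern in isolation, as a runnable sketch with plain dicts standing
in for test-stream events (names and data are illustrative):

import apache_beam as beam
from apache_beam import pvalue


def route_by_tag(event):
  # Tagged events go to the output named after their tag; untagged events
  # fall through to the main output.
  if event.get('tag'):
    yield pvalue.TaggedOutput(event['tag'], event)
  else:
    yield event


with beam.Pipeline() as p:
  routed = (
      p
      | beam.Create([
          {'tag': 'a', 'value': 1},
          {'tag': 'b', 'value': 2},
          {'tag': None, 'value': 3},
      ])
      | beam.ParDo(route_by_tag).with_outputs('a', 'b', main='untagged'))
  # Each tagged stream can now get its own downstream transform, just as
  # each TestStream tag gets its own _WatermarkController.
  _ = routed.a | 'HandleA' >> beam.Map(lambda e: e['value'])
  _ = routed.b | 'HandleB' >> beam.Map(lambda e: e['value'])
  _ = routed.untagged | 'HandleUntagged' >> beam.Map(lambda e: e['value'])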
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py
index 5a9601d..d2be037 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service.py
@@ -64,8 +64,7 @@ class ExpansionServiceServicer(
           pcoll_id in t_proto.outputs.items()
       }
       transform = with_pipeline(
-          ptransform.PTransform.from_runner_api(
-              request.transform.spec, context))
+          ptransform.PTransform.from_runner_api(request.transform, context))
       inputs = transform._pvaluish_from_dict({
           tag:
           with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_test.py b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
index 809a2c4..34acac1 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service_test.py
@@ -62,7 +62,8 @@ class CountPerElementTransform(ptransform.PTransform):
     return 'beam:transforms:xlang:count', None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return CountPerElementTransform()
 
 
@@ -82,7 +83,7 @@ class FilterLessThanTransform(ptransform.PTransform):
         'beam:transforms:xlang:filter_less_than', self._payload.encode('utf8'))
 
   @staticmethod
-  def from_runner_api_parameter(payload, unused_context):
+  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
     return FilterLessThanTransform(payload.decode('utf8'))
 
 
@@ -101,7 +102,7 @@ class PrefixTransform(ptransform.PTransform):
         {'data': self._payload}).payload()
 
   @staticmethod
-  def from_runner_api_parameter(payload, unused_context):
+  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
     return PrefixTransform(parse_string_payload(payload)['data'])
 
 
@@ -134,7 +135,8 @@ class GBKTransform(ptransform.PTransform):
     return TEST_GBK_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return GBKTransform()
 
 
@@ -155,7 +157,8 @@ class CoGBKTransform(ptransform.PTransform):
     return TEST_CGBK_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return CoGBKTransform()
 
 
@@ -169,7 +172,8 @@ class CombineGloballyTransform(ptransform.PTransform):
     return TEST_COMGL_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return CombineGloballyTransform()
 
 
@@ -184,7 +188,8 @@ class CombinePerKeyTransform(ptransform.PTransform):
     return TEST_COMPK_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return CombinePerKeyTransform()
 
 
@@ -197,7 +202,8 @@ class FlattenTransform(ptransform.PTransform):
     return TEST_FLATTEN_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return FlattenTransform()
 
 
@@ -214,7 +220,8 @@ class PartitionTransform(ptransform.PTransform):
     return TEST_PARTITION_URN, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return PartitionTransform()
 
 
@@ -230,7 +237,7 @@ class PayloadTransform(ptransform.PTransform):
     return b'payload', self._payload.encode('ascii')
 
   @staticmethod
-  def from_runner_api_parameter(payload, unused_context):
+  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
     return PayloadTransform(payload.decode('ascii'))
 
 
@@ -259,7 +266,7 @@ class FibTransform(ptransform.PTransform):
     return 'fib', str(self._level).encode('ascii')
 
   @staticmethod
-  def from_runner_api_parameter(level, unused_context):
+  def from_runner_api_parameter(unused_ptransform, level, unused_context):
     return FibTransform(int(level.decode('ascii')))
 
 
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 406f4da..967cbcf 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -76,16 +76,21 @@ class Event(with_metaclass(ABCMeta, object)):  # type: ignore[misc]
   @staticmethod
   def from_runner_api(proto, element_coder):
     if proto.HasField('element_event'):
+      event = proto.element_event
+      tag = None if event.tag == 'None' else event.tag
       return ElementEvent([
           TimestampedValue(
               element_coder.decode(tv.encoded_element),
               timestamp.Timestamp(micros=1000 * tv.timestamp))
           for tv in proto.element_event.elements
-      ])
+      ], tag=tag) # yapf: disable
     elif proto.HasField('watermark_event'):
+      event = proto.watermark_event
+      tag = None if event.tag == 'None' else event.tag
       return WatermarkEvent(
           timestamp.Timestamp(
-              micros=1000 * proto.watermark_event.new_watermark))
+              micros=1000 * proto.watermark_event.new_watermark),
+          tag=tag)
     elif proto.HasField('processing_time_event'):
       return ProcessingTimeEvent(
           timestamp.Duration(
@@ -113,6 +118,7 @@ class ElementEvent(Event):
     return self.timestamped_values < other.timestamped_values
 
   def to_runner_api(self, element_coder):
+    tag = 'None' if self.tag is None else self.tag
     return beam_runner_api_pb2.TestStreamPayload.Event(
         element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
             elements=[
@@ -120,7 +126,8 @@ class ElementEvent(Event):
                     encoded_element=element_coder.encode(tv.value),
                     timestamp=tv.timestamp.micros // 1000)
                 for tv in self.timestamped_values
-            ]))
+            ],
+            tag=tag))
 
 
 class WatermarkEvent(Event):
@@ -133,15 +140,21 @@ class WatermarkEvent(Event):
     return self.new_watermark == other.new_watermark and self.tag == other.tag
 
   def __hash__(self):
-    return hash(self.new_watermark)
+    return hash(str(self.new_watermark) + str(self.tag))
 
   def __lt__(self, other):
     return self.new_watermark < other.new_watermark
 
   def to_runner_api(self, unused_element_coder):
+    tag = 'None' if self.tag is None else self.tag
+
+    # Assert that no precision is lost.
+    assert 1000 * (
+        self.new_watermark.micros // 1000) == self.new_watermark.micros
     return beam_runner_api_pb2.TestStreamPayload.Event(
         watermark_event=beam_runner_api_pb2.TestStreamPayload.Event.
-        AdvanceWatermark(new_watermark=self.new_watermark.micros // 1000))
+        AdvanceWatermark(
+            new_watermark=self.new_watermark.micros // 1000, tag=tag))
 
 
 class ProcessingTimeEvent(Event):
@@ -171,13 +184,20 @@ class TestStream(PTransform):
   time. After all of the specified elements are emitted, ceases to produce
   output.
   """
-  def __init__(self, coder=coders.FastPrimitivesCoder(), events=None):
+  def __init__(
+      self, coder=coders.FastPrimitivesCoder(), events=None, output_tags=None):
     super(TestStream, self).__init__()
     assert coder is not None
+
     self.coder = coder
     self.watermarks = {None: timestamp.MIN_TIMESTAMP}
     self._events = [] if events is None else list(events)
-    self.output_tags = set()
+    self.output_tags = set(output_tags) if output_tags else set()
+
+    event_tags = set(
+        e.tag for e in self._events
+        if isinstance(e, (WatermarkEvent, ElementEvent)))
+    assert event_tags.issubset(self.output_tags)
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
@@ -188,7 +208,17 @@ class TestStream(PTransform):
   def expand(self, pbegin):
     assert isinstance(pbegin, pvalue.PBegin)
     self.pipeline = pbegin.pipeline
-    return pvalue.PCollection(self.pipeline, is_bounded=False)
+    if not self.output_tags:
+      self.output_tags = set([None])
+
+    # For backwards compatibility return a single PCollection.
+    if len(self.output_tags) == 1:
+      return pvalue.PCollection(
+          self.pipeline, is_bounded=False, tag=list(self.output_tags)[0])
+    return {
+        tag: pvalue.PCollection(self.pipeline, is_bounded=False, tag=tag)
+        for tag in self.output_tags
+    }
 
   def _add(self, event):
     if isinstance(event, ElementEvent):
@@ -276,8 +306,11 @@ class TestStream(PTransform):
   @PTransform.register_urn(
       common_urns.primitives.TEST_STREAM.urn,
       beam_runner_api_pb2.TestStreamPayload)
-  def from_runner_api_parameter(payload, context):
+  def from_runner_api_parameter(ptransform, payload, context):
     coder = context.coders.get_by_id(payload.coder_id)
+    output_tags = set(
+        None if k == 'None' else k for k in ptransform.outputs.keys())
     return TestStream(
         coder=coder,
-        events=[Event.from_runner_api(e, coder) for e in payload.events])
+        events=[Event.from_runner_api(e, coder) for e in payload.events],
+        output_tags=output_tags)
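
With output_tags in place, a TestStream that carries tagged events now
expands to one PCollection per tag. A minimal usage sketch, assuming the
per-tag PCollections returned by expand() can be indexed by tag as shown;
the tags and elements are illustrative, and actually running the pipeline
may additionally need watermark advancement to infinity plus a
streaming-capable runner such as the DirectRunner path above.

import apache_beam as beam
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream

multi_stream = (TestStream()
                .advance_watermark_to(0, tag='letters')
                .add_elements(['a', 'b'], tag='letters')
                .advance_watermark_to(0, tag='numbers')
                .add_elements([1, 2], tag='numbers'))  # yapf: disable

# The builder records every tag it sees.
assert {'letters', 'numbers'} <= multi_stream.output_tags

p = TestPipeline(options=StandardOptions(streaming=True))
outputs = p | multi_stream  # one PCollection per tag, keyed by tag
upper = outputs['letters'] | 'Upper' >> beam.Map(lambda s: s.upper())
double = outputs['numbers'] | 'Double' >> beam.Map(lambda n: n * 2)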
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 05b1d59..cb98b14 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -26,6 +26,7 @@ import unittest
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.portability import common_urns
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import ElementEvent
 from apache_beam.testing.test_stream import ProcessingTimeEvent
@@ -528,6 +529,56 @@ class TestStreamTest(unittest.TestCase):
 
     p.run()
 
+  def test_roundtrip_proto(self):
+    test_stream = (TestStream()
+                   .advance_processing_time(1)
+                   .advance_watermark_to(2)
+                   .add_elements([1, 2, 3])) # yapf: disable
+
+    p = TestPipeline(options=StandardOptions(streaming=True))
+    p | test_stream
+
+    pipeline_proto, context = p.to_runner_api(return_context=True)
+
+    for t in pipeline_proto.components.transforms.values():
+      if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
+        test_stream_proto = t
+
+    self.assertTrue(test_stream_proto)
+    roundtrip_test_stream = TestStream().from_runner_api(
+        test_stream_proto, context)
+
+    self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
+    self.assertSetEqual(
+        test_stream.output_tags, roundtrip_test_stream.output_tags)
+    self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
+
+  def test_roundtrip_proto_multi(self):
+    test_stream = (TestStream()
+                   .advance_processing_time(1)
+                   .advance_watermark_to(2, tag='a')
+                   .advance_watermark_to(3, tag='b')
+                   .add_elements([1, 2, 3], tag='a')
+                   .add_elements([4, 5, 6], tag='b')) # yapf: disable
+
+    p = TestPipeline(options=StandardOptions(streaming=True))
+    p | test_stream
+
+    pipeline_proto, context = p.to_runner_api(return_context=True)
+
+    for t in pipeline_proto.components.transforms.values():
+      if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
+        test_stream_proto = t
+
+    self.assertTrue(test_stream_proto)
+    roundtrip_test_stream = TestStream().from_runner_api(
+        test_stream_proto, context)
+
+    self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
+    self.assertSetEqual(
+        test_stream.output_tags, roundtrip_test_stream.output_tags)
+    self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 57660c5..f5ccafa 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1339,7 +1339,7 @@ class ParDo(PTransformWithSideInputs):
   @staticmethod
   @PTransform.register_urn(
       common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
-  def from_runner_api_parameter(pardo_payload, context):
+  def from_runner_api_parameter(unused_ptransform, pardo_payload, context):
     assert pardo_payload.do_fn.urn == python_urns.PICKLED_DOFN_INFO
     fn, args, kwargs, si_tags_and_types, windowing = pickler.loads(
         pardo_payload.do_fn.payload)
@@ -1932,7 +1932,7 @@ class CombinePerKey(PTransformWithSideInputs):
   @PTransform.register_urn(
       common_urns.composites.COMBINE_PER_KEY.urn,
       beam_runner_api_pb2.CombinePayload)
-  def from_runner_api_parameter(combine_payload, context):
+  def from_runner_api_parameter(unused_ptransform, combine_payload, context):
     return CombinePerKey(
         CombineFn.from_runner_api(combine_payload.combine_fn, context))
 
@@ -1975,7 +1975,7 @@ class CombineValues(PTransformWithSideInputs):
   @PTransform.register_urn(
       common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
       beam_runner_api_pb2.CombinePayload)
-  def from_runner_api_parameter(combine_payload, context):
+  def from_runner_api_parameter(unused_ptransform, combine_payload, context):
     return CombineValues(
         CombineFn.from_runner_api(combine_payload.combine_fn, context))
 
@@ -2203,7 +2203,8 @@ class GroupByKey(PTransform):
 
   @staticmethod
   @PTransform.register_urn(common_urns.primitives.GROUP_BY_KEY.urn, None)
-  def from_runner_api_parameter(unused_payload, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_payload, unused_context):
     return GroupByKey()
 
   def runner_api_requires_keyed_input(self):
@@ -2494,7 +2495,7 @@ class WindowInto(ParDo):
         self.windowing.to_runner_api(context))
 
   @staticmethod
-  def from_runner_api_parameter(proto, context):
+  def from_runner_api_parameter(unused_ptransform, proto, context):
     windowing = Windowing.from_runner_api(proto, context)
     return WindowInto(
         windowing.windowfn,
@@ -2568,7 +2569,8 @@ class Flatten(PTransform):
     return common_urns.primitives.FLATTEN.urn, None
 
   @staticmethod
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return Flatten()
 
 
@@ -2681,5 +2683,6 @@ class Impulse(PTransform):
 
   @staticmethod
   @PTransform.register_urn(common_urns.primitives.IMPULSE.urn, None)
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return Impulse()
diff --git a/sdks/python/apache_beam/transforms/external_it_test.py b/sdks/python/apache_beam/transforms/external_it_test.py
index 40ee910..d99c218 100644
--- a/sdks/python/apache_beam/transforms/external_it_test.py
+++ b/sdks/python/apache_beam/transforms/external_it_test.py
@@ -46,7 +46,7 @@ class ExternalTransformIT(unittest.TestCase):
         return 'simple', None
 
       @staticmethod
-      def from_runner_api_parameter(_1, _2):
+      def from_runner_api_parameter(_0, _1, _2):
         return SimpleTransform()
 
     pipeline = TestPipeline(is_integration_test=True)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 66a953b..5e44ef7 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -669,23 +669,25 @@ class PTransform(WithTypeHints, HasDisplayData):
 
   @classmethod
   def from_runner_api(cls,
-                      proto,  # type: Optional[beam_runner_api_pb2.FunctionSpec]
+                      proto,  # type: Optional[beam_runner_api_pb2.PTransform]
                       context  # type: PipelineContext
                      ):
     # type: (...) -> Optional[PTransform]
-    if proto is None or not proto.urn:
+    if proto is None or proto.spec is None or not proto.spec.urn:
       return None
-    parameter_type, constructor = cls._known_urns[proto.urn]
+    parameter_type, constructor = cls._known_urns[proto.spec.urn]
 
     try:
       return constructor(
-          proto_utils.parse_Bytes(proto.payload, parameter_type), context)
+          proto,
+          proto_utils.parse_Bytes(proto.spec.payload, parameter_type),
+          context)
     except Exception:
       if context.allow_proto_holders:
         # For external transforms we cannot build a Python ParDo object so
         # we build a holder transform instead.
         from apache_beam.transforms.core import RunnerAPIPTransformHolder
-        return RunnerAPIPTransformHolder(proto, context)
+        return RunnerAPIPTransformHolder(proto.spec, context)
       raise
 
   def to_runner_api_parameter(
@@ -707,14 +709,14 @@ class PTransform(WithTypeHints, HasDisplayData):
 
 
 @PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None)
-def _create_transform(payload, unused_context):
+def _create_transform(unused_ptransform, payload, unused_context):
   empty_transform = PTransform()
   empty_transform._fn_api_payload = payload
   return empty_transform
 
 
 @PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None)
-def _unpickle_transform(pickled_bytes, unused_context):
+def _unpickle_transform(unused_ptransform, pickled_bytes, unused_context):
   return pickler.loads(pickled_bytes)
 
 
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 3ad99fe..361edbe 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -722,7 +722,8 @@ class Reshuffle(PTransform):
 
   @staticmethod
   @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
-  def from_runner_api_parameter(unused_parameter, unused_context):
+  def from_runner_api_parameter(
+      unused_ptransform, unused_parameter, unused_context):
     return Reshuffle()
 
 

