Repository: beam Updated Branches: refs/heads/master 34b38ef95 -> 9cc8018b3 http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 31307cc..ccf84b2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -332,20 +332,58 @@ final class StreamingTransformTranslator { }; } + private static TransformEvaluator> parDo() { + return new TransformEvaluator>() { + @Override + public void evaluate(final ParDo.Bound transform, + final EvaluationContext context) { + final DoFn doFn = transform.getFn(); + rejectSplittable(doFn); + rejectStateAndTimers(doFn); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); + final WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + final SparkPCollectionView pviews = context.getPViews(); + + @SuppressWarnings("unchecked") + UnboundedDataset unboundedDataset = + ((UnboundedDataset) context.borrowDataset(transform)); + JavaDStream> dStream = unboundedDataset.getDStream(); + + final String stepName = context.getCurrentTransform().getFullName(); + + JavaDStream> outStream = + dStream.transform(new Function>, + JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) throws + Exception { + final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); + final Accumulator aggAccum = + SparkAggregators.getNamedAggregators(jsc); + final Accumulator metricsAccum = + MetricsAccumulator.getInstance(); + final Map, KV, SideInputBroadcast>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), + jsc, pviews); + return rdd.mapPartitions( + new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext, + sideInputs, windowingStrategy)); + } + }); + + context.putDataset(transform, + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); + } + }; + } + private static TransformEvaluator> multiDo() { return new TransformEvaluator>() { - public void evaluate( - final ParDo.BoundMulti transform, final EvaluationContext context) { - if (transform.getSideOutputTags().size() == 0) { - evaluateSingle(transform, context); - } else { - evaluateMulti(transform, context); - } - } - - private void evaluateMulti( - final ParDo.BoundMulti transform, final EvaluationContext context) { + @Override + public void evaluate(final ParDo.BoundMulti transform, + final EvaluationContext context) { final DoFn doFn = transform.getFn(); rejectSplittable(doFn); rejectStateAndTimers(doFn); @@ -389,60 +427,10 @@ final class StreamingTransformTranslator { JavaDStream> values = (JavaDStream>) (JavaDStream) TranslationUtils.dStreamValues(filtered); - context.putDataset( - e.getValue(), new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); + context.putDataset(e.getValue(), + new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); } } - - private void evaluateSingle( - final ParDo.BoundMulti transform, final EvaluationContext context) { - final DoFn doFn = transform.getFn(); - rejectSplittable(doFn); - rejectStateAndTimers(doFn); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final WindowingStrategy windowingStrategy = - context.getInput(transform).getWindowingStrategy(); - final SparkPCollectionView pviews = context.getPViews(); - - @SuppressWarnings("unchecked") - UnboundedDataset unboundedDataset = - ((UnboundedDataset) context.borrowDataset(transform)); - JavaDStream> dStream = unboundedDataset.getDStream(); - - final String stepName = context.getCurrentTransform().getFullName(); - - JavaDStream> outStream = - dStream.transform( - new Function>, JavaRDD>>() { - @Override - public JavaRDD> call(JavaRDD> rdd) - throws Exception { - final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); - final Accumulator aggAccum = - SparkAggregators.getNamedAggregators(jsc); - final Accumulator metricsAccum = - MetricsAccumulator.getInstance(); - final Map, KV, SideInputBroadcast>> - sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), jsc, pviews); - return rdd.mapPartitions( - new DoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - runtimeContext, - sideInputs, - windowingStrategy)); - } - }); - - PCollection output = - (PCollection) - Iterables.getOnlyElement(context.getOutputs(transform)).getValue(); - context.putDataset( - output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); - } }; } @@ -487,6 +475,7 @@ final class StreamingTransformTranslator { EVALUATORS.put(Read.Unbounded.class, readUnbounded()); EVALUATORS.put(GroupByKey.class, groupByKey()); EVALUATORS.put(Combine.GroupedValues.class, combineGrouped()); + EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); EVALUATORS.put(CreateStream.class, createFromQueue()); http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java index d66633b..b181a04 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java @@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest { p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } @@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest { PCollectionList.of(pcol1).and(pcol2).apply(Flatten.pCollections()); flattened.apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 9225231..19c5a2d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -738,8 +738,12 @@ public class ParDo { @Override public PCollection expand(PCollection input) { - TupleTag mainOutput = new TupleTag<>(); - return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); + validateWindowType(input, fn); + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), + input.getWindowingStrategy(), + input.isBounded()) + .setTypeDescriptor(getFn().getOutputTypeDescriptor()); } @Override