beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [1/2] beam git commit: Add CombineTranslation
Date Thu, 25 May 2017 20:37:11 GMT
Repository: beam
Updated Branches:
  refs/heads/master 1be6f67aa -> 2040e2bd4


Add CombineTranslation

This translates Combines to CombinePayloads and back


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5b899a85
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5b899a85
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5b899a85

Branch: refs/heads/master
Commit: 5b899a8518cc8910f0a855303c14088d72b332e5
Parents: 4ec3366
Author: Thomas Groh <tgroh@google.com>
Authored: Thu May 18 10:23:16 2017 -0700
Committer: Thomas Groh <tgroh@google.com>
Committed: Wed May 24 13:04:55 2017 -0700

----------------------------------------------------------------------
 .../core/construction/CombineTranslation.java   | 125 ++++++++++++++++++
 .../construction/CombineTranslationTest.java    | 130 +++++++++++++++++++
 .../org/apache/beam/sdk/transforms/Count.java   |  10 ++
 .../org/apache/beam/sdk/transforms/Sum.java     |  30 +++++
 4 files changed, 295 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
new file mode 100644
index 0000000..e0b6d5c
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.collect.Iterables;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.CombinePayload;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.AppliedCombineFn;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * Methods for translating between {@link Combine.PerKey} {@link PTransform PTransforms}
and {@link
+ * RunnerApi.CombinePayload} protos.
+ */
+public class CombineTranslation {
+  private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1";
+
+  public static CombinePayload toProto(
+      AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>> combine, SdkComponents
sdkComponents)
+      throws IOException {
+    GlobalCombineFn<?, ?, ?> combineFn = combine.getTransform().getFn();
+    try {
+      Coder<?> accumulatorCoder = extractAccumulatorCoder(combineFn, (AppliedPTransform)
combine);
+      Map<String, SideInput> sideInputs = new HashMap<>();
+      return RunnerApi.CombinePayload.newBuilder()
+          .setAccumulatorCoderId(sdkComponents.registerCoder(accumulatorCoder))
+          .putAllSideInputs(sideInputs)
+          .setCombineFn(toProto(combineFn))
+          .build();
+    } catch (CannotProvideCoderException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  private static <K, InputT, AccumT> Coder<AccumT> extractAccumulatorCoder(
+      GlobalCombineFn<InputT, AccumT, ?> combineFn,
+      AppliedPTransform<PCollection<KV<K, InputT>>, ?, Combine.PerKey<K,
InputT, ?>> transform)
+      throws CannotProvideCoderException {
+    KvCoder<K, InputT> inputCoder =
+        (KvCoder<K, InputT>)
+            ((PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values()))
+                .getCoder();
+    return AppliedCombineFn.withInputCoder(
+            combineFn,
+            transform.getPipeline().getCoderRegistry(),
+            inputCoder,
+            transform.getTransform().getSideInputs(),
+            ((PCollection<?>) Iterables.getOnlyElement(transform.getOutputs().values()))
+                .getWindowingStrategy())
+        .getAccumulatorCoder();
+  }
+
+  private static SdkFunctionSpec toProto(GlobalCombineFn<?, ?, ?> combineFn) {
+    return SdkFunctionSpec.newBuilder()
+        // TODO: Set Java SDK Environment URN
+        .setSpec(
+            FunctionSpec.newBuilder()
+                .setUrn(JAVA_SERIALIZED_COMBINE_FN_URN)
+                .setParameter(
+                    Any.pack(
+                        BytesValue.newBuilder()
+                            .setValue(
+                                ByteString.copyFrom(
+                                    SerializableUtils.serializeToByteArray(combineFn)))
+                            .build())))
+        .build();
+  }
+
+  public static Coder<?> getAccumulatorCoder(
+      CombinePayload payload, RunnerApi.Components components) throws IOException {
+    String id = payload.getAccumulatorCoderId();
+    return Coders.fromProto(components.getCodersOrThrow(id), components);
+  }
+
+  public static GlobalCombineFn<?, ?, ?> getCombineFn(CombinePayload payload)
+      throws IOException {
+    checkArgument(payload.getCombineFn().getSpec().getUrn().equals(JAVA_SERIALIZED_COMBINE_FN_URN));
+    return (GlobalCombineFn<?, ?, ?>)
+        SerializableUtils.deserializeFromByteArray(
+            payload
+                .getCombineFn()
+                .getSpec()
+                .getParameter()
+                .unpack(BytesValue.class)
+                .getValue()
+                .toByteArray(),
+            "CombineFn");
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
new file mode 100644
index 0000000..6251545
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CombineTranslationTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkState;
+import static org.junit.Assert.assertEquals;
+
+import com.google.common.collect.ImmutableList;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.CombinePayload;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.BinaryCombineIntegerFn;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/**
+ * Tests for {@link CombineTranslation}.
+ */
+@RunWith(Parameterized.class)
+public class CombineTranslationTest {
+  @Parameters(name = "{index}: {0}")
+  public static Iterable<Combine.CombineFn<Integer, ?, ?>> params() {
+    BinaryCombineIntegerFn sum = Sum.ofIntegers();
+    CombineFn<Integer, ?, Long> count = Count.combineFn();
+    TestCombineFn test = new TestCombineFn();
+    return ImmutableList.<CombineFn<Integer, ?, ?>>builder().add(sum).add(count).add(test).build();
+  }
+
+  @Rule public TestPipeline pipeline = TestPipeline.create();
+  @Parameter(0)
+  public Combine.CombineFn<Integer, ?, ?> combineFn;
+
+  /**
+   * Translates the {@link Combine.PerKey} found in a {@link Combine#globally} pipeline and checks
+   * that the accumulator coder and the combine fn are recovered from the resulting protos.
+   */
+  @Test
+  public void testToFromProto() throws Exception {
+    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    input.apply(Combine.globally(combineFn));
+    final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> combine =
+        new AtomicReference<>();
+    pipeline.traverseTopologically(
+        new PipelineVisitor.Defaults() {
+          @Override
+          @SuppressWarnings({"rawtypes", "unchecked"})
+          public void leaveCompositeTransform(Node node) {
+            if (node.getTransform() instanceof Combine.PerKey) {
+              // Expect at most one Combine.PerKey in the pipeline.
+              checkState(combine.get() == null);
+              combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline()));
+            }
+          }
+        });
+    checkState(combine.get() != null);
+
+    SdkComponents sdkComponents = SdkComponents.create();
+    CombinePayload combineProto = CombineTranslation.toProto(combine.get(), sdkComponents);
+    RunnerApi.Components componentsProto = sdkComponents.toComponents();
+
+    assertEquals(
+        combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()),
+        CombineTranslation.getAccumulatorCoder(combineProto, componentsProto));
+    assertEquals(combineFn, CombineTranslation.getCombineFn(combineProto));
+  }
+
+  /**
+   * A no-op {@link CombineFn} with class-based equality, so a deserialized copy compares equal to
+   * the original.
+   */
+  private static class TestCombineFn extends Combine.CombineFn<Integer, Void, Void> {
+    @Override
+    public Void createAccumulator() {
+      return null;
+    }
+
+    @Override
+    public Coder<Void> getAccumulatorCoder(CoderRegistry registry, Coder<Integer> inputCoder) {
+      return VoidCoder.of();
+    }
+
+    @Override
+    public Void extractOutput(Void accumulator) {
+      return accumulator;
+    }
+
+    @Override
+    public Void mergeAccumulators(Iterable<Void> accumulators) {
+      return null;
+    }
+
+    @Override
+    public Void addInput(Void accumulator, Integer input) {
+      return accumulator;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      return other != null && other.getClass().equals(TestCombineFn.class);
+    }
+
+    @Override
+    public int hashCode() {
+      return TestCombineFn.class.hashCode();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java
index b405dd1..ee24b3f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java
@@ -195,5 +195,15 @@ public class Count {
         }
       };
     }
+
+    @Override
+    public boolean equals(Object other) {
+      return other != null && other.getClass() == getClass(); // class identity decides equality
+    }
+
+    @Override
+    public int hashCode() {
+      return this.getClass().hashCode(); // consistent with the class-based equals()
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/5b899a85/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java
index ccade4d..6b65416 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sum.java
@@ -151,6 +151,16 @@ public class Sum {
     public int identity() {
       return 0;
     }
+
+    @Override
+    public boolean equals(Object other) {
+      return other != null && getClass() == other.getClass(); // equal iff same concrete class
+    }
+
+    @Override
+    public int hashCode() {
+      return this.getClass().hashCode(); // consistent with the class-based equals()
+    }
   }
 
   private static class SumLongFn extends Combine.BinaryCombineLongFn {
@@ -164,6 +174,16 @@ public class Sum {
     public long identity() {
       return 0;
     }
+
+    @Override
+    public boolean equals(Object other) {
+      return other != null && getClass() == other.getClass(); // equal iff same concrete class
+    }
+
+    @Override
+    public int hashCode() {
+      return this.getClass().hashCode(); // consistent with the class-based equals()
+    }
   }
 
   private static class SumDoubleFn extends Combine.BinaryCombineDoubleFn {
@@ -177,5 +197,15 @@ public class Sum {
     public double identity() {
       return 0;
     }
+
+    @Override
+    public boolean equals(Object other) {
+      return other != null && getClass() == other.getClass(); // equal iff same concrete class
+    }
+
+    @Override
+    public int hashCode() {
+      return this.getClass().hashCode(); // consistent with the class-based equals()
+    }
   }
 }


Mime
View raw message