beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [beam] branch master updated: [BEAM-8111] Enable CloudObjectsTest$DefaultCoders (#9446)
Date Fri, 04 Oct 2019 05:04:01 GMT
This is an automated email from the ASF dual-hosted git repository.

kenn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e14f3f8  [BEAM-8111] Enable CloudObjectsTest$DefaultCoders (#9446)
e14f3f8 is described below

commit e14f3f84ba0da6288db4795b92199e2905134a17
Author: Brian Hulette <bhulette@google.com>
AuthorDate: Thu Oct 3 22:03:17 2019 -0700

    [BEAM-8111] Enable CloudObjectsTest$DefaultCoders (#9446)
---
 .../beam/runners/dataflow/util/CloudObjects.java   |  2 +
 .../runners/dataflow/util/CloudObjectsTest.java    | 54 +++++++++++++++++-----
 .../java/org/apache/beam/sdk/coders/RowCoder.java  | 18 ++++++++
 .../org/apache/beam/sdk/schemas/SchemaCoder.java   | 48 +++++++++++++++++--
 4 files changed, 108 insertions(+), 14 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
index ec51351..ef47e1f 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
@@ -29,6 +29,7 @@ import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.DoubleCoder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
@@ -49,6 +50,7 @@ public class CloudObjects {
           ByteArrayCoder.class,
           KvCoder.class,
           VarLongCoder.class,
+          DoubleCoder.class,
           IntervalWindowCoder.class,
           IterableCoder.class,
           Timer.Coder.class,
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
index 215567e..26b721c 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
@@ -17,11 +17,12 @@
  */
 package org.apache.beam.runners.dataflow.util;
 
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.emptyIterable;
 import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
@@ -32,7 +33,6 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
-import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.sdk.coders.AvroCoder;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
@@ -50,7 +50,9 @@ import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.SetCoder;
 import org.apache.beam.sdk.coders.StructuredCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.schemas.LogicalTypes;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.SchemaCoder;
 import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder;
 import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
@@ -62,7 +64,9 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.Builder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.junit.runners.Parameterized;
@@ -70,7 +74,22 @@ import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
 
 /** Tests for {@link CloudObjects}. */
+@RunWith(Enclosed.class)
 public class CloudObjectsTest {
+  private static final Schema TEST_SCHEMA =
+      Schema.builder()
+          .addBooleanField("bool")
+          .addByteField("int8")
+          .addInt16Field("int16")
+          .addInt32Field("int32")
+          .addInt64Field("int64")
+          .addFloatField("float")
+          .addDoubleField("double")
+          .addStringField("string")
+          .addArrayField("list_int32", FieldType.INT32)
+          .addLogicalTypeField("fixed_bytes", LogicalTypes.FixedBytes.of(4))
+          .build();
+
   /** Tests that all of the Default Coders are tested. */
   @RunWith(JUnit4.class)
   public static class DefaultsPresentTest {
@@ -143,7 +162,8 @@ public class CloudObjectsTest {
                       CoGbkResultSchema.of(
                           ImmutableList.of(new TupleTag<Long>(), new TupleTag<byte[]>())),
                       UnionCoder.of(ImmutableList.of(VarLongCoder.of(), ByteArrayCoder.of()))))
-              .add(SchemaCoder.of(Schema.builder().build()));
+              .add(SchemaCoder.of(Schema.builder().build()))
+              .add(SchemaCoder.of(TEST_SCHEMA));
       for (Class<? extends Coder> atomicCoder :
           DefaultCoderCloudObjectTranslatorRegistrar.KNOWN_ATOMIC_CODERS) {
         dataBuilder.add(InstanceBuilder.ofType(atomicCoder).fromFactoryMethod("of").build());
@@ -177,21 +197,33 @@ public class CloudObjectsTest {
 
     private static void checkPipelineProtoCoderIds(
         Coder<?> coder, CloudObject cloudObject, SdkComponents sdkComponents) throws Exception {
-      if (ModelCoderRegistrar.isKnownCoder(coder)) {
+      if (CloudObjects.DATAFLOW_KNOWN_CODERS.contains(coder.getClass())) {
         assertFalse(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID));
       } else {
         assertTrue(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID));
         assertEquals(
             sdkComponents.registerCoder(coder),
-            cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID));
+            ((CloudObject) cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID))
+                .get(PropertyNames.VALUE));
+      }
+      List<? extends Coder<?>> expectedComponents;
+      if (coder instanceof StructuredCoder) {
+        expectedComponents = ((StructuredCoder) coder).getComponents();
+      } else {
+        expectedComponents = coder.getCoderArguments();
       }
-      List<? extends Coder<?>> coderArguments = coder.getCoderArguments();
       Object cloudComponentsObject = cloudObject.get(PropertyNames.COMPONENT_ENCODINGS);
-      assertTrue(cloudComponentsObject instanceof List);
-      List<CloudObject> cloudComponents = (List<CloudObject>) cloudComponentsObject;
-      assertEquals(coderArguments.size(), cloudComponents.size());
-      for (int i = 0; i < coderArguments.size(); i++) {
-        checkPipelineProtoCoderIds(coderArguments.get(i), cloudComponents.get(i), sdkComponents);
+      List<CloudObject> cloudComponents;
+      if (cloudComponentsObject == null) {
+        cloudComponents = Lists.newArrayList();
+      } else {
+        assertThat(cloudComponentsObject, instanceOf(List.class));
+        cloudComponents = (List<CloudObject>) cloudComponentsObject;
+      }
+      assertEquals(expectedComponents.size(), cloudComponents.size());
+      for (int i = 0; i < expectedComponents.size(); i++) {
+        checkPipelineProtoCoderIds(
+            expectedComponents.get(i), cloudComponents.get(i), sdkComponents);
       }
     }
   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
index f6cfe6a..79faa91 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
@@ -24,6 +24,7 @@ import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.UUID;
 import java.util.stream.Collectors;
 import javax.annotation.Nullable;
@@ -247,4 +248,21 @@ public class RowCoder extends CustomCoder<Row> {
     String string = "Schema: " + schema + "  UUID: " + id + " delegateCoder: " + getDelegateCoder();
     return string;
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    RowCoder rowCoder = (RowCoder) o;
+    return schema.equals(rowCoder.schema);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(schema);
+  }
 }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
index 0199534..9e06b44 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
@@ -20,12 +20,12 @@ package org.apache.beam.sdk.schemas;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.util.Objects;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.coders.CustomCoder;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.values.Row;
 
 /** {@link SchemaCoder} is used as the coder for types that have schemas registered. */
@@ -57,8 +57,7 @@ public class SchemaCoder<T> extends CustomCoder<T> {
 
   /** Returns a {@link SchemaCoder} for {@link Row} classes. */
   public static SchemaCoder<Row> of(Schema schema) {
-    return new SchemaCoder<>(
-        schema, SerializableFunctions.identity(), SerializableFunctions.identity());
+    return new SchemaCoder<>(schema, identity(), identity());
   }
 
   /** Returns the schema associated with this type. */
@@ -100,4 +99,47 @@ public class SchemaCoder<T> extends CustomCoder<T> {
   public String toString() {
     return "SchemaCoder: " + rowCoder.toString();
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    SchemaCoder<?> that = (SchemaCoder<?>) o;
+    return rowCoder.equals(that.rowCoder)
+        && toRowFunction.equals(that.toRowFunction)
+        && fromRowFunction.equals(that.fromRowFunction);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(rowCoder, toRowFunction, fromRowFunction);
+  }
+
+  private static RowIdentity identity() {
+    return new RowIdentity();
+  }
+
+  private static class RowIdentity implements SerializableFunction<Row, Row> {
+    @Override
+    public Row apply(Row input) {
+      return input;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(getClass());
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      return o != null && getClass() == o.getClass();
+    }
+  }
 }


Mime
View raw message