beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From echauc...@apache.org
Subject [beam] 10/37: Lazily initialize the coder because a coder instance cannot be interpolated by Catalyst
Date Thu, 24 Oct 2019 10:18:14 GMT
This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit e6b68a8f21aba2adcb7543eae806d71e08c0bff3
Author: Etienne Chauchot <echauchot@apache.org>
AuthorDate: Mon Sep 2 17:55:24 2019 +0200

    Lazily initialize the coder because a coder instance cannot be interpolated into Catalyst-generated code
---
 runners/spark/build.gradle                         |  1 +
 .../translation/helpers/EncoderHelpers.java        | 63 +++++++++++++++-------
 .../structuredstreaming/utils/EncodersTest.java    |  3 +-
 3 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle
index 73a710b..a948ef1 100644
--- a/runners/spark/build.gradle
+++ b/runners/spark/build.gradle
@@ -77,6 +77,7 @@ dependencies {
   provided "com.esotericsoftware.kryo:kryo:2.21"
   runtimeOnly library.java.jackson_module_scala
   runtimeOnly "org.scala-lang:scala-library:2.11.8"
+  compile "org.scala-lang.modules:scala-java8-compat_2.11:0.9.0"
   testCompile project(":sdks:java:io:kafka")
   testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
   // SparkStateInternalsTest extends abstract StateInternalsTest
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index cc862cd..694bc24 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -18,9 +18,9 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static scala.compat.java8.JFunction.func;
 
 import java.io.ByteArrayInputStream;
-import java.io.IOException;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
 import java.util.List;
@@ -45,7 +45,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode;
 import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.ObjectType;
-import scala.Function1;
 import scala.StringContext;
 import scala.Tuple2;
 import scala.collection.JavaConversions;
@@ -94,17 +93,17 @@ public class EncoderHelpers {
   */
 
   /** A way to construct encoders using generic serializers. */
-  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder/*, Class<T>
claz*/){
+  public static <T> Encoder<T> fromBeamCoder(Class<? extends Coder<T>>
coderClass/*, Class<T> claz*/){
 
     List<Expression> serialiserList = new ArrayList<>();
     Class<T> claz = (Class<T>) Object.class;
-    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz),
true), coder));
+    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz),
true), (Class<Coder<T>>)coderClass));
     ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz);
     return new ExpressionEncoder<>(
         SchemaHelpers.binarySchema(),
         false,
         JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
-        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType),
BinaryType), classTag, coder),
+        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType),
BinaryType), classTag, (Class<Coder<T>>)coderClass),
         classTag);
 
 /*
@@ -127,11 +126,11 @@ public class EncoderHelpers {
   public static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression
{
 
     private Expression child;
-    private Coder<T> beamCoder;
+    private Class<Coder<T>> coderClass;
 
-    public EncodeUsingBeamCoder(Expression child, Coder<T> beamCoder) {
+    public EncodeUsingBeamCoder(Expression child, Class<Coder<T>> coderClass)
{
       this.child = child;
-      this.beamCoder = beamCoder;
+      this.coderClass = coderClass;
     }
 
     @Override public Expression child() {
@@ -140,6 +139,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to serialize.
+      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
       ExprCode input = child.genCode(ctx);
 
       /*
@@ -170,6 +170,7 @@ public class EncoderHelpers {
           new VariableValue("output", Array.class));
     }
 
+
     @Override public DataType dataType() {
       return BinaryType;
     }
@@ -179,7 +180,7 @@ public class EncoderHelpers {
         case 0:
           return child;
         case 1:
-          return beamCoder;
+          return coderClass;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
       }
@@ -201,11 +202,11 @@ public class EncoderHelpers {
         return false;
       }
       EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o;
-      return beamCoder.equals(that.beamCoder);
+      return coderClass.equals(that.coderClass);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), beamCoder);
+      return Objects.hash(super.hashCode(), coderClass);
     }
   }
 
@@ -237,12 +238,12 @@ public class EncoderHelpers {
 
     private Expression child;
     private ClassTag<T> classTag;
-    private Coder<T> beamCoder;
+    private Class<Coder<T>> coderClass;
 
-    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Coder<T>
beamCoder) {
+    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Class<Coder<T>>
coderClass) {
       this.child = child;
       this.classTag = classTag;
-      this.beamCoder = beamCoder;
+      this.coderClass = coderClass;
     }
 
     @Override public Expression child() {
@@ -251,6 +252,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to deserialize.
+      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
       ExprCode input = child.genCode(ctx);
       String javaType = CodeGenerator.javaType(dataType());
 
@@ -291,9 +293,10 @@ public class EncoderHelpers {
 
     @Override public Object nullSafeEval(Object input) {
       try {
+        Coder<T> beamCoder = coderClass.newInstance();
         return beamCoder.decode(new ByteArrayInputStream((byte[]) input));
-      } catch (IOException e) {
-        throw new IllegalStateException("Error decoding bytes for coder: " + beamCoder, e);
+      } catch (Exception e) {
+        throw new IllegalStateException("Error decoding bytes for coder: " + coderClass,
e);
       }
     }
 
@@ -308,7 +311,7 @@ public class EncoderHelpers {
         case 1:
           return classTag;
         case 2:
-          return beamCoder;
+          return coderClass;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
       }
@@ -330,11 +333,11 @@ public class EncoderHelpers {
         return false;
       }
       DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder);
+      return classTag.equals(that.classTag) && coderClass.equals(that.coderClass);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), classTag, beamCoder);
+      return Objects.hash(super.hashCode(), classTag, coderClass);
     }
   }
 /*
@@ -365,4 +368,26 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T],
kryo: B
   }
 */
 
+  private static <T> String lazyInitBeamCoder(CodegenContext ctx, Class<Coder<T>>
coderClass) {
+    String beamCoderInstance = "beamCoder";
+    ctx.addImmutableStateIfNotExists(coderClass.getName(), beamCoderInstance, func(v1 ->
{
+      /*
+    CODE GENERATED
+    v = (coderClass) coderClass.newInstance();
+     */
+        List<String> parts = new ArrayList<>();
+        parts.add("");
+        parts.add(" = (");
+        parts.add(") ");
+        parts.add(".newInstance();");
+        StringContext sc = new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq());
+        List<Object> args = new ArrayList<>();
+        args.add(v1);
+        args.add(coderClass.getName());
+        args.add(coderClass.getName());
+        return sc.s(JavaConversions.collectionAsScalaIterable(args).toSeq());
+      }));
+    return beamCoderInstance;
+  }
+
 }
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
index 7078b0c..0e38fe1 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
@@ -3,6 +3,7 @@ package org.apache.beam.runners.spark.structuredstreaming.utils;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.spark.sql.SparkSession;
 import org.junit.Test;
@@ -23,7 +24,7 @@ public class EncodersTest {
     data.add(1);
     data.add(2);
     data.add(3);
-    sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
+    sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.class));
 //    sparkSession.createDataset(data, EncoderHelpers.genericEncoder());
   }
 }


Mime
View raw message