beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bchamb...@apache.org
Subject [1/2] incubator-beam git commit: Execute ModelEnforcements in TransformExecutor
Date Fri, 01 Apr 2016 17:31:52 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master 9a42971ce -> f236db027


Execute ModelEnforcements in TransformExecutor

This allows Model Enforcements to be applied configurably, selected by the
class of the transform being executed: before and after each element is
processed, and after the transform completes.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/197a93cd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/197a93cd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/197a93cd

Branch: refs/heads/master
Commit: 197a93cdb30db306f2404eb89be94995b0da5d8b
Parents: 9a42971
Author: Thomas Groh <tgroh@google.com>
Authored: Mon Mar 28 09:35:33 2016 -0700
Committer: bchambers <bchambers@google.com>
Committed: Fri Apr 1 10:19:44 2016 -0700

----------------------------------------------------------------------
 .../runners/inprocess/CompletionCallback.java   |   5 +-
 .../ExecutorServiceParallelExecutor.java        |  35 ++-
 .../inprocess/InProcessPipelineRunner.java      |   7 +
 .../runners/inprocess/TransformExecutor.java    |  64 +++++-
 .../inprocess/TransformExecutorTest.java        | 221 ++++++++++++++++++-
 5 files changed, 318 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/197a93cd/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java
b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java
index 4bb74a7..b581616 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/CompletionCallback.java
@@ -24,9 +24,10 @@ import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.C
  */
 interface CompletionCallback {
   /**
-   * Handle a successful result.
+   * Handle a successful result, returning the committed outputs of the result.
    */
-  void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult result);
+  Iterable<? extends CommittedBundle<?>> handleResult(
+      CommittedBundle<?> inputBundle, InProcessTransformResult result);
 
   /**
    * Handle a result that terminated abnormally due to the provided {@link Throwable}.

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/197a93cd/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
index 59c4918..628f107 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ExecutorServiceParallelExecutor.java
@@ -21,6 +21,7 @@ import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers;
 import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle;
 import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
 import com.google.cloud.dataflow.sdk.util.KeyedWorkItem;
 import com.google.cloud.dataflow.sdk.util.KeyedWorkItems;
 import com.google.cloud.dataflow.sdk.util.TimeDomain;
@@ -62,6 +63,10 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
   private final Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers;
   private final Set<PValue> keyedPValues;
   private final TransformEvaluatorRegistry registry;
+  @SuppressWarnings("rawtypes")
+  private final Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
+      transformEnforcements;
+
   private final InProcessEvaluationContext evaluationContext;
 
   private final ConcurrentMap<StepAndKey, TransformExecutorService> currentEvaluations;
@@ -80,9 +85,11 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
       Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
       Set<PValue> keyedPValues,
       TransformEvaluatorRegistry registry,
+      @SuppressWarnings("rawtypes")
+      Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
transformEnforcements,
       InProcessEvaluationContext context) {
     return new ExecutorServiceParallelExecutor(
-        executorService, valueToConsumers, keyedPValues, registry, context);
+        executorService, valueToConsumers, keyedPValues, registry, transformEnforcements,
context);
   }
 
   private ExecutorServiceParallelExecutor(
@@ -90,11 +97,14 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
       Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
       Set<PValue> keyedPValues,
       TransformEvaluatorRegistry registry,
+      @SuppressWarnings("rawtypes")
+      Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
transformEnforcements,
       InProcessEvaluationContext context) {
     this.executorService = executorService;
     this.valueToConsumers = valueToConsumers;
     this.keyedPValues = keyedPValues;
     this.registry = registry;
+    this.transformEnforcements = transformEnforcements;
     this.evaluationContext = context;
 
     currentEvaluations = new ConcurrentHashMap<>();
@@ -128,6 +138,7 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
       @Nullable final CommittedBundle<T> bundle,
       final CompletionCallback onComplete) {
     TransformExecutorService transformExecutor;
+
     if (bundle != null && isKeyed(bundle.getPCollection())) {
       final StepAndKey stepAndKey =
           StepAndKey.of(transform, bundle == null ? null : bundle.getKey());
@@ -135,9 +146,21 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
     } else {
       transformExecutor = parallelExecutorService;
     }
+
+    Collection<ModelEnforcementFactory> enforcements =
+        MoreObjects.firstNonNull(
+            transformEnforcements.get(transform.getTransform().getClass()),
+            Collections.<ModelEnforcementFactory>emptyList());
+
     TransformExecutor<T> callable =
         TransformExecutor.create(
-            registry, evaluationContext, bundle, transform, onComplete, transformExecutor);
+            registry,
+            enforcements,
+            evaluationContext,
+            bundle,
+            transform,
+            onComplete,
+            transformExecutor);
     transformExecutor.schedule(callable);
   }
 
@@ -178,12 +201,14 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
    */
   private class DefaultCompletionCallback implements CompletionCallback {
     @Override
-    public void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult
result) {
+    public Iterable<? extends CommittedBundle<?>> handleResult(
+        CommittedBundle<?> inputBundle, InProcessTransformResult result) {
       Iterable<? extends CommittedBundle<?>> resultBundles =
           evaluationContext.handleResult(inputBundle, Collections.<TimerData>emptyList(),
result);
       for (CommittedBundle<?> outputBundle : resultBundles) {
         allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle));
       }
+      return resultBundles;
     }
 
     @Override
@@ -206,12 +231,14 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor
{
     }
 
     @Override
-    public void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult
result) {
+    public Iterable<? extends CommittedBundle<?>> handleResult(
+        CommittedBundle<?> inputBundle, InProcessTransformResult result) {
       Iterable<? extends CommittedBundle<?>> resultBundles =
           evaluationContext.handleResult(inputBundle, timers, result);
       for (CommittedBundle<?> outputBundle : resultBundles) {
         allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle));
       }
+      return resultBundles;
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/197a93cd/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index 7f65cf0..8123711 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -54,6 +54,7 @@ import com.google.common.collect.ImmutableSet;
 import org.joda.time.Instant;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
@@ -245,6 +246,7 @@ public class InProcessPipelineRunner
             consumerTrackingVisitor.getValueToConsumers(),
             keyedPValueVisitor.getKeyedPValues(),
             TransformEvaluatorRegistry.defaultRegistry(),
+            defaultModelEnforcements(options),
             context);
     executor.start(consumerTrackingVisitor.getRootTransforms());
 
@@ -264,6 +266,11 @@ public class InProcessPipelineRunner
     return result;
   }
 
+  private Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
+      defaultModelEnforcements(InProcessPipelineOptions options) {
+    return Collections.emptyMap();
+  }
+
   /**
    * The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}.
    *

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/197a93cd/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java
b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java
index 06bc6a8..62a9e24 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutor.java
@@ -22,6 +22,8 @@ import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
 import com.google.cloud.dataflow.sdk.util.WindowedValue;
 import com.google.common.base.Throwables;
 
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.concurrent.Callable;
 
 import javax.annotation.Nullable;
@@ -37,6 +39,7 @@ import javax.annotation.Nullable;
 class TransformExecutor<T> implements Callable<InProcessTransformResult> {
   public static <T> TransformExecutor<T> create(
       TransformEvaluatorFactory factory,
+      Iterable<? extends ModelEnforcementFactory> modelEnforcements,
       InProcessEvaluationContext evaluationContext,
       CommittedBundle<T> inputBundle,
       AppliedPTransform<?, ?, ?> transform,
@@ -44,6 +47,7 @@ class TransformExecutor<T> implements Callable<InProcessTransformResult>
{
       TransformExecutorService transformEvaluationState) {
     return new TransformExecutor<>(
         factory,
+        modelEnforcements,
         evaluationContext,
         inputBundle,
         transform,
@@ -52,6 +56,8 @@ class TransformExecutor<T> implements Callable<InProcessTransformResult>
{
   }
 
   private final TransformEvaluatorFactory evaluatorFactory;
+  private final Iterable<? extends ModelEnforcementFactory> modelEnforcements;
+
   private final InProcessEvaluationContext evaluationContext;
 
   /** The transform that will be evaluated. */
@@ -66,12 +72,14 @@ class TransformExecutor<T> implements Callable<InProcessTransformResult>
{
 
   private TransformExecutor(
       TransformEvaluatorFactory factory,
+      Iterable<? extends ModelEnforcementFactory> modelEnforcements,
       InProcessEvaluationContext evaluationContext,
       CommittedBundle<T> inputBundle,
       AppliedPTransform<?, ?, ?> transform,
       CompletionCallback completionCallback,
       TransformExecutorService transformEvaluationState) {
     this.evaluatorFactory = factory;
+    this.modelEnforcements = modelEnforcements;
     this.evaluationContext = evaluationContext;
 
     this.inputBundle = inputBundle;
@@ -86,15 +94,17 @@ class TransformExecutor<T> implements Callable<InProcessTransformResult>
{
   public InProcessTransformResult call() {
     this.thread = Thread.currentThread();
     try {
+      Collection<ModelEnforcement<T>> enforcements = new ArrayList<>();
+      for (ModelEnforcementFactory enforcementFactory : modelEnforcements) {
+        ModelEnforcement<T> enforcement = enforcementFactory.forBundle(inputBundle,
transform);
+        enforcements.add(enforcement);
+      }
       TransformEvaluator<T> evaluator =
           evaluatorFactory.forApplication(transform, inputBundle, evaluationContext);
-      if (inputBundle != null) {
-        for (WindowedValue<T> value : inputBundle.getElements()) {
-          evaluator.processElement(value);
-        }
-      }
-      InProcessTransformResult result = evaluator.finishBundle();
-      onComplete.handleResult(inputBundle, result);
+
+      processElements(evaluator, enforcements);
+
+      InProcessTransformResult result = finishBundle(evaluator, enforcements);
       return result;
     } catch (Throwable t) {
       onComplete.handleThrowable(inputBundle, t);
@@ -106,6 +116,46 @@ class TransformExecutor<T> implements Callable<InProcessTransformResult>
{
   }
 
   /**
+   * Processes all the elements in the input bundle using the transform evaluator, applying
any
+   * necessary {@link ModelEnforcement ModelEnforcements}.
+   */
+  private void processElements(
+      TransformEvaluator<T> evaluator, Collection<ModelEnforcement<T>>
enforcements)
+      throws Exception {
+    if (inputBundle != null) {
+      for (WindowedValue<T> value : inputBundle.getElements()) {
+        for (ModelEnforcement<T> enforcement : enforcements) {
+          enforcement.beforeElement(value);
+        }
+
+        evaluator.processElement(value);
+
+        for (ModelEnforcement<T> enforcement : enforcements) {
+          enforcement.afterElement(value);
+        }
+      }
+    }
+  }
+
+  /**
+   * Finishes processing the input bundle and commits the result using the
+   * {@link CompletionCallback}, applying any {@link ModelEnforcement} if necessary.
+   *
+   * @return the {@link InProcessTransformResult} produced by
+   *         {@link TransformEvaluator#finishBundle()}
+   */
+  private InProcessTransformResult finishBundle(
+      TransformEvaluator<T> evaluator, Collection<ModelEnforcement<T>>
enforcements)
+      throws Exception {
+    InProcessTransformResult result = evaluator.finishBundle();
+    Iterable<? extends CommittedBundle<?>> outputs = onComplete.handleResult(inputBundle,
result);
+    for (ModelEnforcement<T> enforcement : enforcements) {
+      enforcement.afterFinish(inputBundle, result, outputs);
+    }
+    return result;
+  }
+
+  /**
    * If this {@link TransformExecutor} is currently executing, return the thread it is executing
in.
    * Otherwise, return null.
    */

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/197a93cd/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java
b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java
index 2ba7ecb..a710753 100644
--- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TransformExecutorTest.java
@@ -17,18 +17,24 @@
  */
 package com.google.cloud.dataflow.sdk.runners.inprocess;
 
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.isA;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
 import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.when;
 
+import com.google.cloud.dataflow.sdk.coders.ByteArrayCoder;
 import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle;
 import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
 import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
 import com.google.cloud.dataflow.sdk.transforms.WithKeys;
+import com.google.cloud.dataflow.sdk.util.UserCodeException;
 import com.google.cloud.dataflow.sdk.util.WindowedValue;
 import com.google.cloud.dataflow.sdk.values.KV;
 import com.google.cloud.dataflow.sdk.values.PCollection;
@@ -37,7 +43,9 @@ import com.google.common.util.concurrent.MoreExecutors;
 import org.hamcrest.Matchers;
 import org.joda.time.Instant;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.mockito.Mock;
@@ -45,10 +53,13 @@ import org.mockito.MockitoAnnotations;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -56,6 +67,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
  */
 @RunWith(JUnit4.class)
 public class TransformExecutorTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
   private PCollection<String> created;
   private PCollection<KV<Integer, String>> downstream;
 
@@ -108,6 +120,7 @@ public class TransformExecutorTest {
     TransformExecutor<Object> executor =
         TransformExecutor.create(
             registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
             evaluationContext,
             null,
             created.getProducingTransformInternal(),
@@ -153,6 +166,7 @@ public class TransformExecutorTest {
     TransformExecutor<String> executor =
         TransformExecutor.create(
             registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
             evaluationContext,
             inputBundle,
             downstream.getProducingTransformInternal(),
@@ -198,6 +212,7 @@ public class TransformExecutorTest {
     TransformExecutor<String> executor =
         TransformExecutor.create(
             registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
             evaluationContext,
             inputBundle,
             downstream.getProducingTransformInternal(),
@@ -235,6 +250,7 @@ public class TransformExecutorTest {
     TransformExecutor<String> executor =
         TransformExecutor.create(
             registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
             evaluationContext,
             inputBundle,
             downstream.getProducingTransformInternal(),
@@ -276,6 +292,7 @@ public class TransformExecutorTest {
     TransformExecutor<String> executor =
         TransformExecutor.create(
             registry,
+            Collections.<ModelEnforcementFactory>emptyList(),
             evaluationContext,
             null,
             created.getProducingTransformInternal(),
@@ -290,6 +307,171 @@ public class TransformExecutorTest {
     evaluatorLatch.countDown();
   }
 
+  @Test
+  public void callWithEnforcementAppliesEnforcement() throws Exception {
+    final InProcessTransformResult result =
+        StepTransformResult.withoutHold(downstream.getProducingTransformInternal()).build();
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception
{
+          }
+
+          @Override
+          public InProcessTransformResult finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<String> fooElem = WindowedValue.valueInGlobalWindow("foo");
+    WindowedValue<String> barElem = WindowedValue.valueInGlobalWindow("bar");
+    CommittedBundle<String> inputBundle =
+        InProcessBundle.unkeyed(created).add(fooElem).add(barElem).commit(Instant.now());
+    when(
+            registry.forApplication(
+                downstream.getProducingTransformInternal(), inputBundle, evaluationContext))
+        .thenReturn(evaluator);
+
+    TestEnforcementFactory enforcement = new TestEnforcementFactory();
+    TransformExecutor<String> executor =
+        TransformExecutor.create(
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(enforcement),
+            evaluationContext,
+            inputBundle,
+            downstream.getProducingTransformInternal(),
+            completionCallback,
+            transformEvaluationState);
+
+    executor.call();
+    TestEnforcement<?> testEnforcement = enforcement.instance;
+    assertThat(
+        testEnforcement.beforeElements,
+        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
+    assertThat(
+        testEnforcement.afterElements,
+        Matchers.<WindowedValue<?>>containsInAnyOrder(barElem, fooElem));
+    assertThat(testEnforcement.finishedBundles, contains(result));
+  }
+
+  @Test
+  public void callWithEnforcementThrowsOnFinishPropagates() throws Exception {
+    PCollection<byte[]> pcBytes =
+        created.apply(
+            new PTransform<PCollection<String>, PCollection<byte[]>>()
{
+              @Override
+              public PCollection<byte[]> apply(PCollection<String> input) {
+                return PCollection.<byte[]>createPrimitiveOutputInternal(
+                        input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
+                    .setCoder(ByteArrayCoder.of());
+              }
+            });
+
+    final InProcessTransformResult result =
+        StepTransformResult.withoutHold(pcBytes.getProducingTransformInternal()).build();
+    final CountDownLatch testLatch = new CountDownLatch(1);
+    final CountDownLatch evaluatorLatch = new CountDownLatch(1);
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception
{}
+
+          @Override
+          public InProcessTransformResult finishBundle() throws Exception {
+            testLatch.countDown();
+            evaluatorLatch.await();
+            return result;
+          }
+        };
+
+    WindowedValue<byte[]> fooBytes = WindowedValue.valueInGlobalWindow("foo".getBytes());
+    CommittedBundle<byte[]> inputBundle =
+        InProcessBundle.unkeyed(pcBytes).add(fooBytes).commit(Instant.now());
+    when(
+            registry.forApplication(
+                pcBytes.getProducingTransformInternal(), inputBundle, evaluationContext))
+        .thenReturn(evaluator);
+
+    TransformExecutor<byte[]> executor =
+        TransformExecutor.create(
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(ImmutabilityEnforcementFactory.create()),
+            evaluationContext,
+            inputBundle,
+            pcBytes.getProducingTransformInternal(),
+            completionCallback,
+            transformEvaluationState);
+
+    Future<InProcessTransformResult> task = Executors.newSingleThreadExecutor().submit(executor);
+    testLatch.await();
+    fooBytes.getValue()[0] = 'b';
+    evaluatorLatch.countDown();
+
+    thrown.expectCause(isA(UserCodeException.class));
+    task.get();
+  }
+
+  @Test
+  public void callWithEnforcementThrowsOnElementPropagates() throws Exception {
+    PCollection<byte[]> pcBytes =
+        created.apply(
+            new PTransform<PCollection<String>, PCollection<byte[]>>()
{
+              @Override
+              public PCollection<byte[]> apply(PCollection<String> input) {
+                return PCollection.<byte[]>createPrimitiveOutputInternal(
+                        input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
+                    .setCoder(ByteArrayCoder.of());
+              }
+            });
+
+    final InProcessTransformResult result =
+        StepTransformResult.withoutHold(pcBytes.getProducingTransformInternal()).build();
+    final CountDownLatch testLatch = new CountDownLatch(1);
+    final CountDownLatch evaluatorLatch = new CountDownLatch(1);
+
+    TransformEvaluator<Object> evaluator =
+        new TransformEvaluator<Object>() {
+          @Override
+          public void processElement(WindowedValue<Object> element) throws Exception
{
+            testLatch.countDown();
+            evaluatorLatch.await();
+          }
+
+          @Override
+          public InProcessTransformResult finishBundle() throws Exception {
+            return result;
+          }
+        };
+
+    WindowedValue<byte[]> fooBytes = WindowedValue.valueInGlobalWindow("foo".getBytes());
+    CommittedBundle<byte[]> inputBundle =
+        InProcessBundle.unkeyed(pcBytes).add(fooBytes).commit(Instant.now());
+    when(
+            registry.forApplication(
+                pcBytes.getProducingTransformInternal(), inputBundle, evaluationContext))
+        .thenReturn(evaluator);
+
+    TransformExecutor<byte[]> executor =
+        TransformExecutor.create(
+            registry,
+            Collections.<ModelEnforcementFactory>singleton(ImmutabilityEnforcementFactory.create()),
+            evaluationContext,
+            inputBundle,
+            pcBytes.getProducingTransformInternal(),
+            completionCallback,
+            transformEvaluationState);
+
+    Future<InProcessTransformResult> task = Executors.newSingleThreadExecutor().submit(executor);
+    testLatch.await();
+    fooBytes.getValue()[0] = 'b';
+    evaluatorLatch.countDown();
+
+    thrown.expectCause(isA(UserCodeException.class));
+    task.get();
+  }
+
   private static class RegisteringCompletionCallback implements CompletionCallback {
     private InProcessTransformResult handledResult = null;
     private Throwable handledThrowable = null;
@@ -300,9 +482,11 @@ public class TransformExecutorTest {
     }
 
     @Override
-    public void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult
result) {
+    public Iterable<? extends CommittedBundle<?>> handleResult(
+        CommittedBundle<?> inputBundle, InProcessTransformResult result) {
       handledResult = result;
       onMethod.countDown();
+      return Collections.emptyList();
     }
 
     @Override
@@ -311,4 +495,39 @@ public class TransformExecutorTest {
       onMethod.countDown();
     }
   }
+
+  private static class TestEnforcementFactory implements ModelEnforcementFactory {
+    private TestEnforcement<?> instance;
+    @Override
+    public <T> TestEnforcement<T> forBundle(
+        CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) {
+      TestEnforcement<T> newEnforcement = new TestEnforcement<>();
+      instance = newEnforcement;
+      return newEnforcement;
+    }
+  }
+
+  private static class TestEnforcement<T> implements ModelEnforcement<T> {
+    private final List<WindowedValue<T>> beforeElements = new ArrayList<>();
+    private final List<WindowedValue<T>> afterElements = new ArrayList<>();
+    private final List<InProcessTransformResult> finishedBundles = new ArrayList<>();
+
+    @Override
+    public void beforeElement(WindowedValue<T> element) {
+      beforeElements.add(element);
+    }
+
+    @Override
+    public void afterElement(WindowedValue<T> element) {
+      afterElements.add(element);
+    }
+
+    @Override
+    public void afterFinish(
+        CommittedBundle<T> input,
+        InProcessTransformResult result,
+        Iterable<? extends CommittedBundle<?>> outputs) {
+      finishedBundles.add(result);
+    }
+  }
 }


Mime
View raw message