beam-commits mailing list archives

From chamik...@apache.org
Subject [beam] branch master updated: [BEAM-5959] Add GCS KMS support
Date Thu, 07 Feb 2019 00:03:42 GMT
This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8150d3b  [BEAM-5959] Add GCS KMS support
     new 693dfe3  Merge pull request #7682: [BEAM-5959] Add GCS KMS support
8150d3b is described below

commit 8150d3be73e31595d757972c36e3b05fb6589fe8
Author: Udi Meiri <ehudm@google.com>
AuthorDate: Wed Jan 30 16:56:31 2019 -0800

    [BEAM-5959] Add GCS KMS support
    
    GCS operations refactored to support objects with KMS keys. Beam will
    not set KMS keys on GCS objects. Users can utilize bucket default keys.
    
    Also adds --dataflowKmsKey, which is passed to DataflowRunner.
    
    Details:
    - GCS copy operation reimplemented as rewrite - allowing copies of
    objects using KMS keys (source or dest).
      - Rewrite also supports copying across regions and storage classes.
      https://cloud.google.com/storage/docs/json_api/v1/objects/rewrite
    - Introduces --dataflowKmsKey flag, which should apply to Dataflow
    pipeline state.
      - If creating a new bucket for gcpTempLocation, sets the bucket
      default key to --dataflowKmsKey.
    - New integration tests:
      - GcsKmsKeyIT - pipeline with GCS sink using --dataflowKmsKey and
      gcpTempLocation.
      - GcsUtilIT - tests multi-part rewrite against prod GCS
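    
    For illustration, a minimal sketch (not from this commit) of the token-based
    rewrite loop that the new copy implementation relies on, using the GCS JSON
    API client; bucket and object names below are placeholders:
    
        import com.google.api.services.storage.Storage;
        import com.google.api.services.storage.model.RewriteResponse;
    
        // Sketch: copy one object via objects().rewrite(), resuming with the
        // returned rewrite token until the service reports the copy is done.
        static void rewriteExample(Storage storage) throws java.io.IOException {
          Storage.Objects.Rewrite rewrite =
              storage.objects().rewrite("src-bucket", "a.txt", "dst-bucket", "a.txt", null);
          RewriteResponse response = rewrite.execute();
          while (!response.getDone()) {
            // Each unfinished response carries a token that must be echoed back
            // on the next call; GcsUtil's new RewriteOp does this in batches.
            rewrite.setRewriteToken(response.getRewriteToken());
            response = rewrite.execute();
          }
        }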
---
 build.gradle                                       |   1 +
 runners/google-cloud-dataflow-java/build.gradle    |  23 ++++
 .../google-cloud-platform-core/build.gradle        |  31 +++++
 .../sdk/extensions/gcp/options/GcpOptions.java     |  34 +++--
 .../java/org/apache/beam/sdk/util/GcsUtil.java     | 137 ++++++++++++++++-----
 .../java/org/apache/beam/sdk/util/GcsUtilIT.java   |  73 +++++++++++
 .../java/org/apache/beam/sdk/util/GcsUtilTest.java |  52 +++++++-
 sdks/java/io/google-cloud-platform/build.gradle    |  24 ++++
 .../beam/sdk/io/gcp/storage/GcsKmsKeyIT.java       | 113 +++++++++++++++++
 9 files changed, 446 insertions(+), 42 deletions(-)

diff --git a/build.gradle b/build.gradle
index e526580..118c22d 100644
--- a/build.gradle
+++ b/build.gradle
@@ -198,6 +198,7 @@ task javaPreCommitPortabilityApi() {
 
 task javaPostCommit() {
   dependsOn ":beam-runners-google-cloud-dataflow-java:postCommit"
+  dependsOn ":beam-sdks-java-extensions-google-cloud-platform-core:postCommit"
   dependsOn ":beam-sdks-java-io-google-cloud-platform:postCommit"
 }
 
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index 407f9c9..3613f9b 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -113,6 +113,7 @@ def dataflowUploadTemp = project.findProperty('dataflowTempRoot') ?: 'gs://temp-
 def testFilesToStage = project.findProperty('filesToStage') ?: 'test.txt'
 def dataflowLegacyWorkerJar = project.findProperty('dataflowWorkerJar') ?: project(":beam-runners-google-cloud-dataflow-java-legacy-worker").shadowJar.archivePath
 def dataflowFnApiWorkerJar = project.findProperty('dataflowWorkerJar') ?: project(":beam-runners-google-cloud-dataflow-java-fn-api-worker").shadowJar.archivePath
+def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
 
 def dockerImageRoot = project.findProperty('dockerImageRoot') ?: "us.gcr.io/${dataflowProject}/java-postcommit-it"
 def dockerImageContainer = "${dockerImageRoot}/java"
@@ -300,12 +301,32 @@ task googleCloudPlatformLegacyWorkerIntegrationTest(type: Test) {
   include '**/*IT.class'
   exclude '**/BigQueryIOReadIT.class'
   exclude '**/PubsubReadIT.class'
+  exclude '**/*KmsKeyIT.class'
   maxParallelForks 4
   classpath = configurations.googleCloudPlatformIntegrationTest
   testClassesDirs = files(project(":beam-sdks-java-io-google-cloud-platform").sourceSets.test.output.classesDirs)
   useJUnit { }
 }
 
+task googleCloudPlatformLegacyWorkerKmsIntegrationTest(type: Test) {
+    group = "Verification"
+    dependsOn ":beam-runners-google-cloud-dataflow-java-legacy-worker:shadowJar"
+    systemProperty "beamTestPipelineOptions", JsonOutput.toJson([
+            "--runner=TestDataflowRunner",
+            "--project=${dataflowProject}",
+            "--tempRoot=${dataflowPostCommitTempRoot}",
+            "--dataflowWorkerJar=${dataflowLegacyWorkerJar}",
+            "--workerHarnessContainerImage=",
+            "--dataflowKmsKey=${dataflowKmsKey}",
+    ])
+
+    include '**/*KmsKeyIT.class'
+    maxParallelForks 4
+    classpath = configurations.googleCloudPlatformIntegrationTest
+    testClassesDirs = files(project(":beam-sdks-java-io-google-cloud-platform").sourceSets.test.output.classesDirs)
+    useJUnit { }
+}
+
 task googleCloudPlatformFnApiWorkerIntegrationTest(type: Test) {
     group = "Verification"
     dependsOn ":beam-runners-google-cloud-dataflow-java-fn-api-worker:shadowJar"
@@ -326,6 +347,7 @@ task googleCloudPlatformFnApiWorkerIntegrationTest(type: Test) {
     exclude '**/SpannerWriteIT.class'
     exclude '**/BigQueryNestedRecordsIT.class'
     exclude '**/SplitQueryFnIT.class'
+    exclude '**/*KmsKeyIT.class'
 
     maxParallelForks 4
     classpath = configurations.googleCloudPlatformIntegrationTest
@@ -429,6 +451,7 @@ task postCommit {
   group = "Verification"
   description = "Various integration tests using the Dataflow runner."
   dependsOn googleCloudPlatformLegacyWorkerIntegrationTest
+  dependsOn googleCloudPlatformLegacyWorkerKmsIntegrationTest
   dependsOn examplesJavaLegacyWorkerIntegrationTest
   dependsOn coreSDKJavaLegacyWorkerIntegrationTest
 }
diff --git a/sdks/java/extensions/google-cloud-platform-core/build.gradle b/sdks/java/extensions/google-cloud-platform-core/build.gradle
index 45d3a8b..5be4cf2 100644
--- a/sdks/java/extensions/google-cloud-platform-core/build.gradle
+++ b/sdks/java/extensions/google-cloud-platform-core/build.gradle
@@ -16,6 +16,8 @@
  * limitations under the License.
  */
 
+import groovy.json.JsonOutput
+
 apply plugin: org.apache.beam.gradle.BeamModulePlugin
 applyJavaNature()
 
@@ -55,3 +57,32 @@ dependencies {
   testCompile library.java.mockito_core
   testCompile library.java.slf4j_jdk14
 }
+
+// Note that no runner is specified here, so tests running under this task should not be running
+// pipelines.
+task integrationTest(type: Test) {
+  group = "Verification"
+  def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
+  def gcpTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-end-to-end-tests'
+  def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
+  systemProperty "beamTestPipelineOptions", JsonOutput.toJson([
+          "--project=${gcpProject}",
+          "--tempRoot=${gcpTempRoot}",
+          "--dataflowKmsKey=${dataflowKmsKey}",
+  ])
+
+  // Disable Gradle cache: these ITs interact with live service that should always be considered "out of date"
+  outputs.upToDateWhen { false }
+
+  include '**/*IT.class'
+  maxParallelForks 4
+  classpath = sourceSets.test.runtimeClasspath
+  testClassesDirs = sourceSets.test.output.classesDirs
+  useJUnit { }
+}
+
+task postCommit {
+  group = "Verification"
+  description = "Integration tests of GCP connectors using the DirectRunner."
+  dependsOn integrationTest
+}
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptions.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptions.java
index a516b59..1575918 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptions.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcpOptions.java
@@ -27,6 +27,7 @@ import com.google.api.client.util.Sleeper;
 import com.google.api.services.cloudresourcemanager.CloudResourceManager;
 import com.google.api.services.cloudresourcemanager.model.Project;
 import com.google.api.services.storage.model.Bucket;
+import com.google.api.services.storage.model.Bucket.Encryption;
 import com.google.auth.Credentials;
 import com.google.auth.http.HttpCredentialsAdapter;
 import com.google.cloud.hadoop.util.ChainingHttpRequestInitializer;
@@ -42,6 +43,7 @@ import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.extensions.gcp.auth.CredentialFactory;
 import org.apache.beam.sdk.extensions.gcp.auth.GcpCredentialFactory;
 import org.apache.beam.sdk.extensions.gcp.auth.NullCredentialInitializer;
@@ -268,9 +270,9 @@ public interface GcpOptions extends GoogleApiDebugOptions, PipelineOptions {
      */
     @VisibleForTesting
    static String tryCreateDefaultBucket(PipelineOptions options, CloudResourceManager crmClient) {
-      GcsOptions gcpOptions = options.as(GcsOptions.class);
+      GcsOptions gcsOptions = options.as(GcsOptions.class);
 
-      final String projectId = gcpOptions.getProject();
+      final String projectId = gcsOptions.getProject();
       checkArgument(!isNullOrEmpty(projectId), "--project is a required option.");
 
       // Look up the project number, to create a default bucket with a stable
@@ -282,16 +284,20 @@ public interface GcpOptions extends GoogleApiDebugOptions, PipelineOptions {
         throw new RuntimeException("Unable to verify project with ID " + projectId, e);
       }
       String region = DEFAULT_REGION;
-      if (!isNullOrEmpty(gcpOptions.getZone())) {
-        region = getRegionFromZone(gcpOptions.getZone());
+      if (!isNullOrEmpty(gcsOptions.getZone())) {
+        region = getRegionFromZone(gcsOptions.getZone());
       }
       final String bucketName = "dataflow-staging-" + region + "-" + projectNumber;
       LOG.info("No tempLocation specified, attempting to use default bucket: {}", bucketName);
-      Bucket bucket = new Bucket().setName(bucketName).setLocation(region);
+      Bucket bucket =
+          new Bucket()
+              .setName(bucketName)
+              .setLocation(region)
+              .setEncryption(new Encryption().setDefaultKmsKeyName(gcsOptions.getDataflowKmsKey()));
       // Always try to create the bucket before checking access, so that we do not
       // race with other pipelines that may be attempting to do the same thing.
       try {
-        gcpOptions.getGcsUtil().createBucket(projectId, bucket);
+        gcsOptions.getGcsUtil().createBucket(projectId, bucket);
       } catch (FileAlreadyExistsException e) {
         LOG.debug("Bucket '{}'' already exists, verifying access.", bucketName);
       } catch (IOException e) {
@@ -301,7 +307,7 @@ public interface GcpOptions extends GoogleApiDebugOptions, PipelineOptions {
       // Once the bucket is expected to exist, verify that it is correctly owned
       // by the project executing the job.
       try {
-        long owner = gcpOptions.getGcsUtil().bucketOwner(GcsPath.fromComponents(bucketName, ""));
+        long owner = gcsOptions.getGcsUtil().bucketOwner(GcsPath.fromComponents(bucketName, ""));
         checkArgument(
             owner == projectNumber,
             "Bucket owner does not match the project from --project:" + " %s vs. %s",
@@ -390,4 +396,18 @@ public interface GcpOptions extends GoogleApiDebugOptions, PipelineOptions {
       }
     }
   }
+
+  /**
+   * GCP <a href="https://cloud.google.com/kms/">Cloud KMS</a> key for Dataflow pipelines and
+   * buckets created by GcpTempLocationFactory.
+   */
+  @Description(
+      "GCP Cloud KMS key for Dataflow pipelines. Also used by gcpTempLocation as the default
key "
+          + "for new buckets. Key format is: "
+          + "projects/<project>/locations/<location>/keyRings/<keyring>/cryptoKeys/<key>")
+  @Experimental
+  @Nullable
+  String getDataflowKmsKey();
+
+  void setDataflowKmsKey(String dataflowKmsKey);
 }
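
A hedged usage sketch (not part of the commit) of the new option as exposed through
PipelineOptions; the project and key names are placeholders:

    import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    // Parse the flag like any other pipeline option; getDataflowKmsKey() returns
    // null when the flag is unset, in which case Beam sets no KMS key itself.
    GcpOptions options =
        PipelineOptionsFactory.fromArgs(
                "--project=my-project",
                "--dataflowKmsKey=projects/my-project/locations/global/keyRings/my-ring/cryptoKeys/my-key")
            .as(GcpOptions.class);
    String kmsKey = options.getDataflowKmsKey();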
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
index 0ff21bf..f06bed8 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
@@ -31,6 +31,7 @@ import com.google.api.client.util.Sleeper;
 import com.google.api.services.storage.Storage;
 import com.google.api.services.storage.model.Bucket;
 import com.google.api.services.storage.model.Objects;
+import com.google.api.services.storage.model.RewriteResponse;
 import com.google.api.services.storage.model.StorageObject;
 import com.google.auto.value.AutoValue;
 import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadChannel;
@@ -50,6 +51,8 @@ import java.nio.file.FileAlreadyExistsException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ExecutionException;
@@ -57,6 +60,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import javax.annotation.Nullable;
@@ -141,6 +145,11 @@ public class GcsUtil {
   // Exposed for testing.
   final ExecutorService executorService;
 
+  /** Rewrite operation setting. For testing purposes only. */
+  @VisibleForTesting @Nullable Long maxBytesRewrittenPerCall;
+
+  @VisibleForTesting @Nullable AtomicInteger numRewriteTokensUsed;
+
   /** Returns the prefix portion of the glob that doesn't contain wildcards. */
   public static String getNonWildcardPrefix(String globExp) {
     Matcher m = GLOB_PREFIX.matcher(globExp);
@@ -210,6 +219,8 @@ public class GcsUtil {
     this.httpRequestInitializer = httpRequestInitializer;
     this.uploadBufferSizeBytes = uploadBufferSizeBytes;
     this.executorService = executorService;
+    this.maxBytesRewrittenPerCall = null;
+    this.numRewriteTokensUsed = null;
   }
 
   // Use this only for testing purposes.
@@ -279,6 +290,7 @@ public class GcsUtil {
   private static BackOff createBackOff() {
     return BackOffAdapter.toGcpBackOff(BACKOFF_FACTORY.backoff());
   }
+
   /**
   * Returns the file size from GCS or throws {@link FileNotFoundException} if the resource does not
    * exist.
@@ -420,10 +432,11 @@ public class GcsUtil {
             new ClientRequestHelper<>(),
             path.getBucket(),
             path.getObject(),
+            type,
+            /* kmsKeyName= */ null,
             AsyncWriteChannelOptions.newBuilder().build(),
             new ObjectWriteConditions(),
-            Collections.emptyMap(),
-            type);
+            Collections.emptyMap());
     if (uploadBufferSizeBytes != null) {
       channel.setUploadBufferSize(uploadBufferSizeBytes);
     }
@@ -597,13 +610,84 @@ public class GcsUtil {
     return batches;
   }
 
+  /**
+   * Wrapper for RewriteRequest that supports multiple calls.
+   *
+   * <p>Usage: create, enqueue(), and execute batch. Then, check getReadyToEnqueue() if another
+   * round of enqueue() and execute is required. Repeat until getReadyToEnqueue() returns false.
+   */
+  class RewriteOp extends JsonBatchCallback<RewriteResponse> {
+    private GcsPath from;
+    private GcsPath to;
+    private boolean readyToEnqueue;
+    @VisibleForTesting Storage.Objects.Rewrite rewriteRequest;
+
+    public boolean getReadyToEnqueue() {
+      return readyToEnqueue;
+    }
+
+    public void enqueue(BatchRequest batch) throws IOException {
+      if (!readyToEnqueue) {
+        throw new IOException(
+            String.format(
+                "Invalid state for Rewrite, from=%s, to=%s, readyToEnqueue=%s",
+                from, to, readyToEnqueue));
+      }
+      rewriteRequest.queue(batch, this);
+      readyToEnqueue = false;
+    }
+
+    public RewriteOp(GcsPath from, GcsPath to) throws IOException {
+      this.from = from;
+      this.to = to;
+      rewriteRequest =
+          storageClient
+              .objects()
+              .rewrite(from.getBucket(), from.getObject(), to.getBucket(), to.getObject(), null);
+      if (maxBytesRewrittenPerCall != null) {
+        rewriteRequest.setMaxBytesRewrittenPerCall(maxBytesRewrittenPerCall);
+      }
+      readyToEnqueue = true;
+    }
+
+    @Override
+    public void onSuccess(RewriteResponse rewriteResponse, HttpHeaders responseHeaders)
+        throws IOException {
+      if (rewriteResponse.getDone()) {
+        LOG.debug("Rewrite done: {} to {}", from, to);
+        readyToEnqueue = false;
+      } else {
+        LOG.debug(
+            "Rewrite progress: {} of {} bytes, {} to {}",
+            rewriteResponse.getTotalBytesRewritten(),
+            rewriteResponse.getObjectSize(),
+            from,
+            to);
+        rewriteRequest.setRewriteToken(rewriteResponse.getRewriteToken());
+        readyToEnqueue = true;
+        if (numRewriteTokensUsed != null) {
+          numRewriteTokensUsed.incrementAndGet();
+        }
+      }
+    }
+
+    @Override
+    public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException {
+      readyToEnqueue = false;
+      throw new IOException(String.format("Error trying to rewrite %s to %s: %s", from, to, e));
+    }
+  }
+
   public void copy(Iterable<String> srcFilenames, Iterable<String> destFilenames)
       throws IOException {
-    executeBatches(makeCopyBatches(srcFilenames, destFilenames));
+    LinkedList<RewriteOp> rewrites = makeRewriteOps(srcFilenames, destFilenames);
+    while (rewrites.size() > 0) {
+      executeBatches(makeCopyBatches(rewrites));
+    }
   }
 
-  List<BatchRequest> makeCopyBatches(Iterable<String> srcFilenames, Iterable<String> destFilenames)
-      throws IOException {
+  LinkedList<RewriteOp> makeRewriteOps(
+      Iterable<String> srcFilenames, Iterable<String> destFilenames) throws IOException {
     List<String> srcList = Lists.newArrayList(srcFilenames);
     List<String> destList = Lists.newArrayList(destFilenames);
     checkArgument(
@@ -611,13 +695,27 @@ public class GcsUtil {
         "Number of source files %s must equal number of destination files %s",
         srcList.size(),
         destList.size());
-
-    List<BatchRequest> batches = new ArrayList<>();
-    BatchRequest batch = createBatchRequest();
+    LinkedList<RewriteOp> rewrites = Lists.newLinkedList();
     for (int i = 0; i < srcList.size(); i++) {
       final GcsPath sourcePath = GcsPath.fromUri(srcList.get(i));
       final GcsPath destPath = GcsPath.fromUri(destList.get(i));
-      enqueueCopy(sourcePath, destPath, batch);
+      rewrites.addLast(new RewriteOp(sourcePath, destPath));
+    }
+    return rewrites;
+  }
+
+  List<BatchRequest> makeCopyBatches(LinkedList<RewriteOp> rewrites) throws IOException {
+    List<BatchRequest> batches = new ArrayList<>();
+    BatchRequest batch = createBatchRequest();
+    Iterator<RewriteOp> it = rewrites.iterator();
+    while (it.hasNext()) {
+      RewriteOp rewrite = it.next();
+      if (!rewrite.getReadyToEnqueue()) {
+        it.remove();
+        continue;
+      }
+      rewrite.enqueue(batch);
+
       if (batch.size() >= MAX_REQUESTS_PER_BATCH) {
         batches.add(batch);
         batch = createBatchRequest();
@@ -700,27 +798,6 @@ public class GcsUtil {
     }
   }
 
-  private void enqueueCopy(final GcsPath from, final GcsPath to, BatchRequest batch)
-      throws IOException {
-    Storage.Objects.Copy copyRequest =
-        storageClient
-            .objects()
-            .copy(from.getBucket(), from.getObject(), to.getBucket(), to.getObject(), null);
-    copyRequest.queue(
-        batch,
-        new JsonBatchCallback<StorageObject>() {
-          @Override
-          public void onSuccess(StorageObject obj, HttpHeaders responseHeaders) {
-            LOG.debug("Successfully copied {} to {}", from, to);
-          }
-
-          @Override
-          public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException {
-            throw new IOException(String.format("Error trying to copy %s to %s: %s", from, to, e));
-          }
-        });
-  }
-
   private void enqueueDelete(final GcsPath file, BatchRequest batch) throws IOException {
     Storage.Objects.Delete deleteRequest =
         storageClient.objects().delete(file.getBucket(), file.getObject());
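
For callers, the copy API is unchanged; a brief sketch (assuming valid credentials,
with placeholder gs:// paths) of the entry point that now drives the RewriteOp
enqueue/execute cycle shown above:

    import com.google.common.collect.Lists;
    import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.util.GcsUtil;

    // copy() builds one RewriteOp per file pair, then repeatedly batches and
    // executes the ops that are still ready to enqueue until all rewrites finish.
    GcsUtil gcsUtil = PipelineOptionsFactory.create().as(GcsOptions.class).getGcsUtil();
    gcsUtil.copy(
        Lists.newArrayList("gs://src-bucket/a.txt"),
        Lists.newArrayList("gs://dst-bucket/a.txt"));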
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilIT.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilIT.java
new file mode 100644
index 0000000..60e4ca1
--- /dev/null
+++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilIT.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.util;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertNotNull;
+
+import com.google.common.collect.Lists;
+import java.io.IOException;
+import java.util.Date;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.util.gcsfs.GcsPath;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Integration tests for {@link GcsUtil}. These tests are designed to run against production Google
+ * Cloud Storage.
+ *
+ * <p>This is a runnerless integration test, even though the Beam IT framework assumes one. Thus,
+ * this test should only be run against single runner (such as DirectRunner).
+ */
+@RunWith(JUnit4.class)
+public class GcsUtilIT {
+  /** Tests a rewrite operation that requires multiple API calls (using a continuation token). */
+  @Test
+  public void testRewriteMultiPart() throws IOException {
+    TestPipelineOptions options =
+        TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
+    GcsOptions gcsOptions = options.as(GcsOptions.class);
+    // Setting the KMS key is necessary to trigger multi-part rewrites (gcpTempLocation is created
+    // with a bucket default key).
+    assertNotNull(gcsOptions.getDataflowKmsKey());
+
+    GcsUtil gcsUtil = gcsOptions.getGcsUtil();
+    String srcFilename = "gs://dataflow-samples/wikipedia_edits/wiki_data-000000000000.json";
+    String dstFilename =
+        gcsOptions.getGcpTempLocation()
+            + String.format(
+                "/GcsUtilIT-%tF-%<tH-%<tM-%<tS-%<tL.testRewriteMultiPart.copy",
new Date());
+    gcsUtil.maxBytesRewrittenPerCall = 50L * 1024 * 1024;
+    gcsUtil.numRewriteTokensUsed = new AtomicInteger();
+
+    gcsUtil.copy(Lists.newArrayList(srcFilename), Lists.newArrayList(dstFilename));
+
+    assertThat(gcsUtil.numRewriteTokensUsed.get(), equalTo(3));
+    assertThat(
+        gcsUtil.getObject(GcsPath.fromUri(srcFilename)).getMd5Hash(),
+        equalTo(gcsUtil.getObject(GcsPath.fromUri(dstFilename)).getMd5Hash()));
+
+    gcsUtil.remove(Lists.newArrayList(dstFilename));
+  }
+}
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilTest.java
index bebe07e..5829cb8 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilTest.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/GcsUtilTest.java
@@ -66,6 +66,7 @@ import java.nio.file.AccessDeniedException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -74,6 +75,7 @@ import java.util.concurrent.TimeUnit;
 import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
 import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.util.GcsUtil.RewriteOp;
 import org.apache.beam.sdk.util.GcsUtil.StorageObjectOrIOException;
 import org.apache.beam.sdk.util.gcsfs.GcsPath;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
@@ -838,32 +840,72 @@ public class GcsUtilTest {
   }
 
   @Test
+  public void testMakeRewriteOps() throws IOException {
+    GcsOptions gcsOptions = gcsOptionsWithTestCredential();
+    GcsUtil gcsUtil = gcsOptions.getGcsUtil();
+
+    LinkedList<RewriteOp> rewrites =
+        gcsUtil.makeRewriteOps(makeStrings("s", 1), makeStrings("d", 1));
+    assertEquals(1, rewrites.size());
+
+    RewriteOp rewrite = rewrites.pop();
+    assertTrue(rewrite.getReadyToEnqueue());
+    Storage.Objects.Rewrite request = rewrite.rewriteRequest;
+    assertNull(request.getMaxBytesRewrittenPerCall());
+    assertEquals("bucket", request.getSourceBucket());
+    assertEquals("s0", request.getSourceObject());
+    assertEquals("bucket", request.getDestinationBucket());
+    assertEquals("d0", request.getDestinationObject());
+  }
+
+  @Test
+  public void testMakeRewriteOpsWithOptions() throws IOException {
+    GcsOptions gcsOptions = gcsOptionsWithTestCredential();
+    GcsUtil gcsUtil = gcsOptions.getGcsUtil();
+    gcsUtil.maxBytesRewrittenPerCall = 1337L;
+
+    LinkedList<RewriteOp> rewrites =
+        gcsUtil.makeRewriteOps(makeStrings("s", 1), makeStrings("d", 1));
+    assertEquals(1, rewrites.size());
+
+    RewriteOp rewrite = rewrites.pop();
+    assertTrue(rewrite.getReadyToEnqueue());
+    Storage.Objects.Rewrite request = rewrite.rewriteRequest;
+    assertEquals(Long.valueOf(1337L), request.getMaxBytesRewrittenPerCall());
+  }
+
+  @Test
   public void testMakeCopyBatches() throws IOException {
     GcsUtil gcsUtil = gcsOptionsWithTestCredential().getGcsUtil();
 
     // Small number of files fits in 1 batch
-    List<BatchRequest> batches = gcsUtil.makeCopyBatches(makeStrings("s", 3), makeStrings("d", 3));
+    List<BatchRequest> batches =
+        gcsUtil.makeCopyBatches(gcsUtil.makeRewriteOps(makeStrings("s", 3), makeStrings("d", 3)));
     assertThat(batches.size(), equalTo(1));
     assertThat(sumBatchSizes(batches), equalTo(3));
 
     // 1 batch of files fits in 1 batch
-    batches = gcsUtil.makeCopyBatches(makeStrings("s", 100), makeStrings("d", 100));
+    batches =
+        gcsUtil.makeCopyBatches(
+            gcsUtil.makeRewriteOps(makeStrings("s", 100), makeStrings("d", 100)));
     assertThat(batches.size(), equalTo(1));
     assertThat(sumBatchSizes(batches), equalTo(100));
 
     // A little more than 5 batches of files fits in 6 batches
-    batches = gcsUtil.makeCopyBatches(makeStrings("s", 501), makeStrings("d", 501));
+    batches =
+        gcsUtil.makeCopyBatches(
+            gcsUtil.makeRewriteOps(makeStrings("s", 501), makeStrings("d", 501)));
     assertThat(batches.size(), equalTo(6));
     assertThat(sumBatchSizes(batches), equalTo(501));
   }
 
   @Test
-  public void testInvalidCopyBatches() throws IOException {
+  public void testMakeRewriteOpsInvalid() throws IOException {
     GcsUtil gcsUtil = gcsOptionsWithTestCredential().getGcsUtil();
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("Number of source files 3");
 
-    gcsUtil.makeCopyBatches(makeStrings("s", 3), makeStrings("d", 1));
+    gcsUtil.makeRewriteOps(makeStrings("s", 3), makeStrings("d", 1));
   }
 
   @Test
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 4e8c5fb..532d3b8 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -99,6 +99,29 @@ task integrationTest(type: Test) {
   include '**/*IT.class'
   exclude '**/BigQueryIOReadIT.class'
   exclude '**/BigQueryToTableIT.class'
+  exclude '**/*KmsKeyIT.class'
+  maxParallelForks 4
+  classpath = sourceSets.test.runtimeClasspath
+  testClassesDirs = sourceSets.test.output.classesDirs
+  useJUnit { }
+}
+
+task integrationTestKms(type: Test) {
+  group = "Verification"
+  def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
+  def gcpTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-end-to-end-tests'
+  def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
+  systemProperty "beamTestPipelineOptions", JsonOutput.toJson([
+          "--runner=DirectRunner",
+          "--project=${gcpProject}",
+          "--tempRoot=${gcpTempRoot}",
+          "--dataflowKmsKey=${dataflowKmsKey}",
+  ])
+
+  // Disable Gradle cache: these ITs interact with live service that should always be considered "out of date"
+  outputs.upToDateWhen { false }
+
+  include '**/*KmsKeyIT.class'
   maxParallelForks 4
   classpath = sourceSets.test.runtimeClasspath
   testClassesDirs = sourceSets.test.output.classesDirs
@@ -109,4 +132,5 @@ task postCommit {
   group = "Verification"
   description = "Integration tests of GCP connectors using the DirectRunner."
   dependsOn integrationTest
+  dependsOn integrationTestKms
 }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/storage/GcsKmsKeyIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/storage/GcsKmsKeyIT.java
new file mode 100644
index 0000000..7130195
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/storage/GcsKmsKeyIT.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.storage;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.notNullValue;
+import static org.hamcrest.CoreMatchers.startsWith;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Date;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.PipelineResult.State;
+import org.apache.beam.sdk.extensions.gcp.options.GcsOptions;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.io.fs.MatchResult;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
+import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.FileChecksumMatcher;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.util.GcsUtil;
+import org.apache.beam.sdk.util.gcsfs.GcsPath;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+// Run a specific test using:
+//   ./gradlew :beam-sdks-java-io-google-cloud-platform:integrationTest --tests
+// GcsKmsKeyIT.testGcsWriteWithKmsKey --info
+
+/** Integration test for GCS CMEK support. */
+@RunWith(JUnit4.class)
+public class GcsKmsKeyIT {
+
+  private static final String INPUT_FILE = "gs://dataflow-samples/shakespeare/kinglear.txt";
+  private static final String EXPECTED_CHECKSUM = "b9778bfac7fa8b934e42a322ef4bd4706b538fd0";
+
+  @BeforeClass
+  public static void setup() {
+    PipelineOptionsFactory.register(TestPipelineOptions.class);
+  }
+
+  /**
+   * Tests writing to gcpTempLocation with --dataflowKmsKey set on the command line. Verifies that
+   * resulting output uses specified key and is readable. Does not verify any temporary files.
+   */
+  @Test
+  public void testGcsWriteWithKmsKey() {
+    TestPipelineOptions options =
+        TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
+    GcsOptions gcsOptions = options.as(GcsOptions.class);
+    final String expectedKmsKey = gcsOptions.getDataflowKmsKey();
+    assertThat(expectedKmsKey, notNullValue());
+
+    ResourceId filenamePrefix =
+        FileSystems.matchNewResource(gcsOptions.getGcpTempLocation(), true)
+            .resolve(
+                String.format("GcsKmsKeyIT-%tF-%<tH-%<tM-%<tS-%<tL.output", new
Date()),
+                StandardResolveOptions.RESOLVE_FILE);
+
+    Pipeline p = Pipeline.create(options);
+    p.apply("ReadLines", TextIO.read().from(INPUT_FILE))
+        .apply("WriteLines", TextIO.write().to(filenamePrefix));
+
+    PipelineResult result = p.run();
+    State state = result.waitUntilFinish();
+    assertThat(state, equalTo(State.DONE));
+
+    String filePattern = filenamePrefix + "*-of-*";
+    assertThat(result, new FileChecksumMatcher(EXPECTED_CHECKSUM, filePattern));
+
+    // Verify objects have KMS key set.
+    try {
+      MatchResult matchResult =
+          Iterables.getOnlyElement(FileSystems.match(Collections.singletonList(filePattern)));
+      GcsUtil gcsUtil = gcsOptions.getGcsUtil();
+      for (Metadata metadata : matchResult.metadata()) {
+        String kmsKey =
+            gcsUtil.getObject(GcsPath.fromUri(metadata.resourceId().toString())).getKmsKeyName();
+        // Returned kmsKey should have a version suffix.
+        assertThat(
+            metadata.resourceId().toString(),
+            kmsKey,
+            startsWith(expectedKmsKey + "/cryptoKeyVersions/"));
+      }
+    } catch (IOException e) {
+      throw new AssertionError(e);
+    }
+  }
+}

