beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From chamik...@apache.org
Subject [1/2] beam git commit: Added a preprocessing step to the Cloud Spanner sink.
Date Tue, 07 Nov 2017 01:35:27 GMT
Repository: beam
Updated Branches:
  refs/heads/master 35952f655 -> 3dfcb4447


Added a preprocessing step to the Cloud Spanner sink.

The intuition: if mutations are presorted by primary key before batching, mutations in the same batch are more likely to end up in the same partition. This minimizes the number of participants in the distributed transaction on the Cloud Spanner side and leads to better throughput.

Mutations are encoded before the other steps run, so they are not repeatedly serialized and deserialized along the way. Primary keys are encoded using the OrderedCode library, and the ApproximateQuantiles transform is used to sample keys.

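Once primary keys are sampled, each mutation is keyed by the position of its primary key among the sorted samples, and mutations are grouped by that key. Range deletes (and multi-key deletes) cannot be grouped this way, so each one is given a random key and submitted separately.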


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/227801b3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/227801b3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/227801b3

Branch: refs/heads/master
Commit: 227801b31e31294c18dcc28bf98078f0868c41b0
Parents: 35952f6
Author: Mairbek Khadikov <mairbek@google.com>
Authored: Fri Sep 29 16:43:05 2017 -0700
Committer: chamikara@google.com <chamikara@google.com>
Committed: Mon Nov 6 17:22:54 2017 -0800

----------------------------------------------------------------------
 .../sdk/io/gcp/spanner/NaiveSpannerReadFn.java  |   1 +
 .../sdk/io/gcp/spanner/SerializedMutation.java  |  35 ++
 .../io/gcp/spanner/SerializedMutationCoder.java |  60 +++
 .../beam/sdk/io/gcp/spanner/SpannerIO.java      | 309 ++++++++++++-
 .../sdk/io/gcp/spanner/SpannerWriteGroupFn.java | 133 ------
 .../sdk/io/gcp/spanner/SpannerIOWriteTest.java  | 447 ++++++++++++++-----
 .../beam/sdk/io/gcp/spanner/SpannerWriteIT.java |   5 +-
 7 files changed, 746 insertions(+), 244 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
index 5dc6ead..34996f1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java
@@ -48,6 +48,7 @@ class NaiveSpannerReadFn extends DoFn<ReadOperation, Struct> {
   public void setup() throws Exception {
     spannerAccessor = config.connectToSpanner();
   }
+
   @Teardown
   public void teardown() throws Exception {
     spannerAccessor.close();

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutation.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutation.java
new file mode 100644
index 0000000..a5bebce
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutation.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.auto.value.AutoValue;
+
+@AutoValue
+abstract class SerializedMutation { // One MutationGroup in pre-encoded form (see SerializeMutationsFn).
+  static SerializedMutation create(String tableName, byte[] key,
+      byte[] bytes) {
+    return new AutoValue_SerializedMutation(tableName, key, bytes);
+  }
+  // Table targeted by the group's primary mutation.
+  abstract String getTableName();
+  // OrderedCode-encoded primary key; empty for range and multi-key deletes.
+  abstract byte[] getEncodedKey();
+  // The whole MutationGroup as produced by MutationGroupEncoder.encode.
+  abstract byte[] getMutationGroupBytes();
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutationCoder.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutationCoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutationCoder.java
new file mode 100644
index 0000000..33ec1ed
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SerializedMutationCoder.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+
+class SerializedMutationCoder extends AtomicCoder<SerializedMutation> {
+  // Singleton coder for SerializedMutation; obtain it via of().
+  private static final SerializedMutationCoder INSTANCE = new SerializedMutationCoder();
+
+  public static SerializedMutationCoder of() {
+    return INSTANCE;
+  }
+
+  private final ByteArrayCoder byteArrayCoder;
+  private final StringUtf8Coder stringCoder;
+
+  private SerializedMutationCoder() {
+    byteArrayCoder = ByteArrayCoder.of();
+    stringCoder = StringUtf8Coder.of();
+  }
+  // Writes table name, encoded key, then mutation group bytes; must match decode().
+  @Override
+  public void encode(SerializedMutation value, OutputStream out)
+      throws IOException {
+    stringCoder.encode(value.getTableName(), out);
+    byteArrayCoder.encode(value.getEncodedKey(), out);
+    byteArrayCoder.encode(value.getMutationGroupBytes(), out);
+  }
+  // Reads the three fields back in the order encode() wrote them.
+  @Override
+  public SerializedMutation decode(InputStream in)
+      throws IOException {
+    String tableName =  stringCoder.decode(in);
+    byte[] encodedKey = byteArrayCoder.decode(in);
+    byte[] mutationBytes = byteArrayCoder.decode(in);
+    return SerializedMutation.create(tableName, encodedKey, mutationBytes);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index be4417b..530c466 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -23,6 +23,8 @@ import static com.google.common.base.Preconditions.checkNotNull;
 import com.google.auto.value.AutoValue;
 import com.google.cloud.ServiceFactory;
 import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.AbortedException;
+import com.google.cloud.spanner.Key;
 import com.google.cloud.spanner.KeySet;
 import com.google.cloud.spanner.Mutation;
 import com.google.cloud.spanner.Spanner;
@@ -31,23 +33,41 @@ import com.google.cloud.spanner.Statement;
 import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.TimestampBound;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterables;
+import com.google.common.primitives.UnsignedBytes;
+import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
+import java.util.UUID;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.ApproximateQuantiles;
+import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
+import org.joda.time.Duration;
 
 /**
  * Experimental {@link PTransform Transforms} for reading from and writing to <a
@@ -159,6 +179,12 @@ import org.apache.beam.sdk.values.PDone;
 public class SpannerIO {
 
   private static final long DEFAULT_BATCH_SIZE_BYTES = 1024 * 1024; // 1 MB
+  // Mutation-count limit per batch; BatchFn flushes once a batch holds more than this.
+  private static final int MAX_NUM_MUTATIONS = 10000;
+  // Max number of keys kept in memory when computing approximate quantiles (maxNumElements).
+  private static final long MAX_NUM_KEYS = (long) 1e6;
+  // Default number of key samples per table. TODO: derive it from the size of the input.
+  private static final int DEFAULT_NUM_SAMPLES = 1000;
 
   /**
    * Creates an uninitialized instance of {@link Read}. Before use, the {@link Read} must be
@@ -208,6 +234,7 @@ public class SpannerIO {
     return new AutoValue_SpannerIO_Write.Builder()
         .setSpannerConfig(SpannerConfig.create())
         .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES)
+        .setNumSamples(DEFAULT_NUM_SAMPLES)
         .build();
   }
 
@@ -572,6 +599,12 @@ public class SpannerIO {
 
     abstract long getBatchSizeBytes();
 
+    abstract int getNumSamples();
+
+    @Nullable
+     abstract PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>>
+         getSampler();
+
     abstract Builder toBuilder();
 
     @AutoValue.Builder
@@ -581,6 +614,12 @@ public class SpannerIO {
 
       abstract Builder setBatchSizeBytes(long batchSizeBytes);
 
+      abstract Builder setNumSamples(int numSamples);
+
+      abstract Builder setSampler(
+          PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>>
+              sampler);
+
       abstract Write build();
     }
 
@@ -634,6 +673,13 @@ public class SpannerIO {
       return withSpannerConfig(config.withServiceFactory(serviceFactory));
     }
 
+    @VisibleForTesting
+    Write withSampler( // Overrides the per-table key sampler (default: ApproximateQuantiles).
+        PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>>
+            sampler) {
+      return toBuilder().setSampler(sampler).build();
+    }
+
     /**
      * Same transform but can be applied to {@link PCollection} of {@link MutationGroup}.
      */
@@ -652,11 +698,10 @@ public class SpannerIO {
 
       input
           .apply("To mutation group", ParDo.of(new ToMutationGroupFn()))
-          .apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteGroupFn(this)));
+          .apply("Write mutations to Cloud Spanner", new WriteGrouped(this));
       return PDone.in(input.getPipeline());
     }
 
-
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
@@ -666,6 +711,19 @@ public class SpannerIO {
     }
   }
 
+  /**
+   * Serializable singleton comparing byte arrays as unsigned values, lexicographically, by
+   * delegating to {@code UnsignedBytes#lexicographicalComparator}, which is not serializable.
+   */
+  @VisibleForTesting
+  enum SerializableBytesComparator implements Comparator<byte[]>, Serializable {
+    INSTANCE {
+      @Override public int compare(byte[] a, byte[] b) {
+        return UnsignedBytes.lexicographicalComparator().compare(a, b);
+      }
+    }
+  }
+
   /** Same as {@link Write} but supports grouped mutations. */
   public static class WriteGrouped extends PTransform<PCollection<MutationGroup>, PDone> {
     private final Write spec;
@@ -674,9 +732,53 @@ public class SpannerIO {
       this.spec = spec;
     }
 
-    @Override public PDone expand(PCollection<MutationGroup> input) {
-      input.apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteGroupFn(spec)));
+    @Override
+    public PDone expand(PCollection<MutationGroup> input) {
+      PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>>
+          sampler = spec.getSampler();
+      if (sampler == null) { // No sampler set via withSampler(): use ApproximateQuantiles.
+        sampler = createDefaultSampler();
+      }
+      // First, read the Cloud Spanner schema.
+      final PCollectionView<SpannerSchema> schemaView = input.getPipeline()
+          .apply(Create.of((Void) null)) // single placeholder element driving the schema read
+          .apply("Read information schema",
+              ParDo.of(new ReadSpannerSchema(spec.getSpannerConfig())))
+          .apply("Schema View", View.<SpannerSchema>asSingleton());
+
+      // Serialize mutations once so they need no further encoding/decoding while shuffling.
+      // The primary key is encoded via OrderedCode so we can calculate quantiles.
+      PCollection<SerializedMutation> serialized = input
+          .apply("Serialize mutations",
+              ParDo.of(new SerializeMutationsFn(schemaView)).withSideInputs(schemaView))
+          .setCoder(SerializedMutationCoder.of());
+
+      // Sample primary keys per table (ApproximateQuantiles unless a sampler was injected).
+      PCollectionView<Map<String, List<byte[]>>> keySample = serialized
+          .apply("Extract keys", ParDo.of(new ExtractKeys()))
+          .apply("Sample keys", sampler)
+          .apply("Keys sample as view", View.<String, List<byte[]>>asMap());
+
+      // Key each mutation by its position among the sampled keys, then group by that key.
+      AssignPartitionFn assignPartitionFn = new AssignPartitionFn(keySample);
+      serialized
+          .apply("Partition input", ParDo.of(assignPartitionFn).withSideInputs(keySample))
+          .setCoder(KvCoder.of(StringUtf8Coder.of(), SerializedMutationCoder.of()))
+          .apply("Group by partition", GroupByKey.<String, SerializedMutation>create())
+          .apply("Batch mutations together",
+              ParDo.of(new BatchFn(spec.getBatchSizeBytes(), spec.getSpannerConfig(), schemaView))
+                  .withSideInputs(schemaView))
+          .apply("Write mutations to Spanner",
+          ParDo.of(new WriteToSpannerFn(spec.getSpannerConfig())));
       return PDone.in(input.getPipeline());
+
+    }
+
+    private PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>>
+        createDefaultSampler() { // Per-table approximate quantiles of the encoded primary keys.
+      return Combine.perKey(ApproximateQuantiles.ApproximateQuantilesCombineFn
+          .create(spec.getNumSamples(), SerializableBytesComparator.INSTANCE, MAX_NUM_KEYS,
+              1. / spec.getNumSamples())); // presumably epsilon (error bound) — confirm
     }
   }
 
@@ -688,5 +790,202 @@ public class SpannerIO {
     }
   }
 
-  private SpannerIO() {} // Prevent construction.
+  /**
+   * Serializes each MutationGroup to (table name, OrderedCode-encoded primary key, encoded group).
+   */
+  private static class SerializeMutationsFn
+      extends DoFn<MutationGroup, SerializedMutation> {
+
+    final PCollectionView<SpannerSchema> schemaView;
+
+    private SerializeMutationsFn(PCollectionView<SpannerSchema> schemaView) {
+      this.schemaView = schemaView;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      MutationGroup g = c.element();
+      Mutation m = g.primary();
+      SpannerSchema schema = c.sideInput(schemaView);
+      String table = m.getTable();
+      MutationGroupEncoder mutationGroupEncoder = new MutationGroupEncoder(schema);
+      // Only the primary mutation's key decides where the whole group is partitioned.
+      byte[] key;
+      if (m.getOperation() != Mutation.Op.DELETE) {
+        key = mutationGroupEncoder.encodeKey(m);
+      } else if (isPointDelete(m)) {
+        Key next = m.getKeySet().getKeys().iterator().next();
+        key = mutationGroupEncoder.encodeKey(m.getTable(), next);
+      } else {
+        // The key is left empty for non-point deletes, since there is no general way to batch them.
+        key = new byte[] {};
+      }
+      byte[] value = mutationGroupEncoder.encode(g);
+      c.output(SerializedMutation.create(table, key, value));
+    }
+  }
+  /** Emits (table name, encoded primary key) pairs that feed the per-table key sampler. */
+  private static class ExtractKeys
+      extends DoFn<SerializedMutation, KV<String, byte[]>> {
+
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      SerializedMutation m = c.element();
+      c.output(KV.of(m.getTableName(), m.getEncodedKey()));
+    }
+  }
+
+
+  /** True if {@code m} is a DELETE of exactly one key and no key ranges. */
+  private static boolean isPointDelete(Mutation m) {
+    return m.getOperation() == Mutation.Op.DELETE && Iterables.isEmpty(m.getKeySet().getRanges())
+        && Iterables.size(m.getKeySet().getKeys()) == 1;
+  }
+
+  /**
+   * Keys each mutation by "table%index", the position of its key among the table's sampled keys.
+   */
+  private static class AssignPartitionFn
+      extends DoFn<SerializedMutation, KV<String, SerializedMutation>> {
+
+    final PCollectionView<Map<String, List<byte[]>>> sampleView;
+
+    public AssignPartitionFn(PCollectionView<Map<String, List<byte[]>>> sampleView) {
+      this.sampleView = sampleView;
+    }
+
+    @ProcessElement public void processElement(ProcessContext c) {
+      Map<String, List<byte[]>> sample = c.sideInput(sampleView);
+      SerializedMutation g = c.element();
+      String table = g.getTableName();
+      byte[] key = g.getEncodedKey();
+      String groupKey;
+      if (key.length == 0) {
+        // This is a range or multi-key delete mutation. We cannot group it with other mutations,
+        // so it gets a random group key and is applied independently.
+        groupKey = UUID.randomUUID().toString();
+      } else {
+        int partition = Collections // binarySearch requires the sample sorted by this comparator
+            .binarySearch(sample.get(table), key, SerializableBytesComparator.INSTANCE);
+        if (partition < 0) { // key not in the sample: use its insertion point
+          partition = -partition - 1;
+        }
+        groupKey = table + "%" + partition;
+      }
+      c.output(KV.of(groupKey, g));
+    }
+  }
+
+  /**
+   * Splits each group of serialized mutations into batches bounded by byte size and count.
+   */
+  private static class BatchFn
+      extends DoFn<KV<String, Iterable<SerializedMutation>>, Iterable<Mutation>> {
+
+    private static final int MAX_RETRIES = 5; // NOTE(review): unused, as is BUNDLE_WRITE_BACKOFF
+    private static final FluentBackoff BUNDLE_WRITE_BACKOFF = FluentBackoff.DEFAULT
+        .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5));
+
+    private final long maxBatchSizeBytes;
+    private final SpannerConfig spannerConfig;
+    private final PCollectionView<SpannerSchema> schemaView;
+
+    private transient SpannerAccessor spannerAccessor; // NOTE(review): opened in setup(), unused
+    // Current batch of mutations to be written.
+    private List<Mutation> mutations;
+    // Estimated size in bytes of the current batch (via MutationSizeEstimator).
+    private long batchSizeBytes;
+
+    private BatchFn(long maxBatchSizeBytes, SpannerConfig spannerConfig,
+        PCollectionView<SpannerSchema> schemaView) {
+      this.maxBatchSizeBytes = maxBatchSizeBytes;
+      this.spannerConfig = spannerConfig;
+      this.schemaView = schemaView;
+    }
+
+    @Setup
+    public void setup() throws Exception {
+      mutations = new ArrayList<>();
+      batchSizeBytes = 0;
+      spannerAccessor = spannerConfig.connectToSpanner();
+    }
+
+    @Teardown
+    public void teardown() throws Exception {
+      spannerAccessor.close();
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c) throws Exception {
+      MutationGroupEncoder mutationGroupEncoder = new MutationGroupEncoder(c.sideInput(schemaView));
+      KV<String, Iterable<SerializedMutation>> element = c.element();
+      for (SerializedMutation kv : element.getValue()) {
+        byte[] value = kv.getMutationGroupBytes();
+        MutationGroup mg = mutationGroupEncoder.decode(value);
+        Iterables.addAll(mutations, mg); // whole group added before the limits are checked
+        batchSizeBytes += MutationSizeEstimator.sizeOf(mg);
+        if (batchSizeBytes >= maxBatchSizeBytes || mutations.size() > MAX_NUM_MUTATIONS) {
+          c.output(mutations);
+          mutations = new ArrayList<>();
+          batchSizeBytes = 0;
+        }
+      }
+      if (!mutations.isEmpty()) { // flush the remainder so no batch spans two input elements
+        c.output(mutations);
+        mutations = new ArrayList<>();
+        batchSizeBytes = 0;
+      }
+    }
+  }
+  /** Writes each batch of mutations to Cloud Spanner, retrying aborted commits with backoff. */
+  private static class WriteToSpannerFn
+      extends DoFn<Iterable<Mutation>, Void> {
+    private static final int MAX_RETRIES = 5;
+    private static final FluentBackoff BUNDLE_WRITE_BACKOFF = FluentBackoff.DEFAULT
+        .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5));
+
+    private transient SpannerAccessor spannerAccessor;
+    private final SpannerConfig spannerConfig;
+
+    public WriteToSpannerFn(SpannerConfig spannerConfig) {
+      this.spannerConfig = spannerConfig;
+    }
+
+    @Setup
+    public void setup() throws Exception {
+      spannerAccessor = spannerConfig.connectToSpanner();
+    }
+
+    @Teardown
+    public void teardown() throws Exception {
+      spannerAccessor.close();
+    }
+
+    /** Commits one batch with writeAtLeastOnce; rethrows AbortedException once retries run out. */
+    @ProcessElement
+    public void processElement(ProcessContext c) throws Exception {
+      Sleeper sleeper = Sleeper.DEFAULT;
+      BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff();
+
+      Iterable<Mutation> mutations = c.element();
+
+      while (true) {
+        // Commit the whole batch in a single writeAtLeastOnce call.
+        try {
+          spannerAccessor.getDatabaseClient().writeAtLeastOnce(mutations);
+          // Break if the commit threw no exception.
+          break;
+        } catch (AbortedException exception) {
+          // Aborts are treated as transient: back off and retry. Nothing is logged; once
+          // BUNDLE_WRITE_BACKOFF gives up, the last AbortedException is rethrown.
+          if (!BackOffUtils.next(sleeper, backoff)) {
+            throw exception;
+          }
+        }
+      }
+    }
+
+  }
+
+    private SpannerIO() {} // Prevent construction. NOTE(review): indented 4 spaces; members use 2.
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
deleted file mode 100644
index 9343c0c..0000000
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java
+++ /dev/null
@@ -1,133 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.gcp.spanner;
-
-import com.google.cloud.spanner.AbortedException;
-import com.google.cloud.spanner.DatabaseClient;
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Iterables;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.util.BackOff;
-import org.apache.beam.sdk.util.BackOffUtils;
-import org.apache.beam.sdk.util.FluentBackoff;
-import org.apache.beam.sdk.util.Sleeper;
-import org.joda.time.Duration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/** Batches together and writes mutations to Google Cloud Spanner. */
-@VisibleForTesting
-class SpannerWriteGroupFn extends DoFn<MutationGroup, Void> {
-  private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteGroupFn.class);
-  private final SpannerIO.Write spec;
-  // Current batch of mutations to be written.
-  private List<MutationGroup> mutations;
-  private long batchSizeBytes = 0;
-
-  private static final int MAX_RETRIES = 5;
-  private static final FluentBackoff BUNDLE_WRITE_BACKOFF =
-      FluentBackoff.DEFAULT
-          .withMaxRetries(MAX_RETRIES)
-          .withInitialBackoff(Duration.standardSeconds(5));
-
-  private transient SpannerAccessor spannerAccessor;
-
-  @VisibleForTesting
-  SpannerWriteGroupFn(SpannerIO.Write spec) {
-    this.spec = spec;
-  }
-
-  @Setup
-  public void setup() throws Exception {
-    spannerAccessor = spec.getSpannerConfig().connectToSpanner();
-    mutations = new ArrayList<>();
-    batchSizeBytes = 0;
-  }
-
-  @Teardown
-  public void teardown() throws Exception {
-    spannerAccessor.close();
-  }
-
-  @ProcessElement
-  public void processElement(ProcessContext c) throws Exception {
-    MutationGroup m = c.element();
-    mutations.add(m);
-    batchSizeBytes += MutationSizeEstimator.sizeOf(m);
-    if (batchSizeBytes >= spec.getBatchSizeBytes()) {
-      flushBatch();
-    }
-  }
-
-  @FinishBundle
-  public void finishBundle() throws Exception {
-    if (!mutations.isEmpty()) {
-      flushBatch();
-    }
-  }
-
-  /**
-   * Writes a batch of mutations to Cloud Spanner.
-   *
-   * <p>If a commit fails, it will be retried up to {@link #MAX_RETRIES} times. If the retry limit
-   * is exceeded, the last exception from Cloud Spanner will be thrown.
-   *
-   * @throws AbortedException if the commit fails or IOException or InterruptedException if
-   *     backing off between retries fails.
-   */
-  private void flushBatch() throws AbortedException, IOException, InterruptedException {
-    LOG.debug("Writing batch of {} mutations", mutations.size());
-    Sleeper sleeper = Sleeper.DEFAULT;
-    BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff();
-
-    DatabaseClient databaseClient = spannerAccessor.getDatabaseClient();
-    while (true) {
-      // Batch upsert rows.
-      try {
-        databaseClient.writeAtLeastOnce(Iterables.concat(mutations));
-
-        // Break if the commit threw no exception.
-        break;
-      } catch (AbortedException exception) {
-        // Only log the code and message for potentially-transient errors. The entire exception
-        // will be propagated upon the last retry.
-        LOG.error(
-            "Error writing to Spanner ({}): {}", exception.getCode(), exception.getMessage());
-        if (!BackOffUtils.next(sleeper, backoff)) {
-          LOG.error("Aborting after {} retries.", MAX_RETRIES);
-          throw exception;
-        }
-      }
-    }
-    LOG.debug("Successfully wrote {} mutations", mutations.size());
-    mutations = new ArrayList<>();
-    batchSizeBytes = 0;
-  }
-
-  @Override
-  public void populateDisplayData(DisplayData.Builder builder) {
-    super.populateDisplayData(builder);
-    spec.populateDisplayData(builder);
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index 53783d1..de1d403 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -21,21 +21,38 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisp
 import static org.hamcrest.Matchers.hasSize;
 import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.argThat;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
-import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Key;
+import com.google.cloud.spanner.KeyRange;
+import com.google.cloud.spanner.KeySet;
 import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.ReadOnlyTransaction;
+import com.google.cloud.spanner.ResultSets;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Type;
+import com.google.cloud.spanner.Value;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
-
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.hamcrest.Description;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -50,17 +67,68 @@ import org.mockito.ArgumentMatcher;
  */
 @RunWith(JUnit4.class)
 public class SpannerIOWriteTest implements Serializable {
-  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
   @Rule public transient ExpectedException thrown = ExpectedException.none();
 
   private FakeServiceFactory serviceFactory;
 
-  @Before
-  @SuppressWarnings("unchecked")
-  public void setUp() throws Exception {
+  @Before @SuppressWarnings("unchecked") public void setUp() throws Exception {
     serviceFactory = new FakeServiceFactory();
+
+    ReadOnlyTransaction tx = mock(ReadOnlyTransaction.class);
+    when(serviceFactory.mockDatabaseClient().readOnlyTransaction()).thenReturn(tx);
+
+    // Simplest schema: a table with int64 key
+    preparePkMetadata(tx, Arrays.asList(pkMetadata("test", "key", "ASC")));
+    prepareColumnMetadata(tx, Arrays.asList(columnMetadata("test", "key", "INT64")));
+  }
+
+  /** Returns one row shaped like the result of the information_schema.columns query. */
+  private static Struct columnMetadata(String tableName, String columnName, String type) {
+    return Struct.newBuilder()
+        .add("table_name", Value.string(tableName))
+        .add("column_name", Value.string(columnName))
+        .add("spanner_type", Value.string(type))
+        .build();
+  }
+
+  /** Returns one primary-key row shaped like the information_schema.index_columns result. */
+  private static Struct pkMetadata(String tableName, String columnName, String ordering) {
+    return Struct.newBuilder()
+        .add("table_name", Value.string(tableName))
+        .add("column_name", Value.string(columnName))
+        .add("column_ordering", Value.string(ordering))
+        .build();
+  }
+
+  /** Stubs {@code tx} to answer queries on information_schema.columns with {@code rows}. */
+  private static void prepareColumnMetadata(ReadOnlyTransaction tx, List<Struct> rows) {
+    Type type = Type.struct(Type.StructField.of("table_name", Type.string()),
+        Type.StructField.of("column_name", Type.string()),
+        Type.StructField.of("spanner_type", Type.string()));
+    when(tx.executeQuery(sqlContains("information_schema.columns")))
+        .thenReturn(ResultSets.forRows(type, rows));
+  }
+
+  /**
+   * Matches a {@link Statement} whose SQL text contains {@code fragment}. Used by the metadata
+   * stubs to tell the information_schema queries apart; {@code describeTo} makes a failed match
+   * readable in Mockito's error output.
+   */
+  private static Statement sqlContains(final String fragment) {
+    return argThat(new ArgumentMatcher<Statement>() {
+      @Override public boolean matches(Object argument) {
+        return argument instanceof Statement
+            && ((Statement) argument).getSql().contains(fragment);
+      }
+
+      @Override public void describeTo(Description description) {
+        description.appendText("Statement with SQL containing ").appendValue(fragment);
+      }
+    });
+  }
+
+  /** Stubs {@code tx} to answer queries on information_schema.index_columns with {@code rows}. */
+  private static void preparePkMetadata(ReadOnlyTransaction tx, List<Struct> rows) {
+    Type type = Type.struct(Type.StructField.of("table_name", Type.string()),
+        Type.StructField.of("column_name", Type.string()),
+        Type.StructField.of("column_ordering", Type.string()));
+    when(tx.executeQuery(sqlContains("information_schema.index_columns")))
+        .thenReturn(ResultSets.forRows(type, rows));
   }
 
+
   @Test
   public void emptyTransform() throws Exception {
     SpannerIO.Write write = SpannerIO.write();
@@ -88,7 +156,7 @@ public class SpannerIOWriteTest implements Serializable {
   @Test
   @Category(NeedsRunner.class)
   public void singleMutationPipeline() throws Exception {
-    Mutation mutation = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2).build();
+    Mutation mutation = m(2L);
     PCollection<Mutation> mutations = pipeline.apply(Create.of(mutation));
 
     mutations.apply(
@@ -98,20 +166,17 @@ public class SpannerIOWriteTest implements Serializable {
             .withDatabaseId("test-database")
             .withServiceFactory(serviceFactory));
     pipeline.run();
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
-    verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+
+    verifyBatches(
+        batch(m(2L))
+    );
   }
 
   @Test
   @Category(NeedsRunner.class)
   public void singleMutationGroupPipeline() throws Exception {
-    Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
-    Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
-    Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build();
     PCollection<MutationGroup> mutations = pipeline
-        .apply(Create.<MutationGroup>of(g(one, two, three)));
+        .apply(Create.<MutationGroup>of(g(m(1L), m(2L), m(3L))));
     mutations.apply(
         SpannerIO.write()
             .withProjectId("test-project")
@@ -120,106 +185,195 @@ public class SpannerIOWriteTest implements Serializable {
             .withServiceFactory(serviceFactory)
             .grouped());
     pipeline.run();
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
-    verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(3)));
+
+    verifyBatches(
+        batch(m(1L), m(2L), m(3L))
+    );
   }
 
   @Test
+  @Category(NeedsRunner.class)
   public void batching() throws Exception {
-    MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
-    MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
-    SpannerIO.Write write =
-        SpannerIO.write()
-            .withProjectId("test-project")
-            .withInstanceId("test-instance")
-            .withDatabaseId("test-database")
-            .withBatchSizeBytes(1000000000)
-            .withServiceFactory(serviceFactory);
-    SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
-    DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
-    fnTester.processBundle(Arrays.asList(one, two));
-
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
-    verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(2)));
+    // A huge batch size plus a single sample key (1000) above every mutation key: both groups
+    // are expected to be written together in one batch.
+    MutationGroup one = g(m(1L));
+    MutationGroup two = g(m(2L));
+    PCollection<MutationGroup> mutations = pipeline.apply(Create.of(one, two));
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1000000000)
+        .withSampler(fakeSampler(m(1000L)))
+        .grouped());
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L), m(2L))
+    );
   }
 
   @Test
+  @Category(NeedsRunner.class)
+  public void batchingWithDeletes() throws Exception {
+    // Deletes of single keys are expected in the same batch as the inserts.
+    PCollection<MutationGroup> mutations = pipeline
+        .apply(Create.of(g(m(1L)), g(m(2L)), g(del(3L)), g(del(4L))));
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1000000000)
+        .withSampler(fakeSampler(m(1000L)))
+        .grouped());
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L), m(2L), del(3L), del(4L))
+    );
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void noBatchingRangeDelete() throws Exception {
+    // Deletes over several keys, key ranges, a key prefix or the whole table are each expected
+    // in a batch of their own, even though the batch size would fit all of them.
+    Mutation all = Mutation.delete("test", KeySet.all());
+    Mutation prefix = Mutation.delete("test", KeySet.prefixRange(Key.of(1L)));
+    Mutation range = Mutation.delete("test", KeySet.range(KeyRange.openOpen(Key.of(1L), Key
+        .newBuilder().build())));
+
+    PCollection<MutationGroup> mutations = pipeline.apply(Create
+        .of(
+            g(m(1L)),
+            g(m(2L)),
+            g(del(5L, 6L)),
+            g(delRange(50L, 55L)),
+            g(delRange(11L, 20L)),
+            g(all),
+            g(prefix), g(range)
+        )
+    );
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1000000000)
+        .withSampler(fakeSampler(m(1000L)))
+        .grouped());
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L), m(2L)),
+        batch(del(5L, 6L)),
+        batch(delRange(11L, 20L)),
+        batch(delRange(50L, 55L)),
+        batch(all),
+        batch(prefix),
+        batch(range)
+    );
+  }
+
+  /**
+   * Verifies that each expected batch was passed to exactly one {@code writeAtLeastOnce} call.
+   * Mutation order inside a batch is ignored; writes not listed here are not checked.
+   */
+  @SafeVarargs
+  private final void verifyBatches(Iterable<Mutation>... batches) {
+    for (Iterable<Mutation> b : batches) {
+      verify(serviceFactory.mockDatabaseClient(), times(1)).writeAtLeastOnce(mutationsInNoOrder(b));
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
   public void batchingGroups() throws Exception {
+    // Three single-mutation groups with a batch size of one group's size plus one byte: the
+    // test expects exactly one batch of two groups and one batch of a single group.
-    MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
-    MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
-    MutationGroup three = g(Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build());

     // Have a room to accumulate one more item.
-    long batchSize = MutationSizeEstimator.sizeOf(one) + 1;
+    long batchSize = MutationSizeEstimator.sizeOf(g(m(1L))) + 1;
 
-    SpannerIO.Write write =
-        SpannerIO.write()
-            .withProjectId("test-project")
-            .withInstanceId("test-instance")
-            .withDatabaseId("test-database")
-            .withBatchSizeBytes(batchSize)
-            .withServiceFactory(serviceFactory);
-    SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
-    DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
-    fnTester.processBundle(Arrays.asList(one, two, three));
-
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
+    PCollection<MutationGroup> mutations = pipeline.apply(Create.of(g(m(1L)), g(m(2L)), g(m(3L))));
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(batchSize)
+        .withSampler(fakeSampler(m(1000L)))
+        .grouped());
+
+    pipeline.run();
+
+    // The content of batches is not deterministic. Just verify that the size is correct.
     verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(2)));
+        .writeAtLeastOnce(iterableOfSize(2));
     verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+        .writeAtLeastOnce(iterableOfSize(1));
   }
 
   @Test
+  @Category(NeedsRunner.class)
   public void noBatching() throws Exception {
-    MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build());
-    MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build());
-    SpannerIO.Write write =
-        SpannerIO.write()
-            .withProjectId("test-project")
-            .withInstanceId("test-instance")
-            .withDatabaseId("test-database")
-            .withBatchSizeBytes(0) // turn off batching.
-            .withServiceFactory(serviceFactory);
-    SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
-    DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
-    fnTester.processBundle(Arrays.asList(one, two));
-
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
-    verify(serviceFactory.mockDatabaseClient(), times(2))
-        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+    // A 1-byte batch size leaves no room to combine groups: each group is expected to be
+    // written in its own batch.
+    PCollection<MutationGroup> mutations = pipeline.apply(Create.of(g(m(1L)), g(m(2L))));
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1)
+        .withSampler(fakeSampler(m(1000L)))
+        .grouped());
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L)),
+        batch(m(2L))
+    );
   }
 
   @Test
-  public void groups() throws Exception {
-    Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
-    Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
-    Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build();
+  @Category(NeedsRunner.class)
+  public void batchingPlusSampling() throws Exception {
+    // Sample keys 2, 5 and 10 split the key space. The expected batches below are {1, 2},
+    // {3, 4, 5} and {6..10}: each key lands with the smallest sample key >= itself.
+    PCollection<MutationGroup> mutations = pipeline
+        .apply(Create.of(
+            g(m(1L)), g(m(2L)), g(m(3L)), g(m(4L)),  g(m(5L)),
+            g(m(6L)), g(m(7L)), g(m(8L)), g(m(9L)),  g(m(10L)))
+        );
 
-    // Smallest batch size
-    long batchSize = 1;
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1000000000)
+        .withSampler(fakeSampler(m(2L), m(5L), m(10L)))
+        .grouped());
+    pipeline.run();
 
-    SpannerIO.Write write =
-        SpannerIO.write()
-            .withProjectId("test-project")
-            .withInstanceId("test-instance")
-            .withDatabaseId("test-database")
-            .withBatchSizeBytes(batchSize)
-            .withServiceFactory(serviceFactory);
-    SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write);
-    DoFnTester<MutationGroup, Void> fnTester = DoFnTester.of(writerFn);
-    fnTester.processBundle(Arrays.asList(g(one, two, three)));
-
-    verify(serviceFactory.mockSpanner())
-        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
-    verify(serviceFactory.mockDatabaseClient(), times(1))
-        .writeAtLeastOnce(argThat(new IterableOfSize(3)));
+    verifyBatches(
+        batch(m(1L), m(2L)),
+        batch(m(3L), m(4L), m(5L)),
+        batch(m(6L), m(7L), m(8L), m(9L), m(10L))
+    );
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void noBatchingPlusSampling() throws Exception {
+    // A sample key does not override the 1-byte batch size: every mutation is expected alone.
+    PCollection<MutationGroup> mutations = pipeline
+        .apply(Create.of(g(m(1L)), g(m(2L)), g(m(3L)), g(m(4L)), g(m(5L))));
+    mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1)
+        .withSampler(fakeSampler(m(2L)))
+        .grouped());
+
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L)),
+        batch(m(2L)),
+        batch(m(3L)),
+        batch(m(4L)),
+        batch(m(5L))
+    );
   }
 
   @Test
@@ -239,20 +393,105 @@ public class SpannerIOWriteTest implements Serializable {
     assertThat(data, hasDisplayItem("batchSizeBytes", 123));
   }
 
-  private static class IterableOfSize extends ArgumentMatcher<Iterable<Mutation>> {
-    private final int size;
+  /** Shorthand for {@link MutationGroup#create}: a group of {@code first} followed by {@code rest}. */
+  private static MutationGroup g(Mutation first, Mutation... rest) {
+    return MutationGroup.create(first, rest);
+  }
 
-    private IterableOfSize(int size) {
-      this.size = size;
-    }
+  /** Returns an insert-or-update on table "test" that sets only the "key" column. */
+  private static Mutation m(Long key) {
+    return Mutation.newInsertOrUpdateBuilder("test")
+        .set("key")
+        .to(key)
+        .build();
+  }
 
-    @Override
-    public boolean matches(Object argument) {
-      return argument instanceof Iterable && Iterables.size((Iterable<?>) argument) == size;
+  /** Lists the mutations expected in a single {@code writeAtLeastOnce} call. */
+  private static Iterable<Mutation> batch(Mutation... mutations) {
+    return Arrays.asList(mutations);
+  }
+
+  /** Returns a delete on table "test" covering exactly the given point {@code keys}. */
+  private static Mutation del(Long... keys) {
+    KeySet.Builder keySet = KeySet.newBuilder();
+    for (Long k : keys) {
+      keySet.addKey(Key.of(k));
     }
+    return Mutation.delete("test", keySet.build());
   }
 
-  private static MutationGroup g(Mutation m, Mutation... other) {
-    return MutationGroup.create(m, other);
+  /** Returns a delete on table "test" for the closed key range [start, end]. */
+  private static Mutation delRange(Long start, Long end) {
+    KeyRange range = KeyRange.closedClosed(Key.of(start), Key.of(end));
+    return Mutation.delete("test", KeySet.range(range));
+  }
+
+  /**
+   * Returns a Mockito matcher for an {@code Iterable<Mutation>} that holds exactly the mutations
+   * of {@code expected}, in any order. Element counts must also agree, so a written batch that
+   * repeats a mutation does not match an expected batch of distinct mutations.
+   */
+  private static Iterable<Mutation> mutationsInNoOrder(Iterable<Mutation> expected) {
+    final ImmutableSet<Mutation> mutations = ImmutableSet.copyOf(expected);
+    final int expectedSize = Iterables.size(expected);
+    return argThat(new ArgumentMatcher<Iterable<Mutation>>() {
+
+      @Override
+      public boolean matches(Object argument) {
+        if (!(argument instanceof Iterable)) {
+          return false;
+        }
+        // Copy once so the argument is iterated a single time, then compare both the size and
+        // the distinct elements: set equality alone would ignore duplicated mutations.
+        List<Object> actual = new ArrayList<>();
+        for (Object element : (Iterable<?>) argument) {
+          actual.add(element);
+        }
+        return actual.size() == expectedSize && ImmutableSet.copyOf(actual).equals(mutations);
+      }
+
+      @Override
+      public void describeTo(Description description) {
+        description.appendText("Iterable must match ").appendValue(mutations);
+      }
+
+    });
+  }
+
+  /**
+   * Returns a Mockito matcher for any {@code Iterable} with exactly {@code size} elements. Static
+   * so the anonymous matcher does not capture the enclosing test instance.
+   */
+  private static Iterable<Mutation> iterableOfSize(final int size) {
+    return argThat(new ArgumentMatcher<Iterable<Mutation>>() {
+
+      @Override
+      public boolean matches(Object argument) {
+        return argument instanceof Iterable && Iterables.size((Iterable<?>) argument) == size;
+      }
+
+      @Override
+      public void describeTo(Description description) {
+        description.appendText("The size of the iterable must equal ").appendValue(size);
+      }
+    });
+  }
+
+  /**
+   * Creates a {@link FakeSampler} over the test schema (table "test" keyed by the INT64 column
+   * "key") whose samples are the primary keys of {@code mutations}.
+   */
+  private static FakeSampler fakeSampler(Mutation... mutations) {
+    SpannerSchema.Builder builder = SpannerSchema.builder();
+    builder.addColumn("test", "key", "INT64");
+    builder.addKeyPart("test", "key", false);
+    SpannerSchema schema = builder.build();
+    return new FakeSampler(schema, Arrays.asList(mutations));
+  }
+
+  /**
+   * Stand-in for the key sampling step. It ignores its input and emits, for every table touched
+   * by {@code mutations}, the encoded primary keys of those mutations sorted with
+   * {@link SpannerIO.SerializableBytesComparator}, so each test chooses the sample keys.
+   */
+  private static class FakeSampler
+      extends PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, List<byte[]>>>> {
+
+    private final SpannerSchema schema;
+    private final List<Mutation> mutations;
+
+    private FakeSampler(SpannerSchema schema, List<Mutation> mutations) {
+      this.schema = schema;
+      this.mutations = mutations;
+    }
+
+    @Override
+    public PCollection<KV<String, List<byte[]>>> expand(
+        PCollection<KV<String, byte[]>> input) {
+      MutationGroupEncoder encoder = new MutationGroupEncoder(schema);
+      // Group the encoded keys by table name.
+      Map<String, List<byte[]>> keysByTable = new HashMap<>();
+      for (Mutation mutation : mutations) {
+        String table = mutation.getTable();
+        if (!keysByTable.containsKey(table)) {
+          keysByTable.put(table, new ArrayList<byte[]>());
+        }
+        keysByTable.get(table).add(encoder.encodeKey(mutation));
+      }
+      // Emit one sorted sample list per table.
+      List<KV<String, List<byte[]>>> samples = new ArrayList<>();
+      for (String table : keysByTable.keySet()) {
+        List<byte[]> keys = keysByTable.get(table);
+        Collections.sort(keys, SpannerIO.SerializableBytesComparator.INSTANCE);
+        samples.add(KV.of(table, keys));
+      }
+      return input.getPipeline().apply(Create.of(samples));
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/227801b3/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
index d208f5c..89be159 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
@@ -119,7 +119,8 @@ public class SpannerWriteIT {
 
   @Test
   public void testWrite() throws Exception {
-    p.apply(GenerateSequence.from(0).to(100))
+    int numRecords = 100;
+    p.apply(GenerateSequence.from(0).to(numRecords))
         .apply(ParDo.of(new GenerateMutations(options.getTable())))
         .apply(
             SpannerIO.write()
@@ -138,7 +139,7 @@ public class SpannerWriteIT {
             .singleUse()
             .executeQuery(Statement.of("SELECT COUNT(*) FROM " + options.getTable()));
     assertThat(resultSet.next(), is(true));
-    assertThat(resultSet.getLong(0), equalTo(100L));
+    assertThat(resultSet.getLong(0), equalTo((long) numRecords));
     assertThat(resultSet.next(), is(false));
   }
 


Mime
View raw message