beam-commits mailing list archives

From da...@apache.org
Subject [1/2] incubator-beam git commit: Add DatastoreIO to Python SDK
Date Wed, 23 Nov 2016 18:42:33 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 21f9c6d2c -> 9b9d016c8


Add DatastoreIO to Python SDK


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2b69cce0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2b69cce0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2b69cce0

Branch: refs/heads/python-sdk
Commit: 2b69cce0f311a2ef40fdef4fe60d3e6fc13a8868
Parents: 21f9c6d
Author: Vikas Kedigehalli <vikasrk@google.com>
Authored: Tue Nov 15 16:41:24 2016 -0800
Committer: Davor Bonaci <davor@google.com>
Committed: Wed Nov 23 10:42:00 2016 -0800

----------------------------------------------------------------------
 .../apache_beam/examples/datastore_wordcount.py | 118 ++++++++
 .../apache_beam/io/datastore/v1/datastoreio.py  | 287 +++++++++++++++++++
 .../io/datastore/v1/datastoreio_test.py         | 172 +++++++++++
 .../io/datastore/v1/fake_datastore.py           |  75 +++++
 .../apache_beam/io/datastore/v1/helper.py       | 152 ++++++++++
 .../apache_beam/io/datastore/v1/helper_test.py  | 125 +++++++-
 .../io/datastore/v1/query_splitter_test.py      |  62 +---
 7 files changed, 930 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/examples/datastore_wordcount.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/examples/datastore_wordcount.py b/sdks/python/apache_beam/examples/datastore_wordcount.py
new file mode 100644
index 0000000..af75b1c
--- /dev/null
+++ b/sdks/python/apache_beam/examples/datastore_wordcount.py
@@ -0,0 +1,118 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A word-counting workflow that uses Google Cloud Datastore."""
+
+from __future__ import absolute_import
+
+import argparse
+import logging
+import re
+
+import apache_beam as beam
+from apache_beam.io.datastore.v1.datastoreio import ReadFromDatastore
+from apache_beam.utils.options import GoogleCloudOptions
+from apache_beam.utils.options import PipelineOptions
+from apache_beam.utils.options import SetupOptions
+from google.datastore.v1 import query_pb2
+
+
+empty_line_aggregator = beam.Aggregator('emptyLines')
+average_word_size_aggregator = beam.Aggregator('averageWordLength',
+                                               beam.combiners.MeanCombineFn(),
+                                               float)
+
+
+class WordExtractingDoFn(beam.DoFn):
+  """Parse each line of input text into words."""
+
+  def process(self, context):
+    """Returns an iterator over the words of this element.
+    The element is a line of text.  If the line is blank, note that, too.
+    Args:
+      context: the call-specific context: data and aggregator.
+    Returns:
+      The processed element.
+    """
+    content_value = context.element.properties.get('content', None)
+    text_line = ''
+    if content_value:
+      text_line = content_value.string_value
+
+    if not text_line:
+      context.aggregate_to(empty_line_aggregator, 1)
+    words = re.findall(r'[A-Za-z\']+', text_line)
+    for w in words:
+      context.aggregate_to(average_word_size_aggregator, len(w))
+    return words
+
+
+def run(argv=None):
+  """Main entry point; defines and runs the wordcount pipeline."""
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--kind',
+                      dest='kind',
+                      required=True,
+                      help='Datastore Kind')
+  parser.add_argument('--namespace',
+                      dest='namespace',
+                      help='Datastore Namespace')
+  parser.add_argument('--output',
+                      dest='output',
+                      required=True,
+                      help='Output file to write results to.')
+  known_args, pipeline_args = parser.parse_known_args(argv)
+  # We use the save_main_session option because one or more DoFn's in this
+  # workflow rely on global context (e.g., a module imported at module level).
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = True
+  gcloud_options = pipeline_options.view_as(GoogleCloudOptions)
+  p = beam.Pipeline(options=pipeline_options)
+
+  query = query_pb2.Query()
+  query.kind.add().name = known_args.kind
+
+  # Read entities from Cloud Datastore into a PCollection.
+  lines = p | 'read from datastore' >> ReadFromDatastore(
+      gcloud_options.project, query, known_args.namespace)
+
+  # Count the occurrences of each word.
+  counts = (lines
+            | 'split' >> (beam.ParDo(WordExtractingDoFn())
+                          .with_output_types(unicode))
+            | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
+            | 'group' >> beam.GroupByKey()
+            | 'count' >> beam.Map(lambda (word, ones): (word, sum(ones))))
+
+  # Format the counts into a PCollection of strings.
+  output = counts | 'format' >> beam.Map(lambda (word, c): '%s: %s' % (word, c))
+
+  # Write the output using a "Write" transform that has side effects.
+  # pylint: disable=expression-not-assigned
+  output | 'write' >> beam.io.Write(beam.io.TextFileSink(known_args.output))
+
+  # Actually run the pipeline (all operations above are deferred).
+  result = p.run()
+  empty_line_values = result.aggregated_values(empty_line_aggregator)
+  logging.info('number of empty lines: %d', sum(empty_line_values.values()))
+  word_length_values = result.aggregated_values(average_word_size_aggregator)
+  logging.info('average word lengths: %s', word_length_values.values())
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
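
For reference, a pipeline like the example above could be launched roughly as
follows; the project, kind, namespace, and output values are placeholders, with
`--project` consumed by GoogleCloudOptions and the remaining flags by the
example's own argument parser:

    # Hypothetical invocation sketch; all flag values are placeholders.
    from apache_beam.examples import datastore_wordcount

    datastore_wordcount.run([
        '--project=my-gcp-project',        # consumed by GoogleCloudOptions
        '--kind=ExampleKind',              # Datastore kind to read
        '--namespace=example_ns',          # optional Datastore namespace
        '--output=gs://my-bucket/counts',  # where the formatted counts go
    ])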

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/datastoreio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/datastoreio.py b/sdks/python/apache_beam/io/datastore/v1/datastoreio.py
new file mode 100644
index 0000000..d542439
--- /dev/null
+++ b/sdks/python/apache_beam/io/datastore/v1/datastoreio.py
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A connector for reading from and writing to Google Cloud Datastore"""
+
+import logging
+
+from googledatastore import helper as datastore_helper
+
+from apache_beam.io.datastore.v1 import helper
+from apache_beam.io.datastore.v1 import query_splitter
+from apache_beam.transforms import Create
+from apache_beam.transforms import DoFn
+from apache_beam.transforms import FlatMap
+from apache_beam.transforms import GroupByKey
+from apache_beam.transforms import PTransform
+from apache_beam.transforms import ParDo
+from apache_beam.transforms.util import Values
+
+__all__ = ['ReadFromDatastore']
+
+
+class ReadFromDatastore(PTransform):
+  """A ``PTransform`` for reading from Google Cloud Datastore.
+
+  To read a ``PCollection[Entity]`` from a Cloud Datastore ``Query``, use the
+  ``ReadFromDatastore`` transform, providing a `project` id and a `query` to
+  read from. You can optionally provide a `namespace` and/or specify how many
+  splits you want for the query through the `num_splits` option.
+
+  Note: Normally, a runner will read from Cloud Datastore in parallel across
+  many workers. However, when the `query` is configured with a `limit`, or
+  contains inequality filters such as `GREATER_THAN` or `LESS_THAN`, all the
+  results are read by a single worker in order to ensure correctness. Because
+  the data is read by a single worker, this can significantly impact the
+  performance of the job.
+
+  The semantics of query splitting are as follows:
+    1. If `num_splits` is equal to 0, then the number of splits will be chosen
+    dynamically at runtime based on the query data size.
+
+    2. Any value of `num_splits` greater than
+    `ReadFromDatastore._NUM_QUERY_SPLITS_MAX` will be capped at that value.
+
+    3. If the `query` has a user limit set, or contains inequality filters, then
+    `num_splits` will be ignored and no split will be performed.
+
+    4. In some cases Cloud Datastore is unable to split a query into the
+    requested number of splits. In such cases we just use whatever Cloud
+    Datastore returns.
+
+  See https://developers.google.com/datastore/ for more details on Google Cloud
+  Datastore.
+  """
+
+  # An upper bound on the number of splits for a query.
+  _NUM_QUERY_SPLITS_MAX = 50000
+  # A lower bound on the number of splits for a query. This is to ensure that
+  # we parallelize the query even when Datastore statistics are not available.
+  _NUM_QUERY_SPLITS_MIN = 12
+  # Default bundle size of 64MB.
+  _DEFAULT_BUNDLE_SIZE_BYTES = 64 * 1024 * 1024
+
+  def __init__(self, project, query, namespace=None, num_splits=0):
+    """Initialize the ReadFromDatastore transform.
+
+    Args:
+      project: The Project ID
+      query: Cloud Datastore query to be read from.
+      namespace: An optional namespace.
+      num_splits: Number of splits for the query.
+    """
+    super(ReadFromDatastore, self).__init__()
+
+    if not project:
+      raise ValueError("Project cannot be empty")
+    if not query:
+      raise ValueError("Query cannot be empty")
+    if num_splits < 0:
+      raise ValueError("num_splits must be greater than or equal to 0")
+
+    self._project = project
+    # using _namespace conflicts with DisplayData._namespace
+    self._datastore_namespace = namespace
+    self._query = query
+    self._num_splits = num_splits
+
+  def apply(self, pcoll):
+    # This is a composite transform that involves the following:
+    #   1. Create a singleton of the user provided `query` and apply a ``ParDo``
+    #   that splits the query into `num_splits` and assign each split query a
+    #   unique `int` as the key. The resulting output is of the type
+    #   ``PCollection[(int, Query)]``.
+    #
+    #   If the value of `num_splits` is less than or equal to 0, then the
+    #   number of splits will be computed dynamically based on the size of the
+    #   data for the `query`.
+    #
+    #   2. The resulting ``PCollection`` is sharded using a ``GroupByKey``
+    #   operation. The queries are extracted from the (int, Iterable[Query])
+    #   tuples and flattened to output a ``PCollection[Query]``.
+    #
+    #   3. In the third step, a ``ParDo`` reads entities for each query and
+    #   outputs a ``PCollection[Entity]``.
+
+    queries = (pcoll.pipeline
+               | 'User Query' >> Create([self._query])
+               | 'Split Query' >> ParDo(ReadFromDatastore.SplitQueryFn(
+                   self._project, self._query, self._datastore_namespace,
+                   self._num_splits)))
+
+    sharded_queries = queries | GroupByKey() | Values() | FlatMap('flatten',
+                                                                  lambda x: x)
+
+    entities = sharded_queries | 'Read' >> ParDo(
+        ReadFromDatastore.ReadFn(self._project, self._datastore_namespace))
+    return entities
+
+  def display_data(self):
+    disp_data = {'project': self._project,
+                 'query': str(self._query),
+                 'num_splits': self._num_splits}
+
+    if self._datastore_namespace is not None:
+      disp_data['namespace'] = self._datastore_namespace
+
+    return disp_data
+
+  class SplitQueryFn(DoFn):
+    """A `DoFn` that splits a given query into multiple sub-queries."""
+    def __init__(self, project, query, namespace, num_splits):
+      super(ReadFromDatastore.SplitQueryFn, self).__init__()
+      self._datastore = None
+      self._project = project
+      self._datastore_namespace = namespace
+      self._query = query
+      self._num_splits = num_splits
+
+    def start_bundle(self, context):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, p_context, *args, **kwargs):
+      # distinct key to be used to group query splits.
+      key = 1
+      query = p_context.element
+
+      # If the query has a user-set limit, then it cannot be split.
+      if query.HasField('limit'):
+        return [(key, query)]
+
+      # Compute the estimated num_splits if not specified by the user.
+      if self._num_splits == 0:
+        estimated_num_splits = ReadFromDatastore.get_estimated_num_splits(
+            self._project, self._datastore_namespace, self._query,
+            self._datastore)
+      else:
+        estimated_num_splits = self._num_splits
+
+      logging.info("Splitting the query into %d splits", estimated_num_splits)
+      try:
+        query_splits = query_splitter.get_splits(
+            self._datastore, query, estimated_num_splits,
+            helper.make_partition(self._project, self._datastore_namespace))
+      except Exception:
+        logging.warning("Unable to parallelize the given query: %s", query,
+                        exc_info=True)
+        query_splits = [(key, query)]
+
+      sharded_query_splits = []
+      for split_query in query_splits:
+        sharded_query_splits.append((key, split_query))
+        key += 1
+
+      return sharded_query_splits
+
+    def display_data(self):
+      disp_data = {'project': self._project,
+                   'query': str(self._query),
+                   'num_splits': self._num_splits}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  class ReadFn(DoFn):
+    """A DoFn that reads entities from Cloud Datastore, for a given query."""
+    def __init__(self, project, namespace=None):
+      super(ReadFromDatastore.ReadFn, self).__init__()
+      self._project = project
+      self._datastore_namespace = namespace
+      self._datastore = None
+
+    def start_bundle(self, context):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, p_context, *args, **kwargs):
+      query = p_context.element
+      # Returns an iterator of entities that reads in batches.
+      entities = helper.fetch_entities(self._project, self._datastore_namespace,
+                                       query, self._datastore)
+      return entities
+
+    def display_data(self):
+      disp_data = {'project': self._project}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  @staticmethod
+  def query_latest_statistics_timestamp(project, namespace, datastore):
+    """Fetches the latest timestamp of statistics from Cloud Datastore.
+
+    Cloud Datastore system tables with statistics are periodically updated.
+    This method fetches the latest timestamp (in microseconds) of statistics
+    update using the `__Stat_Total__` table.
+    """
+    query = helper.make_latest_timestamp_query(namespace)
+    req = helper.make_request(project, namespace, query)
+    resp = datastore.run_query(req)
+    if len(resp.batch.entity_results) == 0:
+      raise RuntimeError("Datastore total statistics unavailable.")
+
+    entity = resp.batch.entity_results[0].entity
+    return datastore_helper.micros_from_timestamp(
+        entity.properties['timestamp'].timestamp_value)
+
+  @staticmethod
+  def get_estimated_size_bytes(project, namespace, query, datastore):
+    """Get the estimated size of the data returned by the given query.
+
+    Cloud Datastore provides no way to get a good estimate of how large the
+    result of a query is going to be. Hence we use the __Stat_Kind__ system
+    table to get the size of the entire kind as an approximate estimate, assuming
+    exactly 1 kind is specified in the query.
+    See https://cloud.google.com/datastore/docs/concepts/stats.
+    """
+    kind = query.kind[0].name
+    latest_timestamp = ReadFromDatastore.query_latest_statistics_timestamp(
+        project, namespace, datastore)
+    logging.info('Latest stats timestamp for kind %s is %s',
+                 kind, latest_timestamp)
+
+    kind_stats_query = (
+        helper.make_kind_stats_query(namespace, kind, latest_timestamp))
+
+    req = helper.make_request(project, namespace, kind_stats_query)
+    resp = datastore.run_query(req)
+    if len(resp.batch.entity_results) == 0:
+      raise RuntimeError("Datastore statistics for kind %s unavailable" % kind)
+
+    entity = resp.batch.entity_results[0].entity
+    return datastore_helper.get_value(entity.properties['entity_bytes'])
+
+  @staticmethod
+  def get_estimated_num_splits(project, namespace, query, datastore):
+    """Computes the number of splits to be performed on the given query."""
+    try:
+      estimated_size_bytes = ReadFromDatastore.get_estimated_size_bytes(
+          project, namespace, query, datastore)
+      logging.info('Estimated size bytes for query: %s', estimated_size_bytes)
+      num_splits = int(min(ReadFromDatastore._NUM_QUERY_SPLITS_MAX, round(
+          (float(estimated_size_bytes) /
+           ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES))))
+
+    except Exception as e:
+      logging.warning('Failed to fetch estimated size bytes: %s', e)
+      # Fallback in case estimated size is unavailable.
+      num_splits = ReadFromDatastore._NUM_QUERY_SPLITS_MIN
+
+    return max(num_splits, ReadFromDatastore._NUM_QUERY_SPLITS_MIN)
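
To make the split-count arithmetic in get_estimated_num_splits concrete, here is
a small worked sketch; the 3 GB figure is an arbitrary illustration rather than
anything the commit computes:

    # Hypothetical numbers plugged into the formula above.
    estimated_size_bytes = 3 * 1024 * 1024 * 1024  # suppose __Stat_Kind__ reports ~3 GB
    bundle_size = 64 * 1024 * 1024                 # _DEFAULT_BUNDLE_SIZE_BYTES
    num_splits = int(min(50000, round(float(estimated_size_bytes) / bundle_size)))  # 48
    num_splits = max(num_splits, 12)               # keep at least _NUM_QUERY_SPLITS_MIN; still 48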

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/datastoreio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/datastoreio_test.py b/sdks/python/apache_beam/io/datastore/v1/datastoreio_test.py
new file mode 100644
index 0000000..2bf01f4
--- /dev/null
+++ b/sdks/python/apache_beam/io/datastore/v1/datastoreio_test.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from google.datastore.v1 import datastore_pb2
+from google.datastore.v1 import query_pb2
+from google.protobuf import timestamp_pb2
+from googledatastore import helper as datastore_helper
+from mock import MagicMock, call, patch
+from apache_beam.io.datastore.v1 import helper
+from apache_beam.io.datastore.v1 import query_splitter
+from apache_beam.io.datastore.v1.datastoreio import ReadFromDatastore
+
+
+class DatastoreioTest(unittest.TestCase):
+  _PROJECT = 'project'
+  _KIND = 'kind'
+  _NAMESPACE = 'namespace'
+
+  def setUp(self):
+    self._mock_datastore = MagicMock()
+    self._query = query_pb2.Query()
+    self._query.kind.add().name = self._KIND
+
+  def test_get_estimated_size_bytes_without_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp)
+
+  def test_get_estimated_size_bytes_with_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp, self._NAMESPACE)
+
+  def test_SplitQueryFn_with_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      num_splits = 23
+
+      def fake_get_splits(datastore, query, num_splits, partition=None):
+        return self.split_query(query, num_splits)
+
+      with patch.object(query_splitter, 'get_splits',
+                        side_effect=fake_get_splits):
+
+        split_query_fn = ReadFromDatastore.SplitQueryFn(
+            self._PROJECT, self._query, None, num_splits)
+        mock_context = MagicMock()
+        mock_context.element = self._query
+        split_query_fn.start_bundle(mock_context)
+        returned_split_queries = []
+        for split_query in split_query_fn.process(mock_context):
+          returned_split_queries.append(split_query)
+
+        self.assertEqual(len(returned_split_queries), num_splits)
+        self.assertEqual(0, len(self._mock_datastore.run_query.call_args_list))
+        self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_without_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      # Force SplitQueryFn to compute the number of query splits
+      num_splits = 0
+      expected_num_splits = 23
+      entity_bytes = (expected_num_splits *
+                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
+      with patch.object(ReadFromDatastore, 'get_estimated_size_bytes',
+                        return_value=entity_bytes):
+
+        def fake_get_splits(datastore, query, num_splits, partition=None):
+          return self.split_query(query, num_splits)
+
+        with patch.object(query_splitter, 'get_splits',
+                          side_effect=fake_get_splits):
+          split_query_fn = ReadFromDatastore.SplitQueryFn(
+              self._PROJECT, self._query, None, num_splits)
+          mock_context = MagicMock()
+          mock_context.element = self._query
+          split_query_fn.start_bundle(mock_context)
+          returned_split_queries = []
+          for split_query in split_query_fn.process(mock_context):
+            returned_split_queries.append(split_query)
+
+          self.assertEqual(len(returned_split_queries), expected_num_splits)
+          self.assertEqual(0,
+                           len(self._mock_datastore.run_query.call_args_list))
+          self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_with_query_limit(self):
+    """A test that verifies no split is performed when the query has a limit."""
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      self._query.limit.value = 3
+      split_query_fn = ReadFromDatastore.SplitQueryFn(
+          self._PROJECT, self._query, None, 4)
+      mock_context = MagicMock()
+      mock_context.element = self._query
+      split_query_fn.start_bundle(mock_context)
+      returned_split_queries = []
+      for split_query in split_query_fn.process(mock_context):
+        returned_split_queries.append(split_query)
+
+      self.assertEqual(1, len(returned_split_queries))
+      self.assertEqual(0, len(self._mock_datastore.method_calls))
+
+  def verify_unique_keys(self, queries):
+    """A helper function that verifies if all the queries have unique keys."""
+    keys, _ = zip(*queries)
+    keys = set(keys)
+    self.assertEqual(len(keys), len(queries))
+
+  def check_estimated_size_bytes(self, entity_bytes, timestamp, namespace=None):
+    """A helper method to test get_estimated_size_bytes"""
+
+    timestamp_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_latest_timestamp_query(namespace))
+    timestamp_resp = self.make_stats_response(
+        {'timestamp': datastore_helper.from_timestamp(timestamp)})
+    kind_stat_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_kind_stats_query(
+            namespace, self._query.kind[0].name,
+            datastore_helper.micros_from_timestamp(timestamp)))
+    kind_stat_resp = self.make_stats_response(
+        {'entity_bytes': entity_bytes})
+
+    def fake_run_query(req):
+      if req == timestamp_req:
+        return timestamp_resp
+      elif req == kind_stat_req:
+        return kind_stat_resp
+      else:
+        print kind_stat_req
+        raise ValueError("Unknown req: %s" % req)
+
+    self._mock_datastore.run_query.side_effect = fake_run_query
+    self.assertEqual(entity_bytes, ReadFromDatastore.get_estimated_size_bytes(
+        self._PROJECT, namespace, self._query, self._mock_datastore))
+    self.assertEqual(self._mock_datastore.run_query.call_args_list,
+                     [call(timestamp_req), call(kind_stat_req)])
+
+  def make_stats_response(self, property_map):
+    resp = datastore_pb2.RunQueryResponse()
+    entity_result = resp.batch.entity_results.add()
+    datastore_helper.add_properties(entity_result.entity, property_map)
+    return resp
+
+  def split_query(self, query, num_splits):
+    """Generate dummy query splits."""
+    split_queries = []
+    for _ in range(0, num_splits):
+      q = query_pb2.Query()
+      q.CopyFrom(query)
+      split_queries.append(q)
+    return split_queries
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/fake_datastore.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/fake_datastore.py b/sdks/python/apache_beam/io/datastore/v1/fake_datastore.py
new file mode 100644
index 0000000..631908e
--- /dev/null
+++ b/sdks/python/apache_beam/io/datastore/v1/fake_datastore.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Fake datastore used for unit testing."""
+import uuid
+
+from google.datastore.v1 import datastore_pb2
+from google.datastore.v1 import query_pb2
+
+
+def create_run_query(entities, batch_size):
+  """A fake datastore run_query method that returns entities in batches.
+
+  Note: the outer method is needed to make the `entities` and `batch_size`
+  available in the scope of the inner run_query method.
+
+  Args:
+    entities: list of entities supposed to be contained in the datastore.
+    batch_size: the number of entities that run_query method returns in one
+                request.
+  """
+  def run_query(req):
+    start = int(req.query.start_cursor) if req.query.start_cursor else 0
+    # if query limit is less than batch_size, then only return that much.
+    count = min(batch_size, req.query.limit.value)
+    # cannot go more than the number of entities contained in datastore.
+    end = min(len(entities), start + count)
+    finish = False
+    # Finish reading when there are no more entities to return,
+    # or request query limit has been satisfied.
+    if end == len(entities) or count == req.query.limit.value:
+      finish = True
+    return create_response(entities[start:end], str(end), finish)
+  return run_query
+
+
+def create_response(entities, end_cursor, finish):
+  """Creates a query response for a given batch of scatter entities."""
+  resp = datastore_pb2.RunQueryResponse()
+  if finish:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NO_MORE_RESULTS
+  else:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NOT_FINISHED
+
+  resp.batch.end_cursor = end_cursor
+  for entity_result in entities:
+    resp.batch.entity_results.add().CopyFrom(entity_result)
+
+  return resp
+
+
+def create_entities(count):
+  """Creates a list of entities with random keys."""
+  entities = []
+
+  for _ in range(count):
+    entity_result = query_pb2.EntityResult()
+    entity_result.entity.key.path.add().name = str(uuid.uuid4())
+    entities.append(entity_result)
+
+  return entities
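
The factory above is meant to be attached to a mocked Datastore client, roughly
the way the helper tests further down wire it up:

    # Sketch: plug the fake datastore into a mock client (mirrors helper_test).
    from mock import MagicMock
    from apache_beam.io.datastore.v1 import fake_datastore

    entities = fake_datastore.create_entities(10)
    mock_datastore = MagicMock()
    mock_datastore.run_query.side_effect = fake_datastore.create_run_query(
        entities, batch_size=5)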

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/helper.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/helper.py b/sdks/python/apache_beam/io/datastore/v1/helper.py
index 626ab35..39ca40c 100644
--- a/sdks/python/apache_beam/io/datastore/v1/helper.py
+++ b/sdks/python/apache_beam/io/datastore/v1/helper.py
@@ -16,6 +16,18 @@
 #
 
 """Cloud Datastore helper functions."""
+import sys
+
+from google.datastore.v1 import datastore_pb2
+from google.datastore.v1 import entity_pb2
+from google.datastore.v1 import query_pb2
+from googledatastore import PropertyFilter, CompositeFilter
+from googledatastore import helper as datastore_helper
+from googledatastore.connection import Datastore
+from googledatastore.connection import RPCError
+import googledatastore
+
+from apache_beam.utils import retry
 
 
 def key_comparator(k1, k2):
@@ -82,3 +94,143 @@ def str_compare(s1, s2):
     return -1
   else:
     return 1
+
+
+def get_datastore(project):
+  """Returns a Cloud Datastore client."""
+  credentials = googledatastore.helper.get_credentials_from_env()
+  datastore = Datastore(project, credentials)
+  return datastore
+
+
+def make_request(project, namespace, query):
+  """Make a Cloud Datastore request for the given query."""
+  req = datastore_pb2.RunQueryRequest()
+  req.partition_id.CopyFrom(make_partition(project, namespace))
+
+  req.query.CopyFrom(query)
+  return req
+
+
+def make_partition(project, namespace):
+  """Make a PartitionId for the given project and namespace."""
+  partition = entity_pb2.PartitionId()
+  partition.project_id = project
+  if namespace is not None:
+    partition.namespace_id = namespace
+
+  return partition
+
+
+def retry_on_rpc_error(exception):
+  """A retry filter for Cloud Datastore RPCErrors."""
+  if isinstance(exception, RPCError):
+    if exception.code >= 500:
+      return True
+    else:
+      return False
+  else:
+    # TODO(vikasrk): Figure out what other errors should be retried.
+    return False
+
+
+def fetch_entities(project, namespace, query, datastore):
+  """A helper method to fetch entities from Cloud Datastore.
+
+  Args:
+    project: Project ID
+    namespace: Cloud Datastore namespace
+    query: Query to be read from
+    datastore: Cloud Datastore Client
+
+  Returns:
+    An iterator of entities.
+  """
+  return QueryIterator(project, namespace, query, datastore)
+
+
+def make_latest_timestamp_query(namespace):
+  """Make a Query to fetch the latest timestamp statistics."""
+  query = query_pb2.Query()
+  if namespace is None:
+    query.kind.add().name = '__Stat_Total__'
+  else:
+    query.kind.add().name = '__Stat_Ns_Total__'
+
+  # Descending order of `timestamp`
+  datastore_helper.add_property_orders(query, "-timestamp")
+  # Only get the latest entity
+  query.limit.value = 1
+  return query
+
+
+def make_kind_stats_query(namespace, kind, latest_timestamp):
+  """Make a Query to fetch the latest kind statistics."""
+  kind_stat_query = query_pb2.Query()
+  if namespace is None:
+    kind_stat_query.kind.add().name = '__Stat_Kind__'
+  else:
+    kind_stat_query.kind.add().name = '__Stat_Ns_Kind__'
+
+  kind_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'kind_name', PropertyFilter.EQUAL, unicode(kind))
+  timestamp_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'timestamp', PropertyFilter.EQUAL,
+      latest_timestamp)
+
+  datastore_helper.set_composite_filter(kind_stat_query.filter,
+                                        CompositeFilter.AND, kind_filter,
+                                        timestamp_filter)
+  return kind_stat_query
+
+
+class QueryIterator(object):
+  """A iterator class for entities of a given query.
+
+  Entities are read in batches. Retries on failures.
+  """
+  _NOT_FINISHED = query_pb2.QueryResultBatch.NOT_FINISHED
+  # Maximum number of results to request per query.
+  _BATCH_SIZE = 500
+
+  def __init__(self, project, namespace, query, datastore):
+    self._query = query
+    self._datastore = datastore
+    self._project = project
+    self._namespace = namespace
+    self._start_cursor = None
+    self._limit = self._query.limit.value or sys.maxint
+    self._req = make_request(project, namespace, query)
+
+  @retry.with_exponential_backoff(num_retries=5,
+                                  retry_filter=retry_on_rpc_error)
+  def _next_batch(self):
+    """Fetches the next batch of entities."""
+    if self._start_cursor is not None:
+      self._req.query.start_cursor = self._start_cursor
+
+    # set batch size
+    self._req.query.limit.value = min(self._BATCH_SIZE, self._limit)
+    resp = self._datastore.run_query(self._req)
+    return resp
+
+  def __iter__(self):
+    more_results = True
+    while more_results:
+      resp = self._next_batch()
+      for entity_result in resp.batch.entity_results:
+        yield entity_result.entity
+
+      self._start_cursor = resp.batch.end_cursor
+      num_results = len(resp.batch.entity_results)
+      self._limit -= num_results
+
+      # Check if we need to read more entities.
+      # True when query limit hasn't been satisfied and there are more entities
+      # to be read. The latter is true if the response has a status
+      # `NOT_FINISHED` or if the number of results read in the previous batch
+      # is equal to `_BATCH_SIZE` (all indications that there is more data to be
+      # read).
+      more_results = ((self._limit > 0) and
+                      ((num_results == self._BATCH_SIZE) or
+                       (resp.batch.more_results == self._NOT_FINISHED)))
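
A minimal sketch of consuming the batched iterator above; the project id and
kind are placeholders, and credentials are assumed to be available in the
environment, as get_datastore expects:

    # Sketch: read all entities matching a query, batching and retrying as above.
    from apache_beam.io.datastore.v1 import helper
    from google.datastore.v1 import query_pb2

    query = query_pb2.Query()
    query.kind.add().name = 'ExampleKind'               # placeholder kind
    datastore = helper.get_datastore('my-gcp-project')  # placeholder project id
    for entity in helper.fetch_entities('my-gcp-project', None, query, datastore):
      print entity.key                                  # Python 2, matching the SDK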

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/helper_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/helper_test.py b/sdks/python/apache_beam/io/datastore/v1/helper_test.py
index 50f8e4c..69741d2 100644
--- a/sdks/python/apache_beam/io/datastore/v1/helper_test.py
+++ b/sdks/python/apache_beam/io/datastore/v1/helper_test.py
@@ -16,14 +16,134 @@
 #
 
 """Tests for datastore helper."""
-
+import imp
+import sys
 import unittest
-from apache_beam.io.datastore.v1 import helper
+
+from google.datastore.v1 import datastore_pb2
+from google.datastore.v1 import query_pb2
 from google.datastore.v1.entity_pb2 import Key
+from googledatastore.connection import RPCError
+from mock import MagicMock, Mock, patch
+
+from apache_beam.io.datastore.v1 import fake_datastore
+from apache_beam.io.datastore.v1 import helper
+from apache_beam.utils import retry
 
 
 class HelperTest(unittest.TestCase):
 
+  def setUp(self):
+    self._mock_datastore = MagicMock()
+    self._query = query_pb2.Query()
+    self._query.kind.add().name = 'dummy_kind'
+    self.patch_retry()
+
+  def patch_retry(self):
+
+    """A function to patch retry module to use mock clock and logger."""
+    real_retry_with_exponential_backoff = retry.with_exponential_backoff
+
+    def patched_retry_with_exponential_backoff(num_retries, retry_filter):
+      """A patch for retry decorator to use a mock dummy clock and logger."""
+      return real_retry_with_exponential_backoff(
+          num_retries=num_retries, retry_filter=retry_filter, logger=Mock(),
+          clock=Mock())
+
+    patch.object(retry, 'with_exponential_backoff',
+                 side_effect=patched_retry_with_exponential_backoff).start()
+
+    # Reload module after patching.
+    imp.reload(helper)
+
+    def kill_patches():
+      patch.stopall()
+      # Reload module again after removing patch.
+      imp.reload(helper)
+
+    self.addCleanup(kill_patches)
+
+  def permanent_datastore_failure(self, req):
+    raise RPCError("dummy", 500, "failed")
+
+  def transient_datastore_failure(self, req):
+    if self._transient_fail_count:
+      self._transient_fail_count -= 1
+      raise RPCError("dummy", 500, "failed")
+    else:
+      return datastore_pb2.RunQueryResponse()
+
+  def test_query_iterator(self):
+    self._mock_datastore.run_query.side_effect = (
+        self.permanent_datastore_failure)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+    self.assertRaises(RPCError, iter(query_iterator).next)
+    self.assertEqual(6, len(self._mock_datastore.run_query.call_args_list))
+
+  def test_query_iterator_with_transient_failures(self):
+    self._mock_datastore.run_query.side_effect = (
+        self.transient_datastore_failure)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+    fail_count = 2
+    self._transient_fail_count = fail_count
+    for _ in query_iterator:
+      pass
+
+    self.assertEqual(fail_count + 1,
+                     len(self._mock_datastore.run_query.call_args_list))
+
+  def test_query_iterator_with_single_batch(self):
+    num_entities = 100
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_multiple_batches(self):
+    num_entities = 1098
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_exact_batch_multiple(self):
+    num_entities = 1000
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_query_limit(self):
+    num_entities = 1098
+    batch_size = 500
+    self._query.limit.value = 1004
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_large_query_limit(self):
+    num_entities = 1098
+    batch_size = 500
+    self._query.limit.value = 10000
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def check_query_iterator(self, num_entities, batch_size, query):
+    """A helper method to test the QueryIterator.
+
+    Args:
+      num_entities: number of entities contained in the fake datastore.
+      batch_size: the number of entities returned by fake datastore in one req.
+      query: the query to be executed
+
+    """
+    entities = fake_datastore.create_entities(num_entities)
+    self._mock_datastore.run_query.side_effect = \
+        fake_datastore.create_run_query(entities, batch_size)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+
+    i = 0
+    for entity in query_iterator:
+      self.assertEqual(entity, entities[i].entity)
+      i += 1
+
+    limit = query.limit.value if query.HasField('limit') else sys.maxint
+    self.assertEqual(i, min(num_entities, limit))
+
   def test_compare_path_with_different_kind(self):
     p1 = Key.PathElement()
     p1.kind = 'dummy1'
@@ -120,5 +240,6 @@ class HelperTest(unittest.TestCase):
     p21.kind = 'dummy1'
     self.assertLess(helper.key_comparator(k1, k2), 0)
 
+
 if __name__ == '__main__':
   unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b69cce0/sdks/python/apache_beam/io/datastore/v1/query_splitter_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/datastore/v1/query_splitter_test.py b/sdks/python/apache_beam/io/datastore/v1/query_splitter_test.py
index 979a69f..810719b 100644
--- a/sdks/python/apache_beam/io/datastore/v1/query_splitter_test.py
+++ b/sdks/python/apache_beam/io/datastore/v1/query_splitter_test.py
@@ -18,11 +18,11 @@
 """Cloud Datastore query splitter test."""
 
 import unittest
-import uuid
 
 from mock import MagicMock
 from mock import call
 
+from apache_beam.io.datastore.v1 import fake_datastore
 from apache_beam.io.datastore.v1 import query_splitter
 
 from google.datastore.v1 import datastore_pb2
@@ -150,11 +150,11 @@ class QuerySplitterTest(unittest.TestCase):
       batch_size: the number of entities returned by fake datastore in one req.
     """
 
-    entities = QuerySplitterTest.create_entities(num_entities)
+    entities = fake_datastore.create_entities(num_entities)
     mock_datastore = MagicMock()
     # Assign a fake run_query method as a side_effect to the mock.
     mock_datastore.run_query.side_effect = \
-      QuerySplitterTest.create_run_query(entities, batch_size)
+        fake_datastore.create_run_query(entities, batch_size)
 
     split_queries = query_splitter.get_splits(mock_datastore, query, num_splits)
 
@@ -173,33 +173,6 @@ class QuerySplitterTest(unittest.TestCase):
     self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list)
 
   @staticmethod
-  def create_run_query(entities, batch_size):
-    """A fake datastore run_query method that returns entities in batches.
-
-    Note: the outer method is needed to make the `entities` and `batch_size`
-    available in the scope of fake_run_query method.
-
-    Args:
-      entities: list of entities supposed to be contained in the datastore.
-      batch_size: the number of entities that run_query method returns in one
-                  request.
-    """
-    def fake_run_query(req):
-      start = int(req.query.start_cursor) if req.query.start_cursor else 0
-      # if query limit is less than batch_size, then only return that much.
-      count = min(batch_size, req.query.limit.value)
-      # cannot go more than the number of entities contained in datastore.
-      end = min(len(entities), start + count)
-      finish = False
-      # Finish reading when there are no more entities to return,
-      # or request query limit has been satisfied.
-      if end == len(entities) or count == req.query.limit.value:
-        finish = True
-      return QuerySplitterTest.create_scatter_response(entities[start:end],
-                                                       str(end), finish)
-    return fake_run_query
-
-  @staticmethod
   def create_scatter_requests(query, num_splits, batch_size, num_entities):
     """Creates a list of expected scatter requests from the query splitter.
 
@@ -223,35 +196,6 @@ class QuerySplitterTest(unittest.TestCase):
 
     return requests
 
-  @staticmethod
-  def create_scatter_response(entities, end_cursor, finish):
-    """Creates a query response for a given batch of scatter entities."""
-
-    resp = datastore_pb2.RunQueryResponse()
-    if finish:
-      resp.batch.more_results = query_pb2.QueryResultBatch.NO_MORE_RESULTS
-    else:
-      resp.batch.more_results = query_pb2.QueryResultBatch.NOT_FINISHED
-
-    resp.batch.end_cursor = end_cursor
-    for entity_result in entities:
-      resp.batch.entity_results.add().CopyFrom(entity_result)
-
-    return resp
-
-  @staticmethod
-  def create_entities(count):
-    """Creates a list of entities with random keys."""
-
-    entities = []
-
-    for _ in range(0, count):
-      entity_result = query_pb2.EntityResult()
-      entity_result.entity.key.path.add().name = str(uuid.uuid4())
-      entities.append(entity_result)
-
-    return entities
-
 
 if __name__ == '__main__':
   unittest.main()


