diff --git a/google/cloud/bigtable/__init__.py b/google/cloud/bigtable/__init__.py index daa562c0c..251e41e42 100644 --- a/google/cloud/bigtable/__init__.py +++ b/google/cloud/bigtable/__init__.py @@ -22,6 +22,7 @@ from google.cloud.bigtable.client import Table from google.cloud.bigtable.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.read_rows_query import RowRange from google.cloud.bigtable.row_response import RowResponse from google.cloud.bigtable.row_response import CellResponse @@ -43,6 +44,7 @@ "Table", "RowKeySamples", "ReadRowsQuery", + "RowRange", "MutationsBatcher", "Mutation", "BulkMutationsEntry", diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 64583b2d7..9fd349d5f 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -13,36 +13,192 @@ # limitations under the License. # from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from .row_response import row_key +from dataclasses import dataclass +from google.cloud.bigtable.row_filters import RowFilter if TYPE_CHECKING: - from google.cloud.bigtable.row_filters import RowFilter from google.cloud.bigtable import RowKeySamples +@dataclass +class _RangePoint: + """Model class for a point in a row range""" + + key: row_key + is_inclusive: bool + + +@dataclass +class RowRange: + start: _RangePoint | None + end: _RangePoint | None + + def __init__( + self, + start_key: str | bytes | None = None, + end_key: str | bytes | None = None, + start_is_inclusive: bool | None = None, + end_is_inclusive: bool | None = None, + ): + # check for invalid combinations of arguments + if start_is_inclusive is None: + start_is_inclusive = True + elif start_key is None: + raise ValueError("start_is_inclusive must be set with start_key") + if end_is_inclusive is None: + end_is_inclusive = False + elif end_key is None: + raise ValueError("end_is_inclusive must be set with end_key") + # ensure that start_key and end_key are bytes + if isinstance(start_key, str): + start_key = start_key.encode() + elif start_key is not None and not isinstance(start_key, bytes): + raise ValueError("start_key must be a string or bytes") + if isinstance(end_key, str): + end_key = end_key.encode() + elif end_key is not None and not isinstance(end_key, bytes): + raise ValueError("end_key must be a string or bytes") + + self.start = ( + _RangePoint(start_key, start_is_inclusive) + if start_key is not None + else None + ) + self.end = ( + _RangePoint(end_key, end_is_inclusive) if end_key is not None else None + ) + + def _to_dict(self) -> dict[str, bytes]: + """Converts this object to a dictionary""" + output = {} + if self.start is not None: + key = "start_key_closed" if self.start.is_inclusive else "start_key_open" + output[key] = self.start.key + if self.end is not None: + key = "end_key_closed" if self.end.is_inclusive else "end_key_open" + output[key] = self.end.key + return output + + class ReadRowsQuery: """ Class to encapsulate details of a read row request """ def __init__( - self, row_keys: list[str | bytes] | str | bytes | None = None, limit=None + self, + row_keys: list[str | bytes] | str | bytes | None = None, + row_ranges: list[RowRange] | RowRange | None = None, + limit: int | None = None, + row_filter: RowFilter | None = None, ): - pass + """ + Create a new ReadRowsQuery - def set_limit(self, limit: int) -> ReadRowsQuery: - raise NotImplementedError + Args: + - row_keys: row keys to include in the 
query.
+                A query can contain multiple keys, but ranges should be preferred
+            - row_ranges: ranges of rows to include in the query
+            - limit: the maximum number of rows to return. None or 0 means no limit
+                default: None (no limit)
+            - row_filter: a RowFilter to apply to the query
+        """
+        self.row_keys: set[bytes] = set()
+        self.row_ranges: list[RowRange | dict[str, bytes]] = []
+        if row_ranges:
+            if isinstance(row_ranges, RowRange):
+                row_ranges = [row_ranges]
+            for r in row_ranges:
+                self.add_range(r)
+        if row_keys:
+            if not isinstance(row_keys, list):
+                row_keys = [row_keys]
+            for k in row_keys:
+                self.add_key(k)
+        self.limit: int | None = limit
+        self.filter: RowFilter | dict[str, Any] | None = row_filter

-    def set_filter(self, filter: "RowFilter") -> ReadRowsQuery:
-        raise NotImplementedError
+    @property
+    def limit(self) -> int | None:
+        return self._limit

-    def add_rows(self, row_id_list: list[str]) -> ReadRowsQuery:
-        raise NotImplementedError
+    @limit.setter
+    def limit(self, new_limit: int | None):
+        """
+        Set the maximum number of rows returned by this query.
+
+        None or 0 means no limit
+
+        Args:
+            - new_limit: the new limit to apply to this query
+        Raises:
+            - ValueError if new_limit is < 0
+        """
+        if new_limit is not None and new_limit < 0:
+            raise ValueError("limit must be >= 0")
+        self._limit = new_limit
+
+    @property
+    def filter(self) -> RowFilter | dict[str, Any] | None:
+        return self._filter
+
+    @filter.setter
+    def filter(self, row_filter: RowFilter | dict[str, Any] | None):
+        """
+        Set a RowFilter to apply to this query
+
+        Args:
+            - row_filter: a RowFilter to apply to this query
+                Can be a RowFilter object or a dict representation
+        Raises:
+            - ValueError if row_filter is not a RowFilter, dict, or None
+        """
+        if not (
+            isinstance(row_filter, dict)
+            or isinstance(row_filter, RowFilter)
+            or row_filter is None
+        ):
+            raise ValueError("row_filter must be a RowFilter or dict")
+        self._filter = row_filter
+
+    def add_key(self, row_key: str | bytes):
+        """
+        Add a row key to this query
+
+        A query can contain multiple keys, but ranges should be preferred
+
+        Args:
+            - row_key: a key to add to this query
+        Raises:
+            - ValueError if row_key is not a string or bytes
+        """
+        if isinstance(row_key, str):
+            row_key = row_key.encode()
+        elif not isinstance(row_key, bytes):
+            raise ValueError("row_key must be string or bytes")
+        self.row_keys.add(row_key)

     def add_range(
-        self, start_key: str | bytes | None = None, end_key: str | bytes | None = None
-    ) -> ReadRowsQuery:
-        raise NotImplementedError
+        self,
+        row_range: RowRange | dict[str, bytes],
+    ):
+        """
+        Add a range of row keys to this query.
+
+        Args:
+            - row_range: a range of row keys to add to this query
+                Can be a RowRange object or a dict representation in
+                RowRange proto format
+        Raises:
+            - ValueError if row_range is not a RowRange or dict
+        """
+        if not (isinstance(row_range, dict) or isinstance(row_range, RowRange)):
+            raise ValueError("row_range must be a RowRange or dict")
+        self.row_ranges.append(row_range)

     def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]:
         """
@@ -54,3 +210,27 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery
         query (if possible)
         """
         raise NotImplementedError
+
+    def _to_dict(self) -> dict[str, Any]:
+        """
+        Convert this query into a dictionary that can be used to construct a
+        ReadRowsRequest protobuf
+        """
+        row_ranges = []
+        for r in self.row_ranges:
+            dict_range = r._to_dict() if isinstance(r, RowRange) else r
+            row_ranges.append(dict_range)
+        row_keys = list(self.row_keys)
+        row_keys.sort()
+        row_set = {"row_keys": row_keys, "row_ranges": row_ranges}
+        final_dict: dict[str, Any] = {
+            "rows": row_set,
+        }
+        dict_filter = (
+            self.filter.to_dict() if isinstance(self.filter, RowFilter) else self.filter
+        )
+        if dict_filter:
+            final_dict["filter"] = dict_filter
+        if self.limit is not None:
+            final_dict["rows_limit"] = self.limit
+        return final_dict
diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py
new file mode 100644
index 000000000..aa690bc86
--- /dev/null
+++ b/tests/unit/test_read_rows_query.py
@@ -0,0 +1,359 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
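+
+# Unit tests for the RowRange and ReadRowsQuery request models: construction,
+# input validation, and _to_dict proto conversion. No Bigtable backend is
+# involved; only the model classes are exercised.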
+ +import unittest + +TEST_ROWS = [ + "row_key_1", + b"row_key_2", +] + + +class TestRowRange(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.read_rows_query import RowRange + + return RowRange + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor_start_end(self): + row_range = self._make_one("test_row", "test_row2") + self.assertEqual(row_range.start.key, "test_row".encode()) + self.assertEqual(row_range.end.key, "test_row2".encode()) + self.assertEqual(row_range.start.is_inclusive, True) + self.assertEqual(row_range.end.is_inclusive, False) + + def test_ctor_start_only(self): + row_range = self._make_one("test_row3") + self.assertEqual(row_range.start.key, "test_row3".encode()) + self.assertEqual(row_range.start.is_inclusive, True) + self.assertEqual(row_range.end, None) + + def test_ctor_end_only(self): + row_range = self._make_one(end_key="test_row4") + self.assertEqual(row_range.end.key, "test_row4".encode()) + self.assertEqual(row_range.end.is_inclusive, False) + self.assertEqual(row_range.start, None) + + def test_ctor_inclusive_flags(self): + row_range = self._make_one("test_row5", "test_row6", False, True) + self.assertEqual(row_range.start.key, "test_row5".encode()) + self.assertEqual(row_range.end.key, "test_row6".encode()) + self.assertEqual(row_range.start.is_inclusive, False) + self.assertEqual(row_range.end.is_inclusive, True) + + def test_ctor_defaults(self): + row_range = self._make_one() + self.assertEqual(row_range.start, None) + self.assertEqual(row_range.end, None) + + def test_ctor_flags_only(self): + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=True, end_is_inclusive=True) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=False, end_is_inclusive=False) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=False) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(end_is_inclusive=True) + self.assertEqual( + exc.exception.args, ("end_is_inclusive must be set with end_key",) + ) + + def test_ctor_invalid_keys(self): + # test with invalid keys + with self.assertRaises(ValueError) as exc: + self._make_one(1, "2") + self.assertEqual(exc.exception.args, ("start_key must be a string or bytes",)) + with self.assertRaises(ValueError) as exc: + self._make_one("1", 2) + self.assertEqual(exc.exception.args, ("end_key must be a string or bytes",)) + + def test__to_dict_defaults(self): + row_range = self._make_one("test_row", "test_row2") + expected = { + "start_key_closed": b"test_row", + "end_key_open": b"test_row2", + } + self.assertEqual(row_range._to_dict(), expected) + + def test__to_dict_inclusive_flags(self): + row_range = self._make_one("test_row", "test_row2", False, True) + expected = { + "start_key_open": b"test_row", + "end_key_closed": b"test_row2", + } + self.assertEqual(row_range._to_dict(), expected) + + +class TestReadRowsQuery(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.read_rows_query import ReadRowsQuery + + return ReadRowsQuery + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) 
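+
+    # The tests below cover constructor defaults and validation, the limit and
+    # filter properties, add_key/add_range, and _to_dict proto conversion.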
+ + def test_ctor_defaults(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + self.assertEqual(query.row_ranges, []) + self.assertEqual(query.filter, None) + self.assertEqual(query.limit, None) + + def test_ctor_explicit(self): + from google.cloud.bigtable.row_filters import RowFilterChain + + filter_ = RowFilterChain() + query = self._make_one(["row_key_1", "row_key_2"], limit=10, row_filter=filter_) + self.assertEqual(len(query.row_keys), 2) + self.assertIn("row_key_1".encode(), query.row_keys) + self.assertIn("row_key_2".encode(), query.row_keys) + self.assertEqual(query.row_ranges, []) + self.assertEqual(query.filter, filter_) + self.assertEqual(query.limit, 10) + + def test_ctor_invalid_limit(self): + with self.assertRaises(ValueError) as exc: + self._make_one(limit=-1) + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) + + def test_set_filter(self): + from google.cloud.bigtable.row_filters import RowFilterChain + + filter1 = RowFilterChain() + query = self._make_one() + self.assertEqual(query.filter, None) + query.filter = filter1 + self.assertEqual(query.filter, filter1) + filter2 = RowFilterChain() + query.filter = filter2 + self.assertEqual(query.filter, filter2) + query.filter = None + self.assertEqual(query.filter, None) + query.filter = RowFilterChain() + self.assertEqual(query.filter, RowFilterChain()) + with self.assertRaises(ValueError) as exc: + query.filter = 1 + self.assertEqual( + exc.exception.args, ("row_filter must be a RowFilter or dict",) + ) + + def test_set_filter_dict(self): + from google.cloud.bigtable.row_filters import RowSampleFilter + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + + filter1 = RowSampleFilter(0.5) + filter1_dict = filter1.to_dict() + query = self._make_one() + self.assertEqual(query.filter, None) + query.filter = filter1_dict + self.assertEqual(query.filter, filter1_dict) + output = query._to_dict() + self.assertEqual(output["filter"], filter1_dict) + proto_output = ReadRowsRequest(**output) + self.assertEqual(proto_output.filter, filter1._to_pb()) + + query.filter = None + self.assertEqual(query.filter, None) + + def test_set_limit(self): + query = self._make_one() + self.assertEqual(query.limit, None) + query.limit = 10 + self.assertEqual(query.limit, 10) + query.limit = 9 + self.assertEqual(query.limit, 9) + query.limit = 0 + self.assertEqual(query.limit, 0) + with self.assertRaises(ValueError) as exc: + query.limit = -1 + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) + with self.assertRaises(ValueError) as exc: + query.limit = -100 + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) + + def test_add_key_str(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_str = "test_row" + query.add_key(input_str) + self.assertEqual(len(query.row_keys), 1) + self.assertIn(input_str.encode(), query.row_keys) + input_str2 = "test_row2" + query.add_key(input_str2) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(input_str.encode(), query.row_keys) + self.assertIn(input_str2.encode(), query.row_keys) + + def test_add_key_bytes(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_bytes = b"test_row" + query.add_key(input_bytes) + self.assertEqual(len(query.row_keys), 1) + self.assertIn(input_bytes, query.row_keys) + input_bytes2 = b"test_row2" + query.add_key(input_bytes2) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(input_bytes, query.row_keys) + self.assertIn(input_bytes2, 
query.row_keys) + + def test_add_rows_batch(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_batch = ["test_row", b"test_row2", "test_row3"] + for k in input_batch: + query.add_key(k) + self.assertEqual(len(query.row_keys), 3) + self.assertIn(b"test_row", query.row_keys) + self.assertIn(b"test_row2", query.row_keys) + self.assertIn(b"test_row3", query.row_keys) + # test adding another batch + for k in ["test_row4", b"test_row5"]: + query.add_key(k) + self.assertEqual(len(query.row_keys), 5) + self.assertIn(input_batch[0].encode(), query.row_keys) + self.assertIn(input_batch[1], query.row_keys) + self.assertIn(input_batch[2].encode(), query.row_keys) + self.assertIn(b"test_row4", query.row_keys) + self.assertIn(b"test_row5", query.row_keys) + + def test_add_key_invalid(self): + query = self._make_one() + with self.assertRaises(ValueError) as exc: + query.add_key(1) + self.assertEqual(exc.exception.args, ("row_key must be string or bytes",)) + with self.assertRaises(ValueError) as exc: + query.add_key(["s"]) + self.assertEqual(exc.exception.args, ("row_key must be string or bytes",)) + + def test_duplicate_rows(self): + # should only hold one of each input key + key_1 = b"test_row" + key_2 = b"test_row2" + query = self._make_one(row_keys=[key_1, key_1, key_2]) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(key_1, query.row_keys) + self.assertIn(key_2, query.row_keys) + key_3 = "test_row3" + for i in range(10): + query.add_key(key_3) + self.assertEqual(len(query.row_keys), 3) + + def test_add_range(self): + from google.cloud.bigtable.read_rows_query import RowRange + + query = self._make_one() + self.assertEqual(query.row_ranges, []) + input_range = RowRange(start_key=b"test_row") + query.add_range(input_range) + self.assertEqual(len(query.row_ranges), 1) + self.assertEqual(query.row_ranges[0], input_range) + input_range2 = RowRange(start_key=b"test_row2") + query.add_range(input_range2) + self.assertEqual(len(query.row_ranges), 2) + self.assertEqual(query.row_ranges[0], input_range) + self.assertEqual(query.row_ranges[1], input_range2) + + def test_add_range_dict(self): + query = self._make_one() + self.assertEqual(query.row_ranges, []) + input_range = {"start_key_closed": b"test_row"} + query.add_range(input_range) + self.assertEqual(len(query.row_ranges), 1) + self.assertEqual(query.row_ranges[0], input_range) + + def test_to_dict_rows_default(self): + # dictionary should be in rowset proto format + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + + query = self._make_one() + output = query._to_dict() + self.assertTrue(isinstance(output, dict)) + self.assertEqual(len(output.keys()), 1) + expected = {"rows": {"row_keys": [], "row_ranges": []}} + self.assertEqual(output, expected) + + request_proto = ReadRowsRequest(**output) + self.assertEqual(request_proto.rows.row_keys, []) + self.assertEqual(request_proto.rows.row_ranges, []) + self.assertFalse(request_proto.filter) + self.assertEqual(request_proto.rows_limit, 0) + + def test_to_dict_rows_populated(self): + # dictionary should be in rowset proto format + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + from google.cloud.bigtable.row_filters import PassAllFilter + from google.cloud.bigtable.read_rows_query import RowRange + + row_filter = PassAllFilter(False) + query = self._make_one(limit=100, row_filter=row_filter) + query.add_range(RowRange("test_row", "test_row2")) + query.add_range(RowRange("test_row3")) + query.add_range(RowRange(start_key=None, 
end_key="test_row5")) + query.add_range(RowRange(b"test_row6", b"test_row7", False, True)) + query.add_range(RowRange()) + query.add_key("test_row") + query.add_key(b"test_row2") + query.add_key("test_row3") + query.add_key(b"test_row3") + query.add_key(b"test_row4") + output = query._to_dict() + self.assertTrue(isinstance(output, dict)) + request_proto = ReadRowsRequest(**output) + rowset_proto = request_proto.rows + # check rows + self.assertEqual(len(rowset_proto.row_keys), 4) + self.assertEqual(rowset_proto.row_keys[0], b"test_row") + self.assertEqual(rowset_proto.row_keys[1], b"test_row2") + self.assertEqual(rowset_proto.row_keys[2], b"test_row3") + self.assertEqual(rowset_proto.row_keys[3], b"test_row4") + # check ranges + self.assertEqual(len(rowset_proto.row_ranges), 5) + self.assertEqual(rowset_proto.row_ranges[0].start_key_closed, b"test_row") + self.assertEqual(rowset_proto.row_ranges[0].end_key_open, b"test_row2") + self.assertEqual(rowset_proto.row_ranges[1].start_key_closed, b"test_row3") + self.assertEqual(rowset_proto.row_ranges[1].end_key_open, b"") + self.assertEqual(rowset_proto.row_ranges[2].start_key_closed, b"") + self.assertEqual(rowset_proto.row_ranges[2].end_key_open, b"test_row5") + self.assertEqual(rowset_proto.row_ranges[3].start_key_open, b"test_row6") + self.assertEqual(rowset_proto.row_ranges[3].end_key_closed, b"test_row7") + self.assertEqual(rowset_proto.row_ranges[4].start_key_closed, b"") + self.assertEqual(rowset_proto.row_ranges[4].end_key_open, b"") + # check limit + self.assertEqual(request_proto.rows_limit, 100) + # check filter + filter_proto = request_proto.filter + self.assertEqual(filter_proto, row_filter._to_pb()) + + def test_shard(self): + pass