diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index ab2d5dd63..42ba3f83c 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -81,7 +81,11 @@ def visit_FunctionDef(self, node): def visit_Constant(self, node): """Replace string type annotations""" - node.s = self.replacements.get(node.s, node.s) + try: + node.s = self.replacements.get(node.s, node.s) + except TypeError: + # ignore unhashable types (e.g. list) + pass return node diff --git a/docs/data_client/async_data_authorized_view.rst b/docs/data_client/async_data_authorized_view.rst new file mode 100644 index 000000000..7d7312970 --- /dev/null +++ b/docs/data_client/async_data_authorized_view.rst @@ -0,0 +1,11 @@ +Authorized View Async +~~~~~~~~~~~~~~~~~~~~~ + + .. note:: + + It is generally not recommended to use the async client in an otherwise synchronous codebase. To make use of asyncio's + performance benefits, the codebase should be designed to be async from the ground up. + +.. autoclass:: google.cloud.bigtable.data._async.client.AuthorizedViewAsync + :members: + :inherited-members: diff --git a/docs/data_client/async_data_table.rst b/docs/data_client/async_data_table.rst index 3b7973e8e..37c396570 100644 --- a/docs/data_client/async_data_table.rst +++ b/docs/data_client/async_data_table.rst @@ -8,4 +8,4 @@ Table Async .. autoclass:: google.cloud.bigtable.data._async.client.TableAsync :members: - :show-inheritance: + :inherited-members: diff --git a/docs/data_client/data_client_usage.rst b/docs/data_client/data_client_usage.rst index f5bbac278..708dafc62 100644 --- a/docs/data_client/data_client_usage.rst +++ b/docs/data_client/data_client_usage.rst @@ -9,6 +9,7 @@ Sync Surface sync_data_client sync_data_table + sync_data_authorized_view sync_data_mutations_batcher sync_data_execute_query_iterator @@ -20,6 +21,7 @@ Async Surface async_data_client async_data_table + async_data_authorized_view async_data_mutations_batcher async_data_execute_query_iterator diff --git a/docs/data_client/sync_data_authorized_view.rst b/docs/data_client/sync_data_authorized_view.rst new file mode 100644 index 000000000..c0ac29721 --- /dev/null +++ b/docs/data_client/sync_data_authorized_view.rst @@ -0,0 +1,6 @@ +Authorized View +~~~~~~~~~~~~~~~ + +.. 
autoclass:: google.cloud.bigtable.data._sync_autogen.client.AuthorizedView + :members: + :inherited-members: diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 15f9bc167..9439f0f8d 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -17,9 +17,11 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._async.client import AuthorizedViewAsync from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._sync_autogen.client import BigtableDataClient from google.cloud.bigtable.data._sync_autogen.client import Table +from google.cloud.bigtable.data._sync_autogen.client import AuthorizedView from google.cloud.bigtable.data._sync_autogen.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery @@ -76,9 +78,11 @@ __all__ = ( "BigtableDataClientAsync", "TableAsync", + "AuthorizedViewAsync", "MutationsBatcherAsync", "BigtableDataClient", "Table", + "AuthorizedView", "MutationsBatcher", "RowKeySamples", "ReadRowsQuery", diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index bf618bf04..8e6833bca 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -15,10 +15,10 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING -import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries +import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory @@ -36,12 +36,16 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient as GapicClientType, ) - from google.cloud.bigtable.data._async.client import TableAsync as TableType + from google.cloud.bigtable.data._async.client import ( # type: ignore + _DataApiTargetAsync as TargetType, + ) else: from google.cloud.bigtable_v2.services.bigtable.client import ( # type: ignore BigtableClient as GapicClientType, ) - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import ( # type: ignore + _DataApiTarget as TargetType, + ) __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._mutate_rows" @@ -59,7 +63,7 @@ class _MutateRowsOperationAsync: Args: gapic_client: the client to use for the mutate_rows call - table: the table associated with the request + target: the table or view associated with the request mutation_entries: a list of RowMutationEntry objects to send to the server operation_timeout: the timeout to use for the entire operation, in seconds. attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. @@ -70,7 +74,7 @@ class _MutateRowsOperationAsync: def __init__( self, gapic_client: GapicClientType, - table: TableType, + target: TargetType, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -84,13 +88,8 @@ def __init__( f"{_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across " f"all entries. 
Found {total_mutations}." ) - # create partial function to pass to trigger rpc call - self._gapic_fn = functools.partial( - gapic_client.mutate_rows, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - retry=None, - ) + self._target = target + self._gapic_fn = gapic_client.mutate_rows # create predicate for determining which errors are retryable self.is_retryable = retries.if_exception_type( # RPC level errors @@ -173,8 +172,12 @@ async def _run_attempt(self): # make gapic request try: result_generator = await self._gapic_fn( + request=types_pb.MutateRowsRequest( + entries=request_entries, + app_profile_id=self._target.app_profile_id, + **self._target._request_path, + ), timeout=next(self.timeout_generator), - entries=request_entries, retry=None, ) async for result_list in result_generator: diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 6d2fa3a7d..8787bfa71 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -37,9 +37,11 @@ if TYPE_CHECKING: if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync as TableType + from google.cloud.bigtable.data._async.client import ( + _DataApiTargetAsync as TargetType, + ) else: - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import _DataApiTarget as TargetType # type: ignore __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._read_rows" @@ -59,7 +61,7 @@ class _ReadRowsOperationAsync: Args: query: The query to execute - table: The table to send the request to + target: The table or view to send the request to operation_timeout: The total time to allow for the operation, in seconds attempt_timeout: The time to allow for each individual attempt, in seconds retryable_exceptions: A list of exceptions that should trigger a retry @@ -69,7 +71,7 @@ class _ReadRowsOperationAsync: "attempt_timeout_gen", "operation_timeout", "request", - "table", + "target", "_predicate", "_last_yielded_row_key", "_remaining_count", @@ -78,7 +80,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: TableType, + target: TargetType, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -90,12 +92,12 @@ def __init__( if isinstance(query, dict): self.request = ReadRowsRequestPB( **query, - table_name=table.table_name, - app_profile_id=table.app_profile_id, + **target._request_path, + app_profile_id=target.app_profile_id, ) else: - self.request = query._to_pb(table) - self.table = table + self.request = query._to_pb(target) + self.target = target self._predicate = retries.if_exception_type(*retryable_exceptions) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None @@ -142,7 +144,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: if self._remaining_count == 0: return self.merge_rows(None) # create and return a new row merger - gapic_stream = self.table.client._gapic_client.read_rows( + gapic_stream = self.target.client._gapic_client.read_rows( self.request, timeout=next(self.attempt_timeout_gen), retry=None, diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 3c5093d10..6ee21b554 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ 
b/google/cloud/bigtable/data/_async/client.py @@ -25,6 +25,7 @@ TYPE_CHECKING, ) +import abc import time import warnings import random @@ -47,6 +48,10 @@ DEFAULT_CLIENT_INFO, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.bigtable_v2.types.bigtable import SampleRowKeysRequest +from google.cloud.bigtable_v2.types.bigtable import MutateRowRequest +from google.cloud.bigtable_v2.types.bigtable import CheckAndMutateRowRequest +from google.cloud.bigtable_v2.types.bigtable import ReadModifyWriteRowRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore from google.api_core import retry as retries @@ -210,8 +215,8 @@ def __init__( self.transport = cast(TransportType, self._gapic_client.transport) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() - # keep track of table objects associated with each instance - # only remove instance from _active_instances when all associated tables remove it + # keep track of _DataApiTarget objects associated with each instance + # only remove instance from _active_instances when all associated targets are closed self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_task: CrossSync.Task[None] | None = None @@ -320,7 +325,7 @@ async def _ping_and_warm_instances( ], wait_for_ready=True, ) - for (instance_name, table_name, app_profile_id) in instance_list + for (instance_name, app_profile_id) in instance_list ] result_list = await CrossSync.gather_partials( partial_list, return_exceptions=True, sync_executor=self._executor @@ -404,10 +409,13 @@ async def _manage_channel( replace_symbols={ "TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + "_DataApiTargetAsync": "_DataApiTarget", } ) async def _register_instance( - self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync + self, + instance_id: str, + owner: _DataApiTargetAsync | ExecuteQueryIteratorAsync, ) -> None: """ Registers an instance with the client, and warms the channel for the instance @@ -422,9 +430,7 @@ async def _register_instance( owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) + instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) if instance_key not in self._active_instances: self._active_instances.add(instance_key) @@ -440,10 +446,13 @@ async def _register_instance( replace_symbols={ "TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + "_DataApiTargetAsync": "_DataApiTarget", } ) async def _remove_instance_registration( - self, instance_id: str, owner: TableAsync | "ExecuteQueryIteratorAsync" + self, + instance_id: str, + owner: _DataApiTargetAsync | ExecuteQueryIteratorAsync, ) -> bool: """ Removes an instance from the client's registered instances, to prevent @@ -460,9 +469,7 @@ async def _remove_instance_registration( bool: True if instance was removed, else False """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) + instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id) owner_list 
= self._instance_owners.get(instance_key, set()) try: owner_list.remove(id(owner)) @@ -528,6 +535,72 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert( + replace_symbols={"AuthorizedViewAsync": "AuthorizedView"}, + docstring_format_vars={ + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), + }, + ) + def get_authorized_view( + self, instance_id: str, table_id: str, authorized_view_id: str, *args, **kwargs + ) -> AuthorizedViewAsync: + """ + Returns an authorized view instance for making data API requests. All arguments are passed + directly to the AuthorizedViewAsync constructor. + + {LOOP_MESSAGE} + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + authorized_view_id: The id for the authorized view to use for requests + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to Table's value + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to Table's value + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to Table's value + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to Table's value + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to Table's value + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to Table's value + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. If not set, + defaults to Table's value + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. If not set, + defaults to Table's value + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. If not set, defaults to + Table's value + Returns: + AuthorizedViewAsync: an authorized view instance for making data API requests + Raises: + {RAISE_NO_LOOP} + """ + return CrossSync.AuthorizedView( + self, + instance_id, + table_id, + authorized_view_id, + *args, + **kwargs, + ) + @CrossSync.convert( replace_symbols={"ExecuteQueryIteratorAsync": "ExecuteQueryIterator"} ) @@ -679,13 +752,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.convert_class(sync_name="Table", add_mapping_for_name="Table") -class TableAsync: +@CrossSync.convert_class(sync_name="_DataApiTarget") +class _DataApiTargetAsync(abc.ABC): """ - Main Data API surface + Abstract class containing API surface for BigtableDataClient. 
Should not be created directly - Table object maintains table_id, and app_profile_id context, and passes them with - each call + Can be instantiated as a Table or an AuthorizedView """ @CrossSync.convert( @@ -809,6 +881,7 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () + try: self._register_instance_future = CrossSync.create_task( self.client._register_instance, @@ -821,6 +894,20 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e + @property + @abc.abstractmethod + def _request_path(self) -> dict[str, str]: + """ + Used to populate table_name or authorized_view_name for rpc requests, depending on the subclass + + Unimplemented in base class + """ + raise NotImplementedError + + def __str__(self): + path_str = list(self._request_path.values())[0] if self._request_path else "" + return f"{self.__class__.__name__}<{path_str!r}>" + @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"}) async def read_rows_stream( self, @@ -1177,8 +1264,9 @@ async def sample_row_keys( @CrossSync.convert async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=SampleRowKeysRequest( + app_profile_id=self.app_profile_id, **self._request_path + ), timeout=next(attempt_timeout_gen), retry=None, ) @@ -1305,10 +1393,14 @@ async def mutate_row( target = partial( self.client._gapic_client.mutate_row, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - mutations=[mutation._to_pb() for mutation in mutations_list], - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=MutateRowRequest( + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=attempt_timeout, retry=None, ) @@ -1430,12 +1522,16 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] result = await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=CheckAndMutateRowRequest( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=operation_timeout, retry=None, ) @@ -1480,10 +1576,14 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") result = await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=ReadModifyWriteRowRequest( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=operation_timeout, retry=None, 
) @@ -1520,3 +1620,107 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): grpc channels will no longer be warmed """ await self.close() + + +@CrossSync.convert_class( + sync_name="Table", + add_mapping_for_name="Table", + replace_symbols={"_DataApiTargetAsync": "_DataApiTarget"}, +) +class TableAsync(_DataApiTargetAsync): + """ + Main Data API surface for interacting with a Bigtable table. + + Table object maintains table_id and app_profile_id context, and passes them with + each call + """ + + @property + def _request_path(self) -> dict[str, str]: + return {"table_name": self.table_name} + + +@CrossSync.convert_class( + sync_name="AuthorizedView", + add_mapping_for_name="AuthorizedView", + replace_symbols={"_DataApiTargetAsync": "_DataApiTarget"}, +) +class AuthorizedViewAsync(_DataApiTargetAsync): + """ + Provides access to an authorized view of a table. + + An authorized view is a subset of a table that you configure to include specific table data. + Then you grant access to the authorized view separately from access to the table. + + AuthorizedView object maintains table_id, app_profile_id, and authorized_view_id context, + and passes them with each call + """ + + @CrossSync.convert( + docstring_format_vars={ + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), + } + ) + def __init__( + self, + client, + instance_id, + table_id, + authorized_view_id, + app_profile_id: str | None = None, + **kwargs, + ): + """ + Initialize an AuthorizedView instance + + {LOOP_MESSAGE} + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + authorized_view_id: The id for the authorized view to use for requests + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. 
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + {RAISE_NO_LOOP} + """ + super().__init__(client, instance_id, table_id, app_profile_id, **kwargs) + self.authorized_view_id = authorized_view_id + self.authorized_view_name: str = self.client._gapic_client.authorized_view_path( + self.client.project, instance_id, table_id, authorized_view_id + ) + + @property + def _request_path(self) -> dict[str, str]: + return {"authorized_view_name": self.authorized_view_name} diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 6e15bb5f3..a8e99ea9e 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -37,9 +37,11 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync as TableType + from google.cloud.bigtable.data._async.client import ( + _DataApiTargetAsync as TargetType, + ) else: - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import _DataApiTarget as TargetType # type: ignore __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.mutations_batcher" @@ -179,7 +181,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] @CrossSync.convert_class(sync_name="MutationsBatcher") class MutationsBatcherAsync: """ - Allows users to send batches using context manager API: + Allows users to send batches using context manager API. Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining to use as few network requests as required @@ -191,7 +193,7 @@ class MutationsBatcherAsync: - when batcher is closed or destroyed Args: - table: Table to preform rpc calls + table: table or authorized view used to perform rpc calls flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count @@ -210,7 +212,7 @@ class MutationsBatcherAsync: def __init__( self, - table: TableType, + table: TargetType, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -230,7 +232,7 @@ def __init__( ) self._closed = CrossSync.Event() - self._table = table + self._target = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 self._flow_control = CrossSync._FlowControl( @@ -380,8 +382,8 @@ async def _execute_mutate_rows( """ try: operation = CrossSync._MutateRowsOperation( - self._table.client._gapic_client, - self._table, + self._target.client._gapic_client, + self._target, batch, operation_timeout=self._operation_timeout, attempt_timeout=self._attempt_timeout, @@ -491,7 +493,7 @@ def _on_exit(self): """ if not self._closed.is_set() and self._staged_entries: warnings.warn( - f"MutationsBatcher for table {self._table.table_name} was not closed. " + f"MutationsBatcher for target {self._target!r} was not closed. " f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." 
) diff --git a/google/cloud/bigtable/data/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_cross_sync/_decorators.py index ea86e83af..a0dd140dd 100644 --- a/google/cloud/bigtable/data/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_cross_sync/_decorators.py @@ -179,7 +179,8 @@ def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) for k, v in zip(ast_node.keys, ast_node.values) } - raise ValueError(f"Unsupported type {type(ast_node)}") + # unsupported node type + return ast_node class ConvertClass(AstDecorator): @@ -421,6 +422,15 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy + arg_nodes = [ + a if isinstance(a, ast.expr) else ast.Constant(value=a) for a in self._args + ] + kwarg_nodes = [] + for k, v in self._kwargs.items(): + if not isinstance(v, ast.expr): + v = ast.Constant(value=v) + kwarg_nodes.append(ast.keyword(arg=k, value=v)) + new_node = copy.deepcopy(wrapped_node) if not hasattr(new_node, "decorator_list"): new_node.decorator_list = [] @@ -431,11 +441,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): attr="fixture", ctx=ast.Load(), ), - args=[ast.Constant(value=a) for a in self._args], - keywords=[ - ast.keyword(arg=k, value=ast.Constant(value=v)) - for k, v in self._kwargs.items() - ], + args=arg_nodes, + keywords=kwarg_nodes, ) ) return new_node diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a70ebfb6d..424a34486 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -28,8 +28,8 @@ if TYPE_CHECKING: import grpc - from google.cloud.bigtable.data import TableAsync - from google.cloud.bigtable.data import Table + from google.cloud.bigtable.data._async.client import _DataApiTargetAsync + from google.cloud.bigtable.data._sync_autogen.client import _DataApiTarget """ Helper functions used in various places in the library. @@ -44,9 +44,10 @@ # used by read_rows_sharded to limit how many requests are attempted in parallel _CONCURRENCY_LIMIT = 10 -# used to register instance data with the client for channel warming +# used to identify an active bigtable resource that needs to be warmed through PingAndWarm +# each instance/app_profile_id pair needs to be individually tracked _WarmedInstanceKey = namedtuple( - "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] + "_WarmedInstanceKey", ["instance_name", "app_profile_id"] ) @@ -121,7 +122,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync" | "Table", + table: "_DataApiTargetAsync" | "_DataApiTarget", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -226,7 +227,7 @@ def _get_error_type( def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, - table: "TableAsync" | "Table", + table: "_DataApiTargetAsync" | "_DataApiTarget", ) -> list[type[Exception]]: """ Convert passed in retryable error codes to a list of exception types. 
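Taken together, the changes above replace the hard-coded table_name plumbing with the owning target's _request_path mapping, so a Table and an AuthorizedView flow through the same read/mutate code paths. A minimal usage sketch of the resulting public surface, not part of the patch itself (the project, instance, table, and view ids below are placeholders, and a reachable Bigtable backend or emulator is assumed):

from google.cloud.bigtable.data import BigtableDataClient, ReadRowsQuery

# placeholder ids; substitute real resources
client = BigtableDataClient(project="my-project")
# Table routes rpcs by table_name
table = client.get_table("my-instance", "my-table")
# AuthorizedView routes the same rpcs by authorized_view_name
view = client.get_authorized_view("my-instance", "my-table", "my-view")

for target in (table, view):
    # identical data API surface on both target types
    for row in target.read_rows(ReadRowsQuery(limit=5)):
        print(target, row.row_key)

client.close()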
diff --git a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py index 8e8c5ca89..3bf7b562f 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py @@ -17,9 +17,9 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING -import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries +import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory @@ -32,7 +32,9 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient as GapicClientType, ) - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import ( + _DataApiTarget as TargetType, + ) class _MutateRowsOperation: @@ -47,7 +49,7 @@ class _MutateRowsOperation: Args: gapic_client: the client to use for the mutate_rows call - table: the table associated with the request + target: the table or view associated with the request mutation_entries: a list of RowMutationEntry objects to send to the server operation_timeout: the timeout to use for the entire operation, in seconds. attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. @@ -57,7 +59,7 @@ class _MutateRowsOperation: def __init__( self, gapic_client: GapicClientType, - table: TableType, + target: TargetType, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -68,12 +70,8 @@ def __init__( raise ValueError( f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." 
) - self._gapic_fn = functools.partial( - gapic_client.mutate_rows, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - retry=None, - ) + self._target = target + self._gapic_fn = gapic_client.mutate_rows self.is_retryable = retries.if_exception_type( *retryable_exceptions, bt_exceptions._MutateRowsIncomplete ) @@ -140,8 +138,12 @@ def _run_attempt(self): return try: result_generator = self._gapic_fn( + request=types_pb.MutateRowsRequest( + entries=request_entries, + app_profile_id=self._target.app_profile_id, + **self._target._request_path, + ), timeout=next(self.timeout_generator), - entries=request_entries, retry=None, ) for result_list in result_generator: diff --git a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py index 92619c6a4..3593475a9 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py @@ -34,7 +34,9 @@ from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import ( + _DataApiTarget as TargetType, + ) class _ReadRowsOperation: @@ -51,7 +53,7 @@ class _ReadRowsOperation: Args: query: The query to execute - table: The table to send the request to + target: The table or view to send the request to operation_timeout: The total time to allow for the operation, in seconds attempt_timeout: The time to allow for each individual attempt, in seconds retryable_exceptions: A list of exceptions that should trigger a retry @@ -61,7 +63,7 @@ class _ReadRowsOperation: "attempt_timeout_gen", "operation_timeout", "request", - "table", + "target", "_predicate", "_last_yielded_row_key", "_remaining_count", @@ -70,7 +72,7 @@ class _ReadRowsOperation: def __init__( self, query: ReadRowsQuery, - table: TableType, + target: TargetType, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -81,13 +83,11 @@ def __init__( self.operation_timeout = operation_timeout if isinstance(query, dict): self.request = ReadRowsRequestPB( - **query, - table_name=table.table_name, - app_profile_id=table.app_profile_id, + **query, **target._request_path, app_profile_id=target.app_profile_id ) else: - self.request = query._to_pb(table) - self.table = table + self.request = query._to_pb(target) + self.target = target self._predicate = retries.if_exception_type(*retryable_exceptions) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None @@ -125,7 +125,7 @@ def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: self.request.rows_limit = self._remaining_count if self._remaining_count == 0: return self.merge_rows(None) - gapic_stream = self.table.client._gapic_client.read_rows( + gapic_stream = self.target.client._gapic_client.read_rows( self.request, timeout=next(self.attempt_timeout_gen), retry=None ) chunked_stream = self.chunk_stream(gapic_stream) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 5e21c1f51..b36bf359a 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -18,6 +18,7 @@ from __future__ import annotations from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +import abc import time import warnings import random @@ 
-38,6 +39,10 @@ DEFAULT_CLIENT_INFO, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.bigtable_v2.types.bigtable import SampleRowKeysRequest +from google.cloud.bigtable_v2.types.bigtable import MutateRowRequest +from google.cloud.bigtable_v2.types.bigtable import CheckAndMutateRowRequest +from google.cloud.bigtable_v2.types.bigtable import ReadModifyWriteRowRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.api_core import retry as retries @@ -243,7 +248,7 @@ def _ping_and_warm_instances( ], wait_for_ready=True, ) - for (instance_name, table_name, app_profile_id) in instance_list + for (instance_name, app_profile_id) in instance_list ] result_list = CrossSync._Sync_Impl.gather_partials( partial_list, return_exceptions=True, sync_executor=self._executor @@ -300,7 +305,7 @@ def _manage_channel( next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) def _register_instance( - self, instance_id: str, owner: Table | ExecuteQueryIterator + self, instance_id: str, owner: _DataApiTarget | ExecuteQueryIterator ) -> None: """Registers an instance with the client, and warms the channel for the instance The client will periodically refresh grpc channel used to make @@ -313,9 +318,7 @@ def _register_instance( _instance_owners, and instances will only be unregistered when all owners call _remove_instance_registration""" instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) + instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) if instance_key not in self._active_instances: self._active_instances.add(instance_key) @@ -325,7 +328,7 @@ def _register_instance( self._start_background_channel_refresh() def _remove_instance_registration( - self, instance_id: str, owner: Table | "ExecuteQueryIterator" + self, instance_id: str, owner: _DataApiTarget | ExecuteQueryIterator ) -> bool: """Removes an instance from the client's registered instances, to prevent warming new channels for the instance @@ -340,9 +343,7 @@ def _remove_instance_registration( Returns: bool: True if instance was removed, else False""" instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) + instance_key = _WarmedInstanceKey(instance_name, owner.app_profile_id) owner_list = self._instance_owners.get(instance_key, set()) try: owner_list.remove(id(owner)) @@ -393,6 +394,52 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: None""" return Table(self, instance_id, table_id, *args, **kwargs) + def get_authorized_view( + self, instance_id: str, table_id: str, authorized_view_id: str, *args, **kwargs + ) -> AuthorizedView: + """Returns an authorized view instance for making data API requests. All arguments are passed + directly to the AuthorizedView constructor. + + + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. 
table_id is combined with the + instance_id and the client's project to fully specify the table + authorized_view_id: The id for the authorized view to use for requests + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to Table's value + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to Table's value + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to Table's value + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to Table's value + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to Table's value + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to Table's value + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. If not set, + defaults to Table's value + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. If not set, + defaults to Table's value + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. If not set, defaults to + Table's value + Returns: + AuthorizedView: an authorized view instance for making data API requests + Raises: + None""" + return CrossSync._Sync_Impl.AuthorizedView( + self, instance_id, table_id, authorized_view_id, *args, **kwargs + ) + def execute_query( self, query: str, @@ -532,13 +579,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._gapic_client.__exit__(exc_type, exc_val, exc_tb) -@CrossSync._Sync_Impl.add_mapping_decorator("Table") -class Table: +class _DataApiTarget(abc.ABC): """ - Main Data API surface + Abstract class containing API surface for BigtableDataClient. Should not be created directly - Table object maintains table_id, and app_profile_id context, and passes them with - each call + Can be instantiated as a Table or an AuthorizedView """ def __init__( @@ -653,6 +698,18 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." 
) from e + @property + @abc.abstractmethod + def _request_path(self) -> dict[str, str]: + """Used to populate table_name or authorized_view_name for rpc requests, depending on the subclass + + Unimplemented in base class""" + raise NotImplementedError + + def __str__(self): + path_str = list(self._request_path.values())[0] if self._request_path else "" + return f"{self.__class__.__name__}<{path_str!r}>" + def read_rows_stream( self, query: ReadRowsQuery, @@ -979,8 +1036,9 @@ def sample_row_keys( def execute_rpc(): results = self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=SampleRowKeysRequest( + app_profile_id=self.app_profile_id, **self._request_path + ), timeout=next(attempt_timeout_gen), retry=None, ) @@ -1096,10 +1154,14 @@ def mutate_row( sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) target = partial( self.client._gapic_client.mutate_row, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - mutations=[mutation._to_pb() for mutation in mutations_list], - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=MutateRowRequest( + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=attempt_timeout, retry=None, ) @@ -1214,12 +1276,16 @@ def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] result = self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=CheckAndMutateRowRequest( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=operation_timeout, retry=None, ) @@ -1261,10 +1327,14 @@ def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") result = self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, + request=ReadModifyWriteRowRequest( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + app_profile_id=self.app_profile_id, + **self._request_path, + ), timeout=operation_timeout, retry=None, ) @@ -1291,3 +1361,85 @@ def __exit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed""" self.close() + + +@CrossSync._Sync_Impl.add_mapping_decorator("Table") +class Table(_DataApiTarget): + """ + Main Data API surface for interacting with a Bigtable table. 
+ + Table object maintains table_id and app_profile_id context, and passes them with + each call + """ + + @property + def _request_path(self) -> dict[str, str]: + return {"table_name": self.table_name} + + +@CrossSync._Sync_Impl.add_mapping_decorator("AuthorizedView") +class AuthorizedView(_DataApiTarget): + """ + Provides access to an authorized view of a table. + + An authorized view is a subset of a table that you configure to include specific table data. + Then you grant access to the authorized view separately from access to the table. + + AuthorizedView object maintains table_id, app_profile_id, and authorized_view_id context, + and passes them with each call + """ + + def __init__( + self, + client, + instance_id, + table_id, + authorized_view_id, + app_profile_id: str | None = None, + **kwargs, + ): + """Initialize an AuthorizedView instance + + + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + authorized_view_id: The id for the authorized view to use for requests + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. 
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + None""" + super().__init__(client, instance_id, table_id, app_profile_id, **kwargs) + self.authorized_view_id = authorized_view_id + self.authorized_view_name: str = self.client._gapic_client.authorized_view_path( + self.client.project, instance_id, table_id, authorized_view_id + ) + + @property + def _request_path(self) -> dict[str, str]: + return {"authorized_view_name": self.authorized_view_name} diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index 2e4237b74..84f0ba8c0 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -32,7 +32,9 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - from google.cloud.bigtable.data._sync_autogen.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import ( + _DataApiTarget as TargetType, + ) _MB_SIZE = 1024 * 1024 @@ -148,7 +150,7 @@ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): class MutationsBatcher: """ - Allows users to send batches using context manager API: + Allows users to send batches using context manager API. Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining to use as few network requests as required @@ -160,7 +162,7 @@ class MutationsBatcher: - when batcher is closed or destroyed Args: - table: Table to preform rpc calls + table: table or authorized view used to perform rpc calls flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count @@ -179,7 +181,7 @@ class MutationsBatcher: def __init__( self, - table: TableType, + table: TargetType, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -198,7 +200,7 @@ def __init__( batch_retryable_errors, table ) self._closed = CrossSync._Sync_Impl.Event() - self._table = table + self._target = table self._staged_entries: list[RowMutationEntry] = [] (self._staged_count, self._staged_bytes) = (0, 0) self._flow_control = CrossSync._Sync_Impl._FlowControl( @@ -324,8 +326,8 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information""" try: operation = CrossSync._Sync_Impl._MutateRowsOperation( - self._table.client._gapic_client, - self._table, + self._target.client._gapic_client, + self._target, batch, operation_timeout=self._operation_timeout, attempt_timeout=self._attempt_timeout, @@ -414,7 +416,7 @@ def _on_exit(self): """Called when program is exited. Raises warning if unflushed mutations remain""" if not self._closed.is_set() and self._staged_entries: warnings.warn( - f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." + f"MutationsBatcher for target {self._target!r} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." 
) @staticmethod diff --git a/google/cloud/bigtable/data/read_rows_query.py b/google/cloud/bigtable/data/read_rows_query.py index e0839a2af..7652bfbb9 100644 --- a/google/cloud/bigtable/data/read_rows_query.py +++ b/google/cloud/bigtable/data/read_rows_query.py @@ -489,11 +489,11 @@ def _to_pb(self, table) -> ReadRowsRequestPB: ReadRowsRequest protobuf """ return ReadRowsRequestPB( - table_name=table.table_name, app_profile_id=table.app_profile_id, filter=self.filter._to_pb() if self.filter else None, rows_limit=self.limit or 0, rows=self._row_set, + **table._request_path, ) def __eq__(self, other): diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 3b5a0af06..a77ffc008 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -20,6 +20,12 @@ import os import uuid +from . import TEST_FAMILY, TEST_FAMILY_2 + +# authorized view subset to allow all qualifiers +ALLOW_ALL = "" +ALL_QUALIFIERS = {"qualifier_prefixes": [ALLOW_ALL]} + @pytest.fixture(scope="session") def admin_client(): @@ -140,6 +146,63 @@ def table_id( print(f"Table {init_table_id} not found, skipping deletion") +@pytest.fixture(scope="session") +def authorized_view_id( + admin_client, + project_id, + instance_id, + table_id, +): + """ + Creates and returns a new temporary authorized view for the test session + + Args: + - admin_client: Client for interacting with the Table Admin API. Supplied by the admin_client fixture. + - project_id: The project ID of the GCP project to test against. Supplied by the project_id fixture. + - instance_id: The ID of the Bigtable instance to test against. Supplied by the instance_id fixture. + - table_id: The ID of the table to create the authorized view for. Supplied by the table_id fixture. + """ + from google.api_core import exceptions + from google.api_core import retry + + retry = retry.Retry( + predicate=retry.if_exception_type(exceptions.FailedPrecondition) + ) + new_view_id = uuid.uuid4().hex[:8] + parent_path = f"projects/{project_id}/instances/{instance_id}/tables/{table_id}" + new_path = f"{parent_path}/authorizedViews/{new_view_id}" + try: + print(f"Creating view: {new_path}") + admin_client.table_admin_client.create_authorized_view( + request={ + "parent": parent_path, + "authorized_view_id": new_view_id, + "authorized_view": { + "subset_view": { + "row_prefixes": [ALLOW_ALL], + "family_subsets": { + TEST_FAMILY: ALL_QUALIFIERS, + TEST_FAMILY_2: ALL_QUALIFIERS, + }, + }, + }, + }, + retry=retry, + ) + except exceptions.AlreadyExists: + pass + except exceptions.MethodNotImplemented: + # will occur when run in emulator. 
Pass empty id + new_view_id = None + yield new_view_id + if new_view_id: + print(f"Deleting view: {new_path}") + try: + admin_client.table_admin_client.delete_authorized_view(name=new_path) + except exceptions.NotFound: + print(f"View {new_view_id} not found, skipping deletion") + + @pytest.fixture(scope="session") def project_id(client): """Returns the project ID from the client.""" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 53e97acc1..3eba384e1 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -18,7 +18,7 @@ import uuid import os from google.api_core import retry -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import ClientError, PermissionDenied from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE @@ -33,6 +33,12 @@ __CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system_autogen" +TARGETS = ["table"] +if not os.environ.get(BIGTABLE_EMULATOR): + # emulator doesn't support authorized views + TARGETS.append("authorized_view") + + @CrossSync.convert_class( sync_name="TempRowBuilder", add_mapping_for_name="TempRowBuilder", @@ -42,9 +48,9 @@ class TempRowBuilderAsync: Used to add rows to a table for testing purposes. """ - def __init__(self, table): + def __init__(self, target): self.rows = [] - self.table = table + self.target = target @CrossSync.convert async def add_row( @@ -55,7 +61,7 @@ async def add_row( elif isinstance(value, int): value = value.to_bytes(8, byteorder="big", signed=True) request = { - "table_name": self.table.table_name, + "table_name": self.target.table_name, "row_key": row_key, "mutations": [ { @@ -67,20 +73,20 @@ async def add_row( } ], } - await self.table.client._gapic_client.mutate_row(request) + await self.target.client._gapic_client.mutate_row(request) self.rows.append(row_key) @CrossSync.convert async def delete_rows(self): if self.rows: request = { - "table_name": self.table.table_name, + "table_name": self.target.table_name, "entries": [ {"row_key": row, "mutations": [{"delete_from_row": {}}]} for row in self.rows ], } - await self.table.client._gapic_client.mutate_rows(request) + await self.target.client._gapic_client.mutate_rows(request) @CrossSync.convert_class(sync_name="TestSystem") @@ -93,10 +99,23 @@ async def client(self): yield client @CrossSync.convert - @CrossSync.pytest_fixture(scope="session") - async def table(self, client, table_id, instance_id): - async with client.get_table(instance_id, table_id) as table: - yield table + @CrossSync.pytest_fixture(scope="session", params=TARGETS) + async def target(self, client, table_id, authorized_view_id, instance_id, request): + """ + This fixture runs twice: once for a standard table, and once with an authorized view + + Note: emulator doesn't support authorized views. 
Only use target + """ + if request.param == "table": + async with client.get_table(instance_id, table_id) as table: + yield table + elif request.param == "authorized_view": + async with client.get_authorized_view( + instance_id, table_id, authorized_view_id + ) as view: + yield view + else: + raise ValueError(f"unknown target type: {request.param}") @CrossSync.drop @pytest.fixture(scope="session") @@ -138,14 +157,14 @@ def cluster_config(self, project_id): return cluster @CrossSync.convert - @pytest.mark.usefixtures("table") - async def _retrieve_cell_value(self, table, row_key): + @pytest.mark.usefixtures("target") + async def _retrieve_cell_value(self, target, row_key): """ Helper to read an individual row """ from google.cloud.bigtable.data import ReadRowsQuery - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + row_list = await target.read_rows(ReadRowsQuery(row_keys=row_key)) assert len(row_list) == 1 row = row_list[0] cell = row.cells[0] @@ -174,32 +193,32 @@ async def _create_row_and_mutation( @CrossSync.convert @CrossSync.pytest_fixture(scope="function") - async def temp_rows(self, table): - builder = CrossSync.TempRowBuilder(table) + async def temp_rows(self, target): + builder = CrossSync.TempRowBuilder(target) yield builder await builder.delete_rows() - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.usefixtures("client") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 ) @CrossSync.pytest - async def test_ping_and_warm_gapic(self, client, table): + async def test_ping_and_warm_gapic(self, client, target): """ Simple ping rpc test This test ensures channels are able to authenticate with backend """ - request = {"name": table.instance_name} + request = {"name": target.instance_name} await client._gapic_client.ping_and_warm(request) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.usefixtures("client") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_ping_and_warm(self, client, table): + async def test_ping_and_warm(self, client, target): """ Test ping and warm from handwritten client """ @@ -240,41 +259,43 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): await client.close() @CrossSync.pytest - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - async def test_mutation_set_cell(self, table, temp_rows): + async def test_mutation_set_cell(self, target, temp_rows): """ Ensure cells can be set properly """ row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) - await table.mutate_row(row_key, mutation) + await target.mutate_row(row_key, mutation) # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key)) == new_value @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" ) @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_sample_row_keys(self, client, table, temp_rows, 
column_split_config): + async def test_sample_row_keys( + self, client, target, temp_rows, column_split_config + ): """ - Sample keys should return a single sample in small test tables + Sample keys should return a single sample in small test targets """ await temp_rows.add_row(b"row_key_1") await temp_rows.add_row(b"row_key_2") - results = await table.sample_row_keys() + results = await target.sample_row_keys() assert len(results) == len(column_split_config) + 1 # first keys should match the split config for idx in range(len(column_split_config)): @@ -285,9 +306,9 @@ async def test_sample_row_keys(self, client, table, temp_rows, column_split_conf assert isinstance(results[-1][1], int) @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_bulk_mutations_set_cell(self, client, table, temp_rows): + async def test_bulk_mutations_set_cell(self, client, target, temp_rows): """ Ensure cells can be set properly """ @@ -295,17 +316,17 @@ async def test_bulk_mutations_set_cell(self, client, table, temp_rows): new_value = uuid.uuid4().hex.encode() row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - await table.bulk_mutate_rows([bulk_mutation]) + await target.bulk_mutate_rows([bulk_mutation]) # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key)) == new_value @CrossSync.pytest - async def test_bulk_mutations_raise_exception(self, client, table): + async def test_bulk_mutations_raise_exception(self, client, target): """ If an invalid mutation is passed, an exception should be raised """ @@ -320,7 +341,7 @@ async def test_bulk_mutations_raise_exception(self, client, table): bulk_mutation = RowMutationEntry(row_key, [mutation]) with pytest.raises(MutationsExceptionGroup) as exc: - await table.bulk_mutate_rows([bulk_mutation]) + await target.bulk_mutate_rows([bulk_mutation]) assert len(exc.value.exceptions) == 1 entry_error = exc.value.exceptions[0] assert isinstance(entry_error, FailedMutationEntryError) @@ -328,12 +349,12 @@ async def test_bulk_mutations_raise_exception(self, client, table): assert entry_error.entry == bulk_mutation @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_mutations_batcher_context_manager(self, client, table, temp_rows): + async def test_mutations_batcher_context_manager(self, client, target, temp_rows): """ test batcher with context manager. 
Should flush on exit """ @@ -341,28 +362,28 @@ async def test_mutations_batcher_context_manager(self, client, table, temp_rows) new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation = RowMutationEntry(row_key, [mutation]) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - async with table.mutations_batcher() as batcher: + async with target.mutations_batcher() as batcher: await batcher.append(bulk_mutation) await batcher.append(bulk_mutation2) # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key)) == new_value assert len(batcher._staged_entries) == 0 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + async def test_mutations_batcher_timer_flush(self, client, target, temp_rows): """ batch should occur after flush_interval seconds """ @@ -370,26 +391,26 @@ async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): new_value = uuid.uuid4().hex.encode() row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) flush_interval = 0.1 - async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + async with target.mutations_batcher(flush_interval=flush_interval) as batcher: await batcher.append(bulk_mutation) await CrossSync.yield_to_event_loop() assert len(batcher._staged_entries) == 1 await CrossSync.sleep(flush_interval + 0.1) assert len(batcher._staged_entries) == 0 # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key)) == new_value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_mutations_batcher_count_flush(self, client, table, temp_rows): + async def test_mutations_batcher_count_flush(self, client, target, temp_rows): """ batch should flush after flush_limit_mutation_count mutations """ @@ -397,15 +418,15 @@ async def test_mutations_batcher_count_flush(self, client, table, temp_rows): new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + async with target.mutations_batcher(flush_limit_mutation_count=2) as batcher: await batcher.append(bulk_mutation) assert len(batcher._flush_jobs) == 0 # should be 
noop; flush not scheduled @@ -421,16 +442,16 @@ async def test_mutations_batcher_count_flush(self, client, table, temp_rows): assert len(batcher._staged_entries) == 0 assert len(batcher._flush_jobs) == 0 # ensure cells were updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + assert (await self._retrieve_cell_value(target, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key2)) == new_value2 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + async def test_mutations_batcher_bytes_flush(self, client, target, temp_rows): """ batch should flush after flush_limit_bytes bytes """ @@ -438,17 +459,17 @@ async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + async with target.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: await batcher.append(bulk_mutation) assert len(batcher._flush_jobs) == 0 assert len(batcher._staged_entries) == 1 @@ -462,13 +483,13 @@ async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): # for sync version: grab result future.result() # ensure cells were updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + assert (await self._retrieve_cell_value(target, row_key)) == new_value + assert (await self._retrieve_cell_value(target, row_key2)) == new_value2 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_mutations_batcher_no_flush(self, client, table, temp_rows): + async def test_mutations_batcher_no_flush(self, client, target, temp_rows): """ test with no flush requirements met """ @@ -477,16 +498,16 @@ async def test_mutations_batcher_no_flush(self, client, table, temp_rows): new_value = uuid.uuid4().hex.encode() start_value = b"unchanged" row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - async with table.mutations_batcher( + async with target.mutations_batcher( flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 ) as batcher: await 
batcher.append(bulk_mutation) @@ -498,16 +519,16 @@ async def test_mutations_batcher_no_flush(self, client, table, temp_rows): assert len(batcher._staged_entries) == 2 assert len(batcher._flush_jobs) == 0 # ensure cells were not updated - assert (await self._retrieve_cell_value(table, row_key)) == start_value - assert (await self._retrieve_cell_value(table, row_key2)) == start_value + assert (await self._retrieve_cell_value(target, row_key)) == start_value + assert (await self._retrieve_cell_value(target, row_key2)) == start_value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_mutations_batcher_large_batch(self, client, table, temp_rows): + async def test_mutations_batcher_large_batch(self, client, target, temp_rows): """ test batcher with large batch of mutations """ @@ -523,14 +544,14 @@ async def test_mutations_batcher_large_batch(self, client, table, temp_rows): # append row key for eventual deletion temp_rows.rows.append(row_key) - async with table.mutations_batcher() as batcher: + async with target.mutations_batcher() as batcher: for mutation in row_mutations: await batcher.append(mutation) # ensure cell is updated assert len(batcher._staged_entries) == 0 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.parametrize( "start,increment,expected", [ @@ -548,7 +569,7 @@ async def test_mutations_batcher_large_batch(self, client, table, temp_rows): ) @CrossSync.pytest async def test_read_modify_write_row_increment( - self, client, table, temp_rows, start, increment, expected + self, client, target, temp_rows, start, increment, expected ): """ test read_modify_write_row @@ -563,17 +584,17 @@ async def test_read_modify_write_row_increment( ) rule = IncrementRule(family, qualifier, increment) - result = await table.read_modify_write_row(row_key, rule) + result = await target.read_modify_write_row(row_key, rule) assert result.row_key == row_key assert len(result) == 1 assert result[0].family == family assert result[0].qualifier == qualifier assert int(result[0]) == expected # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + assert (await self._retrieve_cell_value(target, row_key)) == result[0].value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.parametrize( "start,append,expected", [ @@ -588,7 +609,7 @@ async def test_read_modify_write_row_increment( ) @CrossSync.pytest async def test_read_modify_write_row_append( - self, client, table, temp_rows, start, append, expected + self, client, target, temp_rows, start, append, expected ): """ test read_modify_write_row @@ -603,19 +624,19 @@ async def test_read_modify_write_row_append( ) rule = AppendValueRule(family, qualifier, append) - result = await table.read_modify_write_row(row_key, rule) + result = await target.read_modify_write_row(row_key, rule) assert result.row_key == row_key assert len(result) == 1 assert result[0].family == family assert result[0].qualifier == qualifier assert result[0].value == expected # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + assert (await self._retrieve_cell_value(target, row_key)) == result[0].value 
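+
+    # Reviewer sketch (not part of the original change): read_modify_write_row
+    # applies its rules server-side in a single atomic call and returns the
+    # updated cells, so it behaves the same against a table or an authorized
+    # view target, e.g. (hypothetical key/family/qualifier values):
+    #
+    #   rule = IncrementRule("family", b"qualifier", 1)
+    #   row = await target.read_modify_write_row(b"row-key", rule)
+    #   assert int(row[0]) == previous_value + 1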
@pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_read_modify_write_row_chained(self, client, table, temp_rows): + async def test_read_modify_write_row_chained(self, client, target, temp_rows): """ test read_modify_write_row with multiple rules """ @@ -636,7 +657,7 @@ async def test_read_modify_write_row_chained(self, client, table, temp_rows): AppendValueRule(family, qualifier, "world"), AppendValueRule(family, qualifier, "!"), ] - result = await table.read_modify_write_row(row_key, rule) + result = await target.read_modify_write_row(row_key, rule) assert result.row_key == row_key assert result[0].family == family assert result[0].qualifier == qualifier @@ -647,10 +668,10 @@ async def test_read_modify_write_row_chained(self, client, table, temp_rows): + b"helloworld!" ) # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + assert (await self._retrieve_cell_value(target, row_key)) == result[0].value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.parametrize( "start_val,predicate_range,expected_result", [ @@ -660,7 +681,7 @@ async def test_read_modify_write_row_chained(self, client, table, temp_rows): ) @CrossSync.pytest async def test_check_and_mutate( - self, client, table, temp_rows, start_val, predicate_range, expected_result + self, client, target, temp_rows, start_val, predicate_range, expected_result ): """ test that check_and_mutate_row works applies the right mutations, and returns the right result @@ -685,7 +706,7 @@ async def test_check_and_mutate( family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value ) predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = await table.check_and_mutate_row( + result = await target.check_and_mutate_row( row_key, predicate, true_case_mutations=true_mutation, @@ -696,34 +717,34 @@ async def test_check_and_mutate( expected_value = ( true_mutation_value if expected_result else false_mutation_value ) - assert (await self._retrieve_cell_value(table, row_key)) == expected_value + assert (await self._retrieve_cell_value(target, row_key)) == expected_value @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't raise InvalidArgument", ) @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_check_and_mutate_empty_request(self, client, table): + async def test_check_and_mutate_empty_request(self, client, target): """ check_and_mutate with no true or fale mutations should raise an error """ from google.api_core import exceptions with pytest.raises(exceptions.InvalidArgument) as e: - await table.check_and_mutate_row( + await target.check_and_mutate_row( b"row_key", None, true_case_mutations=None, false_case_mutations=None ) assert "No mutations provided" in str(e.value) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_stream(self, table, temp_rows): + async def test_read_rows_stream(self, target, temp_rows): """ Ensure that the read_rows_stream method works """ @@ -731,7 +752,7 @@ async def test_read_rows_stream(self, table, temp_rows): await 
temp_rows.add_row(b"row_key_2") # full table scan - generator = await table.read_rows_stream({}) + generator = await target.read_rows_stream({}) first_row = await generator.__anext__() second_row = await generator.__anext__() assert first_row.row_key == b"row_key_1" @@ -739,29 +760,29 @@ async def test_read_rows_stream(self, table, temp_rows): with pytest.raises(CrossSync.StopIteration): await generator.__anext__() - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows(self, table, temp_rows): + async def test_read_rows(self, target, temp_rows): """ Ensure that the read_rows method works """ await temp_rows.add_row(b"row_key_1") await temp_rows.add_row(b"row_key_2") # full table scan - row_list = await table.read_rows({}) + row_list = await target.read_rows({}) assert len(row_list) == 2 assert row_list[0].row_key == b"row_key_1" assert row_list[1].row_key == b"row_key_2" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_sharded_simple(self, table, temp_rows): + async def test_read_rows_sharded_simple(self, target, temp_rows): """ Test read rows sharded with two queries """ @@ -773,19 +794,19 @@ async def test_read_rows_sharded_simple(self, table, temp_rows): await temp_rows.add_row(b"d") query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = await table.read_rows_sharded([query1, query2]) + row_list = await target.read_rows_sharded([query1, query2]) assert len(row_list) == 4 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"c" assert row_list[2].row_key == b"b" assert row_list[3].row_key == b"d" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_sharded_from_sample(self, table, temp_rows): + async def test_read_rows_sharded_from_sample(self, target, temp_rows): """ Test end-to-end sharding """ @@ -797,21 +818,21 @@ async def test_read_rows_sharded_from_sample(self, table, temp_rows): await temp_rows.add_row(b"c") await temp_rows.add_row(b"d") - table_shard_keys = await table.sample_row_keys() + table_shard_keys = await target.sample_row_keys() query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) shard_queries = query.shard(table_shard_keys) - row_list = await table.read_rows_sharded(shard_queries) + row_list = await target.read_rows_sharded(shard_queries) assert len(row_list) == 3 assert row_list[0].row_key == b"b" assert row_list[1].row_key == b"c" assert row_list[2].row_key == b"d" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_sharded_filters_limits(self, table, temp_rows): + async def test_read_rows_sharded_filters_limits(self, target, temp_rows): """ Test read rows sharded with filters and limits """ @@ -827,7 +848,7 @@ async def test_read_rows_sharded_filters_limits(self, table, temp_rows): label_filter2 = ApplyLabelFilter("second") query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) 
- row_list = await table.read_rows_sharded([query1, query2]) + row_list = await target.read_rows_sharded([query1, query2]) assert len(row_list) == 3 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"b" @@ -836,12 +857,12 @@ async def test_read_rows_sharded_filters_limits(self, table, temp_rows): assert row_list[1][0].labels == ["second"] assert row_list[2][0].labels == ["second"] - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_range_query(self, table, temp_rows): + async def test_read_rows_range_query(self, target, temp_rows): """ Ensure that the read_rows method works """ @@ -854,17 +875,17 @@ async def test_read_rows_range_query(self, table, temp_rows): await temp_rows.add_row(b"d") # full table scan query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = await table.read_rows(query) + row_list = await target.read_rows(query) assert len(row_list) == 2 assert row_list[0].row_key == b"b" assert row_list[1].row_key == b"c" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_single_key_query(self, table, temp_rows): + async def test_read_rows_single_key_query(self, target, temp_rows): """ Ensure that the read_rows method works with specified query """ @@ -876,17 +897,17 @@ async def test_read_rows_single_key_query(self, table, temp_rows): await temp_rows.add_row(b"d") # retrieve specific keys query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = await table.read_rows(query) + row_list = await target.read_rows(query) assert len(row_list) == 2 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"c" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @CrossSync.pytest - async def test_read_rows_with_filter(self, table, temp_rows): + async def test_read_rows_with_filter(self, target, temp_rows): """ ensure filters are applied """ @@ -901,15 +922,15 @@ async def test_read_rows_with_filter(self, table, temp_rows): expected_label = "test-label" row_filter = ApplyLabelFilter(expected_label) query = ReadRowsQuery(row_filter=row_filter) - row_list = await table.read_rows(query) + row_list = await target.read_rows(query) assert len(row_list) == 4 for row in row_list: assert row[0].labels == [expected_label] - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) @CrossSync.pytest - async def test_read_rows_stream_close(self, table, temp_rows): + async def test_read_rows_stream_close(self, target, temp_rows): """ Ensure that the read_rows_stream can be closed """ @@ -919,7 +940,7 @@ async def test_read_rows_stream_close(self, table, temp_rows): await temp_rows.add_row(b"row_key_2") # full table scan query = ReadRowsQuery() - generator = await table.read_rows_stream(query) + generator = await target.read_rows_stream(query) # grab first row first_row = await generator.__anext__() assert first_row.row_key == b"row_key_1" @@ -928,16 +949,16 @@ async def test_read_rows_stream_close(self, table, temp_rows): with pytest.raises(CrossSync.StopIteration): await generator.__anext__() - @pytest.mark.usefixtures("table") + 
@pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_read_row(self, table, temp_rows): + async def test_read_row(self, target, temp_rows): """ Test read_row (single row helper) """ from google.cloud.bigtable.data import Row await temp_rows.add_row(b"row_key_1", value=b"value") - row = await table.read_row(b"row_key_1") + row = await target.read_row(b"row_key_1") assert isinstance(row, Row) assert row.row_key == b"row_key_1" assert row.cells[0].value == b"value" @@ -946,24 +967,24 @@ async def test_read_row(self, table, temp_rows): bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't raise InvalidArgument", ) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_read_row_missing(self, table): + async def test_read_row_missing(self, target): """ Test read_row when row does not exist """ from google.api_core import exceptions row_key = "row_key_not_exist" - result = await table.read_row(row_key) + result = await target.read_row(row_key) assert result is None with pytest.raises(exceptions.InvalidArgument) as e: - await table.read_row("") + await target.read_row("") assert "Row keys must be non-empty" in str(e) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_read_row_w_filter(self, table, temp_rows): + async def test_read_row_w_filter(self, target, temp_rows): """ Test read_row (single row helper) """ @@ -973,7 +994,7 @@ async def test_read_row_w_filter(self, table, temp_rows): await temp_rows.add_row(b"row_key_1", value=b"value") expected_label = "test-label" label_filter = ApplyLabelFilter(expected_label) - row = await table.read_row(b"row_key_1", row_filter=label_filter) + row = await target.read_row(b"row_key_1", row_filter=label_filter) assert isinstance(row, Row) assert row.row_key == b"row_key_1" assert row.cells[0].value == b"value" @@ -983,26 +1004,26 @@ async def test_read_row_w_filter(self, table, temp_rows): bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't raise InvalidArgument", ) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.pytest - async def test_row_exists(self, table, temp_rows): + async def test_row_exists(self, target, temp_rows): from google.api_core import exceptions """Test row_exists with rows that exist and don't exist""" - assert await table.row_exists(b"row_key_1") is False + assert await target.row_exists(b"row_key_1") is False await temp_rows.add_row(b"row_key_1") - assert await table.row_exists(b"row_key_1") is True - assert await table.row_exists("row_key_1") is True - assert await table.row_exists(b"row_key_2") is False - assert await table.row_exists("row_key_2") is False - assert await table.row_exists("3") is False + assert await target.row_exists(b"row_key_1") is True + assert await target.row_exists("row_key_1") is True + assert await target.row_exists(b"row_key_2") is False + assert await target.row_exists("row_key_2") is False + assert await target.row_exists("3") is False await temp_rows.add_row(b"3") - assert await table.row_exists(b"3") is True + assert await target.row_exists(b"3") is True with pytest.raises(exceptions.InvalidArgument) as e: - await table.row_exists("") + await target.row_exists("") assert "Row keys must be non-empty" in str(e) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @@ -1033,7 +1054,7 @@ async def test_row_exists(self, 
table, temp_rows):
     )
     @CrossSync.pytest
     async def test_literal_value_filter(
-        self, table, temp_rows, cell_value, filter_input, expect_match
+        self, target, temp_rows, cell_value, filter_input, expect_match
     ):
         """
         Literal value filter does complex escaping on re2 strings.
@@ -1045,7 +1066,7 @@ async def test_literal_value_filter(
         f = LiteralValueFilter(filter_input)
         await temp_rows.add_row(b"row_key_1", value=cell_value)
         query = ReadRowsQuery(row_filter=f)
-        row_list = await table.read_rows(query)
+        row_list = await target.read_rows(query)
         assert len(row_list) == bool(
             expect_match
         ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter"
@@ -1055,10 +1076,31 @@ async def test_literal_value_filter(
-        reason="emulator doesn't support SQL",
+        reason="emulator doesn't support authorized views",
     )
     @CrossSync.pytest
+    async def test_authorized_view_unauthorized(
+        self, client, authorized_view_id, instance_id, table_id
+    ):
+        """
+        Requesting a family outside the authorized family_subset should raise an exception
+        """
+        from google.cloud.bigtable.data.mutations import SetCell
+
+        async with client.get_authorized_view(
+            instance_id, table_id, authorized_view_id
+        ) as view:
+            mutation = SetCell(family="unauthorized", qualifier="q", new_value="v")
+            with pytest.raises(PermissionDenied) as e:
+                await view.mutate_row(b"row-key", mutation)
+            assert "outside the Authorized View" in e.value.message
+
+    @pytest.mark.skipif(
+        bool(os.environ.get(BIGTABLE_EMULATOR)),
+        reason="emulator doesn't support SQL",
+    )
     @pytest.mark.usefixtures("client")
     @CrossSync.Retry(
         predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
     )
+    @CrossSync.pytest
     async def test_execute_query_simple(self, client, table_id, instance_id):
         result = await client.execute_query("SELECT 1 AS a, 'foo' AS b", instance_id)
         rows = [r async for r in result]
@@ -1072,11 +1114,11 @@ async def test_execute_query_simple(self, client, table_id, instance_id):
         reason="emulator doesn't support SQL",
     )
     @CrossSync.pytest
-    @pytest.mark.usefixtures("table")
+    @pytest.mark.usefixtures("target")
     @CrossSync.Retry(
         predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
     )
-    async def test_execute_against_table(
+    async def test_execute_against_target(
         self, client, instance_id, table_id, temp_rows
     ):
         await temp_rows.add_row(b"row_key_1")
@@ -1197,7 +1239,7 @@ async def test_execute_query_params(self, client, table_id, instance_id):
         reason="emulator doesn't support SQL",
     )
     @CrossSync.pytest
-    @pytest.mark.usefixtures("table")
+    @pytest.mark.usefixtures("target")
     @CrossSync.Retry(
         predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
     )
diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py
index ede24be76..baa6de485 100644
--- a/tests/system/data/test_system_autogen.py
+++ b/tests/system/data/test_system_autogen.py
@@ -20,7 +20,7 @@
 import uuid
 import os
 from google.api_core import retry
-from google.api_core.exceptions import ClientError
+from google.api_core.exceptions import ClientError, PermissionDenied
 from google.cloud.bigtable.data.execute_query.metadata import SqlType
 from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
 from google.cloud.environment_vars import BIGTABLE_EMULATOR
@@ -28,6 +28,10 @@
 from google.cloud.bigtable.data._cross_sync import CrossSync
 from .
import TEST_FAMILY, TEST_FAMILY_2 +TARGETS = ["table"] +if not os.environ.get(BIGTABLE_EMULATOR): + TARGETS.append("authorized_view") + @CrossSync._Sync_Impl.add_mapping_decorator("TempRowBuilder") class TempRowBuilder: @@ -35,9 +39,9 @@ class TempRowBuilder: Used to add rows to a table for testing purposes. """ - def __init__(self, table): + def __init__(self, target): self.rows = [] - self.table = table + self.target = target def add_row( self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" @@ -47,7 +51,7 @@ def add_row( elif isinstance(value, int): value = value.to_bytes(8, byteorder="big", signed=True) request = { - "table_name": self.table.table_name, + "table_name": self.target.table_name, "row_key": row_key, "mutations": [ { @@ -59,19 +63,19 @@ def add_row( } ], } - self.table.client._gapic_client.mutate_row(request) + self.target.client._gapic_client.mutate_row(request) self.rows.append(row_key) def delete_rows(self): if self.rows: request = { - "table_name": self.table.table_name, + "table_name": self.target.table_name, "entries": [ {"row_key": row, "mutations": [{"delete_from_row": {}}]} for row in self.rows ], } - self.table.client._gapic_client.mutate_rows(request) + self.target.client._gapic_client.mutate_rows(request) class TestSystem: @@ -81,10 +85,21 @@ def client(self): with CrossSync._Sync_Impl.DataClient(project=project) as client: yield client - @pytest.fixture(scope="session") - def table(self, client, table_id, instance_id): - with client.get_table(instance_id, table_id) as table: - yield table + @pytest.fixture(scope="session", params=TARGETS) + def target(self, client, table_id, authorized_view_id, instance_id, request): + """This fixture runs twice: once for a standard table, and once with an authorized view + + Note: emulator doesn't support authorized views. 
Only use target""" + if request.param == "table": + with client.get_table(instance_id, table_id) as table: + yield table + elif request.param == "authorized_view": + with client.get_authorized_view( + instance_id, table_id, authorized_view_id + ) as view: + yield view + else: + raise ValueError(f"unknown target type: {request.param}") @pytest.fixture(scope="session") def column_family_config(self): @@ -110,12 +125,12 @@ def cluster_config(self, project_id): } return cluster - @pytest.mark.usefixtures("table") - def _retrieve_cell_value(self, table, row_key): + @pytest.mark.usefixtures("target") + def _retrieve_cell_value(self, target, row_key): """Helper to read an individual row""" from google.cloud.bigtable.data import ReadRowsQuery - row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) + row_list = target.read_rows(ReadRowsQuery(row_keys=row_key)) assert len(row_list) == 1 row = row_list[0] cell = row.cells[0] @@ -138,28 +153,28 @@ def _create_row_and_mutation( return (row_key, mutation) @pytest.fixture(scope="function") - def temp_rows(self, table): - builder = CrossSync._Sync_Impl.TempRowBuilder(table) + def temp_rows(self, target): + builder = CrossSync._Sync_Impl.TempRowBuilder(target) yield builder builder.delete_rows() - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.usefixtures("client") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 ) - def test_ping_and_warm_gapic(self, client, table): + def test_ping_and_warm_gapic(self, client, target): """Simple ping rpc test This test ensures channels are able to authenticate with backend""" - request = {"name": table.instance_name} + request = {"name": target.instance_name} client._gapic_client.ping_and_warm(request) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.usefixtures("client") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_ping_and_warm(self, client, table): + def test_ping_and_warm(self, client, target): """Test ping and warm from handwritten client""" results = client._ping_and_warm_instances() assert len(results) == 1 @@ -192,33 +207,33 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows): finally: client.close() - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutation_set_cell(self, table, temp_rows): + def test_mutation_set_cell(self, target, temp_rows): """Ensure cells can be set properly""" row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) - table.mutate_row(row_key, mutation) - assert self._retrieve_cell_value(table, row_key) == new_value + target.mutate_row(row_key, mutation) + assert self._retrieve_cell_value(target, row_key) == new_value @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" ) @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_sample_row_keys(self, client, table, temp_rows, column_split_config): - """Sample keys should return a single sample in small test tables""" + def test_sample_row_keys(self, client, 
target, temp_rows, column_split_config): + """Sample keys should return a single sample in small test targets""" temp_rows.add_row(b"row_key_1") temp_rows.add_row(b"row_key_2") - results = table.sample_row_keys() + results = target.sample_row_keys() assert len(results) == len(column_split_config) + 1 for idx in range(len(column_split_config)): assert results[idx][0] == column_split_config[idx] @@ -227,20 +242,20 @@ def test_sample_row_keys(self, client, table, temp_rows, column_split_config): assert isinstance(results[-1][1], int) @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_bulk_mutations_set_cell(self, client, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_bulk_mutations_set_cell(self, client, target, temp_rows): """Ensure cells can be set properly""" from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - table.bulk_mutate_rows([bulk_mutation]) - assert self._retrieve_cell_value(table, row_key) == new_value + target.bulk_mutate_rows([bulk_mutation]) + assert self._retrieve_cell_value(target, row_key) == new_value - def test_bulk_mutations_raise_exception(self, client, table): + def test_bulk_mutations_raise_exception(self, client, target): """If an invalid mutation is passed, an exception should be raised""" from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup @@ -252,7 +267,7 @@ def test_bulk_mutations_raise_exception(self, client, table): ) bulk_mutation = RowMutationEntry(row_key, [mutation]) with pytest.raises(MutationsExceptionGroup) as exc: - table.bulk_mutate_rows([bulk_mutation]) + target.bulk_mutate_rows([bulk_mutation]) assert len(exc.value.exceptions) == 1 entry_error = exc.value.exceptions[0] assert isinstance(entry_error, FailedMutationEntryError) @@ -260,71 +275,71 @@ def test_bulk_mutations_raise_exception(self, client, table): assert entry_error.entry == bulk_mutation @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutations_batcher_context_manager(self, client, table, temp_rows): + def test_mutations_batcher_context_manager(self, client, target, temp_rows): """test batcher with context manager. 
Should flush on exit""" from google.cloud.bigtable.data.mutations import RowMutationEntry (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation = RowMutationEntry(row_key, [mutation]) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher() as batcher: + with target.mutations_batcher() as batcher: batcher.append(bulk_mutation) batcher.append(bulk_mutation2) - assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(target, row_key) == new_value assert len(batcher._staged_entries) == 0 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + def test_mutations_batcher_timer_flush(self, client, target, temp_rows): """batch should occur after flush_interval seconds""" from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) flush_interval = 0.1 - with table.mutations_batcher(flush_interval=flush_interval) as batcher: + with target.mutations_batcher(flush_interval=flush_interval) as batcher: batcher.append(bulk_mutation) CrossSync._Sync_Impl.yield_to_event_loop() assert len(batcher._staged_entries) == 1 CrossSync._Sync_Impl.sleep(flush_interval + 0.1) assert len(batcher._staged_entries) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(target, row_key) == new_value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutations_batcher_count_flush(self, client, table, temp_rows): + def test_mutations_batcher_count_flush(self, client, target, temp_rows): """batch should flush after flush_limit_mutation_count mutations""" from google.cloud.bigtable.data.mutations import RowMutationEntry (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + with target.mutations_batcher(flush_limit_mutation_count=2) as batcher: batcher.append(bulk_mutation) assert len(batcher._flush_jobs) == 0 assert len(batcher._staged_entries) == 1 @@ -335,29 +350,29 @@ def test_mutations_batcher_count_flush(self, client, table, temp_rows): future.result() assert len(batcher._staged_entries) == 0 assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value - assert 
self._retrieve_cell_value(table, row_key2) == new_value2 + assert self._retrieve_cell_value(target, row_key) == new_value + assert self._retrieve_cell_value(target, row_key2) == new_value2 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + def test_mutations_batcher_bytes_flush(self, client, target, temp_rows): """batch should flush after flush_limit_bytes bytes""" from google.cloud.bigtable.data.mutations import RowMutationEntry (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value + target, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 + target, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + with target.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: batcher.append(bulk_mutation) assert len(batcher._flush_jobs) == 0 assert len(batcher._staged_entries) == 1 @@ -367,27 +382,27 @@ def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): for future in list(batcher._flush_jobs): future future.result() - assert self._retrieve_cell_value(table, row_key) == new_value - assert self._retrieve_cell_value(table, row_key2) == new_value2 + assert self._retrieve_cell_value(target, row_key) == new_value + assert self._retrieve_cell_value(target, row_key2) == new_value2 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_mutations_batcher_no_flush(self, client, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_mutations_batcher_no_flush(self, client, target, temp_rows): """test with no flush requirements met""" from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() start_value = b"unchanged" (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + target, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - with table.mutations_batcher( + with target.mutations_batcher( flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 ) as batcher: batcher.append(bulk_mutation) @@ -397,15 +412,15 @@ def test_mutations_batcher_no_flush(self, client, table, temp_rows): CrossSync._Sync_Impl.yield_to_event_loop() assert len(batcher._staged_entries) == 2 assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == start_value - assert self._retrieve_cell_value(table, row_key2) == start_value + assert self._retrieve_cell_value(target, row_key) == start_value + assert self._retrieve_cell_value(target, row_key2) == start_value @pytest.mark.usefixtures("client") - 
@pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_mutations_batcher_large_batch(self, client, table, temp_rows): + def test_mutations_batcher_large_batch(self, client, target, temp_rows): """test batcher with large batch of mutations""" from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell @@ -417,13 +432,13 @@ def test_mutations_batcher_large_batch(self, client, table, temp_rows): row_key = uuid.uuid4().hex.encode() row_mutations.append(RowMutationEntry(row_key, [add_mutation])) temp_rows.rows.append(row_key) - with table.mutations_batcher() as batcher: + with target.mutations_batcher() as batcher: for mutation in row_mutations: batcher.append(mutation) assert len(batcher._staged_entries) == 0 @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.parametrize( "start,increment,expected", [ @@ -440,7 +455,7 @@ def test_mutations_batcher_large_batch(self, client, table, temp_rows): ], ) def test_read_modify_write_row_increment( - self, client, table, temp_rows, start, increment, expected + self, client, target, temp_rows, start, increment, expected ): """test read_modify_write_row""" from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule @@ -450,16 +465,16 @@ def test_read_modify_write_row_increment( qualifier = b"test-qualifier" temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) rule = IncrementRule(family, qualifier, increment) - result = table.read_modify_write_row(row_key, rule) + result = target.read_modify_write_row(row_key, rule) assert result.row_key == row_key assert len(result) == 1 assert result[0].family == family assert result[0].qualifier == qualifier assert int(result[0]) == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value + assert self._retrieve_cell_value(target, row_key) == result[0].value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @pytest.mark.parametrize( "start,append,expected", [ @@ -473,7 +488,7 @@ def test_read_modify_write_row_increment( ], ) def test_read_modify_write_row_append( - self, client, table, temp_rows, start, append, expected + self, client, target, temp_rows, start, append, expected ): """test read_modify_write_row""" from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -483,17 +498,17 @@ def test_read_modify_write_row_append( qualifier = b"test-qualifier" temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) rule = AppendValueRule(family, qualifier, append) - result = table.read_modify_write_row(row_key, rule) + result = target.read_modify_write_row(row_key, rule) assert result.row_key == row_key assert len(result) == 1 assert result[0].family == family assert result[0].qualifier == qualifier assert result[0].value == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value + assert self._retrieve_cell_value(target, row_key) == result[0].value @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_read_modify_write_row_chained(self, client, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_read_modify_write_row_chained(self, client, target, temp_rows): """test read_modify_write_row with multiple rules""" from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule 
from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
@@ -512,7 +527,7 @@ def test_read_modify_write_row_chained(self, client, table, temp_rows):
             AppendValueRule(family, qualifier, "world"),
             AppendValueRule(family, qualifier, "!"),
         ]
-        result = table.read_modify_write_row(row_key, rule)
+        result = target.read_modify_write_row(row_key, rule)
         assert result.row_key == row_key
         assert result[0].family == family
         assert result[0].qualifier == qualifier
@@ -521,16 +536,16 @@ def test_read_modify_write_row_chained(self, client, table, temp_rows):
             == (start_amount + increment_amount).to_bytes(8, "big", signed=True)
             + b"helloworld!"
         )
-        assert self._retrieve_cell_value(table, row_key) == result[0].value
+        assert self._retrieve_cell_value(target, row_key) == result[0].value

     @pytest.mark.usefixtures("client")
-    @pytest.mark.usefixtures("table")
+    @pytest.mark.usefixtures("target")
     @pytest.mark.parametrize(
         "start_val,predicate_range,expected_result",
         [(1, (0, 2), True), (-1, (0, 2), False)],
     )
     def test_check_and_mutate(
-        self, client, table, temp_rows, start_val, predicate_range, expected_result
+        self, client, target, temp_rows, start_val, predicate_range, expected_result
     ):
-        """test that check_and_mutate_row works applies the right mutations, and returns the right result"""
+        """test that check_and_mutate_row applies the right mutations, and returns the right result"""
         from google.cloud.bigtable.data.mutations import SetCell
@@ -549,7 +564,7 @@ def test_check_and_mutate(
             family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value
         )
         predicate = ValueRangeFilter(predicate_range[0], predicate_range[1])
-        result = table.check_and_mutate_row(
+        result = target.check_and_mutate_row(
             row_key,
             predicate,
             true_case_mutations=true_mutation,
@@ -559,33 +574,33 @@ def test_check_and_mutate(
         expected_value = (
             true_mutation_value if expected_result else false_mutation_value
         )
-        assert self._retrieve_cell_value(table, row_key) == expected_value
+        assert self._retrieve_cell_value(target, row_key) == expected_value

     @pytest.mark.skipif(
         bool(os.environ.get(BIGTABLE_EMULATOR)),
         reason="emulator doesn't raise InvalidArgument",
     )
     @pytest.mark.usefixtures("client")
-    @pytest.mark.usefixtures("table")
-    def test_check_and_mutate_empty_request(self, client, table):
-        """check_and_mutate with no true or fale mutations should raise an error"""
+    @pytest.mark.usefixtures("target")
+    def test_check_and_mutate_empty_request(self, client, target):
+        """check_and_mutate with no true or false mutations should raise an error"""
         from google.api_core import exceptions

         with pytest.raises(exceptions.InvalidArgument) as e:
-            table.check_and_mutate_row(
+            target.check_and_mutate_row(
                 b"row_key", None, true_case_mutations=None, false_case_mutations=None
             )
         assert "No mutations provided" in str(e.value)

-    @pytest.mark.usefixtures("table")
+    @pytest.mark.usefixtures("target")
     @CrossSync._Sync_Impl.Retry(
         predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
     )
-    def test_read_rows_stream(self, table, temp_rows):
+    def test_read_rows_stream(self, target, temp_rows):
         """Ensure that the read_rows_stream method works"""
         temp_rows.add_row(b"row_key_1")
         temp_rows.add_row(b"row_key_2")
-        generator = table.read_rows_stream({})
+        generator = target.read_rows_stream({})
         first_row = generator.__next__()
         second_row = generator.__next__()
         assert first_row.row_key == b"row_key_1"
@@ -593,24 +608,24 @@ def test_read_rows_stream(self, table, temp_rows):
         with pytest.raises(CrossSync._Sync_Impl.StopIteration):
             generator.__next__()

-    @pytest.mark.usefixtures("table")
+    @pytest.mark.usefixtures("target")
     @CrossSync._Sync_Impl.Retry(
         predicate=retry.if_exception_type(ClientError),
initial=1, maximum=5 ) - def test_read_rows(self, table, temp_rows): + def test_read_rows(self, target, temp_rows): """Ensure that the read_rows method works""" temp_rows.add_row(b"row_key_1") temp_rows.add_row(b"row_key_2") - row_list = table.read_rows({}) + row_list = target.read_rows({}) assert len(row_list) == 2 assert row_list[0].row_key == b"row_key_1" assert row_list[1].row_key == b"row_key_2" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_sharded_simple(self, table, temp_rows): + def test_read_rows_sharded_simple(self, target, temp_rows): """Test read rows sharded with two queries""" from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery @@ -620,18 +635,18 @@ def test_read_rows_sharded_simple(self, table, temp_rows): temp_rows.add_row(b"d") query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = table.read_rows_sharded([query1, query2]) + row_list = target.read_rows_sharded([query1, query2]) assert len(row_list) == 4 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"c" assert row_list[2].row_key == b"b" assert row_list[3].row_key == b"d" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_sharded_from_sample(self, table, temp_rows): + def test_read_rows_sharded_from_sample(self, target, temp_rows): """Test end-to-end sharding""" from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -640,20 +655,20 @@ def test_read_rows_sharded_from_sample(self, table, temp_rows): temp_rows.add_row(b"b") temp_rows.add_row(b"c") temp_rows.add_row(b"d") - table_shard_keys = table.sample_row_keys() + table_shard_keys = target.sample_row_keys() query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) shard_queries = query.shard(table_shard_keys) - row_list = table.read_rows_sharded(shard_queries) + row_list = target.read_rows_sharded(shard_queries) assert len(row_list) == 3 assert row_list[0].row_key == b"b" assert row_list[1].row_key == b"c" assert row_list[2].row_key == b"d" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_sharded_filters_limits(self, table, temp_rows): + def test_read_rows_sharded_filters_limits(self, target, temp_rows): """Test read rows sharded with filters and limits""" from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.row_filters import ApplyLabelFilter @@ -666,7 +681,7 @@ def test_read_rows_sharded_filters_limits(self, table, temp_rows): label_filter2 = ApplyLabelFilter("second") query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = table.read_rows_sharded([query1, query2]) + row_list = target.read_rows_sharded([query1, query2]) assert len(row_list) == 3 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"b" @@ -675,11 +690,11 @@ def test_read_rows_sharded_filters_limits(self, table, temp_rows): assert row_list[1][0].labels == ["second"] assert row_list[2][0].labels == ["second"] - 
@pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_range_query(self, table, temp_rows): + def test_read_rows_range_query(self, target, temp_rows): """Ensure that the read_rows method works""" from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable.data import RowRange @@ -689,16 +704,16 @@ def test_read_rows_range_query(self, table, temp_rows): temp_rows.add_row(b"c") temp_rows.add_row(b"d") query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = table.read_rows(query) + row_list = target.read_rows(query) assert len(row_list) == 2 assert row_list[0].row_key == b"b" assert row_list[1].row_key == b"c" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_single_key_query(self, table, temp_rows): + def test_read_rows_single_key_query(self, target, temp_rows): """Ensure that the read_rows method works with specified query""" from google.cloud.bigtable.data import ReadRowsQuery @@ -707,16 +722,16 @@ def test_read_rows_single_key_query(self, table, temp_rows): temp_rows.add_row(b"c") temp_rows.add_row(b"d") query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = table.read_rows(query) + row_list = target.read_rows(query) assert len(row_list) == 2 assert row_list[0].row_key == b"a" assert row_list[1].row_key == b"c" - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_read_rows_with_filter(self, table, temp_rows): + def test_read_rows_with_filter(self, target, temp_rows): """ensure filters are applied""" from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable.data.row_filters import ApplyLabelFilter @@ -728,33 +743,33 @@ def test_read_rows_with_filter(self, table, temp_rows): expected_label = "test-label" row_filter = ApplyLabelFilter(expected_label) query = ReadRowsQuery(row_filter=row_filter) - row_list = table.read_rows(query) + row_list = target.read_rows(query) assert len(row_list) == 4 for row in row_list: assert row[0].labels == [expected_label] - @pytest.mark.usefixtures("table") - def test_read_rows_stream_close(self, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_read_rows_stream_close(self, target, temp_rows): """Ensure that the read_rows_stream can be closed""" from google.cloud.bigtable.data import ReadRowsQuery temp_rows.add_row(b"row_key_1") temp_rows.add_row(b"row_key_2") query = ReadRowsQuery() - generator = table.read_rows_stream(query) + generator = target.read_rows_stream(query) first_row = generator.__next__() assert first_row.row_key == b"row_key_1" generator.close() with pytest.raises(CrossSync._Sync_Impl.StopIteration): generator.__next__() - @pytest.mark.usefixtures("table") - def test_read_row(self, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_read_row(self, target, temp_rows): """Test read_row (single row helper)""" from google.cloud.bigtable.data import Row temp_rows.add_row(b"row_key_1", value=b"value") - row = table.read_row(b"row_key_1") + row = target.read_row(b"row_key_1") assert isinstance(row, Row) assert row.row_key == b"row_key_1" assert row.cells[0].value == b"value" @@ -763,20 +778,20 @@ def test_read_row(self, table, temp_rows): 
bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't raise InvalidArgument", ) - @pytest.mark.usefixtures("table") - def test_read_row_missing(self, table): + @pytest.mark.usefixtures("target") + def test_read_row_missing(self, target): """Test read_row when row does not exist""" from google.api_core import exceptions row_key = "row_key_not_exist" - result = table.read_row(row_key) + result = target.read_row(row_key) assert result is None with pytest.raises(exceptions.InvalidArgument) as e: - table.read_row("") + target.read_row("") assert "Row keys must be non-empty" in str(e) - @pytest.mark.usefixtures("table") - def test_read_row_w_filter(self, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_read_row_w_filter(self, target, temp_rows): """Test read_row (single row helper)""" from google.cloud.bigtable.data import Row from google.cloud.bigtable.data.row_filters import ApplyLabelFilter @@ -784,7 +799,7 @@ def test_read_row_w_filter(self, table, temp_rows): temp_rows.add_row(b"row_key_1", value=b"value") expected_label = "test-label" label_filter = ApplyLabelFilter(expected_label) - row = table.read_row(b"row_key_1", row_filter=label_filter) + row = target.read_row(b"row_key_1", row_filter=label_filter) assert isinstance(row, Row) assert row.row_key == b"row_key_1" assert row.cells[0].value == b"value" @@ -794,25 +809,25 @@ def test_read_row_w_filter(self, table, temp_rows): bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't raise InvalidArgument", ) - @pytest.mark.usefixtures("table") - def test_row_exists(self, table, temp_rows): + @pytest.mark.usefixtures("target") + def test_row_exists(self, target, temp_rows): from google.api_core import exceptions "Test row_exists with rows that exist and don't exist" - assert table.row_exists(b"row_key_1") is False + assert target.row_exists(b"row_key_1") is False temp_rows.add_row(b"row_key_1") - assert table.row_exists(b"row_key_1") is True - assert table.row_exists("row_key_1") is True - assert table.row_exists(b"row_key_2") is False - assert table.row_exists("row_key_2") is False - assert table.row_exists("3") is False + assert target.row_exists(b"row_key_1") is True + assert target.row_exists("row_key_1") is True + assert target.row_exists(b"row_key_2") is False + assert target.row_exists("row_key_2") is False + assert target.row_exists("3") is False temp_rows.add_row(b"3") - assert table.row_exists(b"3") is True + assert target.row_exists(b"3") is True with pytest.raises(exceptions.InvalidArgument) as e: - table.row_exists("") + target.row_exists("") assert "Row keys must be non-empty" in str(e) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @@ -842,7 +857,7 @@ def test_row_exists(self, table, temp_rows): ], ) def test_literal_value_filter( - self, table, temp_rows, cell_value, filter_input, expect_match + self, target, temp_rows, cell_value, filter_input, expect_match ): """Literal value filter does complex escaping on re2 strings. 
Make sure inputs are properly interpreted by the server""" @@ -852,11 +867,28 @@ def test_literal_value_filter( f = LiteralValueFilter(filter_input) temp_rows.add_row(b"row_key_1", value=cell_value) query = ReadRowsQuery(row_filter=f) - row_list = table.read_rows(query) + row_list = target.read_rows(query) assert len(row_list) == bool( expect_match ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't support authorized views" + ) + def test_authorized_view_unauthorized( + self, client, authorized_view_id, instance_id, table_id + ): + """Mutating a family outside the authorized view's family_subset should raise PermissionDenied""" + from google.cloud.bigtable.data.mutations import SetCell + + with client.get_authorized_view( + instance_id, table_id, authorized_view_id + ) as view: + mutation = SetCell(family="unauthorized", qualifier="q", new_value="v") + with pytest.raises(PermissionDenied) as e: + view.mutate_row(b"row-key", mutation) + assert "outside the Authorized View" in e.value.message + @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't support SQL" ) @@ -875,11 +907,11 @@ def test_execute_query_simple(self, client, table_id, instance_id): @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't support SQL" ) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - def test_execute_against_table(self, client, instance_id, table_id, temp_rows): + def test_execute_against_target(self, client, instance_id, table_id, temp_rows): temp_rows.add_row(b"row_key_1") result = client.execute_query("SELECT * FROM `" + table_id + "`", instance_id) rows = [r for r in result] @@ -982,7 +1014,7 @@ def test_execute_query_params(self, client, table_id, instance_id): @pytest.mark.skipif( bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't support SQL" ) - @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 13f668fd3..f14fa6dee 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -15,6 +15,8 @@ import pytest from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.rpc import status_pb2 from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden @@ -37,8 +39,11 @@ def _target_class(self): def _make_one(self, *args, **kwargs): if not args: + fake_target = CrossSync.Mock() + fake_target._request_path = {"table_name": "table"} + fake_target.app_profile_id = None kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", CrossSync.Mock()) + kwargs["target"] = kwargs.pop("target", fake_target) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) @@ -46,9 +51,8 @@ def _make_one(self, *args, **kwargs): return self._target_class()(*args, 
**kwargs) def _make_mutation(self, count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", [DeleteAllFromRow() for _ in range(count)]) + mutation.size = lambda: size return mutation @CrossSync.convert @@ -95,16 +99,10 @@ def test_ctor(self): attempt_timeout, retryable_exceptions, ) - # running gapic_fn should trigger a client call + # running gapic_fn should trigger a client call with baked-in args assert client.mutate_rows.call_count == 0 instance._gapic_fn() assert client.mutate_rows.call_count == 1 - # gapic_fn should call with table details - inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 3 - assert inner_kwargs["table_name"] == table.table_name - assert inner_kwargs["app_profile_id"] == table.app_profile_id - assert inner_kwargs["retry"] is None # entries should be passed down entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] assert instance.mutations == entries_w_pb @@ -174,6 +172,8 @@ async def test_mutate_rows_attempt_exception(self, exc_type): """ client = CrossSync.Mock() table = mock.Mock() + table._request_path = {"table_name": "table"} + table.app_profile_id = None entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") @@ -307,7 +307,8 @@ async def test_run_attempt_single_entry_success(self): assert mock_gapic_fn.call_count == 1 _, kwargs = mock_gapic_fn.call_args assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_pb()] + request = kwargs["request"] + assert request.entries == [mutation._to_pb()] @CrossSync.pytest async def test_run_attempt_empty_request(self): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 944681a84..c43f46d5a 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -54,7 +54,7 @@ def test_ctor(self): client.read_rows.return_value = None table = mock.Mock() table._client = client - table.table_name = "test_table" + table._request_path = {"table_name": "test_table"} table.app_profile_id = "test_profile" expected_operation_timeout = 42 expected_request_timeout = 44 @@ -78,7 +78,7 @@ def test_ctor(self): assert instance._remaining_count == row_limit assert instance.operation_timeout == expected_operation_timeout assert client.read_rows.call_count == 0 - assert instance.request.table_name == table.table_name + assert instance.request.table_name == "test_table" assert instance.request.app_profile_id == table.app_profile_id assert instance.request.rows_limit == row_limit @@ -267,7 +267,7 @@ async def mock_stream(): query = ReadRowsQuery(limit=start_limit) table = mock.Mock() - table.table_name = "table_name" + table._request_path = {"table_name": "table_name"} table.app_profile_id = "app_profile_id" instance = self._make_one(query, table, 10, 10) assert instance._remaining_count == start_limit @@ -306,7 +306,7 @@ async def mock_stream(): query = ReadRowsQuery(limit=start_limit) table = mock.Mock() - table.table_name = "table_name" + table._request_path = {"table_name": "table_name"} table.app_profile_id = "app_profile_id" instance = self._make_one(query, table, 10, 10) assert instance._remaining_count == start_limit diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 96fcf66b3..6326e9ec6 100644 --- a/tests/unit/data/_async/test_client.py +++ 
b/tests/unit/data/_async/test_client.py @@ -28,6 +28,7 @@ from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule @@ -272,9 +273,7 @@ async def test__ping_and_warm_instances(self): assert gather.call_args[1]["return_exceptions"] is True assert gather.call_args[1]["sync_executor"] == client_mock._executor # test with instances - client_mock._active_instances = [ - (mock.Mock(), mock.Mock(), mock.Mock()) - ] * 4 + client_mock._active_instances = [(mock.Mock(), mock.Mock())] * 4 gather.reset_mock() channel.reset_mock() result = await self._get_target_class()._ping_and_warm_instances( @@ -292,7 +291,6 @@ async def test__ping_and_warm_instances(self): for idx, (_, kwargs) in enumerate(grpc_call_args): ( expected_instance, - expected_table, expected_app_profile, ) = client_mock._active_instances[idx] request = kwargs["request"] @@ -323,7 +321,7 @@ async def test__ping_and_warm_single_instance(self): gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 - test_key = ("test-instance", "test-table", "test-app-profile") + test_key = ("test-instance", "test-app-profile") result = await self._get_target_class()._ping_and_warm_instances( client_mock, test_key ) @@ -551,7 +549,6 @@ async def test__register_instance(self): # ensure active_instances and instance_owners were updated properly expected_key = ( "prefix/instance-1", - table_mock.table_name, table_mock.app_profile_id, ) assert len(active_instances) == 1 @@ -577,7 +574,6 @@ async def test__register_instance(self): assert len(instance_owners) == 2 expected_key2 = ( "prefix/instance-2", - table_mock2.table_name, table_mock2.app_profile_id, ) assert any( @@ -612,7 +608,6 @@ async def test__register_instance_duplicate(self): table_mock = mock.Mock() expected_key = ( "prefix/instance-1", - table_mock.table_name, table_mock.app_profile_id, ) # fake first registration @@ -639,13 +634,13 @@ async def test__register_instance_duplicate(self): @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ - ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), - ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), - ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ([("i", None)], [("i", None)], [("i", None)]), + ([("i", "p")], [("i", "p")], [("i", "p")]), + ([("1", "p"), ("1", "p")], [("1", "p")], [("1", "p")]), ( - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], + [("1", "p"), ("2", "p")], + [("1", "p"), ("2", "p")], + [("1", "p"), ("2", "p")], ), ], ) @@ -666,8 +661,7 @@ async def test__register_instance_state( client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() # register instances - for instance, table, profile in insert_instances: - table_mock.table_name = table + for instance, profile in insert_instances: table_mock.app_profile_id = profile await self._get_target_class()._register_instance( client_mock, instance, table_mock @@ -700,11 +694,11 @@ async def test__remove_instance_registration(self): instance_1_path = client._gapic_client.instance_path( client.project, 
"instance-1" ) - instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_1_key = (instance_1_path, table.app_profile_id) instance_2_path = client._gapic_client.instance_path( client.project, "instance-2" ) - instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + instance_2_key = (instance_2_path, table.app_profile_id) assert len(client._instance_owners[instance_1_key]) == 1 assert list(client._instance_owners[instance_1_key])[0] == id(table) assert len(client._instance_owners[instance_2_key]) == 1 @@ -735,13 +729,13 @@ async def test__multiple_table_registration(self): client.project, "instance_1" ) instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id + instance_1_path, table_1.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 1 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] # duplicate table should register in instance_owners under same key - async with client.get_table("instance_1", "table_1") as table_2: + async with client.get_table("instance_1", "table_2") as table_2: assert table_2._register_instance_future is not None if not CrossSync.is_async: # give the background task time to run @@ -751,7 +745,9 @@ async def test__multiple_table_registration(self): assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] # unique table should register in instance_owners and active_instances - async with client.get_table("instance_1", "table_3") as table_3: + async with client.get_table( + "instance_1", "table_3", app_profile_id="diff" + ) as table_3: assert table_3._register_instance_future is not None if not CrossSync.is_async: # give the background task time to run @@ -760,7 +756,7 @@ async def test__multiple_table_registration(self): client.project, "instance_1" ) instance_3_key = _WarmedInstanceKey( - instance_3_path, table_3.table_name, table_3.app_profile_id + instance_3_path, table_3.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._instance_owners[instance_3_key]) == 1 @@ -800,13 +796,13 @@ async def test__multiple_instance_registration(self): client.project, "instance_1" ) instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id + instance_1_path, table_1.app_profile_id ) instance_2_path = client._gapic_client.instance_path( client.project, "instance_2" ) instance_2_key = _WarmedInstanceKey( - instance_2_path, table_2.table_name, table_2.app_profile_id + instance_2_path, table_2.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 1 assert len(client._instance_owners[instance_2_key]) == 1 @@ -824,8 +820,12 @@ async def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) @CrossSync.pytest - async def test_get_table(self): + async def test_get_api_surface(self, method): + """ + test client.get_table and client.get_authorized_view + """ from google.cloud.bigtable.data._helpers import _WarmedInstanceKey client = self._make_client(project="project-id") @@ -833,67 +833,90 @@ async def test_get_table(self): expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" - table = client.get_table( - expected_instance_id, - 
expected_table_id, - expected_app_profile_id, - ) + if method == "get_table": + surface = client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + ) + assert isinstance(surface, CrossSync.TestTable._get_target_class()) + elif method == "get_authorized_view": + surface = client.get_authorized_view( + expected_instance_id, + expected_table_id, + "view_id", + expected_app_profile_id, + ) + assert isinstance(surface, CrossSync.TestAuthorizedView._get_target_class()) + assert ( + surface.authorized_view_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}/authorizedViews/view_id" + ) + else: + raise TypeError(f"unexpected method: {method}") await CrossSync.yield_to_event_loop() - assert isinstance(table, CrossSync.TestTable._get_target_class()) - assert table.table_id == expected_table_id + assert surface.table_id == expected_table_id assert ( - table.table_name + surface.table_name == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" ) - assert table.instance_id == expected_instance_id + assert surface.instance_id == expected_instance_id assert ( - table.instance_name + surface.instance_name == f"projects/{client.project}/instances/{expected_instance_id}" ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) + assert surface.app_profile_id == expected_app_profile_id + assert surface.client is client + instance_key = _WarmedInstanceKey(surface.instance_name, surface.app_profile_id) assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} + assert client._instance_owners[instance_key] == {id(surface)} await client.close() + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) @CrossSync.pytest - async def test_get_table_arg_passthrough(self): + async def test_api_surface_arg_passthrough(self, method): """ - All arguments passed in get_table should be sent to constructor + All arguments passed in get_table and get_authorized_view should be sent to constructor """ + if method == "get_table": + surface_type = CrossSync.TestTable._get_target_class() + elif method == "get_authorized_view": + surface_type = CrossSync.TestAuthorizedView._get_target_class() + else: + raise TypeError(f"unexpected method: {method}") + async with self._make_client(project="project-id") as client: - with mock.patch.object( - CrossSync.TestTable._get_target_class(), "__init__" - ) as mock_constructor: + with mock.patch.object(surface_type, "__init__") as mock_constructor: mock_constructor.return_value = None assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_args = (1, "test", {"test": 2}) + expected_args = ( + "table", + "instance", + "view", + "app_profile", + 1, + "test", + {"test": 2}, + ) expected_kwargs = {"hello": "world", "test": 2} - client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, + getattr(client, method)( *expected_args, **expected_kwargs, ) mock_constructor.assert_called_once_with( client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, *expected_args, **expected_kwargs, ) + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) @CrossSync.pytest - async def 
test_get_table_context_manager(self): + async def test_api_surface_context_manager(self, method): + """ + get_table and get_authorized_view should work as context managers + """ + from functools import partial from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" @@ -901,17 +924,35 @@ async def test_get_table_context_manager(self): expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object( - CrossSync.TestTable._get_target_class(), "close" - ) as close_mock: + if method == "get_table": + surface_type = CrossSync.TestTable._get_target_class() + elif method == "get_authorized_view": + surface_type = CrossSync.TestAuthorizedView._get_target_class() + else: + raise TypeError(f"unexpected method: {method}") + + with mock.patch.object(surface_type, "close") as close_mock: async with self._make_client(project=expected_project_id) as client: - async with client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, - ) as table: + if method == "get_table": + fn = partial( + client.get_table, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + ) + elif method == "get_authorized_view": + fn = partial( + client.get_authorized_view, + expected_instance_id, + expected_table_id, + "view_id", + expected_app_profile_id, + ) + else: + raise TypeError(f"unexpected method: {method}") + async with fn() as table: await CrossSync.yield_to_event_loop() - assert isinstance(table, CrossSync.TestTable._get_target_class()) + assert isinstance(table, surface_type) assert table.table_id == expected_table_id assert ( table.table_name @@ -925,7 +966,7 @@ async def test_get_table_context_manager(self): assert table.app_profile_id == expected_app_profile_id assert table.client is client instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id + table.instance_name, table.app_profile_id ) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} @@ -1009,8 +1050,20 @@ def _make_client(self, *args, **kwargs): def _get_target_class(): return CrossSync.Table + def _make_one( + self, + client, + instance_id="instance", + table_id="table", + app_profile_id=None, + **kwargs, + ): + return self._get_target_class()( + client, instance_id, table_id, app_profile_id, **kwargs + ) + @CrossSync.pytest - async def test_table_ctor(self): + async def test_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" @@ -1040,11 +1093,17 @@ async def test_table_ctor(self): await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) assert table.app_profile_id == expected_app_profile_id assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) + instance_key = _WarmedInstanceKey(table.instance_name, table.app_profile_id) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} assert table.default_operation_timeout == expected_operation_timeout @@ -1073,23 +1132,15 @@ async def test_table_ctor(self): await client.close() @CrossSync.pytest - async 
def test_table_ctor_defaults(self): + async def test_ctor_defaults(self): """ should provide default timeout values and app_profile_id """ - expected_table_id = "table-id" - expected_instance_id = "instance-id" client = self._make_client() assert not client._active_instances - table = self._get_target_class()( - client, - expected_instance_id, - expected_table_id, - ) + table = self._make_one(client) await CrossSync.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id assert table.app_profile_id is None assert table.client is client assert table.default_operation_timeout == 60 @@ -1101,7 +1152,7 @@ async def test_table_ctor_defaults(self): await client.close() @CrossSync.pytest - async def test_table_ctor_invalid_timeout_values(self): + async def test_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError """ @@ -1120,10 +1171,10 @@ async def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{attempt_timeout: -1}) + self._make_one(client, **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{operation_timeout: -1}) + self._make_one(client, **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) await client.close() @@ -1173,13 +1224,13 @@ def test_table_ctor_sync(self): ("sample_row_keys", (), False, ()), ( "mutate_row", - (b"row_key", [mock.Mock()]), + (b"row_key", [DeleteAllFromRow()]), False, (), ), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + ([mutations.RowMutationEntry(b"key", [DeleteAllFromRow()])],), False, (_MutateRowsIncomplete,), ), @@ -1291,7 +1342,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ gapic_client = gapic_client._client gapic_client._transport = transport_mock gapic_client._is_universe_domain_valid = True - table = self._get_target_class()(client, "instance-id", "table-id", profile) + table = self._make_one(client, app_profile_id=profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1307,12 +1358,128 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ # expect x-goog-request-params tag assert metadata[0][0] == "x-goog-request-params" routing_str = metadata[0][1] - assert "table_name=" + table.table_name in routing_str + assert self._expected_routing_header(table) in routing_str if include_app_profile: assert "app_profile_id=profile" in routing_str else: assert "app_profile_id=" not in routing_str + @staticmethod + def _expected_routing_header(table): + """ + the expected routing header for this _ApiSurface type + """ + return f"table_name={table.table_name}" + + +@CrossSync.convert_class( + "TestAuthorizedView", add_mapping_for_name="TestAuthorizedView" +) +class TestAuthorizedViewsAsync(CrossSync.TestTable): + """ + Inherit tests from TestTableAsync, with some modifications + """ + + @staticmethod + @CrossSync.convert + def _get_target_class(): + return CrossSync.AuthorizedView + + def _make_one( + self, + client, + instance_id="instance", + table_id="table", + view_id="view", + app_profile_id=None, + **kwargs, + ): + return self._get_target_class()( + client, instance_id, table_id, view_id, app_profile_id, **kwargs + ) + + 
@staticmethod + def _expected_routing_header(view): + """ + the expected routing header for this _ApiSurface type + """ + return f"authorized_view_name={view.authorized_view_name}" + + @CrossSync.pytest + async def test_ctor(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_view_id = "view_id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + + view = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_view_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + await CrossSync.yield_to_event_loop() + assert view.table_id == expected_table_id + assert ( + view.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert view.instance_id == expected_instance_id + assert ( + view.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert view.authorized_view_id == expected_view_id + assert ( + view.authorized_view_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}/authorizedViews/{expected_view_id}" + ) + assert view.app_profile_id == expected_app_profile_id + assert view.client is client + instance_key = _WarmedInstanceKey(view.instance_name, view.app_profile_id) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(view)} + assert view.default_operation_timeout == expected_operation_timeout + assert view.default_attempt_timeout == expected_attempt_timeout + assert ( + view.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + view.default_read_rows_attempt_timeout == expected_read_rows_attempt_timeout + ) + assert ( + view.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + view.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + # ensure task reaches completion + await view._register_instance_future + assert view._register_instance_future.done() + assert not view._register_instance_future.cancelled() + assert view._register_instance_future.exception() is None + await client.close() + @CrossSync.convert_class( "TestReadRows", @@ -2144,11 +2311,12 @@ async def test_sample_row_keys_gapic_params(self): await table.sample_row_keys(attempt_timeout=expected_timeout) args, kwargs = sample_row_keys.call_args assert len(args) == 0 - assert len(kwargs) == 4 + assert len(kwargs) == 3 assert kwargs["timeout"] == expected_timeout - assert kwargs["app_profile_id"] == expected_profile - assert kwargs["table_name"] == table.table_name assert kwargs["retry"] is None + request = kwargs["request"] + assert request.app_profile_id == expected_profile + assert 
request.table_name == table.table_name @pytest.mark.parametrize( "retryable_exception", @@ -2244,17 +2412,18 @@ async def test_mutate_row(self, mutation_arg): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0].kwargs + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["row_key"] == b"row_key" + assert request.row_key == b"row_key" formatted_mutations = ( [mutation._to_pb() for mutation in mutation_arg] if isinstance(mutation_arg, list) else [mutation_arg._to_pb()] ) - assert kwargs["mutations"] == formatted_mutations + assert request.mutations == formatted_mutations assert kwargs["timeout"] == expected_attempt_timeout # make sure gapic layer is not retrying assert kwargs["retry"] is None @@ -2426,11 +2595,12 @@ async def test_bulk_mutate_rows(self, mutation_arg): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args[1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert request.entries == [bulk_mutation._to_pb()] assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None @@ -2451,12 +2621,13 @@ async def test_bulk_mutate_rows_multiple_entries(self): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args[1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"][0] == entry_1._to_pb() - assert kwargs["entries"][1] == entry_2._to_pb() + assert request.entries[0] == entry_1._to_pb() + assert request.entries[1] == entry_2._to_pb() @CrossSync.pytest @pytest.mark.parametrize( @@ -2764,8 +2935,8 @@ async def test_check_and_mutate(self, gapic_result): ) row_key = b"row_key" predicate = None - true_mutations = [mock.Mock()] - false_mutations = [mock.Mock(), mock.Mock()] + true_mutations = [DeleteAllFromRow()] + false_mutations = [DeleteAllFromRow(), DeleteAllFromRow()] operation_timeout = 0.2 found = await table.check_and_mutate_row( row_key, @@ -2776,16 +2947,17 @@ async def test_check_and_mutate(self, gapic_result): ) assert found == gapic_result kwargs = mock_gapic.call_args[1] - assert kwargs["table_name"] == table.table_name - assert kwargs["row_key"] == row_key - assert kwargs["predicate_filter"] == predicate - assert kwargs["true_mutations"] == [ + request = kwargs["request"] + assert request.table_name == table.table_name + assert request.row_key == row_key + assert bool(request.predicate_filter) is False + assert request.true_mutations == [ m._to_pb() for m in true_mutations ] - assert kwargs["false_mutations"] == [ + assert request.false_mutations == [ m._to_pb() for m in false_mutations ] - assert kwargs["app_profile_id"] == app_profile + assert request.app_profile_id == app_profile assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None @@ -2827,16 +2999,18 @@ async def test_check_and_mutate_single_mutations(self): false_case_mutations=false_mutation, ) kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == [true_mutation._to_pb()] - assert kwargs["false_mutations"] == [false_mutation._to_pb()] + request = kwargs["request"] + assert request.true_mutations == [true_mutation._to_pb()] + assert request.false_mutations == [false_mutation._to_pb()] @CrossSync.pytest async def test_check_and_mutate_predicate_object(self): 
"""predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable_v2.types.data import RowFilter mock_predicate = mock.Mock() - predicate_pb = {"predicate": "dict"} + predicate_pb = RowFilter({"sink": True}) mock_predicate._to_pb.return_value = predicate_pb async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -2849,10 +3023,11 @@ async def test_check_and_mutate_predicate_object(self): await table.check_and_mutate_row( b"row_key", mock_predicate, - false_case_mutations=[mock.Mock()], + false_case_mutations=[DeleteAllFromRow()], ) kwargs = mock_gapic.call_args[1] - assert kwargs["predicate_filter"] == predicate_pb + request = kwargs["request"] + assert request.predicate_filter == predicate_pb assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None @@ -2860,11 +3035,11 @@ async def test_check_and_mutate_predicate_object(self): async def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - from google.cloud.bigtable.data.mutations import DeleteAllFromRow + from google.cloud.bigtable.data.mutations import DeleteAllFromFamily mutations = [mock.Mock() for _ in range(5)] for idx, mutation in enumerate(mutations): - mutation._to_pb.return_value = f"fake {idx}" + mutation._to_pb.return_value = DeleteAllFromFamily(f"fake {idx}")._to_pb() mutations.append(DeleteAllFromRow()) async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -2881,11 +3056,15 @@ async def test_check_and_mutate_mutations_parsing(self): false_case_mutations=mutations[2:], ) kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == ["fake 0", "fake 1"] - assert kwargs["false_mutations"] == [ - "fake 2", - "fake 3", - "fake 4", + request = kwargs["request"] + assert request.true_mutations == [ + DeleteAllFromFamily("fake 0")._to_pb(), + DeleteAllFromFamily("fake 1")._to_pb(), + ] + assert request.false_mutations == [ + DeleteAllFromFamily("fake 2")._to_pb(), + DeleteAllFromFamily("fake 3")._to_pb(), + DeleteAllFromFamily("fake 4")._to_pb(), DeleteAllFromRow()._to_pb(), ] assert all( @@ -2933,7 +3112,8 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules await table.read_modify_write_row("key", call_rules) assert mock_gapic.call_count == 1 found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules + request = found_kwargs["request"] + assert request.rules == expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) @@ -2956,15 +3136,16 @@ async def test_read_modify_write_call_defaults(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: - await table.read_modify_write_row(row_key, mock.Mock()) + await table.read_modify_write_row(row_key, IncrementRule("f", "q")) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == f"projects/{project}/instances/{instance}/tables/{table_id}" ) - assert kwargs["app_profile_id"] is None - assert kwargs["row_key"] == row_key.encode() + assert bool(request.app_profile_id) is False + assert request.row_key == row_key.encode() assert kwargs["timeout"] > 1 @CrossSync.pytest @@ -2981,13 +3162,14 @@ async def 
test_read_modify_write_call_overrides(self): ) as mock_gapic: await table.read_modify_write_row( row_key, - mock.Mock(), + IncrementRule("f", "q"), operation_timeout=expected_timeout, ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["app_profile_id"] is profile_id - assert kwargs["row_key"] == row_key + request = kwargs["request"] + assert request.app_profile_id == profile_id + assert request.row_key == row_key assert kwargs["timeout"] == expected_timeout @CrossSync.pytest @@ -2998,10 +3180,11 @@ async def test_read_modify_write_string_key(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: - await table.read_modify_write_row(row_key, mock.Mock()) + await table.read_modify_write_row(row_key, IncrementRule("f", "q")) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["row_key"] == row_key.encode() + request = kwargs["request"] + assert request.row_key == row_key.encode() @CrossSync.pytest async def test_read_modify_write_row_building(self): @@ -3020,7 +3203,9 @@ async def test_read_modify_write_row_building(self): ) as mock_gapic: with mock.patch.object(Row, "_from_pb") as constructor_mock: mock_gapic.return_value = mock_response - await table.read_modify_write_row("key", mock.Mock()) + await table.read_modify_write_row( + "key", IncrementRule("f", "q") + ) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 2df8dde6d..29f2f1026 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -19,6 +19,8 @@ import google.api_core.exceptions as core_exceptions import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data._cross_sync import CrossSync @@ -38,9 +40,9 @@ def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): @staticmethod def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", DeleteAllFromRow()) + mutation.mutations = [DeleteAllFromRow() for _ in range(count)] + mutation.size = lambda: size return mutation def test_ctor(self): @@ -308,6 +310,8 @@ def _make_one(self, table=None, **kwargs): if table is None: table = mock.Mock() + table._request_path = {"table_name": "table"} + table.app_profile_id = None table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 10 table.default_mutate_rows_retryable_errors = ( @@ -319,9 +323,9 @@ def _make_one(self, table=None, **kwargs): @staticmethod def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", DeleteAllFromRow()) + mutation.size = lambda: size + mutation.mutations = [DeleteAllFromRow() for _ in range(count)] return mutation @CrossSync.pytest @@ -334,7 +338,7 @@ async def test_ctor_defaults(self): table.default_mutate_rows_attempt_timeout = 8 table.default_mutate_rows_retryable_errors = [Exception] async with self._make_one(table) as instance: - assert 
instance._table == table + assert instance._target == table assert instance.closed is False assert instance._flush_jobs == set() assert len(instance._staged_entries) == 0 @@ -390,7 +394,7 @@ async def test_ctor_explicit(self): batch_attempt_timeout=attempt_timeout, batch_retryable_errors=retryable_errors, ) as instance: - assert instance._table == table + assert instance._target == table assert instance.closed is False assert instance._flush_jobs == set() assert len(instance._staged_entries) == 0 @@ -435,7 +439,7 @@ async def test_ctor_no_flush_limits(self): flush_limit_mutation_count=flush_limit_count, flush_limit_bytes=flush_limit_bytes, ) as instance: - assert instance._table == table + assert instance._target == table assert instance.closed is False assert instance._staged_entries == [] assert len(instance._oldest_exceptions) == 0 @@ -903,10 +907,10 @@ async def test_timer_flush_end_to_end(self): mutations = [self._make_mutation(count=2, size=2)] * num_mutations async with self._make_one(flush_interval=0.05) as instance: - instance._table.default_operation_timeout = 10 - instance._table.default_attempt_timeout = 9 + instance._target.default_operation_timeout = 10 + instance._target.default_attempt_timeout = 9 with mock.patch.object( - instance._table.client._gapic_client, "mutate_rows" + instance._target.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( lambda *args, **kwargs: self._mock_gapic_return(num_mutations) diff --git a/tests/unit/data/_sync_autogen/test__mutate_rows.py b/tests/unit/data/_sync_autogen/test__mutate_rows.py index 2173c88fb..b198df01b 100644 --- a/tests/unit/data/_sync_autogen/test__mutate_rows.py +++ b/tests/unit/data/_sync_autogen/test__mutate_rows.py @@ -17,6 +17,8 @@ import pytest from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.rpc import status_pb2 from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden @@ -34,8 +36,11 @@ def _target_class(self): def _make_one(self, *args, **kwargs): if not args: + fake_target = CrossSync._Sync_Impl.Mock() + fake_target._request_path = {"table_name": "table"} + fake_target.app_profile_id = None kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", CrossSync._Sync_Impl.Mock()) + kwargs["target"] = kwargs.pop("target", fake_target) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) @@ -43,9 +48,8 @@ def _make_one(self, *args, **kwargs): return self._target_class()(*args, **kwargs) def _make_mutation(self, count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", [DeleteAllFromRow() for _ in range(count)]) + mutation.size = lambda: size return mutation def _mock_stream(self, mutation_list, error_dict): @@ -92,11 +96,6 @@ def test_ctor(self): assert client.mutate_rows.call_count == 0 instance._gapic_fn() assert client.mutate_rows.call_count == 1 - inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 3 - assert inner_kwargs["table_name"] == table.table_name - assert inner_kwargs["app_profile_id"] == table.app_profile_id - assert inner_kwargs["retry"] is None entries_w_pb = 
[_EntryWithProto(e, e._to_pb()) for e in entries] assert instance.mutations == entries_w_pb assert next(instance.timeout_generator) == attempt_timeout @@ -148,6 +147,8 @@ def test_mutate_rows_attempt_exception(self, exc_type): """exceptions raised from attempt should be raised in MutationsExceptionGroup""" client = CrossSync._Sync_Impl.Mock() table = mock.Mock() + table._request_path = {"table_name": "table"} + table.app_profile_id = None entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") @@ -260,7 +261,8 @@ def test_run_attempt_single_entry_success(self): assert mock_gapic_fn.call_count == 1 (_, kwargs) = mock_gapic_fn.call_args assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_pb()] + request = kwargs["request"] + assert request.entries == [mutation._to_pb()] def test_run_attempt_empty_request(self): """Calling with no mutations should result in no API calls""" diff --git a/tests/unit/data/_sync_autogen/test__read_rows.py b/tests/unit/data/_sync_autogen/test__read_rows.py index 973b07bcb..a545142d3 100644 --- a/tests/unit/data/_sync_autogen/test__read_rows.py +++ b/tests/unit/data/_sync_autogen/test__read_rows.py @@ -48,7 +48,7 @@ def test_ctor(self): client.read_rows.return_value = None table = mock.Mock() table._client = client - table.table_name = "test_table" + table._request_path = {"table_name": "test_table"} table.app_profile_id = "test_profile" expected_operation_timeout = 42 expected_request_timeout = 44 @@ -72,7 +72,7 @@ def test_ctor(self): assert instance._remaining_count == row_limit assert instance.operation_timeout == expected_operation_timeout assert client.read_rows.call_count == 0 - assert instance.request.table_name == table.table_name + assert instance.request.table_name == "test_table" assert instance.request.app_profile_id == table.app_profile_id assert instance.request.rows_limit == row_limit @@ -252,7 +252,7 @@ def mock_stream(): query = ReadRowsQuery(limit=start_limit) table = mock.Mock() - table.table_name = "table_name" + table._request_path = {"table_name": "table_name"} table.app_profile_id = "app_profile_id" instance = self._make_one(query, table, 10, 10) assert instance._remaining_count == start_limit @@ -287,7 +287,7 @@ def mock_stream(): query = ReadRowsQuery(limit=start_limit) table = mock.Mock() - table.table_name = "table_name" + table._request_path = {"table_name": "table_name"} table.app_profile_id = "app_profile_id" instance = self._make_one(query, table, 10, 10) assert instance._remaining_count == start_limit diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 720f0e0b6..518575c5a 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -27,6 +27,7 @@ from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -207,9 +208,7 @@ def test__ping_and_warm_instances(self): assert len(result) == 0 assert gather.call_args[1]["return_exceptions"] is True assert gather.call_args[1]["sync_executor"] == client_mock._executor - 
client_mock._active_instances = [ - (mock.Mock(), mock.Mock(), mock.Mock()) - ] * 4 + client_mock._active_instances = [(mock.Mock(), mock.Mock())] * 4 gather.reset_mock() channel.reset_mock() result = self._get_target_class()._ping_and_warm_instances( @@ -223,7 +222,6 @@ def test__ping_and_warm_instances(self): for idx, (_, kwargs) in enumerate(grpc_call_args): ( expected_instance, - expected_table, expected_app_profile, ) = client_mock._active_instances[idx] request = kwargs["request"] @@ -250,7 +248,7 @@ def test__ping_and_warm_single_instance(self): ) as gather: gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] client_mock._active_instances = [mock.Mock()] * 100 - test_key = ("test-instance", "test-table", "test-app-profile") + test_key = ("test-instance", "test-app-profile") result = self._get_target_class()._ping_and_warm_instances( client_mock, test_key ) @@ -436,11 +434,7 @@ def test__register_instance(self): client_mock, "instance-1", table_mock ) assert client_mock._start_background_channel_refresh.call_count == 1 - expected_key = ( - "prefix/instance-1", - table_mock.table_name, - table_mock.app_profile_id, - ) + expected_key = ("prefix/instance-1", table_mock.app_profile_id) assert len(active_instances) == 1 assert expected_key == tuple(list(active_instances)[0]) assert len(instance_owners) == 1 @@ -458,11 +452,7 @@ def test__register_instance(self): assert client_mock._ping_and_warm_instances.call_count == 1 assert len(active_instances) == 2 assert len(instance_owners) == 2 - expected_key2 = ( - "prefix/instance-2", - table_mock2.table_name, - table_mock2.app_profile_id, - ) + expected_key2 = ("prefix/instance-2", table_mock2.app_profile_id) assert any( [ expected_key2 == tuple(list(active_instances)[i]) @@ -489,11 +479,7 @@ def test__register_instance_duplicate(self): client_mock.transport.channels = mock_channels client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() - expected_key = ( - "prefix/instance-1", - table_mock.table_name, - table_mock.app_profile_id, - ) + expected_key = ("prefix/instance-1", table_mock.app_profile_id) self._get_target_class()._register_instance( client_mock, "instance-1", table_mock ) @@ -514,13 +500,13 @@ def test__register_instance_duplicate(self): @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ - ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), - ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), - ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ([("i", None)], [("i", None)], [("i", None)]), + ([("i", "p")], [("i", "p")], [("i", "p")]), + ([("1", "p"), ("1", "p")], [("1", "p")], [("1", "p")]), ( - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], + [("1", "p"), ("2", "p")], + [("1", "p"), ("2", "p")], + [("1", "p"), ("2", "p")], ), ], ) @@ -537,8 +523,7 @@ def test__register_instance_state( client_mock._channel_refresh_task = None client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() - for instance, table, profile in insert_instances: - table_mock.table_name = table + for instance, profile in insert_instances: table_mock.app_profile_id = profile self._get_target_class()._register_instance( client_mock, instance, table_mock @@ -570,11 +555,11 @@ def test__remove_instance_registration(self): instance_1_path = client._gapic_client.instance_path( client.project, "instance-1" ) - instance_1_key = 
(instance_1_path, table.table_name, table.app_profile_id) + instance_1_key = (instance_1_path, table.app_profile_id) instance_2_path = client._gapic_client.instance_path( client.project, "instance-2" ) - instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + instance_2_key = (instance_2_path, table.app_profile_id) assert len(client._instance_owners[instance_1_key]) == 1 assert list(client._instance_owners[instance_1_key])[0] == id(table) assert len(client._instance_owners[instance_2_key]) == 1 @@ -602,26 +587,28 @@ def test__multiple_table_registration(self): client.project, "instance_1" ) instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id + instance_1_path, table_1.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 1 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_1") as table_2: + with client.get_table("instance_1", "table_2") as table_2: assert table_2._register_instance_future is not None table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_3") as table_3: + with client.get_table( + "instance_1", "table_3", app_profile_id="diff" + ) as table_3: assert table_3._register_instance_future is not None table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) instance_3_key = _WarmedInstanceKey( - instance_3_path, table_3.table_name, table_3.app_profile_id + instance_3_path, table_3.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._instance_owners[instance_3_key]) == 1 @@ -652,13 +639,13 @@ def test__multiple_instance_registration(self): client.project, "instance_1" ) instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id + instance_1_path, table_1.app_profile_id ) instance_2_path = client._gapic_client.instance_path( client.project, "instance_2" ) instance_2_key = _WarmedInstanceKey( - instance_2_path, table_2.table_name, table_2.app_profile_id + instance_2_path, table_2.app_profile_id ) assert len(client._instance_owners[instance_1_key]) == 1 assert len(client._instance_owners[instance_2_key]) == 1 @@ -674,7 +661,9 @@ def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 - def test_get_table(self): + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) + def test_get_api_surface(self, method): + """test client.get_table and client.get_authorized_view""" from google.cloud.bigtable.data._helpers import _WarmedInstanceKey client = self._make_client(project="project-id") @@ -682,77 +671,113 @@ def test_get_table(self): expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" - table = client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) + if method == "get_table": + surface = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + assert isinstance( + surface, CrossSync._Sync_Impl.TestTable._get_target_class() + ) + elif method == 
"get_authorized_view": + surface = client.get_authorized_view( + expected_instance_id, + expected_table_id, + "view_id", + expected_app_profile_id, + ) + assert isinstance( + surface, CrossSync._Sync_Impl.TestAuthorizedView._get_target_class() + ) + assert ( + surface.authorized_view_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}/authorizedViews/view_id" + ) + else: + raise TypeError(f"unexpected method: {method}") CrossSync._Sync_Impl.yield_to_event_loop() - assert isinstance(table, CrossSync._Sync_Impl.TestTable._get_target_class()) - assert table.table_id == expected_table_id + assert surface.table_id == expected_table_id assert ( - table.table_name + surface.table_name == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" ) - assert table.instance_id == expected_instance_id + assert surface.instance_id == expected_instance_id assert ( - table.instance_name + surface.instance_name == f"projects/{client.project}/instances/{expected_instance_id}" ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) + assert surface.app_profile_id == expected_app_profile_id + assert surface.client is client + instance_key = _WarmedInstanceKey(surface.instance_name, surface.app_profile_id) assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} + assert client._instance_owners[instance_key] == {id(surface)} client.close() - def test_get_table_arg_passthrough(self): - """All arguments passed in get_table should be sent to constructor""" + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) + def test_api_surface_arg_passthrough(self, method): + """All arguments passed in get_table and get_authorized_view should be sent to constructor""" + if method == "get_table": + surface_type = CrossSync._Sync_Impl.TestTable._get_target_class() + elif method == "get_authorized_view": + surface_type = CrossSync._Sync_Impl.TestAuthorizedView._get_target_class() + else: + raise TypeError(f"unexpected method: {method}") with self._make_client(project="project-id") as client: - with mock.patch.object( - CrossSync._Sync_Impl.TestTable._get_target_class(), "__init__" - ) as mock_constructor: + with mock.patch.object(surface_type, "__init__") as mock_constructor: mock_constructor.return_value = None assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_args = (1, "test", {"test": 2}) - expected_kwargs = {"hello": "world", "test": 2} - client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, + expected_args = ( + "table", + "instance", + "view", + "app_profile", + 1, + "test", + {"test": 2}, ) + expected_kwargs = {"hello": "world", "test": 2} + getattr(client, method)(*expected_args, **expected_kwargs) mock_constructor.assert_called_once_with( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, + client, *expected_args, **expected_kwargs ) - def test_get_table_context_manager(self): + @pytest.mark.parametrize("method", ["get_table", "get_authorized_view"]) + def test_api_surface_context_manager(self, method): + """get_table and get_authorized_view should work as context managers""" + 
from functools import partial from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object( - CrossSync._Sync_Impl.TestTable._get_target_class(), "close" - ) as close_mock: + if method == "get_table": + surface_type = CrossSync._Sync_Impl.TestTable._get_target_class() + elif method == "get_authorized_view": + surface_type = CrossSync._Sync_Impl.TestAuthorizedView._get_target_class() + else: + raise TypeError(f"unexpected method: {method}") + with mock.patch.object(surface_type, "close") as close_mock: with self._make_client(project=expected_project_id) as client: - with client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) as table: - CrossSync._Sync_Impl.yield_to_event_loop() - assert isinstance( - table, CrossSync._Sync_Impl.TestTable._get_target_class() + if method == "get_table": + fn = partial( + client.get_table, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + ) + elif method == "get_authorized_view": + fn = partial( + client.get_authorized_view, + expected_instance_id, + expected_table_id, + "view_id", + expected_app_profile_id, ) + else: + raise TypeError(f"unexpected method: {method}") + with fn() as table: + CrossSync._Sync_Impl.yield_to_event_loop() + assert isinstance(table, surface_type) assert table.table_id == expected_table_id assert ( table.table_name @@ -766,7 +791,7 @@ def test_get_table_context_manager(self): assert table.app_profile_id == expected_app_profile_id assert table.client is client instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id + table.instance_name, table.app_profile_id ) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} @@ -821,7 +846,19 @@ def _make_client(self, *args, **kwargs): def _get_target_class(): return CrossSync._Sync_Impl.Table - def test_table_ctor(self): + def _make_one( + self, + client, + instance_id="instance", + table_id="table", + app_profile_id=None, + **kwargs, + ): + return self._get_target_class()( + client, instance_id, table_id, app_profile_id, **kwargs + ) + + def test_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" @@ -850,11 +887,17 @@ def test_table_ctor(self): CrossSync._Sync_Impl.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) assert table.app_profile_id == expected_app_profile_id assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) + instance_key = _WarmedInstanceKey(table.instance_name, table.app_profile_id) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} assert table.default_operation_timeout == expected_operation_timeout @@ -881,18 +924,12 @@ def test_table_ctor(self): assert table._register_instance_future.exception() is None client.close() - def test_table_ctor_defaults(self): + def test_ctor_defaults(self): """should provide default timeout values and app_profile_id""" - expected_table_id = 
"table-id" - expected_instance_id = "instance-id" client = self._make_client() assert not client._active_instances - table = self._get_target_class()( - client, expected_instance_id, expected_table_id - ) + table = self._make_one(client) CrossSync._Sync_Impl.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id assert table.app_profile_id is None assert table.client is client assert table.default_operation_timeout == 60 @@ -903,7 +940,7 @@ def test_table_ctor_defaults(self): assert table.default_mutate_rows_attempt_timeout == 60 client.close() - def test_table_ctor_invalid_timeout_values(self): + def test_ctor_invalid_timeout_values(self): """bad timeout values should raise ValueError""" client = self._make_client() timeout_pairs = [ @@ -919,10 +956,10 @@ def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{attempt_timeout: -1}) + self._make_one(client, **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{operation_timeout: -1}) + self._make_one(client, **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) client.close() @@ -935,10 +972,10 @@ def test_table_ctor_invalid_timeout_values(self): ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), ("row_exists", (b"row_key",), True, ()), ("sample_row_keys", (), False, ()), - ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ("mutate_row", (b"row_key", [DeleteAllFromRow()]), False, ()), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + ([mutations.RowMutationEntry(b"key", [DeleteAllFromRow()])],), False, (_MutateRowsIncomplete,), ), @@ -1035,7 +1072,7 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): gapic_client = client._gapic_client gapic_client._transport = transport_mock gapic_client._is_universe_domain_valid = True - table = self._get_target_class()(client, "instance-id", "table-id", profile) + table = self._make_one(client, app_profile_id=profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = test_fn(*fn_args) @@ -1048,12 +1085,118 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): assert len(metadata) == 1 assert metadata[0][0] == "x-goog-request-params" routing_str = metadata[0][1] - assert "table_name=" + table.table_name in routing_str + assert self._expected_routing_header(table) in routing_str if include_app_profile: assert "app_profile_id=profile" in routing_str else: assert "app_profile_id=" not in routing_str + @staticmethod + def _expected_routing_header(table): + """the expected routing header for this _ApiSurface type""" + return f"table_name={table.table_name}" + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestAuthorizedView") +class TestAuthorizedView(CrossSync._Sync_Impl.TestTable): + """ + Inherit tests from TestTableAsync, with some modifications + """ + + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl.AuthorizedView + + def _make_one( + self, + client, + instance_id="instance", + table_id="table", + view_id="view", + app_profile_id=None, + **kwargs, + ): + return self._get_target_class()( + client, instance_id, table_id, view_id, app_profile_id, **kwargs + ) + + @staticmethod + def 
_expected_routing_header(view):
+        """the expected routing header for this _ApiSurface type"""
+        return f"authorized_view_name={view.authorized_view_name}"
+
+    def test_ctor(self):
+        from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
+
+        expected_table_id = "table-id"
+        expected_instance_id = "instance-id"
+        expected_view_id = "view_id"
+        expected_app_profile_id = "app-profile-id"
+        expected_operation_timeout = 123
+        expected_attempt_timeout = 12
+        expected_read_rows_operation_timeout = 1.5
+        expected_read_rows_attempt_timeout = 0.5
+        expected_mutate_rows_operation_timeout = 2.5
+        expected_mutate_rows_attempt_timeout = 0.75
+        client = self._make_client()
+        assert not client._active_instances
+        view = self._get_target_class()(
+            client,
+            expected_instance_id,
+            expected_table_id,
+            expected_view_id,
+            expected_app_profile_id,
+            default_operation_timeout=expected_operation_timeout,
+            default_attempt_timeout=expected_attempt_timeout,
+            default_read_rows_operation_timeout=expected_read_rows_operation_timeout,
+            default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout,
+            default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout,
+            default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout,
+        )
+        CrossSync._Sync_Impl.yield_to_event_loop()
+        assert view.table_id == expected_table_id
+        assert (
+            view.table_name
+            == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}"
+        )
+        assert view.instance_id == expected_instance_id
+        assert (
+            view.instance_name
+            == f"projects/{client.project}/instances/{expected_instance_id}"
+        )
+        assert view.authorized_view_id == expected_view_id
+        assert (
+            view.authorized_view_name
+            == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}/authorizedViews/{expected_view_id}"
+        )
+        assert view.app_profile_id == expected_app_profile_id
+        assert view.client is client
+        instance_key = _WarmedInstanceKey(view.instance_name, view.app_profile_id)
+        assert instance_key in client._active_instances
+        assert client._instance_owners[instance_key] == {id(view)}
+        assert view.default_operation_timeout == expected_operation_timeout
+        assert view.default_attempt_timeout == expected_attempt_timeout
+        assert (
+            view.default_read_rows_operation_timeout
+            == expected_read_rows_operation_timeout
+        )
+        assert (
+            view.default_read_rows_attempt_timeout == expected_read_rows_attempt_timeout
+        )
+        assert (
+            view.default_mutate_rows_operation_timeout
+            == expected_mutate_rows_operation_timeout
+        )
+        assert (
+            view.default_mutate_rows_attempt_timeout
+            == expected_mutate_rows_attempt_timeout
+        )
+        assert view._register_instance_future is not None
+        assert view._register_instance_future.done()
+        assert not view._register_instance_future.cancelled()
+        assert view._register_instance_future.exception() is None
+        client.close()
+

@CrossSync._Sync_Impl.add_mapping_decorator("TestReadRows")
class TestReadRows:
@@ -1787,11 +1930,12 @@ def test_sample_row_keys_gapic_params(self):
             table.sample_row_keys(attempt_timeout=expected_timeout)
             (args, kwargs) = sample_row_keys.call_args
             assert len(args) == 0
-            assert len(kwargs) == 4
+            assert len(kwargs) == 3
             assert kwargs["timeout"] == expected_timeout
-            assert kwargs["app_profile_id"] == expected_profile
-            assert kwargs["table_name"] == table.table_name
             assert kwargs["retry"] is None
+            request = kwargs["request"]
+            assert request.app_profile_id == expected_profile
+            assert request.table_name == table.table_name

    @pytest.mark.parametrize(
"retryable_exception", @@ -1879,17 +2023,18 @@ def test_mutate_row(self, mutation_arg): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0].kwargs + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["row_key"] == b"row_key" + assert request.row_key == b"row_key" formatted_mutations = ( [mutation._to_pb() for mutation in mutation_arg] if isinstance(mutation_arg, list) else [mutation_arg._to_pb()] ) - assert kwargs["mutations"] == formatted_mutations + assert request.mutations == formatted_mutations assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None @@ -2033,11 +2178,12 @@ def test_bulk_mutate_rows(self, mutation_arg): ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args[1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert request.entries == [bulk_mutation._to_pb()] assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None @@ -2055,12 +2201,13 @@ def test_bulk_mutate_rows_multiple_entries(self): table.bulk_mutate_rows([entry_1, entry_2]) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args[1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == "projects/project/instances/instance/tables/table" ) - assert kwargs["entries"][0] == entry_1._to_pb() - assert kwargs["entries"][1] == entry_2._to_pb() + assert request.entries[0] == entry_1._to_pb() + assert request.entries[1] == entry_2._to_pb() @pytest.mark.parametrize( "exception", @@ -2328,8 +2475,8 @@ def test_check_and_mutate(self, gapic_result): ) row_key = b"row_key" predicate = None - true_mutations = [mock.Mock()] - false_mutations = [mock.Mock(), mock.Mock()] + true_mutations = [DeleteAllFromRow()] + false_mutations = [DeleteAllFromRow(), DeleteAllFromRow()] operation_timeout = 0.2 found = table.check_and_mutate_row( row_key, @@ -2340,16 +2487,17 @@ def test_check_and_mutate(self, gapic_result): ) assert found == gapic_result kwargs = mock_gapic.call_args[1] - assert kwargs["table_name"] == table.table_name - assert kwargs["row_key"] == row_key - assert kwargs["predicate_filter"] == predicate - assert kwargs["true_mutations"] == [ + request = kwargs["request"] + assert request.table_name == table.table_name + assert request.row_key == row_key + assert bool(request.predicate_filter) is False + assert request.true_mutations == [ m._to_pb() for m in true_mutations ] - assert kwargs["false_mutations"] == [ + assert request.false_mutations == [ m._to_pb() for m in false_mutations ] - assert kwargs["app_profile_id"] == app_profile + assert request.app_profile_id == app_profile assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None @@ -2389,15 +2537,17 @@ def test_check_and_mutate_single_mutations(self): false_case_mutations=false_mutation, ) kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == [true_mutation._to_pb()] - assert kwargs["false_mutations"] == [false_mutation._to_pb()] + request = kwargs["request"] + assert request.true_mutations == [true_mutation._to_pb()] + assert request.false_mutations == [false_mutation._to_pb()] def test_check_and_mutate_predicate_object(self): """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from 
google.cloud.bigtable_v2.types.data import RowFilter mock_predicate = mock.Mock() - predicate_pb = {"predicate": "dict"} + predicate_pb = RowFilter({"sink": True}) mock_predicate._to_pb.return_value = predicate_pb with self._make_client() as client: with client.get_table("instance", "table") as table: @@ -2408,21 +2558,24 @@ def test_check_and_mutate_predicate_object(self): predicate_matched=True ) table.check_and_mutate_row( - b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + b"row_key", + mock_predicate, + false_case_mutations=[DeleteAllFromRow()], ) kwargs = mock_gapic.call_args[1] - assert kwargs["predicate_filter"] == predicate_pb + request = kwargs["request"] + assert request.predicate_filter == predicate_pb assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - from google.cloud.bigtable.data.mutations import DeleteAllFromRow + from google.cloud.bigtable.data.mutations import DeleteAllFromFamily mutations = [mock.Mock() for _ in range(5)] for idx, mutation in enumerate(mutations): - mutation._to_pb.return_value = f"fake {idx}" + mutation._to_pb.return_value = DeleteAllFromFamily(f"fake {idx}")._to_pb() mutations.append(DeleteAllFromRow()) with self._make_client() as client: with client.get_table("instance", "table") as table: @@ -2439,11 +2592,15 @@ def test_check_and_mutate_mutations_parsing(self): false_case_mutations=mutations[2:], ) kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == ["fake 0", "fake 1"] - assert kwargs["false_mutations"] == [ - "fake 2", - "fake 3", - "fake 4", + request = kwargs["request"] + assert request.true_mutations == [ + DeleteAllFromFamily("fake 0")._to_pb(), + DeleteAllFromFamily("fake 1")._to_pb(), + ] + assert request.false_mutations == [ + DeleteAllFromFamily("fake 2")._to_pb(), + DeleteAllFromFamily("fake 3")._to_pb(), + DeleteAllFromFamily("fake 4")._to_pb(), DeleteAllFromRow()._to_pb(), ] assert all( @@ -2486,7 +2643,8 @@ def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): table.read_modify_write_row("key", call_rules) assert mock_gapic.call_count == 1 found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules + request = found_kwargs["request"] + assert request.rules == expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) @@ -2507,15 +2665,16 @@ def test_read_modify_write_call_defaults(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) + table.read_modify_write_row(row_key, IncrementRule("f", "q")) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] + request = kwargs["request"] assert ( - kwargs["table_name"] + request.table_name == f"projects/{project}/instances/{instance}/tables/{table_id}" ) - assert kwargs["app_profile_id"] is None - assert kwargs["row_key"] == row_key.encode() + assert bool(request.app_profile_id) is False + assert request.row_key == row_key.encode() assert kwargs["timeout"] > 1 def test_read_modify_write_call_overrides(self): @@ -2530,12 +2689,15 @@ def test_read_modify_write_call_overrides(self): client._gapic_client, "read_modify_write_row" ) as mock_gapic: table.read_modify_write_row( - row_key, mock.Mock(), operation_timeout=expected_timeout + row_key, + 
IncrementRule("f", "q"), + operation_timeout=expected_timeout, ) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["app_profile_id"] is profile_id - assert kwargs["row_key"] == row_key + request = kwargs["request"] + assert request.app_profile_id == profile_id + assert request.row_key == row_key assert kwargs["timeout"] == expected_timeout def test_read_modify_write_string_key(self): @@ -2545,10 +2707,11 @@ def test_read_modify_write_string_key(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) + table.read_modify_write_row(row_key, IncrementRule("f", "q")) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["row_key"] == row_key.encode() + request = kwargs["request"] + assert request.row_key == row_key.encode() def test_read_modify_write_row_building(self): """results from gapic call should be used to construct row""" @@ -2564,7 +2727,7 @@ def test_read_modify_write_row_building(self): ) as mock_gapic: with mock.patch.object(Row, "_from_pb") as constructor_mock: mock_gapic.return_value = mock_response - table.read_modify_write_row("key", mock.Mock()) + table.read_modify_write_row("key", IncrementRule("f", "q")) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index 59ea621ac..72db64146 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -22,6 +22,8 @@ import google.api_core.exceptions as core_exceptions import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import DeleteAllFromRow from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data._cross_sync import CrossSync @@ -36,9 +38,9 @@ def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): @staticmethod def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", DeleteAllFromRow()) + mutation.mutations = [DeleteAllFromRow() for _ in range(count)] + mutation.size = lambda: size return mutation def test_ctor(self): @@ -258,6 +260,8 @@ def _make_one(self, table=None, **kwargs): if table is None: table = mock.Mock() + table._request_path = {"table_name": "table"} + table.app_profile_id = None table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 10 table.default_mutate_rows_retryable_errors = ( @@ -268,9 +272,9 @@ def _make_one(self, table=None, **kwargs): @staticmethod def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count + mutation = RowMutationEntry("k", DeleteAllFromRow()) + mutation.size = lambda: size + mutation.mutations = [DeleteAllFromRow() for _ in range(count)] return mutation def test_ctor_defaults(self): @@ -284,7 +288,7 @@ def test_ctor_defaults(self): table.default_mutate_rows_attempt_timeout = 8 table.default_mutate_rows_retryable_errors = [Exception] with self._make_one(table) as instance: - assert instance._table == table + assert instance._target == table assert 
instance.closed is False assert instance._flush_jobs == set() assert len(instance._staged_entries) == 0 @@ -341,7 +345,7 @@ def test_ctor_explicit(self): batch_attempt_timeout=attempt_timeout, batch_retryable_errors=retryable_errors, ) as instance: - assert instance._table == table + assert instance._target == table assert instance.closed is False assert instance._flush_jobs == set() assert len(instance._staged_entries) == 0 @@ -387,7 +391,7 @@ def test_ctor_no_flush_limits(self): flush_limit_mutation_count=flush_limit_count, flush_limit_bytes=flush_limit_bytes, ) as instance: - assert instance._table == table + assert instance._target == table assert instance.closed is False assert instance._staged_entries == [] assert len(instance._oldest_exceptions) == 0 @@ -783,10 +787,10 @@ def test_timer_flush_end_to_end(self): num_mutations = 10 mutations = [self._make_mutation(count=2, size=2)] * num_mutations with self._make_one(flush_interval=0.05) as instance: - instance._table.default_operation_timeout = 10 - instance._table.default_attempt_timeout = 9 + instance._target.default_operation_timeout = 10 + instance._target.default_attempt_timeout = 9 with mock.patch.object( - instance._table.client._gapic_client, "mutate_rows" + instance._target.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( lambda *args, **kwargs: self._mock_gapic_return(num_mutations) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index 492d35ddf..d4623a6c8 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -19,6 +19,9 @@ import re from difflib import unified_diff +if sys.version_info < (3, 9): + pytest.skip("ast.unparse is only available in 3.9+", allow_module_level=True) + # add cross_sync to path test_dir_name = os.path.dirname(__file__) repo_root = os.path.join(test_dir_name, "..", "..", "..") @@ -48,9 +51,6 @@ def test_found_files(): ), "test proxy handler not found" -@pytest.mark.skipif( - sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" -) @pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) def test_sync_up_to_date(sync_file): """