From 87eac271df544b83e8385c7a80c2103b3c7b1b02 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Mon, 12 Dec 2022 16:42:36 +0530 Subject: [PATCH 01/11] feat:fgac changes and samples --- google/cloud/spanner_dbapi/parse_utils.py | 2 +- google/cloud/spanner_v1/database.py | 120 ++++++++--- google/cloud/spanner_v1/instance.py | 14 +- google/cloud/spanner_v1/pool.py | 79 +++++-- google/cloud/spanner_v1/session.py | 23 ++- samples/samples/snippets.py | 204 ++++++++++++------- samples/samples/snippets_test.py | 34 ++++ tests/system/test_database_api.py | 123 +++++++++++ tests/unit/spanner_dbapi/test_parse_utils.py | 4 + tests/unit/test_database.py | 56 +++++ tests/unit/test_instance.py | 5 +- tests/unit/test_pool.py | 118 +++++++++-- tests/unit/test_session.py | 81 +++++++- 13 files changed, 717 insertions(+), 146 deletions(-) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index e09b294dff..84cb2dc7a5 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -151,7 +151,7 @@ # DDL statements follow # https://cloud.google.com/spanner/docs/data-definition-language -RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) +RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP|GRANT|REVOKE)", re.IGNORECASE | re.DOTALL) RE_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 7d2384beed..367b2ee40d 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -27,11 +27,14 @@ from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted from google.api_core import gapic_v1 +from google.iam.v1 import iam_policy_pb2 +from google.iam.v1 import options_pb2 from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest from google.cloud.spanner_admin_database_v1 import Database as DatabasePB from google.cloud.spanner_admin_database_v1 import EncryptionConfig from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig +from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect @@ -119,7 +122,8 @@ class Database(object): :class:`~google.cloud.spanner_admin_database_v1.types.DatabaseDialect` :param database_dialect: (Optional) database dialect for the database - + :type database_role: str or None + :param database_role: (Optional) user-assigned database_role for the session. """ _spanner_api = None @@ -133,6 +137,7 @@ def __init__( logger=None, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, + database_role=None, ): self.database_id = database_id self._instance = instance @@ -149,9 +154,10 @@ def __init__( self._logger = logger self._encryption_config = encryption_config self._database_dialect = database_dialect + self._database_role = database_role if pool is None: - pool = BurstyPool() + pool = BurstyPool(database_role=database_role) self._pool = pool pool.bind(self) @@ -314,6 +320,14 @@ def database_dialect(self): """ return self._database_dialect + @property + def database_role(self): + """User-assigned database_role for sessions created by the pool. + :rtype: str + :returns: a str with the name of the database role. + """ + return self._database_role + @property def logger(self): """Logger used by the database. @@ -466,9 +480,7 @@ def update_ddl(self, ddl_statements, operation_id=""): metadata = _metadata_with_prefix(self.name) request = UpdateDatabaseDdlRequest( - database=self.name, - statements=ddl_statements, - operation_id=operation_id, + database=self.name, statements=ddl_statements, operation_id=operation_id, ) future = api.update_database_ddl(request=request, metadata=metadata) @@ -571,8 +583,7 @@ def execute_pdml(): request_options=request_options, ) method = functools.partial( - api.execute_streaming_sql, - metadata=metadata, + api.execute_streaming_sql, metadata=metadata, ) iterator = _restart_on_unavailable(method, request) @@ -584,16 +595,22 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() - def session(self, labels=None): + def session(self, labels=None, database_role=None): """Factory to create a session for this database. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for the session. + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a session bound to this database. """ - return Session(self, labels=labels) + # If role is specified in param, then that role is used + # instead. + role = database_role or self._database_role + return Session(self, labels=labels, database_role=role) def snapshot(self, **kw): """Return an object which wraps a snapshot. @@ -722,10 +739,7 @@ def restore(self, source): backup=source.name, encryption_config=self._encryption_config or None, ) - future = api.restore_database( - request=request, - metadata=metadata, - ) + future = api.restore_database(request=request, metadata=metadata,) return future def is_ready(self): @@ -772,6 +786,26 @@ def list_database_operations(self, filter_="", page_size=None): filter_=database_filter, page_size=page_size ) + def list_database_roles(self, page_size=None): + """Lists Cloud Spanner database roles. + + :type page_size: int + :param page_size: + Optional. The maximum number of database roles in each page of results + from this request. Non-positive values are ignored. Defaults to a + sensible value set by the API. + + :type: Iterable + :returns: + Iterable of :class:`~google.cloud.spanner_admin_database_v1.types.spanner_database_admin.DatabaseRole` + resources within the current database. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = ListDatabaseRolesRequest(parent=self.name, page_size=page_size,) + return api.list_database_roles(request=request, metadata=metadata) + def table(self, table_id): """Factory to create a table object within this database. @@ -811,6 +845,54 @@ def list_tables(self): for row in results: yield self.table(row[0]) + def get_iam_policy(self, policy_version=None): + """Gets the access control policy for a database resource. + + :type policy_version: int + :param policy_version: + (Optional) the maximum policy version that will be + used to format the policy. Valid values are 0, 1 ,3. + + :rtype: :class:`~google.iam.v1.policy_pb2.Policy` + :returns: + returns an Identity and Access Management (IAM) policy. It is used to + specify access control policies for Cloud Platform + resources. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = iam_policy_pb2.GetIamPolicyRequest( + resource=self.name, + options=options_pb2.GetPolicyOptions( + requested_policy_version=policy_version + ), + ) + response = api.get_iam_policy(request=request, metadata=metadata) + return response + + def set_iam_policy(self, policy): + """Sets the access control policy on a database resource. + Replaces any existing policy. + + :type policy: :class:`~google.iam.v1.policy_pb2.Policy` + :param policy_version: + the complete policy to be applied to the resource. + + :rtype: :class:`~google.iam.v1.policy_pb2.Policy` + :returns: + returns the new Identity and Access Management (IAM) policy. + """ + api = self._instance._client.database_admin_api + metadata = _metadata_with_prefix(self.name) + + request = iam_policy_pb2.SetIamPolicyRequest( + resource=self.name, + policy=policy, + ) + response = api.set_iam_policy(request=request, metadata=metadata) + return response + class BatchCheckout(object): """Context manager for using a batch from a database. @@ -1073,11 +1155,7 @@ def generate_read_batches( yield {"partition": partition, "read": read_info.copy()} def process_read_batch( - self, - batch, - *, - retry=gapic_v1.method.DEFAULT, - timeout=gapic_v1.method.DEFAULT, + self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): """Process a single, partitioned read. @@ -1194,11 +1272,7 @@ def generate_query_batches( yield {"partition": partition, "query": query_info} def process_query_batch( - self, - batch, - *, - retry=gapic_v1.method.DEFAULT, - timeout=gapic_v1.method.DEFAULT, + self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): """Process a single, partitioned query. diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 6a9517a0e8..73e58f2e3e 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -431,6 +431,7 @@ def database( logger=None, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, + database_role=None, ): """Factory to create a database within this instance. @@ -477,6 +478,7 @@ def database( logger=logger, encryption_config=encryption_config, database_dialect=database_dialect, + database_role=database_role, ) def list_databases(self, page_size=None): @@ -617,9 +619,7 @@ def list_backups(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListBackupsRequest( - parent=self.name, - filter=filter_, - page_size=page_size, + parent=self.name, filter=filter_, page_size=page_size, ) page_iter = self._client.database_admin_api.list_backups( request=request, metadata=metadata @@ -647,9 +647,7 @@ def list_backup_operations(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListBackupOperationsRequest( - parent=self.name, - filter=filter_, - page_size=page_size, + parent=self.name, filter=filter_, page_size=page_size, ) page_iter = self._client.database_admin_api.list_backup_operations( request=request, metadata=metadata @@ -677,9 +675,7 @@ def list_database_operations(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListDatabaseOperationsRequest( - parent=self.name, - filter=filter_, - page_size=page_size, + parent=self.name, filter=filter_, page_size=page_size, ) page_iter = self._client.database_admin_api.list_database_operations( request=request, metadata=metadata diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 56a78ef672..b5d3f445fc 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -30,14 +30,18 @@ class AbstractSessionPool(object): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ _database = None - def __init__(self, labels=None): + def __init__(self, labels=None, database_role=None): if labels is None: labels = {} self._labels = labels + self._database_role = database_role @property def labels(self): @@ -48,6 +52,15 @@ def labels(self): """ return self._labels + @property + def database_role(self): + """User-assigned database_role for sessions created by the pool. + + :rtype: str + :returns: database_role assigned by the user + """ + return self._database_role + def bind(self, database): """Associate the pool with a database. @@ -104,9 +117,9 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. """ - if self.labels: - return self._database.session(labels=self.labels) - return self._database.session() + return self._database.session( + labels=self.labels, database_role=self.database_role + ) def session(self, **kwargs): """Check out a session from the pool. @@ -146,13 +159,22 @@ class FixedSizePool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ DEFAULT_SIZE = 10 DEFAULT_TIMEOUT = 10 - def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT, labels=None): - super(FixedSizePool, self).__init__(labels=labels) + def __init__( + self, + size=DEFAULT_SIZE, + default_timeout=DEFAULT_TIMEOUT, + labels=None, + database_role=None, + ): + super(FixedSizePool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout self._sessions = queue.LifoQueue(size) @@ -167,12 +189,14 @@ def bind(self, database): self._database = database api = database.spanner_api metadata = _metadata_with_prefix(database.name) + self._database_role = self._database_role or self._database.database_role while not self._sessions.full(): resp = api.batch_create_sessions( database=database.name, session_count=self.size - self._sessions.qsize(), metadata=metadata, + creator_role=self.database_role, ) for session_pb in resp.session: session = self._new_session() @@ -243,10 +267,13 @@ class BurstyPool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, target_size=10, labels=None): - super(BurstyPool, self).__init__(labels=labels) + def __init__(self, target_size=10, labels=None, database_role=None): + super(BurstyPool, self).__init__(labels=labels, database_role=database_role) self.target_size = target_size self._database = None self._sessions = queue.LifoQueue(target_size) @@ -259,6 +286,7 @@ def bind(self, database): when needed. """ self._database = database + self._database_role = self._database_role or self._database.database_role def get(self): """Check a session out from the pool. @@ -340,10 +368,20 @@ class PingingPool(AbstractSessionPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000, labels=None): - super(PingingPool, self).__init__(labels=labels) + def __init__( + self, + size=10, + default_timeout=10, + ping_interval=3000, + labels=None, + database_role=None, + ): + super(PingingPool, self).__init__(labels=labels, database_role=database_role) self.size = size self.default_timeout = default_timeout self._delta = datetime.timedelta(seconds=ping_interval) @@ -360,12 +398,14 @@ def bind(self, database): api = database.spanner_api metadata = _metadata_with_prefix(database.name) created_session_count = 0 + self._database_role = self._database_role or self._database.database_role while created_session_count < self.size: resp = api.batch_create_sessions( database=database.name, session_count=self.size - created_session_count, metadata=metadata, + creator_role=self.database_role, ) for session_pb in resp.session: session = self._new_session() @@ -470,13 +510,27 @@ class TransactionPingingPool(PingingPool): :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for sessions created by the pool. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000, labels=None): + def __init__( + self, + size=10, + default_timeout=10, + ping_interval=3000, + labels=None, + database_role=None, + ): self._pending_sessions = queue.Queue() super(TransactionPingingPool, self).__init__( - size, default_timeout, ping_interval, labels=labels + size, + default_timeout, + ping_interval, + labels=labels, + database_role=database_role, ) self.begin_pending_transactions() @@ -489,6 +543,7 @@ def bind(self, database): when needed. """ super(TransactionPingingPool, self).bind(database) + self._database_role = self._database_role or self._database.database_role self.begin_pending_transactions() def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 1ab6a93626..424c5d4538 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -52,16 +52,20 @@ class Session(object): :type labels: dict (str -> str) :param labels: (Optional) User-assigned labels for the session. + + :type database_role: str + :param database_role: (Optional) user-assigned database_role for the session. """ _session_id = None _transaction = None - def __init__(self, database, labels=None): + def __init__(self, database, labels=None, database_role=None): self._database = database if labels is None: labels = {} self._labels = labels + self._database_role = database_role def __lt__(self, other): return self._session_id < other._session_id @@ -71,6 +75,14 @@ def session_id(self): """Read-only ID, set by the back-end during :meth:`create`.""" return self._session_id + @property + def database_role(self): + """User-assigned database-role for the session. + + :rtype: str + :returns: the database role str (None if no database role were assigned).""" + return self._database_role + @property def labels(self): """User-assigned labels for the session. @@ -115,15 +127,14 @@ def create(self): metadata = _metadata_with_prefix(self._database.name) request = CreateSessionRequest(database=self._database.name) + if self._database.database_role is not None: + request.session.creator_role = self._database.database_role if self._labels: request.session.labels = self._labels with trace_call("CloudSpanner.CreateSession", self, self._labels): - session_pb = api.create_session( - request=request, - metadata=metadata, - ) + session_pb = api.create_session(request=request, metadata=metadata,) self._session_id = session_pb.name.split("/")[-1] def exists(self): @@ -441,4 +452,4 @@ def _get_retry_delay(cause, attempts): nanos = retry_info.retry_delay.nanos return retry_info.retry_delay.seconds + nanos / 1.0e9 - return 2**attempts + random.random() + return 2 ** attempts + random.random() diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 35f348939e..83e3dca814 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -33,6 +33,11 @@ from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore +from google.cloud import spanner_admin_database_v1 +from google.cloud.spanner_v1 import database, param_types +from google.type import expr_pb2 +from google.iam.v1 import policy_pb2 + OPERATION_TIMEOUT_SECONDS = 240 @@ -2210,104 +2215,116 @@ def set_request_tag(instance_id, database_id): # [END spanner_set_request_tag] -# [START spanner_create_instance_config] -def create_instance_config(user_config_name, base_config_id): - """Creates the new user-managed instance configuration using base instance config.""" - - # user_config_name = `custom-nam11` - # base_config_id = `projects//instanceConfigs/nam11` +def add_and_drop_database_roles(instance_id, database_id): + """Showcases how to manage a user defined database role.""" + # [START spanner_add_and_drop_database_roles] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" spanner_client = spanner.Client() - base_config = spanner_client.instance_admin_api.get_instance_config( - name=base_config_id - ) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + role_parent = "new_parent" + role_child = "new_child" - # The replicas for the custom instance configuration must include all the replicas of the base - # configuration, in addition to at least one from the list of optional replicas of the base - # configuration. - replicas = [] - for replica in base_config.replicas: - replicas.append(replica) - replicas.append(base_config.optional_replicas[0]) - operation = spanner_client.instance_admin_api.create_instance_config( - parent=spanner_client.project_name, - instance_config_id=user_config_name, - instance_config=spanner_instance_admin.InstanceConfig( - name="{}/instanceConfigs/{}".format( - spanner_client.project_name, user_config_name - ), - display_name="custom-python-samples", - config_type=spanner_instance_admin.InstanceConfig.Type.USER_MANAGED, - replicas=replicas, - base_config=base_config.name, - labels={"python_cloud_spanner_samples": "true"}, - ), + operation = database.update_ddl( + [ + "CREATE ROLE {}".format(role_parent), + "GRANT SELECT ON TABLE Singers TO ROLE {}".format(role_parent), + "CREATE ROLE {}".format(role_child), + "GRANT ROLE {} TO ROLE {}".format(role_parent, role_child), + ] ) - print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) + print( + "Created roles {} and {} and granted privileges".format(role_parent, role_child) + ) - print("Created instance configuration {}".format(user_config_name)) - + operation = database.update_ddl( + [ + "REVOKE ROLE {} FROM ROLE {}".format(role_parent, role_child), + "DROP ROLE {}".format(role_child), + ] + ) + operation.result(OPERATION_TIMEOUT_SECONDS) + print("Revoked privileges and dropped role {}".format(role_child)) -# [END spanner_create_instance_config] + # [END spanner_add_and_drop_database_roles] -# [START spanner_update_instance_config] -def update_instance_config(user_config_name): - """Updates the user-managed instance configuration.""" - # user_config_name = `custom-nam11` +def read_data_with_database_role(instance_id, database_id): + """Showcases how a user defined database role is used by member.""" + # [START spanner_read_data_with_database_role] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" spanner_client = spanner.Client() - config = spanner_client.instance_admin_api.get_instance_config( - name="{}/instanceConfigs/{}".format( - spanner_client.project_name, user_config_name - ) - ) - config.display_name = "updated custom instance config" - config.labels["updated"] = "true" - operation = spanner_client.instance_admin_api.update_instance_config( - instance_config=config, - update_mask=field_mask_pb2.FieldMask(paths=["display_name", "labels"]), - ) - print("Waiting for operation to complete...") - operation.result(OPERATION_TIMEOUT_SECONDS) - print("Updated instance configuration {}".format(user_config_name)) + instance = spanner_client.instance(instance_id) + role = "new_parent" + database = instance.database(database_id, database_role=role) + with database.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * FROM Singers") + for row in results: + print("SingerId: {}, FirstName: {}, LastName: {}".format(*row)) -# [END spanner_update_instance_config] + # [END spanner_read_data_with_database_role] -# [START spanner_delete_instance_config] -def delete_instance_config(user_config_id): - """Deleted the user-managed instance configuration.""" - spanner_client = spanner.Client() - spanner_client.instance_admin_api.delete_instance_config(name=user_config_id) - print("Instance config {} successfully deleted".format(user_config_id)) +def list_database_roles(instance_id, database_id): + """Showcases how to list Database Roles.""" + # [START spanner_list_database_roles] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) -# [END spanner_delete_instance_config] + # List database roles. + print("Database Roles are:") + for role in database.list_database_roles(): + print(role.name.split("/")[-1]) + # [END spanner_list_database_roles] -# [START spanner_list_instance_config_operations] -def list_instance_config_operations(): - """List the user-managed instance configuration operations.""" +def enable_fine_grained_access( + instance_id, + database_id, + iam_member="user:alice@example.com", + database_role="new_parent", + title="condition title", +): + """Showcases how to enable fine grained access control.""" + # [START spanner_enable_fine_grained_access] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + # iam_member = "user:alice@example.com" + # database_role = "new_parent" + # title = "condition title" spanner_client = spanner.Client() - operations = spanner_client.instance_admin_api.list_instance_config_operations( - request=spanner_instance_admin.ListInstanceConfigOperationsRequest( - parent=spanner_client.project_name, - filter="(metadata.@type=type.googleapis.com/google.spanner.admin.instance.v1.CreateInstanceConfigMetadata)", - ) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + policy = database.get_iam_policy(3) + if policy.version < 3: + policy.version = 3 + + new_binding = policy_pb2.Binding( + role="roles/spanner.fineGrainedAccessUser", + members=[iam_member], + condition=expr_pb2.Expr( + title=title, + expression=f'resource.name.endsWith("/databaseRoles/{database_role}")', + ), ) - for op in operations: - metadata = spanner_instance_admin.CreateInstanceConfigMetadata.pb( - spanner_instance_admin.CreateInstanceConfigMetadata() - ) - op.metadata.Unpack(metadata) - print( - "List instance config operations {} is {}% completed.".format( - metadata.instance_config.name, metadata.progress.progress_percent - ) - ) + policy.version = 3 + policy.bindings.append(new_binding) + database.set_iam_policy(policy) -# [END spanner_list_instance_config_operations] + new_policy = database.get_iam_policy(3) + print( + f"Enabled fine-grained access in IAM. New policy has version {new_policy.version}" + ) + # [END spanner_enable_fine_grained_access] if __name__ == "__main__": # noqa: C901 @@ -2419,6 +2436,23 @@ def list_instance_config_operations(): "create_client_with_query_options", help=create_client_with_query_options.__doc__, ) + subparsers.add_parser( + "add_and_drop_database_roles", help=add_and_drop_database_roles.__doc__ + ) + subparsers.add_parser( + "read_data_with_database_role", help=read_data_with_database_role.__doc__ + ) + subparsers.add_parser("list_database_roles", help=list_database_roles.__doc__) + enable_fine_grained_access_parser = subparsers.add_parser( + "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ + ) + enable_fine_grained_access_parser.add_argument( + "--iam_member", default="user:alice@example.com" + ) + enable_fine_grained_access_parser.add_argument( + "--database_role", default="new_parent" + ) + enable_fine_grained_access_parser.add_argument("--title", default="condition title") args = parser.parse_args() @@ -2534,3 +2568,17 @@ def list_instance_config_operations(): query_data_with_query_options(args.instance_id, args.database_id) elif args.command == "create_client_with_query_options": create_client_with_query_options(args.instance_id, args.database_id) + elif args.command == "add_and_drop_database_roles": + add_and_drop_database_roles(args.instance_id, args.database_id) + elif args.command == "read_data_with_database_role": + read_data_with_database_role(args.instance_id, args.database_id) + elif args.command == "list_database_roles": + list_database_roles(args.instance_id, args.database_id) + elif args.command == "enable_fine_grained_access": + enable_fine_grained_access( + args.instance_id, + args.database_id, + args.iam_member, + args.database_role, + args.title, + ) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 05cfedfdde..fa6df4580a 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -759,3 +759,37 @@ def test_set_request_tag(capsys, instance_id, sample_database): snippets.set_request_tag(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, AlbumId: 1, AlbumTitle: Total Junk" in out + + +@pytest.mark.dependency(name="add_and_drop_database_roles", depends=["insert_data"]) +def test_add_and_drop_database_roles(capsys, instance_id, sample_database): + snippets.add_and_drop_database_roles(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "Created roles new_parent and new_child and granted privileges" in out + assert "Revoked privileges and dropped role new_child" in out + + +@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) +def test_read_data_with_database_role(capsys, instance_id, sample_database): + snippets.read_data_with_database_role(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "ingerId: 1, FirstName: Marc, LastName: Richards" in out + + +@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) +def test_list_database_roles(capsys, instance_id, sample_database): + snippets.list_database_roles(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "new_parent" in out + + +@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) +def test_enable_fine_grained_access(capsys, instance_id, sample_database): + iam_member = "user:asthamohta@google.com" + database_role = "new_parent" + title = "condition title" + snippets.enable_fine_grained_access( + instance_id, sample_database.database_id, iam_member, database_role, title + ) + out, _ = capsys.readouterr() + assert "Enabled fine-grained access in IAM. New policy has version 3" in out diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index e9e6c69287..59a382156b 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -18,7 +18,9 @@ import pytest from google.api_core import exceptions +from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 +from google.type import expr_pb2 from . import _helpers from . import _sample_data @@ -164,6 +166,48 @@ def test_create_database_with_default_leader_success( assert result[0] == default_leader +def test_iam_policy(not_emulator, shared_instance, databases_to_delete): + pool = spanner_v1.BurstyPool(labels={"testcase": "iam_policy"}) + temp_db_id = _helpers.unique_id("iam_db", separator="_") + create_table = ( + f"CREATE TABLE policy (\n" + f" Id STRING(36) NOT NULL,\n" + f" Field1 STRING(36) NOT NULL\n" + f") PRIMARY KEY (Id)" + ) + create_role = f"CREATE ROLE parent" + + temp_db = shared_instance.database( + temp_db_id, + ddl_statements=[create_table, create_role], + pool=pool, + ) + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) + policy = temp_db.get_iam_policy(3) + + assert policy.version == 0 + assert policy.etag == b"\x00 \x01" + + new_binding = policy_pb2.Binding( + role="roles/spanner.fineGrainedAccessUser", + members=["user:asthamohta@google.com"], + condition=expr_pb2.Expr( + title="condition title", + expression=f'resource.name.endsWith("/databaseRoles/parent")', + ), + ) + + policy.version = 3 + policy.bindings.append(new_binding) + temp_db.set_iam_policy(policy) + + new_policy = temp_db.get_iam_policy(3) + assert new_policy.version == 3 + assert new_policy.bindings == [new_binding] + + def test_table_not_found(shared_instance): temp_db_id = _helpers.unique_id("tbl_not_found", separator="_") @@ -301,6 +345,85 @@ def test_update_ddl_w_default_leader_success( assert len(temp_db.ddl_statements) == len(ddl_statements) +def test_create_role_grant_access_success( + not_emulator, + shared_instance, + databases_to_delete, +): + creator_role_parent = _helpers.unique_id("role_parent", separator="_") + creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") + + temp_db_id = _helpers.unique_id("dfl_ldrr_upd_ddl", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + f"CREATE ROLE {creator_role_parent}", + f"CREATE ROLE {creator_role_orphan}", + f"GRANT SELECT ON TABLE contacts TO ROLE {creator_role_parent}", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Perform select with orphan role on table contacts. + # Expect PermissionDenied exception. + temp_db = shared_instance.database(temp_db_id, database_role=creator_role_orphan) + with pytest.raises(exceptions.PermissionDenied): + with temp_db.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * FROM contacts") + for row in results: + pass + + # Perform select with parent role on table contacts. Expect success. + temp_db = shared_instance.database(temp_db_id, database_role=creator_role_parent) + with temp_db.snapshot() as snapshot: + snapshot.execute_sql("SELECT * FROM contacts") + + ddl_remove_roles = [ + f"REVOKE SELECT ON TABLE contacts FROM ROLE {creator_role_parent}", + f"DROP ROLE {creator_role_parent}", + f"DROP ROLE {creator_role_orphan}", + ] + # Revoke permission and Delete roles. + operation = temp_db.update_ddl(ddl_remove_roles) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + +def test_list_database_role_success( + not_emulator, + shared_instance, + databases_to_delete, +): + creator_role_parent = _helpers.unique_id("role_parent", separator="_") + creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") + + temp_db_id = _helpers.unique_id("dfl_ldrr_upd_ddl", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + f"CREATE ROLE {creator_role_parent}", + f"CREATE ROLE {creator_role_orphan}", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # List database roles. + roles_list = [] + for role in temp_db.list_database_roles(): + roles_list.append(role.name.split("/")[-1]) + assert creator_role_parent in roles_list + assert creator_role_orphan in roles_list + + def test_db_batch_insert_then_db_snapshot_read(shared_database): _helpers.retry_has_all_dll(shared_database.reload)() sd = _sample_data diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 511ad838cf..ddd1d5572a 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -54,6 +54,10 @@ def test_classify_stmt(self): "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", STMT_DDL, ), + ("CREATE ROLE parent", STMT_DDL), + ("GRANT SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), + ("REVOKE SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), + ("GRANT ROLE parent TO ROLE child", STMT_DDL), ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), ) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bd47a2ac31..bff89320c7 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -61,6 +61,7 @@ class _BaseTest(unittest.TestCase): BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID TRANSACTION_TAG = "transaction-tag" + DATABASE_ROLE = "dummy-role" def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -112,6 +113,7 @@ def test_ctor_defaults(self): self.assertIsNone(database._logger) # BurstyPool does not create sessions during 'bind()'. self.assertTrue(database._pool._sessions.empty()) + self.assertIsNone(database.database_role) def test_ctor_w_explicit_pool(self): instance = _Instance(self.INSTANCE_NAME) @@ -123,6 +125,15 @@ def test_ctor_w_explicit_pool(self): self.assertIs(database._pool, pool) self.assertIs(pool._bound, database) + def test_ctor_w_database_role(self): + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one( + self.DATABASE_ID, instance, database_role=self.DATABASE_ROLE + ) + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertIs(database.database_role, self.DATABASE_ROLE) + def test_ctor_w_ddl_statements_non_string(self): with self.assertRaises(ValueError): @@ -1527,6 +1538,51 @@ def test_list_database_operations_explicit_filter(self): filter_=expected_filter_, page_size=page_size ) + def test_list_database_roles_grpc_error(self): + from google.api_core.exceptions import Unknown + from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest + + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.list_database_roles.side_effect = Unknown("testing") + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + with self.assertRaises(Unknown): + database.list_database_roles() + + expected_request = ListDatabaseRolesRequest( + parent=database.name, + ) + + api.list_database_roles.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + def test_list_database_roles_defaults(self): + from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest + + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + instance = _Instance(self.INSTANCE_NAME, client=client) + instance.list_database_roles = mock.MagicMock(return_value=[]) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + resp = database.list_database_roles() + + expected_request = ListDatabaseRolesRequest( + parent=database.name, + ) + + api.list_database_roles.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + self.assertIsNotNone(resp) + def test_table_factory_defaults(self): from google.cloud.spanner_v1.table import Table diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index c715fb2ee1..e0a0f663cf 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest - import mock @@ -544,6 +543,7 @@ def test_database_factory_defaults(self): self.assertIsNone(database._logger) pool = database._pool self.assertIs(pool._database, database) + self.assertIsNone(database.database_role) def test_database_factory_explicit(self): from logging import Logger @@ -553,6 +553,7 @@ def test_database_factory_explicit(self): client = _Client(self.PROJECT) instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) DATABASE_ID = "database-id" + DATABASE_ROLE = "dummy-role" pool = _Pool() logger = mock.create_autospec(Logger, instance=True) encryption_config = {"kms_key_name": "kms_key_name"} @@ -563,6 +564,7 @@ def test_database_factory_explicit(self): pool=pool, logger=logger, encryption_config=encryption_config, + database_role=DATABASE_ROLE, ) self.assertIsInstance(database, Database) @@ -573,6 +575,7 @@ def test_database_factory_explicit(self): self.assertIs(database._logger, logger) self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) + self.assertIs(database.database_role, DATABASE_ROLE) def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import Database as DatabasePB diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 593420187d..541f3aa668 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -44,12 +44,15 @@ def test_ctor_defaults(self): pool = self._make_one() self.assertIsNone(pool._database) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(labels=labels) + database_role = "dummy-role" + pool = self._make_one(labels=labels, database_role=database_role) self.assertIsNone(pool._database) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) def test_bind_abstract(self): pool = self._make_one() @@ -82,7 +85,7 @@ def test__new_session_wo_labels(self): new_session = pool._new_session() self.assertIs(new_session, session) - database.session.assert_called_once_with() + database.session.assert_called_once_with(labels={}, database_role=None) def test__new_session_w_labels(self): labels = {"foo": "bar"} @@ -94,7 +97,19 @@ def test__new_session_w_labels(self): new_session = pool._new_session() self.assertIs(new_session, session) - database.session.assert_called_once_with(labels=labels) + database.session.assert_called_once_with(labels=labels, database_role=None) + + def test__new_session_w_database_role(self): + database_role = "dummy-role" + pool = self._make_one(database_role=database_role) + database = pool._database = _make_database("name") + session = _make_session() + database.session.return_value = session + + new_session = pool._new_session() + + self.assertIs(new_session, session) + database.session.assert_called_once_with(labels={}, database_role=database_role) def test_session_wo_kwargs(self): from google.cloud.spanner_v1.pool import SessionCheckout @@ -133,26 +148,34 @@ def test_ctor_defaults(self): self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(size=4, default_timeout=30, labels=labels) + database_role = "dummy-role" + pool = self._make_one( + size=4, default_timeout=30, labels=labels, database_role=database_role + ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) self.assertEqual(pool.default_timeout, 30) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) def test_bind(self): + database_role = "dummy-role" pool = self._make_one() database = _Database("name") SESSIONS = [_Session(database)] * 10 + database._database_role = database_role database._sessions.extend(SESSIONS) pool.bind(database) self.assertIs(pool._database, database) self.assertEqual(pool.size, 10) + self.assertEqual(pool.database_role, database_role) self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.full()) @@ -272,14 +295,25 @@ def test_ctor_defaults(self): self.assertEqual(pool.target_size, 10) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} - pool = self._make_one(target_size=4, labels=labels) + database_role = "dummy-role" + pool = self._make_one(target_size=4, labels=labels, database_role=database_role) self.assertIsNone(pool._database) self.assertEqual(pool.target_size, 4) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_get_empty(self): pool = self._make_one() @@ -392,11 +426,17 @@ def test_ctor_defaults(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} + database_role = "dummy-role" pool = self._make_one( - size=4, default_timeout=30, ping_interval=1800, labels=labels + size=4, + default_timeout=30, + ping_interval=1800, + labels=labels, + database_role=database_role, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -404,6 +444,17 @@ def test_ctor_explicit(self): self.assertEqual(pool._delta.seconds, 1800) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() @@ -624,11 +675,17 @@ def test_ctor_defaults(self): self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, {}) + self.assertIsNone(pool.database_role) def test_ctor_explicit(self): labels = {"foo": "bar"} + database_role = "dummy-role" pool = self._make_one( - size=4, default_timeout=30, ping_interval=1800, labels=labels + size=4, + default_timeout=30, + ping_interval=1800, + labels=labels, + database_role=database_role, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -637,6 +694,17 @@ def test_ctor_explicit(self): self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, labels) + self.assertEqual(pool.database_role, database_role) + + def test_ctor_explicit_w_database_role_in_db(self): + database_role = "dummy-role" + pool = self._make_one() + database = pool._database = _Database("name") + SESSIONS = [_Session(database)] * 10 + database._sessions.extend(SESSIONS) + database._database_role = database_role + pool.bind(database) + self.assertEqual(pool.database_role, database_role) def test_bind(self): pool = self._make_one() @@ -794,10 +862,12 @@ def test_ctor_wo_kwargs(self): def test_ctor_w_kwargs(self): pool = _Pool() - checkout = self._make_one(pool, foo="bar") + checkout = self._make_one(pool, foo="bar", database_role="dummy-role") self.assertIs(checkout._pool, pool) self.assertIsNone(checkout._session) - self.assertEqual(checkout._kwargs, {"foo": "bar"}) + self.assertEqual( + checkout._kwargs, {"foo": "bar", "database_role": "dummy-role"} + ) def test_context_manager_wo_kwargs(self): session = object() @@ -885,17 +955,30 @@ class _Database(object): def __init__(self, name): self.name = name self._sessions = [] + self._database_role = None def mock_batch_create_sessions( - database=None, session_count=10, timeout=10, metadata=[] + database=None, + session_count=10, + timeout=10, + metadata=[], + creator_role=None, + labels={}, ): from google.cloud.spanner_v1 import BatchCreateSessionsResponse from google.cloud.spanner_v1 import Session if session_count < 2: - response = BatchCreateSessionsResponse(session=[Session()]) + response = BatchCreateSessionsResponse( + session=[Session(creator_role=creator_role, labels=labels)] + ) else: - response = BatchCreateSessionsResponse(session=[Session(), Session()]) + response = BatchCreateSessionsResponse( + session=[ + Session(creator_role=creator_role, labels=labels), + Session(creator_role=creator_role, labels=labels), + ] + ) return response from google.cloud.spanner_v1 import SpannerClient @@ -903,7 +986,16 @@ def mock_batch_create_sessions( self.spanner_api = mock.create_autospec(SpannerClient, instance=True) self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions - def session(self): + @property + def database_role(self): + """Database role used in sessions to connect to this database. + + :rtype: str + :returns: an str with the name of the database role. + """ + return self._database_role + + def session(self, **kwargs): # always return first session in the list # to avoid reversing the order of putting # sessions into pool (important for order tests) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0f297654bb..005cd0cd1f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -45,6 +45,7 @@ class TestSession(OpenTelemetryBase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session-id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + DATABASE_ROLE = "dummy-role" BASE_ATTRIBUTES = { "db.type": "spanner", "db.url": "spanner.googleapis.com", @@ -61,19 +62,20 @@ def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) @staticmethod - def _make_database(name=DATABASE_NAME): + def _make_database(name=DATABASE_NAME, database_role=None): from google.cloud.spanner_v1.database import Database database = mock.create_autospec(Database, instance=True) database.name = name database.log_commit_stats = False + database.database_role = database_role return database @staticmethod - def _make_session_pb(name, labels=None): + def _make_session_pb(name, labels=None, database_role=None): from google.cloud.spanner_v1 import Session - return Session(name=name, labels=labels) + return Session(name=name, labels=labels, creator_role=database_role) def _make_spanner_api(self): from google.cloud.spanner_v1 import SpannerClient @@ -87,6 +89,20 @@ def test_constructor_wo_labels(self): self.assertIs(session._database, database) self.assertEqual(session.labels, {}) + def test_constructor_w_database_role(self): + database = self._make_database(database_role=self.DATABASE_ROLE) + session = self._make_one(database, database_role=self.DATABASE_ROLE) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + + def test_constructor_wo_database_role(self): + database = self._make_database() + session = self._make_one(database) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertIs(session.database_role, None) + def test_constructor_w_labels(self): database = self._make_database() labels = {"foo": "bar"} @@ -126,6 +142,65 @@ def test_create_w_session_id(self): self.assertNoSpans() + def test_create_w_database_role(self): + from google.cloud.spanner_v1 import CreateSessionRequest + from google.cloud.spanner_v1 import Session as SessionRequestProto + + session_pb = self._make_session_pb( + self.SESSION_NAME, database_role=self.DATABASE_ROLE + ) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database(database_role=self.DATABASE_ROLE) + database.spanner_api = gax_api + session = self._make_one(database, database_role=self.DATABASE_ROLE) + + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + session_template = SessionRequestProto(creator_role=self.DATABASE_ROLE) + + request = CreateSessionRequest( + database=database.name, + session=session_template, + ) + + gax_api.create_session.assert_called_once_with( + request=request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + self.assertSpanAttributes( + "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES + ) + + def test_create_wo_database_role(self): + from google.cloud.spanner_v1 import CreateSessionRequest + + session_pb = self._make_session_pb(self.SESSION_NAME) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database() + database.spanner_api = gax_api + session = self._make_one(database) + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertIsNone(session.database_role) + + request = CreateSessionRequest( + database=database.name, + ) + + gax_api.create_session.assert_called_once_with( + request=request, metadata=[("google-cloud-resource-prefix", database.name)] + ) + + self.assertSpanAttributes( + "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES + ) + def test_create_ok(self): from google.cloud.spanner_v1 import CreateSessionRequest From 9ba5bb4a2a8cb7a51f77183a5d8631fde952f829 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Tue, 13 Dec 2022 10:46:52 +0530 Subject: [PATCH 02/11] linting --- google/cloud/spanner_v1/database.py | 29 +++++++++++++++++++++++------ google/cloud/spanner_v1/instance.py | 12 +++++++++--- google/cloud/spanner_v1/pool.py | 13 +++++++++++-- google/cloud/spanner_v1/session.py | 7 +++++-- samples/samples/snippets.py | 3 --- samples/samples/snippets_test.py | 12 ------------ tests/system/test_database_api.py | 9 ++++++++- tests/unit/test_pool.py | 9 +++++---- 8 files changed, 61 insertions(+), 33 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 367b2ee40d..e4f5d9a7a0 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -480,7 +480,9 @@ def update_ddl(self, ddl_statements, operation_id=""): metadata = _metadata_with_prefix(self.name) request = UpdateDatabaseDdlRequest( - database=self.name, statements=ddl_statements, operation_id=operation_id, + database=self.name, + statements=ddl_statements, + operation_id=operation_id, ) future = api.update_database_ddl(request=request, metadata=metadata) @@ -583,7 +585,8 @@ def execute_pdml(): request_options=request_options, ) method = functools.partial( - api.execute_streaming_sql, metadata=metadata, + api.execute_streaming_sql, + metadata=metadata, ) iterator = _restart_on_unavailable(method, request) @@ -739,7 +742,10 @@ def restore(self, source): backup=source.name, encryption_config=self._encryption_config or None, ) - future = api.restore_database(request=request, metadata=metadata,) + future = api.restore_database( + request=request, + metadata=metadata, + ) return future def is_ready(self): @@ -803,7 +809,10 @@ def list_database_roles(self, page_size=None): api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - request = ListDatabaseRolesRequest(parent=self.name, page_size=page_size,) + request = ListDatabaseRolesRequest( + parent=self.name, + page_size=page_size, + ) return api.list_database_roles(request=request, metadata=metadata) def table(self, table_id): @@ -1155,7 +1164,11 @@ def generate_read_batches( yield {"partition": partition, "read": read_info.copy()} def process_read_batch( - self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + self, + batch, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Process a single, partitioned read. @@ -1272,7 +1285,11 @@ def generate_query_batches( yield {"partition": partition, "query": query_info} def process_query_batch( - self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + self, + batch, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Process a single, partitioned query. diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 73e58f2e3e..f972f817b3 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -619,7 +619,9 @@ def list_backups(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListBackupsRequest( - parent=self.name, filter=filter_, page_size=page_size, + parent=self.name, + filter=filter_, + page_size=page_size, ) page_iter = self._client.database_admin_api.list_backups( request=request, metadata=metadata @@ -647,7 +649,9 @@ def list_backup_operations(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListBackupOperationsRequest( - parent=self.name, filter=filter_, page_size=page_size, + parent=self.name, + filter=filter_, + page_size=page_size, ) page_iter = self._client.database_admin_api.list_backup_operations( request=request, metadata=metadata @@ -675,7 +679,9 @@ def list_database_operations(self, filter_="", page_size=None): """ metadata = _metadata_with_prefix(self.name) request = ListDatabaseOperationsRequest( - parent=self.name, filter=filter_, page_size=page_size, + parent=self.name, + filter=filter_, + page_size=page_size, ) page_iter = self._client.database_admin_api.list_database_operations( request=request, metadata=metadata diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index b5d3f445fc..abf5f441f9 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -19,6 +19,8 @@ from google.cloud.exceptions import NotFound from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import BatchCreateSessionsRequest _NOW = datetime.datetime.utcnow # unit tests may replace @@ -190,13 +192,16 @@ def bind(self, database): api = database.spanner_api metadata = _metadata_with_prefix(database.name) self._database_role = self._database_role or self._database.database_role + request = BatchCreateSessionsRequest( + session_template=Session(creator_role=self.database_role), + ) while not self._sessions.full(): resp = api.batch_create_sessions( + request=request, database=database.name, session_count=self.size - self._sessions.qsize(), metadata=metadata, - creator_role=self.database_role, ) for session_pb in resp.session: session = self._new_session() @@ -400,12 +405,16 @@ def bind(self, database): created_session_count = 0 self._database_role = self._database_role or self._database.database_role + request = BatchCreateSessionsRequest( + session_template=Session(creator_role=self.database_role), + ) + while created_session_count < self.size: resp = api.batch_create_sessions( + request=request, database=database.name, session_count=self.size - created_session_count, metadata=metadata, - creator_role=self.database_role, ) for session_pb in resp.session: session = self._new_session() diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 424c5d4538..c210f8f61d 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -134,7 +134,10 @@ def create(self): request.session.labels = self._labels with trace_call("CloudSpanner.CreateSession", self, self._labels): - session_pb = api.create_session(request=request, metadata=metadata,) + session_pb = api.create_session( + request=request, + metadata=metadata, + ) self._session_id = session_pb.name.split("/")[-1] def exists(self): @@ -452,4 +455,4 @@ def _get_retry_delay(cause, attempts): nanos = retry_info.retry_delay.nanos return retry_info.retry_delay.seconds + nanos / 1.0e9 - return 2 ** attempts + random.random() + return 2**attempts + random.random() diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 83e3dca814..3307a3aedd 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -29,11 +29,8 @@ import time from google.cloud import spanner -from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.data_types import JsonObject -from google.protobuf import field_mask_pb2 # type: ignore -from google.cloud import spanner_admin_database_v1 from google.cloud.spanner_v1 import database, param_types from google.type import expr_pb2 from google.iam.v1 import policy_pb2 diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index fa6df4580a..6d5822e37b 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -781,15 +781,3 @@ def test_list_database_roles(capsys, instance_id, sample_database): snippets.list_database_roles(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "new_parent" in out - - -@pytest.mark.dependency(depends=["add_and_drop_database_roles"]) -def test_enable_fine_grained_access(capsys, instance_id, sample_database): - iam_member = "user:asthamohta@google.com" - database_role = "new_parent" - title = "condition title" - snippets.enable_fine_grained_access( - instance_id, sample_database.database_id, iam_member, database_role, title - ) - out, _ = capsys.readouterr() - assert "Enabled fine-grained access in IAM. New policy has version 3" in out diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 59a382156b..b708ff23c4 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -166,7 +166,12 @@ def test_create_database_with_default_leader_success( assert result[0] == default_leader -def test_iam_policy(not_emulator, shared_instance, databases_to_delete): +def test_iam_policy( + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, +): pool = spanner_v1.BurstyPool(labels={"testcase": "iam_policy"}) temp_db_id = _helpers.unique_id("iam_db", separator="_") create_table = ( @@ -349,6 +354,7 @@ def test_create_role_grant_access_success( not_emulator, shared_instance, databases_to_delete, + not_postgres, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -397,6 +403,7 @@ def test_list_database_role_success( not_emulator, shared_instance, databases_to_delete, + not_postgres, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 541f3aa668..1a53aa1604 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -958,25 +958,26 @@ def __init__(self, name): self._database_role = None def mock_batch_create_sessions( + request=None, database=None, session_count=10, timeout=10, metadata=[], - creator_role=None, labels={}, ): from google.cloud.spanner_v1 import BatchCreateSessionsResponse from google.cloud.spanner_v1 import Session + database_role = request.session_template.creator_role if request else None if session_count < 2: response = BatchCreateSessionsResponse( - session=[Session(creator_role=creator_role, labels=labels)] + session=[Session(creator_role=database_role, labels=labels)] ) else: response = BatchCreateSessionsResponse( session=[ - Session(creator_role=creator_role, labels=labels), - Session(creator_role=creator_role, labels=labels), + Session(creator_role=database_role, labels=labels), + Session(creator_role=database_role, labels=labels), ] ) return response From 15cede676da240d77f4b2435477b3de5183520c4 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Tue, 13 Dec 2022 12:38:19 +0530 Subject: [PATCH 03/11] fixing samples --- samples/samples/snippets.py | 107 +++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 3 deletions(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 3307a3aedd..2637b71a2d 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -29,12 +29,12 @@ import time from google.cloud import spanner +from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types -from google.cloud.spanner_v1.data_types import JsonObject -from google.cloud.spanner_v1 import database, param_types from google.type import expr_pb2 from google.iam.v1 import policy_pb2 - +from google.cloud.spanner_v1.data_types import JsonObject +from google.protobuf import field_mask_pb2 # type: ignore OPERATION_TIMEOUT_SECONDS = 240 @@ -2212,6 +2212,106 @@ def set_request_tag(instance_id, database_id): # [END spanner_set_request_tag] +# [START spanner_create_instance_config] +def create_instance_config(user_config_name, base_config_id): + """Creates the new user-managed instance configuration using base instance config.""" + + # user_config_name = `custom-nam11` + # base_config_id = `projects//instanceConfigs/nam11` + spanner_client = spanner.Client() + base_config = spanner_client.instance_admin_api.get_instance_config( + name=base_config_id + ) + + # The replicas for the custom instance configuration must include all the replicas of the base + # configuration, in addition to at least one from the list of optional replicas of the base + # configuration. + replicas = [] + for replica in base_config.replicas: + replicas.append(replica) + replicas.append(base_config.optional_replicas[0]) + operation = spanner_client.instance_admin_api.create_instance_config( + parent=spanner_client.project_name, + instance_config_id=user_config_name, + instance_config=spanner_instance_admin.InstanceConfig( + name="{}/instanceConfigs/{}".format( + spanner_client.project_name, user_config_name + ), + display_name="custom-python-samples", + config_type=spanner_instance_admin.InstanceConfig.Type.USER_MANAGED, + replicas=replicas, + base_config=base_config.name, + labels={"python_cloud_spanner_samples": "true"}, + ), + ) + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + + print("Created instance configuration {}".format(user_config_name)) + + +# [END spanner_create_instance_config] + +# [START spanner_update_instance_config] +def update_instance_config(user_config_name): + """Updates the user-managed instance configuration.""" + + # user_config_name = `custom-nam11` + spanner_client = spanner.Client() + config = spanner_client.instance_admin_api.get_instance_config( + name="{}/instanceConfigs/{}".format( + spanner_client.project_name, user_config_name + ) + ) + config.display_name = "updated custom instance config" + config.labels["updated"] = "true" + operation = spanner_client.instance_admin_api.update_instance_config( + instance_config=config, + update_mask=field_mask_pb2.FieldMask(paths=["display_name", "labels"]), + ) + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + print("Updated instance configuration {}".format(user_config_name)) + + +# [END spanner_update_instance_config] + +# [START spanner_delete_instance_config] +def delete_instance_config(user_config_id): + """Deleted the user-managed instance configuration.""" + spanner_client = spanner.Client() + spanner_client.instance_admin_api.delete_instance_config(name=user_config_id) + print("Instance config {} successfully deleted".format(user_config_id)) + + +# [END spanner_delete_instance_config] + + +# [START spanner_list_instance_config_operations] +def list_instance_config_operations(): + """List the user-managed instance configuration operations.""" + spanner_client = spanner.Client() + operations = spanner_client.instance_admin_api.list_instance_config_operations( + request=spanner_instance_admin.ListInstanceConfigOperationsRequest( + parent=spanner_client.project_name, + filter="(metadata.@type=type.googleapis.com/google.spanner.admin.instance.v1.CreateInstanceConfigMetadata)", + ) + ) + for op in operations: + metadata = spanner_instance_admin.CreateInstanceConfigMetadata.pb( + spanner_instance_admin.CreateInstanceConfigMetadata() + ) + op.metadata.Unpack(metadata) + print( + "List instance config operations {} is {}% completed.".format( + metadata.instance_config.name, metadata.progress.progress_percent + ) + ) + + +# [END spanner_list_instance_config_operations] + + def add_and_drop_database_roles(instance_id, database_id): """Showcases how to manage a user defined database role.""" # [START spanner_add_and_drop_database_roles] @@ -2451,6 +2551,7 @@ def enable_fine_grained_access( ) enable_fine_grained_access_parser.add_argument("--title", default="condition title") + args = parser.parse_args() if args.command == "create_instance": From 6d88700b24e970972c4beb1edd5dd79779b66eb7 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Tue, 13 Dec 2022 12:42:11 +0530 Subject: [PATCH 04/11] linting --- samples/samples/snippets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 2637b71a2d..368e7c67c1 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -2551,7 +2551,6 @@ def enable_fine_grained_access( ) enable_fine_grained_access_parser.add_argument("--title", default="condition title") - args = parser.parse_args() if args.command == "create_instance": From b9b46d5696d7154c492fc94d6b58504b7dc7fce6 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Tue, 13 Dec 2022 14:12:34 +0530 Subject: [PATCH 05/11] linting --- tests/system/test_database_api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index b708ff23c4..9fac10ed4d 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -175,12 +175,12 @@ def test_iam_policy( pool = spanner_v1.BurstyPool(labels={"testcase": "iam_policy"}) temp_db_id = _helpers.unique_id("iam_db", separator="_") create_table = ( - f"CREATE TABLE policy (\n" - f" Id STRING(36) NOT NULL,\n" - f" Field1 STRING(36) NOT NULL\n" - f") PRIMARY KEY (Id)" + "CREATE TABLE policy (\n" + + " Id STRING(36) NOT NULL,\n" + + " Field1 STRING(36) NOT NULL\n" + + ") PRIMARY KEY (Id)" ) - create_role = f"CREATE ROLE parent" + create_role = "CREATE ROLE parent" temp_db = shared_instance.database( temp_db_id, @@ -200,7 +200,7 @@ def test_iam_policy( members=["user:asthamohta@google.com"], condition=expr_pb2.Expr( title="condition title", - expression=f'resource.name.endsWith("/databaseRoles/parent")', + expression='resource.name.endsWith("/databaseRoles/parent")', ), ) From 7bcb0948c33d52e60e60c92d4495e35cfef3af32 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Tue, 13 Dec 2022 17:14:28 +0530 Subject: [PATCH 06/11] Update database.py --- google/cloud/spanner_v1/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index e4f5d9a7a0..0d27763432 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -32,9 +32,9 @@ from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest from google.cloud.spanner_admin_database_v1 import Database as DatabasePB +from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest from google.cloud.spanner_admin_database_v1 import EncryptionConfig from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig -from google.cloud.spanner_admin_database_v1 import ListDatabaseRolesRequest from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect From b8dc80997b0471fffb9603ceba68de6cfe965c25 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Tue, 13 Dec 2022 17:15:25 +0530 Subject: [PATCH 07/11] Update pool.py --- google/cloud/spanner_v1/pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index abf5f441f9..216ba5aeff 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -18,9 +18,9 @@ import queue from google.cloud.exceptions import NotFound -from google.cloud.spanner_v1._helpers import _metadata_with_prefix -from google.cloud.spanner_v1 import Session from google.cloud.spanner_v1 import BatchCreateSessionsRequest +from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1._helpers import _metadata_with_prefix _NOW = datetime.datetime.utcnow # unit tests may replace From e6b78534d977b401f9afa1c41c7aa1a280231f44 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Tue, 13 Dec 2022 17:17:57 +0530 Subject: [PATCH 08/11] Update snippets.py --- samples/samples/snippets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 368e7c67c1..ad138b3a1c 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -2400,6 +2400,10 @@ def enable_fine_grained_access( instance = spanner_client.instance(instance_id) database = instance.database(database_id) + # The policy in the response from getDatabaseIAMPolicy might use the policy version + # that you specified, or it might use a lower policy version. For example, if you + # specify version 3, but the policy has no conditional role bindings, the response + # uses version 1. Valid values are 0, 1, and 3. policy = database.get_iam_policy(3) if policy.version < 3: policy.version = 3 From 638d1b0b27d95848cab1dda26ee4c99062f706e1 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Thu, 15 Dec 2022 16:33:54 +0530 Subject: [PATCH 09/11] fixing pools --- google/cloud/spanner_v1/pool.py | 8 ++++---- tests/system/test_database_api.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index abf5f441f9..939d35d4bb 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -193,14 +193,14 @@ def bind(self, database): metadata = _metadata_with_prefix(database.name) self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( + database=database.database_id, + session_count=self.size - self._sessions.qsize(), session_template=Session(creator_role=self.database_role), ) while not self._sessions.full(): resp = api.batch_create_sessions( request=request, - database=database.name, - session_count=self.size - self._sessions.qsize(), metadata=metadata, ) for session_pb in resp.session: @@ -406,14 +406,14 @@ def bind(self, database): self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( + database=database.database_id, + session_count=self.size - created_session_count, session_template=Session(creator_role=self.database_role), ) while created_session_count < self.size: resp = api.batch_create_sessions( request=request, - database=database.name, - session_count=self.size - created_session_count, metadata=metadata, ) for session_pb in resp.session: diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 9fac10ed4d..055fb42685 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -20,6 +20,7 @@ from google.api_core import exceptions from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 +from google.cloud.spanner_v1.pool import PingingPool from google.type import expr_pb2 from . import _helpers from . import _sample_data @@ -73,6 +74,34 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) assert temp_db.name in database_ids +def test_database_binding_of_pool( + not_emulator, shared_instance, databases_to_delete, not_postgres +): + temp_db_id = _helpers.unique_id("binding_db", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + f"CREATE ROLE parent", + f"GRANT SELECT ON TABLE contacts TO ROLE parent", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + pool = PingingPool( + size=1, + default_timeout=500, + ping_interval=100, + database_role="parent", + ) + database = shared_instance.database(temp_db.name, pool=pool) + assert database._pool.database_role == "parent" + + def test_create_database_pitr_invalid_retention_period( not_emulator, # PITR-lite features are not supported by the emulator not_postgres, From c3221373eeb2405af825905bd3f70ec2e89a7862 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Thu, 15 Dec 2022 18:52:07 +0530 Subject: [PATCH 10/11] fixing tests --- tests/system/test_database_api.py | 36 +++++++++++++++++++++++++++---- tests/unit/test_pool.py | 5 ++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 055fb42685..a08b015481 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -20,7 +20,7 @@ from google.api_core import exceptions from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 -from google.cloud.spanner_v1.pool import PingingPool +from google.cloud.spanner_v1.pool import FixedSizePool, PingingPool from google.type import expr_pb2 from . import _helpers from . import _sample_data @@ -74,7 +74,35 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) assert temp_db.name in database_ids -def test_database_binding_of_pool( +def test_database_binding_of_fixed_size_pool( + not_emulator, shared_instance, databases_to_delete, not_postgres +): + temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + "CREATE ROLE parent", + "GRANT SELECT ON TABLE contacts TO ROLE parent", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + pool = FixedSizePool( + size=1, + default_timeout=500, + ping_interval=100, + database_role="parent", + ) + database = shared_instance.database(temp_db.name, pool=pool) + assert database._pool.database_role == "parent" + + +def test_database_binding_of_pinging_pool( not_emulator, shared_instance, databases_to_delete, not_postgres ): temp_db_id = _helpers.unique_id("binding_db", separator="_") @@ -86,8 +114,8 @@ def test_database_binding_of_pool( # Create role and grant select permission on table contacts for parent role. ddl_statements = _helpers.DDL_STATEMENTS + [ - f"CREATE ROLE parent", - f"GRANT SELECT ON TABLE contacts TO ROLE parent", + "CREATE ROLE parent", + "GRANT SELECT ON TABLE contacts TO ROLE parent", ] operation = temp_db.update_ddl(ddl_statements) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 48cc1434ef..3a9d35bc92 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -956,11 +956,10 @@ def __init__(self, name): self.name = name self._sessions = [] self._database_role = None + self.database_id = name def mock_batch_create_sessions( request=None, - database=None, - session_count=10, timeout=10, metadata=[], labels={}, @@ -969,7 +968,7 @@ def mock_batch_create_sessions( from google.cloud.spanner_v1 import Session database_role = request.session_template.creator_role if request else None - if session_count < 2: + if request.session_count < 2: response = BatchCreateSessionsResponse( session=[Session(creator_role=database_role, labels=labels)] ) From b5325766cfe30f14d5b98eb311c0606cb7b03662 Mon Sep 17 00:00:00 2001 From: Astha Mohta Date: Thu, 15 Dec 2022 21:35:22 +0530 Subject: [PATCH 11/11] changes --- tests/system/test_database_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index a08b015481..699b3f4a69 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -95,7 +95,6 @@ def test_database_binding_of_fixed_size_pool( pool = FixedSizePool( size=1, default_timeout=500, - ping_interval=100, database_role="parent", ) database = shared_instance.database(temp_db.name, pool=pool)