diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 29bd604e7b..b38745c311 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -19,6 +19,7 @@ import math import time import base64 +import threading from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -525,3 +526,41 @@ def _metadata_with_leader_aware_routing(value, **kw): List[Tuple[str, str]]: RPC metadata with leader aware routing header """ return ("x-goog-spanner-route-to-leader", str(value).lower()) + + +class AtomicCounter: + def __init__(self, start_value=0): + self.__lock = threading.Lock() + self.__value = start_value + + @property + def value(self): + with self.__lock: + return self.__value + + def increment(self, n=1): + with self.__lock: + self.__value += n + return self.__value + + def __iadd__(self, n): + """ + Defines the inplace += operator result. + """ + with self.__lock: + self.__value += n + return self + + def __add__(self, n): + """ + Defines the result of invoking: value = AtomicCounter + addable + """ + with self.__lock: + n += self.__value + return n + + def __radd__(self, n): + """ + Defines the result of invoking: value = addable + AtomicCounter + """ + return self.__add__(n) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index afe6264717..20c645e0b3 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -26,6 +26,7 @@ import grpc import os import warnings +import threading from google.api_core.gapic_v1 import client_info from google.auth.credentials import AnonymousCredentials @@ -48,6 +49,7 @@ from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1._helpers import AtomicCounter _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -147,6 +149,8 @@ class Client(ClientWithProject): SCOPE = (SPANNER_ADMIN_SCOPE,) """The scopes required for Google Cloud Spanner.""" + NTH_CLIENT = AtomicCounter() + def __init__( self, project=None, @@ -199,6 +203,12 @@ def __init__( self._route_to_leader_enabled = route_to_leader_enabled self._directed_read_options = directed_read_options self._observability_options = observability_options + self._nth_client_id = Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() @property def credentials(self): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 88d2bb60f7..32e23e9bf7 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -50,6 +50,7 @@ from google.cloud.spanner_v1 import SpannerClient from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import ( + AtomicCounter, _metadata_with_prefix, _metadata_with_leader_aware_routing, ) @@ -698,8 +699,15 @@ def execute_partitioned_dml( _metadata_with_leader_aware_routing(self._route_to_leader_enabled) ) + nth_request = self._next_nth_request() + attempt = AtomicCounter(1) # It'll be incremented inside _restart_on_unavailable + def execute_pdml(): with SessionCheckout(self._pool) as session: + channel_id = session._channel_id + metadata = with_request_id( + self._client._nth_client_id, nth_request, attempt.value, metadata + ) txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) @@ -725,6 +733,7 @@ def execute_pdml(): request=request, transaction_selector=txn_selector, observability_options=self.observability_options, + attempt=attempt, ) result_set = StreamedResultSet(iterator) @@ -734,6 +743,9 @@ def execute_pdml(): return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() + def _next_nth_request(self): + return self._instance._client._next_nth_request + def session(self, labels=None, database_role=None): """Factory to create a session for this database. diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 03bff81b52..f3d9ce05ab 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -17,6 +17,7 @@ import datetime import queue import time +import threading from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest @@ -53,6 +54,8 @@ def __init__(self, labels=None, database_role=None): labels = {} self._labels = labels self._database_role = database_role + self.__lock = threading.lock() + self._session_id_to_channel_id = dict() @property def labels(self): @@ -128,10 +131,17 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. """ - return self._database.session( + session = self._database.session( labels=self.labels, database_role=self.database_role ) + with self.__lock: + channel_id = len(self._session_id_to_channel_id) + 1 + self._session_id_to_channel_id[session._session.id] = channel_id + session._channel_id = channel_id + + return session + def session(self, **kwargs): """Check out a session from the pool. diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py new file mode 100644 index 0000000000..121f6e9e16 --- /dev/null +++ b/google/cloud/spanner_v1/request_id_header.py @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading + +REQ_ID_VERSION = 1 # The version of the x-goog-spanner-request-id spec. +REQ_ID_HEADER_KEY = "x-goog-spanner-request-id" + + +def generate_rand_uint64(): + b = os.urandom(8) + return ( + b[7] & 0xFF + | (b[6] & 0xFF) << 8 + | (b[5] & 0xFF) << 16 + | (b[4] & 0xFF) << 24 + | (b[3] & 0xFF) << 32 + | (b[2] & 0xFF) << 36 + | (b[1] & 0xFF) << 48 + | (b[0] & 0xFF) << 56 + ) + + +REQ_RAND_PROCESS_ID = generate_rand_uint64() + + +def with_request_id(client_id, nth_request, attempt, other_metadata=[]): + req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" + other_metadata.append((REQ_ID_HEADER_KEY, req_id)) + return other_metadata diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d73a8cc2b5..85a0966e8f 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -75,6 +75,7 @@ def __init__(self, database, labels=None, database_role=None): self._labels = labels self._database_role = database_role self._last_use_time = datetime.utcnow() + self.__channel_id = 0 def __lt__(self, other): return self._session_id < other._session_id diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6234c96435..045d29ac0d 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -58,6 +58,7 @@ def _restart_on_unavailable( transaction=None, transaction_selector=None, observability_options=None, + attempt=0, ): """Restart iteration after :exc:`.ServiceUnavailable`. @@ -92,6 +93,7 @@ def _restart_on_unavailable( ): iterator = method(request=request) while True: + attempt += 1 try: for item in iterator: item_buffer.append(item) diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py new file mode 100644 index 0000000000..7915a807ad --- /dev/null +++ b/tests/unit/test_atomic_counter.py @@ -0,0 +1,81 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import random +import threading +import unittest +from google.cloud.spanner_v1._helpers import AtomicCounter + +class TestAtomicCounter(unittest.TestCase): + def test_initialization(self): + ac_default = AtomicCounter() + assert ac_default.value == 0 + + ac_1 = AtomicCounter(1) + assert ac_1.value == 1 + + ac_negative_1 = AtomicCounter(-1) + assert ac_negative_1.value == -1 + + def test_increment(self): + ac = AtomicCounter() + result_default = ac.increment() + assert result_default == 1 + assert ac.value == 1 + + result_with_value = ac.increment(2) + assert result_with_value == 3 + assert ac.value == 3 + result_plus_100 = ac.increment(100) + assert result_plus_100 == 103 + + def test_plus_call(self): + ac = AtomicCounter() + ac += 1 + assert ac.value == 1 + + n = ac + 2 + assert n == 3 + assert ac.value == 1 + + n = 200 + ac + assert n == 201 + assert ac.value == 1 + + + def test_multiple_threads_incrementing(self): + ac = AtomicCounter() + n = 200 + m = 10 + + def do_work(): + for i in range(m): + ac.increment() + + threads = [] + for i in range(n): + th = threading.Thread(target=do_work) + threads.append(th) + th.start() + + time.sleep(0.3) + + random.shuffle(threads) + for th in threads: + th.join() + assert th.is_alive() == False + + # Finally the result should be n*m + assert ac.value == n*m