diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 29bd604e7b..1f4bf5b174 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -19,6 +19,7 @@ import math import time import base64 +import threading from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -30,6 +31,7 @@ from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import JsonObject +from google.cloud.spanner_v1.request_id_header import with_request_id # Validation error messages NUMERIC_MAX_SCALE_ERR_MSG = ( @@ -525,3 +527,45 @@ def _metadata_with_leader_aware_routing(value, **kw): List[Tuple[str, str]]: RPC metadata with leader aware routing header """ return ("x-goog-spanner-route-to-leader", str(value).lower()) + + +class AtomicCounter: + def __init__(self, start_value=0): + self.__lock = threading.Lock() + self.__value = start_value + + @property + def value(self): + with self.__lock: + return self.__value + + def increment(self, n=1): + with self.__lock: + self.__value += n + return self.__value + + def __iadd__(self, n): + """ + Defines the inplace += operator result. + """ + with self.__lock: + self.__value += n + return self + + def __add__(self, n): + """ + Defines the result of invoking: value = AtomicCounter + addable + """ + with self.__lock: + n += self.__value + return n + + def __radd__(self, n): + """ + Defines the result of invoking: value = addable + AtomicCounter + """ + return self.__add__(n) + + +def _metadata_with_request_id(*args, **kwargs): + return with_request_id(*args, **kwargs) diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py new file mode 100644 index 0000000000..8376778273 --- /dev/null +++ b/google/cloud/spanner_v1/request_id_header.py @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +REQ_ID_VERSION = 1 # The version of the x-goog-spanner-request-id spec. +REQ_ID_HEADER_KEY = "x-goog-spanner-request-id" + + +def generate_rand_uint64(): + b = os.urandom(8) + return ( + b[7] & 0xFF + | (b[6] & 0xFF) << 8 + | (b[5] & 0xFF) << 16 + | (b[4] & 0xFF) << 24 + | (b[3] & 0xFF) << 32 + | (b[2] & 0xFF) << 36 + | (b[1] & 0xFF) << 48 + | (b[0] & 0xFF) << 56 + ) + + +REQ_RAND_PROCESS_ID = generate_rand_uint64() + + +def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]): + req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" + all_metadata = other_metadata.copy() + all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + return all_metadata diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py new file mode 100644 index 0000000000..92d10cac79 --- /dev/null +++ b/tests/unit/test_atomic_counter.py @@ -0,0 +1,78 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import threading +import unittest +from google.cloud.spanner_v1._helpers import AtomicCounter + + +class TestAtomicCounter(unittest.TestCase): + def test_initialization(self): + ac_default = AtomicCounter() + assert ac_default.value == 0 + + ac_1 = AtomicCounter(1) + assert ac_1.value == 1 + + ac_negative_1 = AtomicCounter(-1) + assert ac_negative_1.value == -1 + + def test_increment(self): + ac = AtomicCounter() + result_default = ac.increment() + assert result_default == 1 + assert ac.value == 1 + + result_with_value = ac.increment(2) + assert result_with_value == 3 + assert ac.value == 3 + result_plus_100 = ac.increment(100) + assert result_plus_100 == 103 + + def test_plus_call(self): + ac = AtomicCounter() + ac += 1 + assert ac.value == 1 + + n = ac + 2 + assert n == 3 + assert ac.value == 1 + + n = 200 + ac + assert n == 201 + assert ac.value == 1 + + def test_multiple_threads_incrementing(self): + ac = AtomicCounter() + n = 200 + m = 10 + + def do_work(): + for i in range(m): + ac.increment() + + threads = [] + for i in range(n): + th = threading.Thread(target=do_work) + threads.append(th) + th.start() + + random.shuffle(threads) + for th in threads: + th.join() + assert not th.is_alive() + + # Finally the result should be n*m + assert ac.value == n * m