Skip to content

Commit

Permalink
feat(x-goog-spanner-request-id): implement request_id generation and …
Browse files Browse the repository at this point in the history
…propagation

Generates a request_id that is then injected inside metadata
that's sent over to the Cloud Spanner backend.

Fixes #1261
  • Loading branch information
odeke-em committed Dec 12, 2024
1 parent 054a186 commit 4f1da67
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 1 deletion.
30 changes: 30 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import math
import time
import base64
import threading

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -437,3 +438,32 @@ def _metadata_with_leader_aware_routing(value, **kw):
List[Tuple[str, str]]: RPC metadata with leader aware routing header
"""
return ("x-goog-spanner-route-to-leader", str(value).lower())


class AtomicInt:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
self.__value = start_value

def __iadd__(self, n):
res = 0
with self.__lock:
res = self.__value
res += n
self.__value = res
return res

def __add__(self, n):
res = 0
with self.__lock:
res = self.__value
res += 0
return res

@property
def value(self):
with self.__lock:
return self.__value

def increment(self, value=1):
return self.__iadd__(value)
10 changes: 10 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import grpc
import os
import warnings
import threading

from google.api_core.gapic_v1 import client_info
from google.auth.credentials import AnonymousCredentials
Expand All @@ -48,6 +49,7 @@
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1._helpers import AtomicInt

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -147,6 +149,8 @@ class Client(ClientWithProject):
SCOPE = (SPANNER_ADMIN_SCOPE,)
"""The scopes required for Google Cloud Spanner."""

NTH_CLIENT = AtomicInt()

def __init__(
self,
project=None,
Expand Down Expand Up @@ -199,6 +203,12 @@ def __init__(
self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
self._observability_options = observability_options
self._nth_client_id = Client.NTH_CLIENT.increment()
self._nth_request = AtomicInt()

@property
def _next_nth_request(self):
return self._nth_request.increment()

@property
def credentials(self):
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
AtomicInt,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
Expand Down Expand Up @@ -693,8 +694,15 @@ def execute_partitioned_dml(
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
)

nth_request = self._next_nth_request()
attempt = AtomicInt(1) # It'll be incremented inside _restart_on_unavailable

def execute_pdml():
with SessionCheckout(self._pool) as session:
channel_id = session._channel_id
metadata = with_request_id(
self._client._nth_client_id, nth_request, attempt.value, metadata
)
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)
Expand All @@ -719,6 +727,7 @@ def execute_pdml():
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
attempt=attempt,
)

result_set = StreamedResultSet(iterator)
Expand All @@ -728,6 +737,9 @@ def execute_pdml():

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

def _next_nth_request(self):
return self._instance._client._next_nth_request

def session(self, labels=None, database_role=None):
"""Factory to create a session for this database.
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import datetime
import queue
import threading

from google.cloud.exceptions import NotFound
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
Expand Down Expand Up @@ -47,6 +48,8 @@ def __init__(self, labels=None, database_role=None):
labels = {}
self._labels = labels
self._database_role = database_role
self.__lock = threading.lock()
self._session_id_to_channel_id = dict()

@property
def labels(self):
Expand Down Expand Up @@ -122,10 +125,17 @@ def _new_session(self):
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: new session instance.
"""
return self._database.session(
session = self._database.session(
labels=self.labels, database_role=self.database_role
)

with self.__lock:
channel_id = len(self._session_id_to_channel_id) + 1
self._session_id_to_channel_id[session._session.id] = channel_id
session._channel_id = channel_id

return session

def session(self, **kwargs):
"""Check out a session from the pool.
Expand Down
42 changes: 42 additions & 0 deletions google/cloud/spanner_v1/request_id_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import threading

REQ_ID_VERSION = 1 # The version of the x-goog-spanner-request-id spec.
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"


def generate_rand_uint64():
b = os.urandom(8)
return (
b[7] & 0xFF
| (b[6] & 0xFF) << 8
| (b[5] & 0xFF) << 16
| (b[4] & 0xFF) << 24
| (b[3] & 0xFF) << 32
| (b[2] & 0xFF) << 36
| (b[1] & 0xFF) << 48
| (b[0] & 0xFF) << 56
)


REQ_RAND_PROCESS_ID = generate_rand_uint64()


def with_request_id(client_id, nth_request, attempt, other_metadata=[]):
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
other_metadata.append((REQ_ID_HEADER_KEY, req_id))
return other_metadata
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, database, labels=None, database_role=None):
labels = {}
self._labels = labels
self._database_role = database_role
self.__channel_id = 0

def __lt__(self, other):
return self._session_id < other._session_id
Expand Down Expand Up @@ -203,6 +204,7 @@ def delete(self):
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
# Generate the request_id
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.DeleteSession",
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _restart_on_unavailable(
transaction=None,
transaction_selector=None,
observability_options=None,
attempt=0,
):
"""Restart iteration after :exc:`.ServiceUnavailable`.
Expand Down Expand Up @@ -91,6 +92,7 @@ def _restart_on_unavailable(
):
iterator = method(request=request)
while True:
attempt += 1
try:
for item in iterator:
item_buffer.append(item)
Expand Down

0 comments on commit 4f1da67

Please sign in to comment.