Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(x-goog-spanner-request-id): implement Request-ID #1264

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random
from typing import Callable

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -641,6 +642,85 @@ def __radd__(self, n):
"""
return self.__add__(n)

def reset(self):
with self.__lock:
self.__value = 0


def _metadata_with_request_id(*args, **kwargs):
return with_request_id(*args, **kwargs)


patched = {}


def inject_retry_header_control(api):
# For each method, add an _attempt value that'll then be
# retrieved for each retry.
# 1. Patch the __getattribute__ method to match items in our manifest.
target = type(api)
hex_id = hex(id(target))
if patched.get(hex_id, None) is not None:
return

orig_getattribute = getattr(target, "__getattribute__")

def patched_getattribute(obj, key, *args, **kwargs):
if key.startswith("_"):
return orig_getattribute(obj, key, *args, **kwargs)

attr = orig_getattribute(obj, key, *args, **kwargs)
print("args", args, "attr.dir", dir(attr))

# 0. If we already patched it, we can return immediately.
if getattr(attr, "_patched", None) is not None:
return attr

# 1. Skip over non-methods.
if not callable(attr):
return attr

# 2. Skip modifying private and mangled methods.
mangled_or_private = attr.__name__.startswith("_")
if mangled_or_private:
return attr

print("\033[35mattr", attr, "hex_id", hex(id(attr)), "\033[00m")

# 3. Wrap the callable attribute and then capture its metadata keyed argument.
def wrapped_attr(*args, **kwargs):
metadata = kwargs.get("metadata", [])
if not metadata:
# Increment the reinvocation count.
print("not metatadata", attr.__name__)
wrapped_attr._attempt += 1
return attr(*args, **kwargs)

# 4. Find all the headers that match the target header key.
all_metadata = []
for key, value in metadata:
if key is REQ_ID_HEADER_KEY:
print("key", key, "value", value, "attempt", wrapped_attr._attempt)
# 5. Increment the original_attempt with that of our re-invocation count.
splits = value.split(".")
hdr_attempt_plus_reinvocation = (
int(splits[-1]) + wrapped_attr._attempt
)
splits[-1] = str(hdr_attempt_plus_reinvocation)
value = ".".join(splits)

all_metadata.append((key, value))

# Increment the reinvocation count.
wrapped_attr._attempt += 1

kwargs["metadata"] = all_metadata
print("\033[34mwrap_callable", hex(id(attr)), attr.__name__, "\033[00m")
return attr(*args, **kwargs)

wrapped_attr._attempt = 0
wrapped_attr._patched = True
return wrapped_attr

setattr(target, "__getattribute__", patched_getattribute)
patched[hex_id] = True
12 changes: 10 additions & 2 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ def commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request,
1,
metadata,
),
)
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
Expand Down Expand Up @@ -352,7 +356,11 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request,
1,
metadata,
),
)
response = _retry(
method,
Expand Down
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1._helpers import AtomicCounter

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -147,6 +148,8 @@ class Client(ClientWithProject):
SCOPE = (SPANNER_ADMIN_SCOPE,)
"""The scopes required for Google Cloud Spanner."""

NTH_CLIENT = AtomicCounter()

def __init__(
self,
project=None,
Expand Down Expand Up @@ -199,6 +202,12 @@ def __init__(
self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
self._observability_options = observability_options
self._nth_client_id = Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter(0)

@property
def _next_nth_request(self):
return self._nth_request.increment()

@property
def credentials(self):
Expand Down
Loading