Skip to content

feat(x-goog-spanner-request-id): implement Request-ID #1264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cb548a2
feat(x-goog-spanner-request-id): implement request_id generation and …
odeke-em Dec 2, 2024
105c397
More plumbing for Database DDL methods
odeke-em Dec 20, 2024
60f1b71
Update test_spanner.test_transaction tests
odeke-em Dec 20, 2024
53653d1
Update test_session tests
odeke-em Dec 20, 2024
4afa2ce
Update tests
odeke-em Dec 25, 2024
3402751
Update propagation of changes
odeke-em Dec 25, 2024
2c4d32d
Plumb test_context_manager
odeke-em Dec 25, 2024
7ea1b79
Fix more tests
odeke-em Dec 25, 2024
0f1ecd3
More test plumbing
odeke-em Dec 25, 2024
3532936
Infer database._channel_id only once along with spanner_api
odeke-em Dec 27, 2024
24f1bd4
Update batch tests
odeke-em Dec 27, 2024
ba39d46
test: add tests for retries and move mock server test to correct dire…
olavloite Dec 27, 2024
04ec941
fix: revert default Python version
olavloite Dec 27, 2024
d4bf747
Fix discrepancy with api.batch_create_sessions automatically retrying…
odeke-em Jan 4, 2025
9ce98c3
Take into account current behavior of /GetSession /BatchCreateSession…
odeke-em Jan 4, 2025
0495082
Implement interceptor to wrap and increase x-goog-spanner-request-id …
odeke-em Jan 15, 2025
df8f81f
Correctly handle wrapping by class for api objects
odeke-em Jan 17, 2025
f435747
Revert pool creation interception attempts
odeke-em Jan 18, 2025
ea0823f
Wire up and revert some prints
odeke-em Jan 18, 2025
f8ad94f
Initial ExecuteStreamingSql request in snapshot should have the header
odeke-em Mar 28, 2025
15984cb
Consolidate retries for _restart_on_unavailable
odeke-em Mar 28, 2025
a5fdebd
Fix missing variable declaration
odeke-em Mar 30, 2025
80faca0
Adjust with updates
odeke-em Mar 30, 2025
06c12a2
Update _execute_partitioned_dml_helper
odeke-em Apr 1, 2025
9d4d942
Monkey patch updates
odeke-em Apr 8, 2025
6ed0b74
Merge branch 'main' into x-goog-spanner-request-id
odeke-em May 17, 2025
c31c03f
Reduce unnecessary changes
odeke-em May 17, 2025
e67cc9e
Experiment with wrapping for gapic retries
odeke-em May 18, 2025
dab92dd
Remove duplicate code
odeke-em May 19, 2025
4334dd4
Complete mockserver tests
odeke-em May 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 271 additions & 1 deletion google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import math
import time
import base64
import inspect
import threading

from google.protobuf.struct_pb2 import ListValue
Expand All @@ -33,7 +34,7 @@
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject, Interval
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
from google.rpc.error_details_pb2 import RetryInfo

try:
Expand All @@ -45,6 +46,7 @@
HAS_OPENTELEMETRY_INSTALLED = False
from typing import List, Tuple
import random
from typing import Callable

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -574,6 +576,7 @@ def _retry(


def _check_rst_stream_error(exc):
print("\033[31mrst_", exc, "\033[00m")
resumable_error = (
any(
resumable_message in exc.message
Expand All @@ -587,6 +590,20 @@ def _check_rst_stream_error(exc):
raise


def _check_unavailable(exc):
resumable_error = (
any(
resumable_message in exc.message
for resumable_message in (
"INTERNAL",
"Service unavailable",
)
),
)
if not resumable_error:
raise


def _metadata_with_leader_aware_routing(value, **kw):
"""Create RPC metadata containing a leader aware routing header

Expand Down Expand Up @@ -749,3 +766,256 @@ def _merge_Transaction_Options(

# Convert protobuf object back into a TransactionOptions instance
return TransactionOptions(merged_pb)


class InterceptingHeaderInjector:
    """Holder for a callable whose invocations are meant to be intercepted.

    Stores the callable so header-injection logic (e.g. rewriting the
    x-goog-spanner-request-id metadata) can later be layered on top of it.
    """

    def __init__(self, original_callable: Callable):
        # Keep a reference to the wrapped callable for later dispatch.
        self._original_callable = original_callable


# Registry of already-patched targets and wrapped attributes, keyed by
# object-identity strings (hex ids); all access is guarded by ``patched_mu``.
patched = {}
patched_mu = threading.Lock()


def inject_retry_header_control(api):
    """Hook for wiring x-goog-spanner-request-id attempt tracking into *api*.

    Intentionally a no-op for now: the ``monkey_patch`` experiments that
    previously ran here are disabled.

    :type api: object
    :param api: the gapic API client that would be instrumented.
    """
    # Deliberately does nothing; kept so call sites stay stable while the
    # interception strategy is finalized.


def monkey_patch(typ):
    """Wrap ``typ.batch_create_sessions`` to track re-invocation attempts.

    The proxy counts calls per (instance, method) pair and adds that count
    to the trailing ``attempt`` component of any x-goog-spanner-request-id
    metadata header, so retried RPCs carry a distinct attempt number.

    :type typ: type
    :param typ: the gapic client class to patch in place.
    """
    attempts = dict()
    for key in dir(typ):
        # Skip private and mangled attributes.
        if key.startswith("_"):
            continue

        # Currently scoped to the one RPC that gapic retries internally.
        if key != "batch_create_sessions":
            continue

        fn = getattr(typ, key)

        # Only methods accepting a ``metadata`` kwarg can carry the header.
        signature = inspect.signature(fn)
        if signature.parameters.get("metadata", None) is None:
            continue

        # ``fn``/``key`` are bound as keyword defaults to avoid the
        # late-binding-closure pitfall if the filter above is widened.
        def as_proxy(db, *args, fn=fn, key=key, **kwargs):
            metadata = kwargs.get("metadata", None)
            if not metadata:
                return fn(db, *args, **kwargs)

            hash_key = hex(id(db)) + "." + key
            attempts.setdefault(hash_key, 0)
            attempts[hash_key] += 1

            # Rewrite headers matching the target key; BUG FIX: use ``==``
            # (not ``is``) for the string comparison, and preserve
            # non-matching metadata entries instead of dropping them.
            all_metadata = []
            for mkey, value in metadata:
                if mkey == REQ_ID_HEADER_KEY:
                    splits = value.split(".")
                    # Fold the re-invocation count into the attempt field.
                    splits[-1] = str(int(splits[-1]) + attempts[hash_key])
                    value = ".".join(splits)
                all_metadata.append((mkey, value))

            kwargs["metadata"] = all_metadata
            return fn(db, *args, **kwargs)

        setattr(typ, key, as_proxy)


def alt_foo():
    """Experimental alternative to ``foo`` for attempt-count header patching.

    NOTE(review): this function is non-functional debug scaffolding — it
    references a name ``obj`` that is not defined in this scope, so calling
    it raises ``NameError`` immediately.  Bugs are flagged inline below.
    """
    memoize_map = dict()
    # NOTE(review): ``obj`` is undefined here — presumably the target API
    # instance was meant to be a parameter; confirm before enabling.
    orig_get_attr = getattr(obj, "__getattribute__")
    hex_orig = hex(id(orig_get_attr))
    hex_patched = None

    def patched_getattribute(obj, key, *args, **kwargs):
        # Private/mangled attributes pass straight through.
        if key.startswith("_"):
            return orig_get_attr(obj, key, *args, **kwargs)

        # Only the one internally-retried RPC is intercepted.
        if key != "batch_create_sessions":
            return orig_get_attr(obj, key, *args, **kwargs)

        # Memoize the wrapper per (attribute, instance) identity pair.
        map_key = hex(id(key)) + hex(id(obj))
        memoized = memoize_map.get(map_key, None)
        if memoized:
            if False:
                # NOTE(review): disabled debug block; also references
                # ``orig_value`` before it is bound (would NameError).
                print(
                    "memoized_hit",
                    key,
                    "\033[35m",
                    inspect.getsource(orig_value),
                    "\033[00m",
                )
            print("memoized_hit", key, "\033[35m", map_key, "\033[00m")
            return memoized

        orig_value = orig_get_attr(obj, key, *args, **kwargs)
        # Non-callables cannot be wrapped.
        if not callable(orig_value):
            return orig_value

        # Only methods accepting a ``metadata`` kwarg carry the header.
        signature = inspect.signature(orig_value)
        if signature.parameters.get("metadata", None) is None:
            return orig_value

        if False:
            # Disabled debug output.
            print(
                key,
                "\033[34m",
                map_key,
                "\033[00m",
                signature,
                signature.parameters.get("metadata", None),
            )

        if False:
            # Disabled debug output: dump a slice of the call stack.
            stack = inspect.stack()
            ends = stack[-50:-20]
            for i, st in enumerate(ends):
                print(i, st.filename, st.lineno)

        print(
            "\033[33mmonkey patching now\033[00m",
            key,
            "hex_orig",
            hex_orig,
            "hex_patched",
            hex_patched,
        )
        # Per-wrapper invocation counter, shared via closure.
        counters = dict(attempt=0)

        def patched_method(*aargs, **kkwargs):
            counters["attempt"] += 1
            print("counters", counters)
            metadata = kkwargs.get("metadata", None)
            if not metadata:
                return orig_value(*aargs, **kkwargs)

            # 4. Find all the headers that match the target header key.
            all_metadata = []
            for mkey, value in metadata:
                # NOTE(review): ``is`` compares string identity, not
                # equality — almost certainly meant ``==``.
                if mkey is REQ_ID_HEADER_KEY:
                    attempt = counters["attempt"]
                    if attempt > 1:
                        # 5. Increment the original_attempt with that of our re-invocation count.
                        splits = value.split(".")
                        print("\033[34mkey", mkey, "\033[00m", splits)
                        hdr_attempt_plus_reinvocation = int(splits[-1]) + attempt
                        splits[-1] = str(hdr_attempt_plus_reinvocation)
                        value = ".".join(splits)

                all_metadata.append((mkey, value))

            # NOTE(review): bug — writes to the enclosing ``kwargs`` instead
            # of ``kkwargs``, so the rewritten metadata is never passed on.
            kwargs["metadata"] = all_metadata

            try:
                return orig_value(*aargs, **kkwargs)

            # NOTE(review): these exception types are not imported in the
            # visible file — presumably google.api_core.exceptions; verify.
            except (InternalServerError, ServiceUnavailable) as exc:
                print("caught this exception, incrementing", exc)
                counters["attempt"] += 1
                raise exc

        memoize_map[map_key] = patched_method
        return patched_method

    hex_patched = hex(id(patched_getattribute))
    setattr(obj, "__getattribute__", patched_getattribute)


def foo(api):
    """Patch ``type(api).__getattribute__`` to add attempt numbers to the
    x-goog-spanner-request-id metadata header of retried RPC calls.

    Every public callable attribute is replaced (once, memoized in the
    module-level ``patched`` registry) by a wrapper that adds its
    re-invocation count to the trailing ``attempt`` component of any
    ``REQ_ID_HEADER_KEY`` metadata entry.

    :type api: object
    :param api: the gapic API client whose class is instrumented in place.
    """
    global patched
    global patched_mu

    # 1. Patch __getattribute__ once per class.  BUG FIX: the previous
    # version checked this guard key but never stored it, so every call
    # re-wrapped __getattribute__ in another layer.
    target = type(api)
    hex_id = hex(id(target))
    with patched_mu:
        if patched.get(hex_id, None) is not None:
            return
        patched[hex_id] = target

    orig_getattribute = getattr(target, "__getattribute__")

    def patched_getattribute(obj, key, *args, **kwargs):
        # 2. Skip private and mangled attributes.
        if key.startswith("_"):
            return orig_getattribute(obj, key, *args, **kwargs)

        attr = orig_getattribute(obj, key, *args, **kwargs)

        # 3. Skip over non-methods.  BUG FIX: the previous version called
        # ``patched_mu.release()`` here without ever acquiring the lock,
        # which raises RuntimeError on any non-callable attribute access.
        if not callable(attr):
            return attr

        # Memoize the wrapper per (attribute name, instance identity).
        patched_key = key + hex(id(obj))
        with patched_mu:
            already_patched = patched.get(patched_key, None)
        if already_patched is not None:
            return already_patched

        # 4. Wrap the callable and rewrite its ``metadata`` keyword argument.
        def wrapped_attr(*args, **kwargs):
            metadata = kwargs.get("metadata", [])
            if not metadata:
                # Still count the re-invocation even without metadata.
                wrapped_attr._attempt += 1
                return attr(*args, **kwargs)

            # 5. Find headers matching the target key and fold the
            # re-invocation count into their trailing attempt field.
            # BUG FIX: compare with ``==`` rather than string identity.
            all_metadata = []
            for mkey, value in metadata:
                if mkey == REQ_ID_HEADER_KEY and wrapped_attr._attempt > 0:
                    splits = value.split(".")
                    splits[-1] = str(int(splits[-1]) + wrapped_attr._attempt)
                    value = ".".join(splits)
                all_metadata.append((mkey, value))

            kwargs["metadata"] = all_metadata
            wrapped_attr._attempt += 1
            return attr(*args, **kwargs)

        wrapped_attr._attempt = 0
        # setdefault guards against a concurrent thread having created a
        # wrapper for the same key since the lookup above.
        with patched_mu:
            return patched.setdefault(patched_key, wrapped_attr)

    setattr(target, "__getattribute__", patched_getattribute)
7 changes: 6 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_merge_Transaction_Options,
AtomicCounter,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -388,7 +389,11 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request,
1,
metadata,
),
)
response = _retry(
method,
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
AtomicCounter,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_metadata_with_request_id,
Expand Down
Loading