Skip to content

Commit

Permalink
Merge branch 'main' into release-please--branches--main
Browse files Browse the repository at this point in the history
  • Loading branch information
harshachinta authored Dec 6, 2024
2 parents 28242a2 + a6811af commit 43f1f39
Show file tree
Hide file tree
Showing 14 changed files with 602 additions and 201 deletions.
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def _retry(
retry_count=5,
delay=2,
allowed_exceptions=None,
beforeNextRetry=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
Expand All @@ -479,6 +480,9 @@ def _retry(
"""
retries = 0
while retries <= retry_count:
if retries > 0 and beforeNextRetry:
beforeNextRetry(retries, delay)

try:
return func()
except Exception as exc:
Expand Down
19 changes: 17 additions & 2 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
db = session._database
attributes = {
"db.type": "spanner",
"db.url": SpannerClient.DEFAULT_ENDPOINT,
"db.instance": session._database.name,
"db.instance": "" if not db else db.name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
OTEL_SCOPE_VERSION: TRACER_VERSION,
Expand All @@ -106,7 +107,10 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
yield span
except Exception as error:
span.set_status(Status(StatusCode.ERROR, str(error)))
span.record_exception(error)
# OpenTelemetry-Python imposes invoking span.record_exception on __exit__
# on any exception. We should file a bug later on with them to only
# invoke .record_exception if not already invoked, hence we should not
# invoke .record_exception on our own else we shall have 2 exceptions.
raise
else:
if (not span._status) or span._status.status_code == StatusCode.UNSET:
Expand All @@ -116,3 +120,14 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
# it wasn't previously set otherwise.
# https://github.com/googleapis/python-spanner/issues/1246
span.set_status(Status(StatusCode.OK))


def get_current_span():
if not HAS_OPENTELEMETRY_INSTALLED:
return None
return trace.get_current_span()


def add_span_event(span, event_name, event_attributes=None):
if span:
span.add_event(event_name, event_attributes)
12 changes: 12 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
SpannerGrpcTransport,
)
from google.cloud.spanner_v1.table import Table
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
)


SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
Expand Down Expand Up @@ -1164,7 +1168,9 @@ def __init__(

def __enter__(self):
"""Begin ``with`` block."""
current_span = get_current_span()
session = self._session = self._database._pool.get()
add_span_event(current_span, "Using session", {"id": session.session_id})
batch = self._batch = Batch(session)
if self._request_options.transaction_tag:
batch.transaction_tag = self._request_options.transaction_tag
Expand All @@ -1187,6 +1193,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
extra={"commit_stats": self._batch.commit_stats},
)
self._database._pool.put(self._session)
current_span = get_current_span()
add_span_event(
current_span,
"Returned session to pool",
{"id": self._session.session_id},
)


class MutationGroupsCheckout(object):
Expand Down
173 changes: 166 additions & 7 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import datetime
import queue
import time

from google.cloud.exceptions import NotFound
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
Expand All @@ -24,6 +25,10 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
)
from warnings import warn

_NOW = datetime.datetime.utcnow # unit tests may replace
Expand Down Expand Up @@ -196,20 +201,50 @@ def bind(self, database):
when needed.
"""
self._database = database
requested_session_count = self.size - self._sessions.qsize()
span = get_current_span()
span_event_attributes = {"kind": type(self).__name__}

if requested_session_count <= 0:
add_span_event(
span,
f"Invalid session pool size({requested_session_count}) <= 0",
span_event_attributes,
)
return

api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
self._database_role = self._database_role or self._database.database_role
if requested_session_count > 0:
add_span_event(
span,
f"Requesting {requested_session_count} sessions",
span_event_attributes,
)

if self._sessions.full():
add_span_event(span, "Session pool is already full", span_event_attributes)
return

request = BatchCreateSessionsRequest(
database=database.name,
session_count=self.size - self._sessions.qsize(),
session_count=requested_session_count,
session_template=Session(creator_role=self.database_role),
)

returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
add_span_event(
span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
Expand All @@ -218,6 +253,13 @@ def bind(self, database):
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self._sessions.put(session)
returned_session_count += 1

add_span_event(
span,
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand All @@ -233,12 +275,43 @@ def get(self, timeout=None):
if timeout is None:
timeout = self.default_timeout

session = self._sessions.get(block=True, timeout=timeout)
age = _NOW() - session.last_use_time
start_time = time.time()
current_span = get_current_span()
span_event_attributes = {"kind": type(self).__name__}
add_span_event(current_span, "Acquiring session", span_event_attributes)

if age >= self._max_age and not session.exists():
session = self._database.session()
session.create()
session = None
try:
add_span_event(
current_span,
"Waiting for a session to become available",
span_event_attributes,
)

session = self._sessions.get(block=True, timeout=timeout)
age = _NOW() - session.last_use_time

if age >= self._max_age and not session.exists():
if not session.exists():
add_span_event(
current_span,
"Session is not valid, recreating it",
span_event_attributes,
)
session = self._database.session()
session.create()
# Replacing with the updated session.id.
span_event_attributes["session.id"] = session._session_id

span_event_attributes["session.id"] = session._session_id
span_event_attributes["time.elapsed"] = time.time() - start_time
add_span_event(current_span, "Acquired session", span_event_attributes)

except queue.Empty as e:
add_span_event(
current_span, "No sessions available in the pool", span_event_attributes
)
raise e

return session

Expand Down Expand Up @@ -312,13 +385,32 @@ def get(self):
:returns: an existing session from the pool, or a newly-created
session.
"""
current_span = get_current_span()
span_event_attributes = {"kind": type(self).__name__}
add_span_event(current_span, "Acquiring session", span_event_attributes)

try:
add_span_event(
current_span,
"Waiting for a session to become available",
span_event_attributes,
)
session = self._sessions.get_nowait()
except queue.Empty:
add_span_event(
current_span,
"No sessions available in pool. Creating session",
span_event_attributes,
)
session = self._new_session()
session.create()
else:
if not session.exists():
add_span_event(
current_span,
"Session is not valid, recreating it",
span_event_attributes,
)
session = self._new_session()
session.create()
return session
Expand Down Expand Up @@ -427,6 +519,38 @@ def bind(self, database):
session_template=Session(creator_role=self.database_role),
)

span_event_attributes = {"kind": type(self).__name__}
current_span = get_current_span()
requested_session_count = request.session_count
if requested_session_count <= 0:
add_span_event(
current_span,
f"Invalid session pool size({requested_session_count}) <= 0",
span_event_attributes,
)
return

add_span_event(
current_span,
f"Requesting {requested_session_count} sessions",
span_event_attributes,
)

if created_session_count >= self.size:
add_span_event(
current_span,
"Created no new sessions as sessionPool is full",
span_event_attributes,
)
return

add_span_event(
current_span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)

returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
Expand All @@ -436,8 +560,16 @@ def bind(self, database):
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1

created_session_count += len(resp.session)

add_span_event(
current_span,
f"Requested for {requested_session_count} sessions, return {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand All @@ -452,7 +584,26 @@ def get(self, timeout=None):
if timeout is None:
timeout = self.default_timeout

ping_after, session = self._sessions.get(block=True, timeout=timeout)
start_time = time.time()
span_event_attributes = {"kind": type(self).__name__}
current_span = get_current_span()
add_span_event(
current_span,
"Waiting for a session to become available",
span_event_attributes,
)

ping_after = None
session = None
try:
ping_after, session = self._sessions.get(block=True, timeout=timeout)
except queue.Empty as e:
add_span_event(
current_span,
"No sessions available in the pool within the specified timeout",
span_event_attributes,
)
raise e

if _NOW() > ping_after:
# Using session.exists() guarantees the returned session exists.
Expand All @@ -462,6 +613,14 @@ def get(self, timeout=None):
session = self._new_session()
session.create()

span_event_attributes.update(
{
"time.elapsed": time.time() - start_time,
"session.id": session._session_id,
"kind": "pinging_pool",
}
)
add_span_event(current_span, "Acquired session", span_event_attributes)
return session

def put(self, session):
Expand Down
Loading

0 comments on commit 43f1f39

Please sign in to comment.