Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add updated span events + trace more methods #1259

12 changes: 8 additions & 4 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def get_tracer(tracer_provider=None):


@contextmanager
def trace_call(name, session, extra_attributes=None, observability_options=None):
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
odeke-em marked this conversation as resolved.
Show resolved Hide resolved
if session:
session._last_use_time = datetime.now()

if not HAS_OPENTELEMETRY_INSTALLED or not session:
if not (HAS_OPENTELEMETRY_INSTALLED and name):
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return
Expand All @@ -72,20 +72,24 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
# on by default.
enable_extended_tracing = True

db_name = ""
if session and getattr(session, "_database", None):
db_name = session._database.name

if isinstance(observability_options, dict): # Avoid false positives with mock.Mock
tracer_provider = observability_options.get("tracer_provider", None)
enable_extended_tracing = observability_options.get(
"enable_extended_tracing", enable_extended_tracing
)
db_name = observability_options.get("db_name", db_name)

tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
db = session._database
attributes = {
"db.type": "spanner",
"db.url": SpannerClient.DEFAULT_ENDPOINT,
"db.instance": "" if not db else db.name,
"db.instance": db_name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
OTEL_SCOPE_VERSION: TRACER_VERSION,
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def insert(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def update(self, table, columns, values):
"""Update one or more existing table rows.
Expand All @@ -84,6 +86,8 @@ def update(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(update=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def insert_or_update(self, table, columns, values):
"""Insert/update one or more table rows.
Expand All @@ -100,6 +104,8 @@ def insert_or_update(self, table, columns, values):
self._mutations.append(
Mutation(insert_or_update=_make_write_pb(table, columns, values))
)
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def replace(self, table, columns, values):
"""Replace one or more table rows.
Expand All @@ -114,6 +120,8 @@ def replace(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def delete(self, table, keyset):
"""Delete one or more table rows.
Expand All @@ -126,6 +134,8 @@ def delete(self, table, keyset):
"""
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
self._mutations.append(Mutation(delete=delete))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269


class Batch(_BatchBase):
Expand Down Expand Up @@ -207,7 +217,7 @@ def commit(
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.Commit",
f"CloudSpanner.{type(self).__name__}.commit",
self._session,
trace_attributes,
observability_options=observability_options,
Expand Down
42 changes: 27 additions & 15 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
trace_call,
)


Expand Down Expand Up @@ -720,6 +721,7 @@ def execute_pdml():

iterator = _restart_on_unavailable(
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
Expand Down Expand Up @@ -881,20 +883,25 @@ def run_in_transaction(self, func, *args, **kw):
:raises Exception:
reraises any non-ABORT exceptions raised by ``func``.
"""
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False
observability_options = getattr(self, "observability_options", None)
with trace_call(
"CloudSpanner.Database.run_in_transaction",
observability_options=observability_options,
):
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False

def restore(self, source):
"""Restore from a backup to this database.
Expand Down Expand Up @@ -1120,7 +1127,12 @@ def observability_options(self):
if not (self._instance and self._instance._client):
return None

return getattr(self._instance._client, "observability_options", None)
opts = getattr(self._instance._client, "observability_options", None)
if not opts:
opts = dict()

opts["db_name"] = self.name
return opts


class BatchCheckout(object):
Expand Down
90 changes: 54 additions & 36 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
trace_call,
)
from warnings import warn

Expand Down Expand Up @@ -237,29 +238,41 @@ def bind(self, database):
session_template=Session(creator_role=self.database_role),
)

returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.FixedPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
add_span_event(
span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)

add_span_event(
span,
"Created sessions",
dict(count=len(resp.session)),
)

for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self._sessions.put(session)
returned_session_count += 1

add_span_event(
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
span,
f"Creating {request.session_count} sessions",
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self._sessions.put(session)
returned_session_count += 1

add_span_event(
span,
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand Down Expand Up @@ -550,25 +563,30 @@ def bind(self, database):
span_event_attributes,
)

returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1

created_session_count += len(resp.session)
created_session_count += len(resp.session)

add_span_event(
current_span,
f"Requested for {requested_session_count} sessions, return {returned_session_count}",
span_event_attributes,
)
add_span_event(
harshachinta marked this conversation as resolved.
Show resolved Hide resolved
span,
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand Down
Loading
Loading