Skip to content

Commit

Permalink
observability: PDML + some batch write spans
Browse files Browse the repository at this point in the history
This change adds spans for Partitioned DML and making
updates for Batch.

Carved out from PR #1241.
  • Loading branch information
odeke-em committed Dec 17, 2024
1 parent ad69c48 commit 6777f97
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 123 deletions.
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.BatchWrite",
"CloudSpanner.batch_write",
self._session,
trace_attributes,
observability_options=observability_options,
Expand Down
173 changes: 106 additions & 67 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,8 @@ def execute_partitioned_dml(
)

def execute_pdml():
with SessionCheckout(self._pool) as session:
def do_execute_pdml(session, span):
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)
Expand Down Expand Up @@ -732,6 +733,13 @@ def execute_pdml():

return result_set.stats.row_count_lower_bound

with trace_call(
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span:
with SessionCheckout(self._pool) as session:
return do_execute_pdml(session, span)

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

def session(self, labels=None, database_role=None):
Expand Down Expand Up @@ -1349,6 +1357,10 @@ def to_dict(self):
"transaction_id": snapshot._transaction_id,
}

@property
def observability_options(self):
return getattr(self._database, "observability_options", {})

def _get_session(self):
"""Create session as needed.
Expand Down Expand Up @@ -1468,27 +1480,32 @@ def generate_read_batches(
mappings of information used perform actual partitioned reads via
:meth:`process_read_batch`.
"""
partitions = self._get_snapshot().partition_read(
table=table,
columns=columns,
keyset=keyset,
index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_read_batches",
extra_attributes=dict(table=table, columns=columns),
observability_options=self.observability_options,
):
partitions = self._get_snapshot().partition_read(
table=table,
columns=columns,
keyset=keyset,
index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

read_info = {
"table": table,
"columns": columns,
"keyset": keyset._to_dict(),
"index": index,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
for partition in partitions:
yield {"partition": partition, "read": read_info.copy()}
read_info = {
"table": table,
"columns": columns,
"keyset": keyset._to_dict(),
"index": index,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
for partition in partitions:
yield {"partition": partition, "read": read_info.copy()}

def process_read_batch(
self,
Expand All @@ -1514,12 +1531,17 @@ def process_read_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)
observability_options = self.observability_options or {}
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_read_batch",
observability_options=observability_options,
):
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)

def generate_query_batches(
self,
Expand Down Expand Up @@ -1594,34 +1616,39 @@ def generate_query_batches(
mappings of information used perform actual partitioned reads via
:meth:`process_read_batch`.
"""
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_query_batches",
extra_attributes=dict(sql=sql),
observability_options=self.observability_options,
):
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)
query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)

for partition in partitions:
yield {"partition": partition, "query": query_info}
for partition in partitions:
yield {"partition": partition, "query": query_info}

def process_query_batch(
self,
Expand All @@ -1646,9 +1673,16 @@ def process_query_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
return self._get_snapshot().execute_sql(
partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_query_batch",
observability_options=self.observability_options,
):
return self._get_snapshot().execute_sql(
partition=batch["partition"],
**batch["query"],
retry=retry,
timeout=timeout,
)

def run_partitioned_query(
self,
Expand Down Expand Up @@ -1703,18 +1737,23 @@ def run_partitioned_query(
:rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
with trace_call(
f"CloudSpanner.${type(self).__name__}.run_partitioned_query",
extra_attributes=dict(sql=sql),
observability_options=self.observability_options,
):
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
)
)
)
return MergedResultSet(self, partitions, 0)
return MergedResultSet(self, partitions, 0)

def process(self, batch):
"""Process a single, partitioned query or read.
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/spanner_v1/merged_result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set):
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
observability_options = getattr(
self._batch_snapshot, "observability_options", {}
)
with trace_call(
"CloudSpanner.PartitionExecutor.run",
observability_options=observability_options,
):
return self.__run()

def __run(self):
results = None
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
Expand Down
29 changes: 9 additions & 20 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,11 @@ def bind(self, database):
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
created_session_count = 0
self._database_role = self._database_role or self._database.database_role

request = BatchCreateSessionsRequest(
database=database.name,
session_count=self.size - created_session_count,
session_count=self.size,
session_template=Session(creator_role=self.database_role),
)

Expand All @@ -549,38 +548,28 @@ def bind(self, database):
span_event_attributes,
)

if created_session_count >= self.size:
add_span_event(
current_span,
"Created no new sessions as sessionPool is full",
span_event_attributes,
)
return

add_span_event(
current_span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)

observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while created_session_count < self.size:
while returned_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)

add_span_event(
span,
f"Created {len(resp.session)} sessions",
)

for session_pb in resp.session:
session = self._new_session()
returned_session_count += 1
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1

created_session_count += len(resp.session)

add_span_event(
span,
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def run_in_transaction(self, func, *args, **kw):
) as span:
while True:
if self._transaction is None:
add_span_event(span, "Creating Transaction")
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = (
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,10 +675,13 @@ def partition_read(
)

trace_attributes = {"table_id": table, "columns": columns}
can_include_index = (index != "") and (index is not None)
if can_include_index:
trace_attributes["index"] = index

with trace_call(
f"CloudSpanner.{type(self).__name__}.partition_read",
self._session,
trace_attributes,
extra_attributes=trace_attributes,
observability_options=getattr(database, "observability_options", None),
):
method = functools.partial(
Expand Down Expand Up @@ -779,7 +782,7 @@ def partition_query(

trace_attributes = {"db.statement": sql}
with trace_call(
"CloudSpanner.PartitionReadWriteTransaction",
f"CloudSpanner.{type(self).__name__}.partition_query",
self._session,
trace_attributes,
observability_options=getattr(database, "observability_options", None),
Expand Down
2 changes: 1 addition & 1 deletion tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def assertSpanAttributes(
):
if HAS_OPENTELEMETRY_INSTALLED:
if not span:
span_list = self.ot_exporter.get_finished_spans()
span_list = self.get_finished_spans()
self.assertEqual(len(span_list) > 0, True)
span = span_list[0]

Expand Down
3 changes: 2 additions & 1 deletion tests/system/test_observability_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces."
)
@pytest.mark.skipif(
not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces."
not _helpers.USE_EMULATOR, reason="Emulator is necessary to test traces."
)
def test_observability_options_propagation():
PROJECT = _helpers.EMULATOR_PROJECT
Expand Down Expand Up @@ -108,6 +108,7 @@ def test_propagation(enable_extended_tracing):
wantNames = [
"CloudSpanner.CreateSession",
"CloudSpanner.Snapshot.execute_streaming_sql",
"CloudSpanner.Database.snapshot",
]
assert gotNames == wantNames

Expand Down
Loading

0 comments on commit 6777f97

Please sign in to comment.