Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

observability: PDML + some batch write spans #1274

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.BatchWrite",
"CloudSpanner.batch_write",
self._session,
trace_attributes,
observability_options=observability_options,
Expand Down
173 changes: 106 additions & 67 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,8 @@ def execute_partitioned_dml(
)

def execute_pdml():
with SessionCheckout(self._pool) as session:
def do_execute_pdml(session, span):
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)
Expand Down Expand Up @@ -732,6 +733,13 @@ def execute_pdml():

return result_set.stats.row_count_lower_bound

with trace_call(
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span:
with SessionCheckout(self._pool) as session:
return do_execute_pdml(session, span)

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

def session(self, labels=None, database_role=None):
Expand Down Expand Up @@ -1349,6 +1357,10 @@ def to_dict(self):
"transaction_id": snapshot._transaction_id,
}

@property
def observability_options(self):
    """Observability options inherited from the wrapped database.

    Returns the database's ``observability_options`` when the database
    exposes one, and an empty dict otherwise, so tracing helpers can
    always treat the result as a mapping.
    """
    database = self._database
    try:
        return database.observability_options
    except AttributeError:
        return {}

def _get_session(self):
"""Create session as needed.

Expand Down Expand Up @@ -1468,27 +1480,32 @@ def generate_read_batches(
mappings of information used to perform actual partitioned reads via
:meth:`process_read_batch`.
"""
partitions = self._get_snapshot().partition_read(
table=table,
columns=columns,
keyset=keyset,
index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_read_batches",
extra_attributes=dict(table=table, columns=columns),
observability_options=self.observability_options,
):
partitions = self._get_snapshot().partition_read(
table=table,
columns=columns,
keyset=keyset,
index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

read_info = {
"table": table,
"columns": columns,
"keyset": keyset._to_dict(),
"index": index,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
for partition in partitions:
yield {"partition": partition, "read": read_info.copy()}
read_info = {
"table": table,
"columns": columns,
"keyset": keyset._to_dict(),
"index": index,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
for partition in partitions:
yield {"partition": partition, "read": read_info.copy()}

def process_read_batch(
self,
Expand All @@ -1514,12 +1531,17 @@ def process_read_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)
observability_options = self.observability_options or {}
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_read_batch",
observability_options=observability_options,
):
kwargs = copy.deepcopy(batch["read"])
keyset_dict = kwargs.pop("keyset")
kwargs["keyset"] = KeySet._from_dict(keyset_dict)
return self._get_snapshot().read(
partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
)

def generate_query_batches(
self,
Expand Down Expand Up @@ -1594,34 +1616,39 @@ def generate_query_batches(
mappings of information used to perform actual partitioned queries via
:meth:`process_query_batch`.
"""
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_query_batches",
extra_attributes=dict(sql=sql),
observability_options=self.observability_options,
):
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)
query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)

for partition in partitions:
yield {"partition": partition, "query": query_info}
for partition in partitions:
yield {"partition": partition, "query": query_info}

def process_query_batch(
self,
Expand All @@ -1646,9 +1673,16 @@ def process_query_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
return self._get_snapshot().execute_sql(
partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_query_batch",
observability_options=self.observability_options,
):
return self._get_snapshot().execute_sql(
partition=batch["partition"],
**batch["query"],
retry=retry,
timeout=timeout,
)

def run_partitioned_query(
self,
Expand Down Expand Up @@ -1703,18 +1737,23 @@ def run_partitioned_query(
:rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
with trace_call(
f"CloudSpanner.${type(self).__name__}.run_partitioned_query",
extra_attributes=dict(sql=sql),
observability_options=self.observability_options,
):
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
)
)
)
return MergedResultSet(self, partitions, 0)
return MergedResultSet(self, partitions, 0)

def process(self, batch):
"""Process a single, partitioned query or read.
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/spanner_v1/merged_result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set):
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
    """Execute this partition inside a ``CloudSpanner.PartitionExecutor.run`` span.

    The observability options come from the batch snapshot when it defines
    them; otherwise an empty dict is used so tracing stays a no-op.
    """
    opts = getattr(self._batch_snapshot, "observability_options", {})
    with trace_call(
        "CloudSpanner.PartitionExecutor.run",
        observability_options=opts,
    ):
        # Delegate the actual batch processing to the private worker.
        return self.__run()

def __run(self):
results = None
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
Expand Down
29 changes: 9 additions & 20 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,11 @@ def bind(self, database):
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
created_session_count = 0
self._database_role = self._database_role or self._database.database_role

request = BatchCreateSessionsRequest(
database=database.name,
session_count=self.size - created_session_count,
session_count=self.size,
session_template=Session(creator_role=self.database_role),
)

Expand All @@ -549,38 +548,28 @@ def bind(self, database):
span_event_attributes,
)

if created_session_count >= self.size:
add_span_event(
current_span,
"Created no new sessions as sessionPool is full",
span_event_attributes,
)
return

add_span_event(
current_span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)

observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while created_session_count < self.size:
while returned_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)

add_span_event(
span,
f"Created {len(resp.session)} sessions",
)

for session_pb in resp.session:
session = self._new_session()
returned_session_count += 1
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1

created_session_count += len(resp.session)

add_span_event(
span,
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def run_in_transaction(self, func, *args, **kw):
) as span:
while True:
if self._transaction is None:
add_span_event(span, "Creating Transaction")
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = (
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,10 +675,13 @@ def partition_read(
)

trace_attributes = {"table_id": table, "columns": columns}
can_include_index = (index != "") and (index is not None)
if can_include_index:
trace_attributes["index"] = index

with trace_call(
f"CloudSpanner.{type(self).__name__}.partition_read",
self._session,
trace_attributes,
extra_attributes=trace_attributes,
observability_options=getattr(database, "observability_options", None),
):
method = functools.partial(
Expand Down Expand Up @@ -779,7 +782,7 @@ def partition_query(

trace_attributes = {"db.statement": sql}
with trace_call(
"CloudSpanner.PartitionReadWriteTransaction",
f"CloudSpanner.{type(self).__name__}.partition_query",
self._session,
trace_attributes,
observability_options=getattr(database, "observability_options", None),
Expand Down
19 changes: 18 additions & 1 deletion tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def assertSpanAttributes(
):
if HAS_OPENTELEMETRY_INSTALLED:
if not span:
span_list = self.ot_exporter.get_finished_spans()
span_list = self.get_finished_spans()
self.assertEqual(len(span_list) > 0, True)
span = span_list[0]

Expand Down Expand Up @@ -132,3 +132,20 @@ def get_finished_spans(self):

def reset(self):
self.tearDown()

def finished_spans_events_statuses(self):
    """Return ``(event_name, attributes)`` pairs for every event on every
    finished span, in span order.

    Attributes that are noisy or run-dependent (stack traces, retry
    delays, exception causes) are masked with the literal string
    ``"EPHEMERAL"`` so results can be compared against fixed expectations.
    """

    def _scrub(attributes):
        # Copy so the span's own attribute mapping is never mutated.
        scrubbed = attributes.copy()
        for key in ("exception.stacktrace", "delay_seconds", "cause"):
            if key in scrubbed:
                scrubbed[key] = "EPHEMERAL"
        return scrubbed

    return [
        (event.name, _scrub(event.attributes))
        for span in self.get_finished_spans()
        for event in span.events
    ]
Loading