Skip to content

Commit

Permalink
feat: support transaction and request tags in dbapi
Browse files Browse the repository at this point in the history
Adds support for setting transaction tags and request tags in dbapi.
This makes these options available to frameworks that depend on
dbapi, like SQLAlchemy and Django.

Towards googleapis/python-spanner-sqlalchemy#525
  • Loading branch information
olavloite committed Dec 9, 2024
1 parent a6811af commit 9031124
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 6 deletions.
35 changes: 32 additions & 3 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs):
self.request_priority = None
self._transaction_begin_marked = False
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
# made at least one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
Expand Down Expand Up @@ -261,6 +261,28 @@ def request_options(self):
self.request_priority = None
return req_opts

@property
def transaction_tag(self):
"""The transaction tag that will be applied to the next read/write
transaction on this `Connection`. This property is automatically cleared
when a new transaction is started.
Returns:
str: The transaction tag that will be applied to the next read/write transaction.
"""
return self._connection_variables.get("transaction_tag", None)

@transaction_tag.setter
def transaction_tag(self, value):
"""Sets the transaction tag for the next read/write transaction on this
`Connection`. This property is automatically cleared when a new transaction
is started.
Args:
value (str): The transaction tag for the next read/write transaction.
"""
self._connection_variables["transaction_tag"] = value

@property
def staleness(self):
"""Current read staleness option value of this `Connection`.
Expand Down Expand Up @@ -340,6 +362,8 @@ def transaction_checkout(self):
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.transaction_tag = self.transaction_tag
self.transaction_tag = None
self._snapshot = None
self._spanner_transaction_started = True
self._transaction.begin()
Expand Down Expand Up @@ -458,7 +482,9 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement: Statement):
def run_statement(
self, statement: Statement, request_options: RequestOptions = None
):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
Expand All @@ -472,6 +498,9 @@ def run_statement(self, statement: Statement):
:param retried: (Optional) Retry the SQL statement if statement
execution failed. Defaults to false.
:type request_options: :class:`RequestOptions`
:param request_options: Request options to use for this statement.
:rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
:class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:returns: Streamed result set of the statement and a
Expand All @@ -482,7 +511,7 @@ def run_statement(self, statement: Statement):
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
request_options=request_options or self.request_options,
)

@check_not_closed
Expand Down
42 changes: 39 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.merged_result_set import MergedResultSet

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Expand Down Expand Up @@ -97,6 +98,39 @@ def __init__(self, connection):
self._parsed_statement: ParsedStatement = None
self._in_retry_mode = False
self._batch_dml_rows_count = None
self._request_tag = None

@property
def request_tag(self):
"""The request tag that will be applied to the next statement on this
cursor. This property is automatically cleared when a statement is
executed.
Returns:
str: The request tag that will be applied to the next statement on
this cursor.
"""
return self._request_tag

@request_tag.setter
def request_tag(self, value):
"""Sets the request tag for the next statement on this cursor. This
property is automatically cleared when a statement is executed.
Args:
value (str): The request tag for the statement.
"""
self._request_tag = value

@property
def request_options(self):
options = self.connection.request_options
if self._request_tag:
if not options:
options = RequestOptions()
options.request_tag = self._request_tag
self._request_tag = None
return options

@property
def is_closed(self):
Expand Down Expand Up @@ -284,7 +318,7 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
sql,
params=args,
param_types=self._parsed_statement.statement.param_types,
request_options=self.connection.request_options,
request_options=self.request_options,
)
self._result_set = None
else:
Expand Down Expand Up @@ -318,7 +352,9 @@ def _execute_in_rw_transaction(self):
if self.connection._client_transaction_started:
while True:
try:
self._result_set = self.connection.run_statement(statement)
self._result_set = self.connection.run_statement(
statement, self.request_options
)
self._itr = PeekIterator(self._result_set)
return
except Aborted:
Expand Down Expand Up @@ -478,7 +514,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
sql,
params,
get_param_types(params),
request_options=self.connection.request_options,
request_options=self.request_options,
)
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
Expand Down
180 changes: 180 additions & 0 deletions tests/mockserver_tests/test_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TypeCode,
CommitRequest,
)
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_single_result,
)


class TestTags(MockServerTestBase):
@classmethod
def setup_class(cls):
super().setup_class()
add_single_result(
"select name from singers", "name", TypeCode.STRING, [("Some Singer",)]
)

def test_select_autocommit_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_autocommit_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_only_transaction_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.read_only = True
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_only_transaction_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.read_only = True
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_write_transaction_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_write_transaction_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_write_transaction_with_transaction_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.transaction_tag = "my_transaction_tag"
# The transaction tag should be included for all statements in the transaction.
self._execute_and_verify_select_singers(connection)
self._execute_and_verify_select_singers(connection)

# The transaction tag was cleared from the connection when the transaction
# was started.
self.assertIsNone(connection.transaction_tag)
# The commit call should also include a transaction tag.
connection.commit()
requests = self.spanner_service.requests
self.assertEqual(5, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))
self.assertEqual(
"my_transaction_tag", requests[2].request_options.transaction_tag
)
self.assertEqual(
"my_transaction_tag", requests[3].request_options.transaction_tag
)
self.assertEqual(
"my_transaction_tag", requests[4].request_options.transaction_tag
)

def test_select_read_write_transaction_with_transaction_and_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.transaction_tag = "my_transaction_tag"
# The transaction tag should be included for all statements in the transaction.
self._execute_and_verify_select_singers(connection, request_tag="my_tag1")
self._execute_and_verify_select_singers(connection, request_tag="my_tag2")

# The transaction tag was cleared from the connection when the transaction
# was started.
self.assertIsNone(connection.transaction_tag)
# The commit call should also include a transaction tag.
connection.commit()
requests = self.spanner_service.requests
self.assertEqual(5, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))
self.assertEqual(
"my_transaction_tag", requests[2].request_options.transaction_tag
)
self.assertEqual("my_tag1", requests[2].request_options.request_tag)
self.assertEqual(
"my_transaction_tag", requests[3].request_options.transaction_tag
)
self.assertEqual("my_tag2", requests[3].request_options.request_tag)
self.assertEqual(
"my_transaction_tag", requests[4].request_options.transaction_tag
)

def test_request_tag_is_cleared(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.request_tag = "my_tag"
cursor.execute("select name from singers")
# This query will not have a request tag.
cursor.execute("select name from singers")
requests = self.spanner_service.requests
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertEqual("my_tag", requests[1].request_options.request_tag)
self.assertEqual("", requests[2].request_options.request_tag)

def _execute_and_verify_select_singers(
self, connection: Connection, request_tag: str = "", transaction_tag: str = ""
) -> ExecuteSqlRequest:
with connection.cursor() as cursor:
if request_tag:
cursor.request_tag = request_tag
cursor.execute("select name from singers")
result_list = cursor.fetchall()
for row in result_list:
self.assertEqual("Some Singer", row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
return next(
request
for request in requests
if isinstance(request, ExecuteSqlRequest)
and request.sql == "select name from singers"
)

0 comments on commit 9031124

Please sign in to comment.