Skip to content

Commit

Permalink
feat: add connection variable for ignoring transaction warnings (#1249)
Browse files Browse the repository at this point in the history
Adds a connection variable for ignoring transaction warnings. Also adds
a **kwargs argument to the connect function. This will be used for
further connection variables in the future.

Fixes googleapis/python-spanner-sqlalchemy#494
  • Loading branch information
olavloite authored Dec 4, 2024
1 parent ccae6e0 commit eeb7836
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
26 changes: 19 additions & 7 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class Connection:
committed by other transactions since the start of the read-only transaction. Commit or rolling back
the read-only transaction is semantically the same, and only indicates that the read-only transaction
should end a that a new one should be started when the next statement is executed.
**kwargs: Initial value for connection variables.
"""

def __init__(self, instance, database=None, read_only=False):
def __init__(self, instance, database=None, read_only=False, **kwargs):
self._instance = instance
self._database = database
self._ddl_statements = []
Expand All @@ -117,6 +119,7 @@ def __init__(self, instance, database=None, read_only=False):
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionRetryHelper(self)
self._autocommit_dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
self._connection_variables = kwargs

@property
def spanner_client(self):
Expand Down Expand Up @@ -206,6 +209,10 @@ def _client_transaction_started(self):
"""
return (not self._autocommit) or self._transaction_begin_marked

@property
def _ignore_transaction_warnings(self):
return self._connection_variables.get("ignore_transaction_warnings", False)

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -398,9 +405,10 @@ def commit(self):
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
if not self._ignore_transaction_warnings:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

self.run_prior_DDL_statements()
Expand All @@ -418,9 +426,10 @@ def rollback(self):
This is a no-op if there is no active client transaction.
"""
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
if not self._ignore_transaction_warnings:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return
try:
if self._spanner_transaction_started and not self._read_only:
Expand Down Expand Up @@ -654,6 +663,7 @@ def connect(
user_agent=None,
client=None,
route_to_leader_enabled=True,
**kwargs,
):
"""Creates a connection to a Google Cloud Spanner database.
Expand Down Expand Up @@ -696,6 +706,8 @@ def connect(
disable leader aware routing. Disabling leader aware routing would
route all requests in RW/PDML transactions to the closest region.
**kwargs: Initial value for connection variables.
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,19 @@ def test_commit_in_autocommit_mode(self, mock_warn):
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)

@mock.patch.object(warnings, "warn")
def test_commit_in_autocommit_mode_with_ignore_warnings(self, mock_warn):
conn = self._make_connection(
DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED,
ignore_transaction_warnings=True,
)
assert conn._ignore_transaction_warnings
conn._autocommit = True

conn.commit()

assert not mock_warn.warn.called

def test_commit_database_error(self):
from google.cloud.spanner_dbapi import Connection

Expand Down

0 comments on commit eeb7836

Please sign in to comment.