Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): add implementation and integration tests for max commi… #1082

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self, return_commit_stats=False, request_options=None):
def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None):
"""Commit mutations to the database.

:type return_commit_stats: bool
Expand Down Expand Up @@ -189,6 +189,7 @@ def commit(self, return_commit_stats=False, request_options=None):
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
request_options=request_options,
max_commit_delay=max_commit_delay,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
method = functools.partial(
Expand Down
8 changes: 5 additions & 3 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

def batch(self, request_options=None):
def batch(self, request_options=None, max_commit_delay=None):
"""Return an object which wraps a batch.

The wrapper *must* be used as a context manager, with the batch
Expand All @@ -737,7 +737,7 @@ def batch(self, request_options=None):
:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
return BatchCheckout(self, request_options)
return BatchCheckout(self, request_options, max_commit_delay)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
Expand Down Expand Up @@ -1037,7 +1037,7 @@ class BatchCheckout(object):
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
"""

def __init__(self, database, request_options=None):
def __init__(self, database, request_options=None, max_commit_delay=None):
self._database = database
self._session = self._batch = None
if request_options is None:
Expand All @@ -1046,6 +1046,7 @@ def __init__(self, database, request_options=None):
self._request_options = RequestOptions(request_options)
else:
self._request_options = request_options
self._max_commit_delay = max_commit_delay

def __enter__(self):
"""Begin ``with`` block."""
Expand All @@ -1062,6 +1063,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._batch.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def run_in_transaction(self, func, *args, **kw):
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
commit_request_options = kw.pop("commit_request_options", None)
max_commit_delay = kw.pop("max_commit_delay", None)
transaction_tag = kw.pop("transaction_tag", None)
attempts = 0

Expand Down Expand Up @@ -400,6 +401,7 @@ def run_in_transaction(self, func, *args, **kw):
txn.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=commit_request_options,
max_commit_delay=max_commit_delay,
)
except Aborted as exc:
del self._transaction
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

def commit(self, return_commit_stats=False, request_options=None):
def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None):
"""Commit mutations to the database.

:type return_commit_stats: bool
Expand Down Expand Up @@ -229,6 +229,7 @@ def commit(self, return_commit_stats=False, request_options=None):
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
request_options=request_options,
max_commit_delay=max_commit_delay,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
method = functools.partial(
Expand Down
37 changes: 36 additions & 1 deletion tests/system/test_database_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import time
import uuid

Expand Down Expand Up @@ -819,3 +819,38 @@ def _transaction_read(transaction):

with pytest.raises(exceptions.InvalidArgument):
shared_database.run_in_transaction(_transaction_read)


def test_db_batch_insert_w_max_commit_delay(shared_database):
_helpers.retry_has_all_dll(shared_database.reload)()
sd = _sample_data

with shared_database.batch(max_commit_delay=datetime.timedelta(milliseconds=100)) as batch:
batch.delete(sd.TABLE, sd.ALL)
batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA)

with shared_database.snapshot(read_timestamp=batch.committed) as snapshot:
from_snap = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL))

sd._check_rows_data(from_snap)


def test_db_run_in_transaction_w_max_commit_delay(shared_database):
_helpers.retry_has_all_dll(shared_database.reload)()
sd = _sample_data

with shared_database.batch() as batch:
batch.delete(sd.TABLE, sd.ALL)

def _unit_of_work(transaction, test):
rows = list(transaction.read(test.TABLE, test.COLUMNS, sd.ALL))
assert rows == []

transaction.insert_or_update(test.TABLE, test.COLUMNS, test.ROW_DATA)

shared_database.run_in_transaction(_unit_of_work, test=sd, max_commit_delay=datetime.timedelta(milliseconds=100))

with shared_database.snapshot() as after:
rows = list(after.execute_sql(sd.SQL))

sd._check_rows_data(rows)