Skip to content

Commit

Permalink
Address more review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Nov 28, 2024
1 parent 7a73d65 commit b6187d8
Showing 1 changed file with 4 additions and 231 deletions.
235 changes: 4 additions & 231 deletions tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,10 @@ def test_spans_get_non_empty_session_exists(self):
"pool.Get",
attributes=TestBurstyPool.BASE_ATTRIBUTES,
)
self.assertSpanEvents(
"pool.Get",
["Acquiring session", "Waiting for a session to become available"],
)

def test_get_non_empty_session_expired(self):
pool = self._make_one()
Expand Down Expand Up @@ -888,237 +892,6 @@ def test_spans_get_empty_pool(self):
self.assertSpanEvents("pool.Get", wantEventNames)


class TestTransactionPingingPool(OpenTelemetryBase):
BASE_ATTRIBUTES = {
"db.type": "spanner",
"db.url": "spanner.googleapis.com",
"db.instance": "name",
"net.host.name": "spanner.googleapis.com",
}
enrich_with_otel_scope(BASE_ATTRIBUTES)

def _getTargetClass(self):
from google.cloud.spanner_v1.pool import TransactionPingingPool

return TransactionPingingPool

def _make_one(self, *args, **kwargs):
return self._getTargetClass()(*args, **kwargs)

def test_ctor_defaults(self):
pool = self._make_one()
self.assertIsNone(pool._database)
self.assertEqual(pool.size, 10)
self.assertEqual(pool.default_timeout, 10)
self.assertEqual(pool._delta.seconds, 3000)
self.assertTrue(pool._sessions.empty())
self.assertTrue(pool._pending_sessions.empty())
self.assertEqual(pool.labels, {})
self.assertIsNone(pool.database_role)

def test_ctor_explicit(self):
labels = {"foo": "bar"}
database_role = "dummy-role"
pool = self._make_one(
size=4,
default_timeout=30,
ping_interval=1800,
labels=labels,
database_role=database_role,
)
self.assertIsNone(pool._database)
self.assertEqual(pool.size, 4)
self.assertEqual(pool.default_timeout, 30)
self.assertEqual(pool._delta.seconds, 1800)
self.assertTrue(pool._sessions.empty())
self.assertTrue(pool._pending_sessions.empty())
self.assertEqual(pool.labels, labels)
self.assertEqual(pool.database_role, database_role)

def test_ctor_explicit_w_database_role_in_db(self):
database_role = "dummy-role"
pool = self._make_one()
database = pool._database = _Database("name")
SESSIONS = [_Session(database)] * 10
database._sessions.extend(SESSIONS)
database._database_role = database_role
pool.bind(database)
self.assertEqual(pool.database_role, database_role)

def test_bind(self):
pool = self._make_one()
database = _Database("name")
SESSIONS = [_Session(database) for _ in range(10)]
database._sessions.extend(SESSIONS)
pool.bind(database)

self.assertIs(pool._database, database)
self.assertEqual(pool.size, 10)
self.assertEqual(pool.default_timeout, 10)
self.assertEqual(pool._delta.seconds, 3000)
self.assertTrue(pool._sessions.full())

api = database.spanner_api
self.assertEqual(api.batch_create_sessions.call_count, 5)
for session in SESSIONS:
session.create.assert_not_called()
txn = session._transaction
txn.begin.assert_not_called()

self.assertTrue(pool._pending_sessions.empty())
self.assertNoSpans()

def test_bind_w_timestamp_race(self):
import datetime
from google.cloud._testing import _Monkey
from google.cloud.spanner_v1 import pool as MUT

NOW = datetime.datetime.utcnow()
pool = self._make_one()
database = _Database("name")
SESSIONS = [_Session(database) for _ in range(10)]
database._sessions.extend(SESSIONS)

with _Monkey(MUT, _NOW=lambda: NOW):
pool.bind(database)

self.assertIs(pool._database, database)
self.assertEqual(pool.size, 10)
self.assertEqual(pool.default_timeout, 10)
self.assertEqual(pool._delta.seconds, 3000)
self.assertTrue(pool._sessions.full())

api = database.spanner_api
self.assertEqual(api.batch_create_sessions.call_count, 5)
for session in SESSIONS:
session.create.assert_not_called()
txn = session._transaction
txn.begin.assert_not_called()

self.assertTrue(pool._pending_sessions.empty())
self.assertNoSpans()

def test_put_full(self):
import queue

pool = self._make_one(size=4)
database = _Database("name")
SESSIONS = [_Session(database) for _ in range(4)]
database._sessions.extend(SESSIONS)
pool.bind(database)

with self.assertRaises(queue.Full):
pool.put(_Session(database))

self.assertTrue(pool._sessions.full())
self.assertNoSpans()

def test_put_non_full_w_active_txn(self):
pool = self._make_one(size=1)
session_queue = pool._sessions = _Queue()
pending = pool._pending_sessions = _Queue()
database = _Database("name")
session = _Session(database)
txn = session.transaction()

pool.put(session)

self.assertEqual(len(session_queue._items), 1)
_, queued = session_queue._items[0]
self.assertIs(queued, session)

self.assertEqual(len(pending._items), 0)
txn.begin.assert_not_called()
self.assertNoSpans()

def test_put_non_full_w_committed_txn(self):
pool = self._make_one(size=1)
session_queue = pool._sessions = _Queue()
pending = pool._pending_sessions = _Queue()
database = _Database("name")
session = _Session(database)
committed = session.transaction()
committed.committed = True

pool.put(session)

self.assertEqual(len(session_queue._items), 0)

self.assertEqual(len(pending._items), 1)
self.assertIs(pending._items[0], session)
self.assertIsNot(session._transaction, committed)
session._transaction.begin.assert_not_called()
self.assertNoSpans()

def test_put_non_full(self):
pool = self._make_one(size=1)
session_queue = pool._sessions = _Queue()
pending = pool._pending_sessions = _Queue()
database = _Database("name")
session = _Session(database)

pool.put(session)

self.assertEqual(len(session_queue._items), 0)
self.assertEqual(len(pending._items), 1)
self.assertIs(pending._items[0], session)

self.assertFalse(pending.empty())
self.assertNoSpans()

def test_begin_pending_transactions_empty(self):
pool = self._make_one(size=1)
pool.begin_pending_transactions() # no raise
self.assertNoSpans()

def test_begin_pending_transactions_non_empty(self):
pool = self._make_one(size=1)
pool._sessions = _Queue()

database = _Database("name")
TRANSACTIONS = [_make_transaction(object())]
PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS]

pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS)
self.assertFalse(pending.empty())

pool.begin_pending_transactions() # no raise

for txn in TRANSACTIONS:
txn.begin.assert_not_called()

self.assertTrue(pending.empty())
self.assertNoSpans()

def test_spans_get_empty_pool(self):
pool = self._make_one()
database = _Database("name")
session1 = _Session(database)
database._sessions.append(session1)
try:
pool.bind(database)
except Exception:
pass

with trace_call("pool.Get", session1):
try:
pool.get()
except Exception:
pass

self.assertTrue(pool._sessions.empty())

self.assertSpanAttributes(
"pool.Get",
attributes=TestTransactionPingingPool.BASE_ATTRIBUTES,
)
wantEventNames = [
"Waiting for a session to become available",
"No session available",
]
self.assertSpanEvents("pool.Get", wantEventNames)


class TestSessionCheckout(unittest.TestCase):
def _getTargetClass(self):
from google.cloud.spanner_v1.pool import SessionCheckout
Expand Down

0 comments on commit b6187d8

Please sign in to comment.