Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 8d0eaa2

Browse files
committed
test: add mockserver tests for asyncIO operations
1 parent 03dd932 commit 8d0eaa2

File tree

23 files changed

+2751
-565
lines changed

23 files changed

+2751
-565
lines changed

google/cloud/aio/_cross_sync/cross_sync.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class CrossSync(metaclass=MappingMeta):
9191
Task: TypeAlias = asyncio.Task
9292
Event: TypeAlias = asyncio.Event
9393
Semaphore: TypeAlias = asyncio.Semaphore
94+
LifoQueue: TypeAlias = asyncio.LifoQueue
95+
PriorityQueue: TypeAlias = asyncio.PriorityQueue
9496
StopIteration: TypeAlias = StopAsyncIteration
9597
# provide aliases for common async type annotations
9698
Awaitable: TypeAlias = typing.Awaitable
@@ -160,6 +162,23 @@ async def run_if_async(func, *args, **kwargs):
160162
return await res
161163
return res
162164

165+
@staticmethod
166+
async def queue_get(queue, block=True, timeout=None):
167+
if not block:
168+
return queue.get_nowait()
169+
if timeout is not None:
170+
return await asyncio.wait_for(queue.get(), timeout=timeout)
171+
return await queue.get()
172+
173+
@staticmethod
174+
async def queue_put(queue, item, block=True, timeout=None):
175+
if not block:
176+
return queue.put_nowait(item)
177+
if timeout is not None:
178+
await asyncio.wait_for(queue.put(item), timeout=timeout)
179+
else:
180+
await queue.put(item)
181+
163182
@staticmethod
164183
async def gather_partials(
165184
partial_list: Sequence[Callable[[], Awaitable[T]]],
@@ -288,6 +307,8 @@ class _Sync_Impl(metaclass=MappingMeta):
288307
Task: TypeAlias = concurrent.futures.Future
289308
Event: TypeAlias = threading.Event
290309
Semaphore: TypeAlias = threading.Semaphore
310+
LifoQueue: TypeAlias = queue.LifoQueue
311+
PriorityQueue: TypeAlias = queue.PriorityQueue
291312
StopIteration: TypeAlias = StopIteration
292313
# type annotations
293314
Awaitable: TypeAlias = Union[T]
@@ -304,6 +325,14 @@ def run_if_async(func, *args, **kwargs):
304325
"""
305326
return func(*args, **kwargs)
306327

328+
@staticmethod
329+
def queue_get(queue, block=True, timeout=None):
330+
return queue.get(block=block, timeout=timeout)
331+
332+
@staticmethod
333+
def queue_put(queue, item, block=True, timeout=None):
334+
queue.put(item, block=block, timeout=timeout)
335+
307336
@classmethod
308337
def Mock(cls, *args, **kwargs):
309338
from unittest.mock import Mock

google/cloud/spanner_v1/_async/client.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,31 @@
4141

4242

4343
from google.cloud.spanner_admin_database_v1 import DatabaseAdminAsyncClient as DatabaseAdminClient
44-
from google.cloud.spanner_admin_database_v1.services.database_admin.transports.grpc import (
45-
DatabaseAdminGrpcTransport,
46-
)
44+
if CrossSync.is_async:
45+
from google.cloud.spanner_admin_database_v1.services.database_admin.transports.grpc_asyncio import (
46+
DatabaseAdminGrpcAsyncIOTransport as DatabaseAdminGrpcTransport,
47+
)
48+
else:
49+
from google.cloud.spanner_admin_database_v1.services.database_admin.transports.grpc import (
50+
DatabaseAdminGrpcTransport,
51+
)
4752
from google.cloud.spanner_admin_instance_v1 import InstanceAdminAsyncClient as InstanceAdminClient
48-
from google.cloud.spanner_admin_instance_v1.services.instance_admin.transports.grpc import (
49-
InstanceAdminGrpcTransport,
50-
)
53+
if CrossSync.is_async:
54+
from google.cloud.spanner_admin_instance_v1.services.instance_admin.transports.grpc_asyncio import (
55+
InstanceAdminGrpcAsyncIOTransport as InstanceAdminGrpcTransport,
56+
)
57+
else:
58+
from google.cloud.spanner_admin_instance_v1.services.instance_admin.transports.grpc import (
59+
InstanceAdminGrpcTransport,
60+
)
5161
from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsRequest
5262
from google.cloud.spanner_admin_instance_v1 import ListInstancesRequest
5363
from google.cloud.spanner_v1 import __version__
5464
from google.cloud.spanner_v1 import ExecuteSqlRequest
5565
from google.cloud.spanner_v1 import DefaultTransactionOptions
5666
from google.cloud.spanner_v1._helpers import _merge_query_options
5767
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
58-
from google.cloud.spanner_v1.instance import Instance
68+
from google.cloud.spanner_v1._async.instance import Instance
5969
from google.cloud.spanner_v1.metrics.constants import (
6070
METRIC_EXPORT_INTERVAL_MS,
6171
)

google/cloud/spanner_v1/_async/database.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from typing import Optional
2323

2424
import grpc
25+
import asyncio
26+
import inspect
2527
import logging
2628
import re
2729
import threading
@@ -64,7 +66,7 @@
6466
from google.cloud.spanner_v1._async.batch import MutationGroups
6567
from google.cloud.spanner_v1.keyset import KeySet
6668
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
67-
from google.cloud.spanner_v1.pool import BurstyPool
69+
from google.cloud.spanner_v1._async.pool import BurstyPool
6870
from google.cloud.spanner_v1._async.session import Session
6971
from google.cloud.spanner_v1._async.database_sessions_manager import (
7072
DatabaseSessionsManager,
@@ -73,9 +75,14 @@
7375
from google.cloud.spanner_v1._async.snapshot import _restart_on_unavailable
7476
from google.cloud.spanner_v1._async.snapshot import Snapshot
7577
from google.cloud.spanner_v1._async.streamed import StreamedResultSet
76-
from google.cloud.spanner_v1.services.spanner.transports.grpc import (
77-
SpannerGrpcTransport,
78-
)
78+
if CrossSync.is_async:
79+
from google.cloud.spanner_v1.services.spanner.transports.grpc_asyncio import (
80+
SpannerGrpcAsyncIOTransport as SpannerGrpcTransport,
81+
)
82+
else:
83+
from google.cloud.spanner_v1.services.spanner.transports.grpc import (
84+
SpannerGrpcTransport,
85+
)
7986
from google.cloud.spanner_v1.table import Table
8087
from google.cloud.spanner_v1._opentelemetry_tracing import (
8188
add_span_event,
@@ -205,7 +212,14 @@ def __init__(
205212
pool = BurstyPool(database_role=database_role)
206213

207214
self._pool = pool
208-
pool.bind(self)
215+
res = pool.bind(self)
216+
try:
217+
loop = asyncio.get_running_loop()
218+
if loop.is_running() and inspect.isawaitable(res):
219+
loop.create_task(res)
220+
except RuntimeError:
221+
# No running loop, bind should have been sync or will be failed later
222+
pass
209223
is_experimental_host = self._instance.experimental_host is not None
210224

211225
self._sessions_manager = DatabaseSessionsManager(
@@ -448,17 +462,21 @@ def spanner_api(self):
448462
client_info = self._instance._client._client_info
449463
client_options = self._instance._client._client_options
450464
if self._instance.emulator_host is not None:
451-
transport = SpannerGrpcTransport(
452-
channel=grpc.insecure_channel(self._instance.emulator_host)
453-
)
465+
if CrossSync.is_async:
466+
channel = grpc.aio.insecure_channel(self._instance.emulator_host)
467+
else:
468+
channel = grpc.insecure_channel(self._instance.emulator_host)
469+
transport = SpannerGrpcTransport(channel=channel)
454470
self._spanner_api = SpannerClient(
455471
client_info=client_info, transport=transport
456472
)
457473
return self._spanner_api
458474
if self._instance.experimental_host is not None:
459-
transport = SpannerGrpcTransport(
460-
channel=grpc.insecure_channel(self._instance.experimental_host)
461-
)
475+
if CrossSync.is_async:
476+
channel = grpc.aio.insecure_channel(self._instance.experimental_host)
477+
else:
478+
channel = grpc.insecure_channel(self._instance.experimental_host)
479+
transport = SpannerGrpcTransport(channel=channel)
462480
self._spanner_api = SpannerClient(
463481
client_info=client_info,
464482
transport=transport,

0 commit comments

Comments
 (0)