Skip to content

Commit

Permalink
Merge branch 'main' into release-please--branches--main
Browse files Browse the repository at this point in the history
  • Loading branch information
harshachinta authored Dec 6, 2024
2 parents 0feb0a8 + 96da8e1 commit 291e605
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 113 deletions.
139 changes: 139 additions & 0 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
SpannerServicer,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc


def add_result(sql: str, result: result_set.ResultSet):
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


def add_update_count(
sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
add_result(sql, result)


def add_select1_result():
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])


def add_single_result(
sql: str, column_name: str, type_code: spanner_type.TypeCode, row
):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name=column_name,
type=spanner_type.Type(dict(code=type_code)),
)
)
]
)
)
)
),
)
)
result.rows.extend(row)
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


class MockServerTestBase(unittest.TestCase):
server: grpc.Server = None
spanner_service: SpannerServicer = None
database_admin_service: DatabaseAdminServicer = None
port: int = None

def __init__(self, *args, **kwargs):
super(MockServerTestBase, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None

@classmethod
def setup_class(cls):
(
MockServerTestBase.server,
MockServerTestBase.spanner_service,
MockServerTestBase.database_admin_service,
MockServerTestBase.port,
) = start_mock_server()

@classmethod
def teardown_class(cls):
if MockServerTestBase.server is not None:
MockServerTestBase.server.stop(grace=None)
MockServerTestBase.server = None

def setup_method(self, *args, **kwargs):
self._client = None
self._instance = None
self._database = None

def teardown_method(self, *args, **kwargs):
MockServerTestBase.spanner_service.clear_requests()
MockServerTestBase.database_admin_service.clear_requests()

@property
def client(self) -> Client:
if self._client is None:
self._client = Client(
project="p",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
api_endpoint="localhost:" + str(MockServerTestBase.port),
),
)
return self._client

@property
def instance(self) -> Instance:
if self._instance is None:
self._instance = self.client.instance("test-instance")
return self._instance

@property
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
)
return self._database
121 changes: 8 additions & 113 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,131 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
SpannerServicer,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import (
Client,
FixedSizePool,
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TransactionOptions,
)
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc


class TestBasics(unittest.TestCase):
server: grpc.Server = None
spanner_service: SpannerServicer = None
database_admin_service: DatabaseAdminServicer = None
port: int = None

def __init__(self, *args, **kwargs):
super(TestBasics, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None

@classmethod
def setUpClass(cls):
(
TestBasics.server,
TestBasics.spanner_service,
TestBasics.database_admin_service,
TestBasics.port,
) = start_mock_server()

@classmethod
def tearDownClass(cls):
if TestBasics.server is not None:
TestBasics.server.stop(grace=None)
TestBasics.server = None

def teardown_method(self, *args, **kwargs):
TestBasics.spanner_service.clear_requests()
TestBasics.database_admin_service.clear_requests()

def _add_select1_result(self):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name="c",
type=spanner_type.Type(
dict(code=spanner_type.TypeCode.INT64)
),
)
)
]
)
)
)
),
)
)
result.rows.extend(["1"])
TestBasics.spanner_service.mock_spanner.add_result("select 1", result)

def add_update_count(
self,
sql: str,
count: int,
dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL,
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
TestBasics.spanner_service.mock_spanner.add_result(sql, result)

@property
def client(self) -> Client:
if self._client is None:
self._client = Client(
project="test-project",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
api_endpoint="localhost:" + str(TestBasics.port),
),
)
return self._client

@property
def instance(self) -> Instance:
if self._instance is None:
self._instance = self.client.instance("test-instance")
return self._instance
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
add_update_count,
)

@property
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
)
return self._database

class TestBasics(MockServerTestBase):
def test_select1(self):
self._add_select1_result()
add_select1_result()
with self.database.snapshot() as snapshot:
results = snapshot.execute_sql("select 1")
result_list = []
Expand Down Expand Up @@ -171,7 +66,7 @@ def test_create_table(self):
# been re-factored to use a base class for the boiler plate code.
def test_dbapi_partitioned_dml(self):
sql = "UPDATE singers SET foo='bar' WHERE active = true"
self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
connection = Connection(self.instance, self.database)
connection.autocommit = True
connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
Expand Down

0 comments on commit 291e605

Please sign in to comment.