diff --git a/cr8/bench_spec.py b/cr8/bench_spec.py index 4940c61..b739a4d 100644 --- a/cr8/bench_spec.py +++ b/cr8/bench_spec.py @@ -30,12 +30,13 @@ def from_dict(d): class Spec: - def __init__(self, setup, teardown, queries=None, load_data=None, meta=None): + def __init__(self, setup, teardown, queries=None, load_data=None, meta=None, session_settings=None): self.setup = setup self.teardown = teardown self.queries = queries self.load_data = load_data self.meta = meta or {} + self.session_settings = session_settings or {} @staticmethod def from_dict(d): @@ -45,6 +46,7 @@ def from_dict(d): meta=d.get('meta', {}), queries=d.get('queries', []), load_data=d.get('load_data', []), + session_settings=d.get('session_settings', {}), ) @staticmethod diff --git a/cr8/clients.py b/cr8/clients.py index 6164b49..0c6de01 100644 --- a/cr8/clients.py +++ b/cr8/clients.py @@ -1,4 +1,5 @@ import json + import aiohttp import itertools import calendar @@ -6,7 +7,7 @@ import time from urllib.parse import urlparse, parse_qs, urlunparse from datetime import datetime, date -from typing import List, Union, Iterable +from typing import List, Union, Iterable, Dict from decimal import Decimal from cr8.aio import asyncio # import via aio for uvloop setup @@ -216,18 +217,25 @@ def _verify_ssl_from_first(hosts): class AsyncpgClient: - def __init__(self, hosts, pool_size=25): + def __init__(self, hosts, pool_size=25, session_settings=None): self.dsn = _to_dsn(hosts) self.pool_size = pool_size self._pool = None self.is_cratedb = True + self.session_settings = session_settings or {} async def _get_pool(self): + + async def set_session_settings(conn): + for setting, value in self.session_settings.items(): + await conn.execute(f'set {setting}={value}') + if not self._pool: self._pool = await asyncpg.create_pool( self.dsn, min_size=self.pool_size, - max_size=self.pool_size + max_size=self.pool_size, + init=set_session_settings ) return self._pool @@ -308,59 +316,83 @@ def _append_sql(host): class HttpClient: - def __init__(self, hosts, conn_pool_limit=25): + def __init__(self, hosts, conn_pool_limit=25, session_settings=None): self.hosts = hosts self.urls = itertools.cycle(list(map(_append_sql, hosts))) - self._connector_params = { - 'limit': conn_pool_limit, - 'verify_ssl': _verify_ssl_from_first(self.hosts) - } - self.__session = None + self.conn_pool_limit = conn_pool_limit self.is_cratedb = True - - @property - async def _session(self): - session = self.__session - if session is None: - conn = aiohttp.TCPConnector(**self._connector_params) - self.__session = session = aiohttp.ClientSession(connector=conn) - return session + self._pools = {} + self.session_settings = session_settings or {} + + async def _session(self, url): + pool = self._pools.get(url) + if not pool: + pool = asyncio.Queue(maxsize=self.conn_pool_limit) + self._pools[url] = pool + _connector_params = { + 'limit': 1, + 'verify_ssl': _verify_ssl_from_first(self.hosts) + } + for n in range(0, self.conn_pool_limit): + tcp_connector = aiohttp.TCPConnector(**_connector_params) + session = aiohttp.ClientSession(connector=tcp_connector) + for setting, value in self.session_settings.items(): + payload = {'stmt': f'set {setting}={value}'} + await _exec( + session, + url, + dumps(payload, cls=CrateJsonEncoder) + ) + pool.put_nowait(session) + + return await pool.get() async def execute(self, stmt, args=None): payload = {'stmt': _plain_or_callable(stmt)} if args: payload['args'] = _plain_or_callable(args) - session = await self._session - return await _exec( + url = next(self.urls) + session = await self._session(url) + result = await _exec( session, - next(self.urls), + url, dumps(payload, cls=CrateJsonEncoder) ) + await self._pools[url].put(session) + return result async def execute_many(self, stmt, bulk_args): data = dumps(dict( stmt=_plain_or_callable(stmt), bulk_args=_plain_or_callable(bulk_args) ), cls=CrateJsonEncoder) - session = await self._session - return await _exec(session, next(self.urls), data) + url = next(self.urls) + session = await self._session(url) + result = await _exec(session, url, data) + await self._pools[url].put(session) + return result async def get_server_version(self): - session = await self._session urlparts = urlparse(self.hosts[0]) url = urlunparse((urlparts.scheme, urlparts.netloc, '/', '', '', '')) + session = await self._session(url) async with session.get(url) as resp: r = await resp.json() version = r['version'] - return { + result = { 'hash': version['build_hash'], 'number': version['number'], 'date': _date_or_none(version['build_timestamp'][:10]) } + await self._pools[url].put(session) + return result async def _close(self): - if self.__session: - await self.__session.close() + for url, pool in self._pools.items(): + while not pool.empty(): + session = await pool.get() + await session.close() + self._pools = {} def close(self): asyncio.get_event_loop().run_until_complete(self._close()) @@ -372,10 +404,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() -def client(hosts, concurrency=25): +def client(hosts, session_settings=None, concurrency=25): hosts = hosts or 'localhost:4200' if hosts.startswith('asyncpg://'): if not asyncpg: raise ValueError('Cannot use "asyncpg" scheme if asyncpg is not available') - return AsyncpgClient(hosts, pool_size=concurrency) - return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency) + return AsyncpgClient(hosts, pool_size=concurrency, session_settings=session_settings) + return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency, session_settings=session_settings) diff --git a/cr8/engine.py b/cr8/engine.py index 213786b..fbe8170 100644 --- a/cr8/engine.py +++ b/cr8/engine.py @@ -5,8 +5,7 @@ from cr8 import aio from cr8.metrics import Stats, get_sampler -from cr8.clients import client - +from cr8.clients import client, HttpClient TimedStats = namedtuple('TimedStats', ['started', 'ended', 'stats']) @@ -69,9 +68,9 @@ def _generate_statements(stmt, args, iterations, duration): class Runner: - def __init__(self, hosts, concurrency, sample_mode): + def __init__(self, hosts, concurrency, sample_mode, session_settings=None): self.concurrency = concurrency - self.client = client(hosts, concurrency=concurrency) + self.client = client(hosts, session_settings=session_settings, concurrency=concurrency) self.sampler = get_sampler(sample_mode) def warmup(self, stmt, num_warmup, concurrency=0, args=None): diff --git a/cr8/run_spec.py b/cr8/run_spec.py index d3f0d37..3216fc7 100644 --- a/cr8/run_spec.py +++ b/cr8/run_spec.py @@ -179,7 +179,7 @@ def _skip_message(self, min_version, stmt): server_version='.'.join((str(x) for x in self.server_version))) return msg - def run_queries(self, queries: Iterable[dict], meta=None): + def run_queries(self, queries: Iterable[dict], meta=None, session_settings=None): for query in queries: stmt = query['statement'] iterations = query.get('iterations', 1) @@ -204,7 +204,7 @@ def run_queries(self, queries: Iterable[dict], meta=None): f' Concurrency: {concurrency}\n' f' {mode_desc}: {duration or iterations}') ) - with Runner(self.benchmark_hosts, concurrency, self.sample_mode) as runner: + with Runner(self.benchmark_hosts, concurrency, self.sample_mode, session_settings) as runner: if warmup > 0: runner.warmup(stmt, warmup, concurrency, args) timed_stats = runner.run( @@ -266,7 +266,7 @@ def do_run_spec(spec, queries = (q for q in spec.queries if 'name' in q and rex.match(q['name'])) else: queries = spec.queries - executor.run_queries(queries, spec.meta) + executor.run_queries(queries, spec.meta, spec.session_settings) finally: if not action or 'teardown' in action: log.info('# Running tearDown') diff --git a/specs/count_countries.json b/specs/count_countries.json index 31fe1e1..1d21d91 100644 --- a/specs/count_countries.json +++ b/specs/count_countries.json @@ -14,6 +14,10 @@ } ] }, + "session_settings": { + "application_name": "my_app", + "timezone": "UTC" + }, "queries": [{ "iterations": 1000, "statement": "select count(*) from countries" diff --git a/specs/sample.py b/specs/sample.py index 5575a3c..df6283b 100644 --- a/specs/sample.py +++ b/specs/sample.py @@ -1,4 +1,3 @@ - from itertools import count from cr8.bench_spec import Spec, Instructions @@ -21,4 +20,5 @@ def queries(): setup=Instructions(statements=["create table t (x int)"]), teardown=Instructions(statements=["drop table t"]), queries=queries(), -) + session_settings={'application_name': 'my_app', 'timezone': 'UTC'} +) \ No newline at end of file diff --git a/specs/sample.toml b/specs/sample.toml index cf61bb4..04b9991 100644 --- a/specs/sample.toml +++ b/specs/sample.toml @@ -14,6 +14,10 @@ statement_files = ["sql/create_countries.sql"] target = "countries" cmd = ['echo', '{"capital": "Demo"}'] +[session_settings] +application_name = 'my_app' +timezone = 'UTC' + [[queries]] name = "count countries" # Can be used to give the queries a name for easier analytics of the results statement = "select count(*) from countries" diff --git a/tests/test_integration.py b/tests/test_integration.py index b30b365..bd4845c 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -63,7 +63,7 @@ def parse(self, string, name=''): class SourceBuildTest(TestCase): def test_build_from_branch(self): - self.assertIsNotNone(get_crate('4.1')) + self.assertIsNotNone(get_crate('5.8')) def load_tests(loader, tests, ignore): diff --git a/tests/test_spec.py b/tests/test_spec.py new file mode 100644 index 0000000..865aa99 --- /dev/null +++ b/tests/test_spec.py @@ -0,0 +1,30 @@ +import os +from unittest import TestCase +from doctest import DocTestSuite + +from cr8.bench_spec import load_spec + +from cr8 import engine + + +class SpecTest(TestCase): + + def test_session_settings_from_spec(self): + spec = self.get_spec('sample.py') + self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'}) + + def test_session_settings_from_toml(self): + spec = self.get_spec('sample.toml') + self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'}) + + def test_session_settings_from_json(self): + spec = self.get_spec('count_countries.json') + self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'}) + + def get_spec(self, name): + return load_spec(os.path.abspath(os.path.join(os.path.dirname(__file__), '../specs/', name))) + + +def load_tests(loader, tests, ignore): + tests.addTests(DocTestSuite(engine)) + return tests