Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrent initialization of connection pool #1825

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ Changelog
0.23
====

0.23.1
------
Fixed
^^^^^
- Concurrent connection pool initialization (#1825)

0.23.0
------
Added
Expand Down
24 changes: 24 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys

from tests.testmodels import Tournament, UniqueName
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.transactions import in_transaction
Expand Down Expand Up @@ -89,3 +90,26 @@ async def test_nonconcurrent_get_or_create(self):
self.assertEqual(len(una_created), 1)
for una in unas:
self.assertEqual(una[0], unas[0][0])


class TestConcurrentDBConnectionInitialization(test.IsolatedTestCase):
"""Tortoise.init is lazy and does not initialize the database connection until the first query.
These tests ensure that concurrent queries do not cause initialization issues."""

async def _setUpDB(self) -> None:
"""Override to avoid database connection initialization when generating the schema."""
await super()._setUpDB()
config = test.getDBConfig(app_label="models", modules=test._MODULES)
await Tortoise.init(config, _create_db=True)

async def test_concurrent_queries(self):
await asyncio.gather(
*[connections.get("models").execute_query("SELECT 1") for _ in range(100)]
)

async def test_concurrent_transactions(self) -> None:
async def transaction() -> None:
async with in_transaction():
await connections.get("models").execute_query("SELECT 1")

await asyncio.gather(*[transaction() for _ in range(100)])
4 changes: 3 additions & 1 deletion tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,9 @@ async def init(
table_name_generator: Callable[[Type["Model"]], str] | None = None,
) -> None:
"""
Sets up Tortoise-ORM.
Sets up Tortoise-ORM: loads apps and models, configures database connections but does not
connect to the database yet. The actual connection or connection pool is established
lazily on first query execution.

You can configure using only one of ``config``, ``config_file``
and ``(db_url, modules)``.
Expand Down
4 changes: 2 additions & 2 deletions tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ async def db_delete(self) -> None:
await self.close()

def acquire_connection(self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]:
return PoolConnectionWrapper(self)
return PoolConnectionWrapper(self, self._pool_init_lock)

def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))
return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)

@translate_exceptions
async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]:
Expand Down
34 changes: 20 additions & 14 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ class ConnectionWrapper(Generic[T_conn]):
"""Wraps the connections with a lock to facilitate safe concurrent access when using
asyncio.gather, TaskGroup, or similar."""

__slots__ = ("connection", "lock", "client")
__slots__ = ("connection", "_lock", "client")

def __init__(self, lock: asyncio.Lock, client: Any) -> None:
self.lock: asyncio.Lock = lock
self._lock: asyncio.Lock = lock
self.client = client
self.connection: T_conn = client._connection

Expand All @@ -235,12 +235,12 @@ async def ensure_connection(self) -> None:
self.connection = self.client._connection

async def __aenter__(self) -> T_conn:
await self.lock.acquire()
await self._lock.acquire()
await self.ensure_connection()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.lock.release()
self._lock.release()


class TransactionContext(Generic[T_conn]):
Expand All @@ -259,15 +259,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
class TransactionContextPooled(TransactionContext):
"A version of TransactionContext that uses a pool to acquire connections."

__slots__ = ("conn_wrapper", "connection", "connection_name", "token")
__slots__ = ("connection", "connection_name", "token", "_pool_init_lock")

def __init__(self, connection: Any) -> None:
def __init__(self, connection: Any, pool_init_lock: asyncio.Lock) -> None:
self.connection = connection
self.connection_name = connection.connection_name
self._pool_init_lock = pool_init_lock

async def ensure_connection(self) -> None:
if not self.connection._parent._pool:
await self.connection._parent.create_connection(with_db=True)
# a safeguard against multiple concurrent tasks trying to initialize the pool
async with self._pool_init_lock:
if not self.connection._parent._pool:
await self.connection._parent.create_connection(with_db=True)

async def __aenter__(self) -> T_conn:
await self.ensure_connection()
Expand Down Expand Up @@ -315,25 +319,27 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
class PoolConnectionWrapper(Generic[T_conn]):
"""Class to manage acquiring from and releasing connections to a pool."""

def __init__(self, client: Any) -> None:
self.pool = client._pool
def __init__(self, client: Any, pool_init_lock: asyncio.Lock) -> None:
self.client = client
self.connection: Optional[T_conn] = None
self._pool_init_lock = pool_init_lock

async def ensure_connection(self) -> None:
if not self.pool:
await self.client.create_connection(with_db=True)
self.pool = self.client._pool
if not self.client._pool:
# a safeguard against multiple concurrent tasks trying to initialize the pool
async with self._pool_init_lock:
if not self.client._pool:
await self.client.create_connection(with_db=True)

async def __aenter__(self) -> T_conn:
await self.ensure_connection()
# get first available connection. If none available, wait until one is released
self.connection = await self.pool.acquire()
self.connection = await self.client._pool.acquire()
return cast(T_conn, self.connection)

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# release the connection back to the pool
await self.pool.release(self.connection)
await self.client._pool.release(self.connection)


class BaseTransactionWrapper:
Expand Down
4 changes: 3 additions & 1 deletion tortoise/backends/base_postgres/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import asyncio
from asyncio.events import AbstractEventLoop
from functools import wraps
from typing import (
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
self._template: dict = {}
self._pool = None
self._connection = None
self._pool_init_lock = asyncio.Lock()

@abc.abstractmethod
async def create_connection(self, with_db: bool) -> None:
Expand Down Expand Up @@ -128,7 +130,7 @@ async def db_delete(self) -> None:
await self.close()

def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
return PoolConnectionWrapper(self._pool)
return PoolConnectionWrapper(self._pool, self._pool_init_lock)

@abc.abstractmethod
def _in_transaction(self) -> "TransactionContext":
Expand Down
2 changes: 1 addition & 1 deletion tortoise/backends/mssql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self.dsn = f"DRIVER={driver};SERVER={host},{port};UID={user};PWD={password};"

def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))
return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)

@translate_exceptions
async def execute_insert(self, query: str, values: list) -> int:
Expand Down
5 changes: 3 additions & 2 deletions tortoise/backends/mysql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
self._template: dict = {}
self._pool: Optional[mysql.Pool] = None
self._connection = None
self._pool_init_lock = asyncio.Lock()

async def create_connection(self, with_db: bool) -> None:
if charset_by_name(self.charset) is None:
Expand Down Expand Up @@ -172,10 +173,10 @@ async def db_delete(self) -> None:
await self.close()

def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
return PoolConnectionWrapper(self)
return PoolConnectionWrapper(self, self._pool_init_lock)

def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))
return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)

@translate_exceptions
async def execute_insert(self, query: str, values: list) -> int:
Expand Down
3 changes: 2 additions & 1 deletion tortoise/backends/odbc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self._template: dict = {}
self._pool: Optional[asyncodbc.Pool] = None
self._connection = None
self._pool_init_lock = asyncio.Lock()

async def create_connection(self, with_db: bool) -> None:
self._template = {
Expand Down Expand Up @@ -114,7 +115,7 @@ async def close(self) -> None:
self._pool = None

def acquire_connection(self) -> ConnWrapperType:
return PoolConnectionWrapper(self)
return PoolConnectionWrapper(self, self._pool_init_lock)

@translate_exceptions
async def execute_many(self, query: str, values: list) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tortoise/backends/oracle/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def __init__(
self.dsn = f"DRIVER={driver};DBQ={dbq};UID={user};PWD={password};"

def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))
return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)

def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
return OraclePoolConnectionWrapper(self)
return OraclePoolConnectionWrapper(self, self._pool_init_lock)

async def db_create(self) -> None:
await self.create_connection(with_db=False)
Expand Down
4 changes: 2 additions & 2 deletions tortoise/backends/psycopg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ async def _translate_exceptions(self, func, *args, **kwargs) -> Exception:
def acquire_connection(
self,
) -> typing.Union[base_client.ConnectionWrapper, PoolConnectionWrapper]:
return PoolConnectionWrapper(self)
return PoolConnectionWrapper(self, self._pool_init_lock)

def _in_transaction(self) -> base_client.TransactionContext:
return base_client.TransactionContextPooled(TransactionWrapper(self))
return base_client.TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock)


class TransactionWrapper(PsycopgClient, base_client.BaseTransactionWrapper):
Expand Down
2 changes: 1 addition & 1 deletion tortoise/backends/sqlite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ async def ensure_connection(self) -> None:
self.connection._connection = self.connection._parent._connection

async def __aenter__(self) -> T_conn:
await self.ensure_connection()
await self._trxlock.acquire()
await self.ensure_connection()
self.token = connections.set(self.connection_name, self.connection)
await self.connection.begin()
return self.connection
Expand Down