Skip to content

Commit

Permalink
fix: improve test infrastructure and fix event loop issues in integra…
Browse files Browse the repository at this point in the history
…tion tests
  • Loading branch information
lfnovo committed Jan 3, 2025
1 parent ab0fd3e commit 101351a
Show file tree
Hide file tree
Showing 8 changed files with 424 additions and 251 deletions.
12 changes: 10 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
.PHONY: ruff lint test
.PHONY: ruff lint test test-integration test-unit test-all

lint:
uv run python -m mypy .

ruff:
ruff check . --fix

test:
test-integration:
uv run pytest -v -m integration

test-unit:
uv run pytest -v -m "not integration"

test-all:
uv run pytest -v

test: test-unit
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ services:
surrealdb:
image: surrealdb/surrealdb:latest
ports:
- "8000:8000"
- "8013:8000"
volumes:
- ./surreal_data:/mydata
# environment:
Expand Down
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "surrantic"
version = "0.1.5"
version = "0.1.7"
description = "A simple Pydantic ORM implementation for SurrealDB"
readme = "README.md"
authors = [
Expand All @@ -21,8 +21,9 @@ classifiers = [
]
dependencies = [
"pydantic>=2.0.0",
"pytest-mock>=3.14.0",
"python-dotenv>=1.0.0",
"surrealdb>=0.3.0",
"surrealdb>=0.4.1",
]

[project.urls]
Expand All @@ -49,7 +50,12 @@ check_untyped_defs = true
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-ra -q --cov=surrantic"
addopts = "-v -m 'not integration'"
markers = [
"integration: marks tests as integration tests"
]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"

[dependency-groups]
dev = [
Expand Down
136 changes: 72 additions & 64 deletions src/surrantic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ def _prepare_data(obj: BaseModel) -> str:
return "{ " + ", ".join(items) + " }"

def _log_query(query: str, result: Any = None) -> None:
"""Log query and result if debug is enabled"""
config = SurranticConfig.get_instance()
if config.debug:
logger.debug("Query: %s", query)
if result is not None:
logger.debug("Result: %s", result)
"""Log a query and its result.
Args:
query: The query to log
result: Optional result to log
"""
logger.debug(f"Query: {query}")
if result is not None:
logger.debug(f"Result type: {type(result)}")
logger.debug(f"Result: {result}")

class SurranticConfig:
"""Configuration class for Surrantic database connection.
Expand Down Expand Up @@ -146,7 +150,7 @@ def _format_datetime_z(dt: datetime) -> str:
@asynccontextmanager
async def _get_db(cls) -> AsyncGenerator[AsyncSurrealDB, None]:
"""Get a configured database connection as a context manager.
Yields:
AsyncSurrealDB: The configured database connection
"""
Expand All @@ -156,17 +160,17 @@ async def _get_db(cls) -> AsyncGenerator[AsyncSurrealDB, None]:
await db.connect()
await db.sign_in(config.user, config.password)
await db.use(config.namespace, config.database)
logger.debug("Database connection established")
_log_query("Database connection established")
yield db
finally:
_log_query("Database connection closed")
await db.close()
logger.debug("Database connection closed")

@classmethod
@contextmanager
def _get_sync_db(cls) -> Generator[SurrealDB, None, None]:
"""Get a configured synchronous database connection as a context manager.
Yields:
SurrealDB: The configured database connection
"""
Expand All @@ -176,80 +180,90 @@ def _get_sync_db(cls) -> Generator[SurrealDB, None, None]:
db.connect()
db.sign_in(config.user, config.password)
db.use(config.namespace, config.database)
logger.debug("Database connection established")
_log_query("Database connection established")
yield db
finally:
_log_query("Database connection closed")
db.close()
logger.debug("Database connection closed")

@classmethod
async def aget_all(cls: Type[T], order_by: Optional[str] = None, order_direction: Optional[str] = None) -> List[T]:
"""Asynchronously retrieve all records from the table.
Args:
order_by: Optional field name to order results by
order_direction: Optional direction ('ASC' or 'DESC') for ordering
Returns:
List of model instances
Raises:
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
if not cls.table_name:
raise ValueError("table_name must be set")

query = f"SELECT * FROM {cls.table_name}"
if order_by:
direction = order_direction or "ASC"
query += f" ORDER BY {order_by} {direction}"

_log_query(query)
async with cls._get_db() as db:
result = await db.query(query)
_log_query(query, result)
return [cls(**item) for item in result[0]["result"]]
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
items = []
for item in result[0]['result']:
items.append(cls(**item))
return items
return []

@classmethod
def get_all(cls: Type[T], order_by: Optional[str] = None, order_direction: Optional[str] = None) -> List[T]:
"""Synchronously retrieve all records from the table.
Args:
order_by: Optional field name to order results by
order_direction: Optional direction ('ASC' or 'DESC') for ordering
Returns:
List of model instances
Raises:
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
if not cls.table_name:
raise ValueError("table_name must be set")

query = f"SELECT * FROM {cls.table_name}"
if order_by:
direction = order_direction or "ASC"
query += f" ORDER BY {order_by} {direction}"

_log_query(query)
with cls._get_sync_db() as db:
result = db.query(query)
_log_query(query, result)
return [cls(**item) for item in result[0]["result"]]
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
items = []
for item in result[0]['result']:
items.append(cls(**item))
return items
return []

@classmethod
async def aget(cls: Type[T], id: Union[str, RecordID]) -> Optional[T]:
"""Asynchronously retrieve a single record by ID.
Args:
id: The record ID to retrieve, either as string or RecordID
Returns:
Model instance if found, None otherwise
Raises:
RuntimeError: If the database operation fails
"""
Expand All @@ -258,20 +272,21 @@ async def aget(cls: Type[T], id: Union[str, RecordID]) -> Optional[T]:
async with cls._get_db() as db:
result = await db.query(query)
_log_query(query, result)
if result and result[0]:
return cls(**result[0][0])
return None
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
item = result[0]['result'][0]
return cls(**item)
return None

@classmethod
def get(cls: Type[T], id: Union[str, RecordID]) -> Optional[T]:
"""Synchronously retrieve a single record by ID.
Args:
id: The record ID to retrieve, either as string or RecordID
Returns:
Model instance if found, None otherwise
Raises:
RuntimeError: If the database operation fails
"""
Expand All @@ -280,75 +295,68 @@ def get(cls: Type[T], id: Union[str, RecordID]) -> Optional[T]:
with cls._get_sync_db() as db:
result = db.query(query)
_log_query(query, result)
if result and result[0]:
return cls(**result[0][0])
return None
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
item = result[0]['result'][0]
return cls(**item)
return None

async def asave(self) -> None:
"""Asynchronously save or update the record in SurrealDB.
Updates the created and updated timestamps automatically.
Creates a new record if id is None, otherwise updates existing record.
"""Asynchronously save the model to the database.
Raises:
Exception: If table_name is not defined
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
if not self.table_name:
raise ValueError("table_name must be set")

now = datetime.now(timezone.utc)
self.updated = datetime.now(timezone.utc)
if not self.created:
self.created = now
self.updated = now
self.created = self.updated

data = _prepare_data(self)
if self.id:
query = f"UPDATE {self.id} SET {data}"
query = f"UPDATE {self.id} CONTENT {data}"
else:
query = f"CREATE {self.table_name} SET {data}"
query = f"CREATE {self.table_name} CONTENT {data}"

_log_query(query)
async with self._get_db() as db:
result = await db.query(query)
_log_query(query, result)
if result and result[0]:
self.id = RecordID.from_string(result[0][0]["id"])
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
self.id = result[0]['result'][0]["id"]

def save(self) -> None:
"""Synchronously save or update the record in SurrealDB.
Updates the created and updated timestamps automatically.
Creates a new record if id is None, otherwise updates existing record.
"""Synchronously save the model to the database.
Raises:
Exception: If table_name is not defined
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
if not self.table_name:
raise ValueError("table_name must be set")

now = datetime.now(timezone.utc)
self.updated = datetime.now(timezone.utc)
if not self.created:
self.created = now
self.updated = now
self.created = self.updated

data = _prepare_data(self)
if self.id:
query = f"UPDATE {self.id} SET {data}"
query = f"UPDATE {self.id} CONTENT {data}"
else:
query = f"CREATE {self.table_name} SET {data}"
query = f"CREATE {self.table_name} CONTENT {data}"

_log_query(query)
with self._get_sync_db() as db:
result = db.query(query)
_log_query(query, result)
if result and result[0]:
self.id = RecordID.from_string(result[0][0]["id"])
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
self.id = result[0]['result'][0]["id"]

async def adelete(self) -> None:
"""Asynchronously delete the record from the database.
Raises:
ValueError: If the record has no ID
RuntimeError: If the database operation fails
Expand All @@ -364,7 +372,7 @@ async def adelete(self) -> None:

def delete(self) -> None:
"""Synchronously delete the record from the database.
Raises:
ValueError: If the record has no ID
RuntimeError: If the database operation fails
Expand Down
1 change: 1 addition & 0 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading

0 comments on commit 101351a

Please sign in to comment.