Skip to content

Commit

Permalink
fix: update query generation to fix SurrealDB syntax issues
Browse files Browse the repository at this point in the history
- Fix SET syntax to use = instead of :
- Fix record ID handling in UPDATE queries
- Add better handling of different response formats
  • Loading branch information
lfnovo committed Jan 3, 2025
1 parent 101351a commit 86def76
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 32 deletions.
3 changes: 2 additions & 1 deletion .windsurfrules
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ This project uses uv for dependency management. That means you should use:
- `uv build` for building


After we are done with a new features, we:
After we are done with a new feature, I can ask you to publish. In that case, we:
- Run the complete test suite
- Bump the version in @pyproject.toml
- Commit and push, merge to main.
- Generate a new tag
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "surrantic"
version = "0.1.7"
version = "0.1.8"
description = "A simple Pydantic ORM implementation for SurrealDB"
readme = "README.md"
authors = [
Expand Down
71 changes: 44 additions & 27 deletions src/surrantic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def _prepare_value(value: Any) -> str:
return str(value)
return json.dumps(value)

def _prepare_data(obj: BaseModel) -> str:
"""Convert Pydantic model to SurrealQL object format using model fields"""
def _prepare_data(obj: Any) -> str:
"""Prepare data for database query."""
items = []
for field_name, field in obj.model_fields.items():
for field_name in obj.model_fields:
value = getattr(obj, field_name)
if value is not None:
items.append(f"{field_name}: {_prepare_value(value)}")
return "{ " + ", ".join(items) + " }"
items.append(f"{field_name} = {_prepare_value(value)}")
return ", ".join(items)

def _log_query(query: str, result: Any = None) -> None:
"""Log a query and its result.
Expand Down Expand Up @@ -301,58 +301,75 @@ def get(cls: Type[T], id: Union[str, RecordID]) -> Optional[T]:
return None

async def asave(self) -> None:
"""Asynchronously save the model to the database.
Raises:
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
"""Save the model asynchronously."""
if not self.table_name:
raise ValueError("table_name must be set")

if not self.created:
self.created = datetime.now(timezone.utc)
self.updated = datetime.now(timezone.utc)
if not self.created:
self.created = self.updated

data = _prepare_data(self)
query = f"UPDATE {self.table_name} SET {data}"
if self.id:
query = f"UPDATE {self.id} CONTENT {data}"
# Extract just the record ID part without the table name
record_id = str(self.id).split(":")[-1]
query = f"UPDATE {self.table_name}:{record_id} SET {data}"
else:
query = f"CREATE {self.table_name} CONTENT {data}"
query = f"CREATE {self.table_name} SET {data}"

_log_query(query)
async with self._get_db() as db:
result = await db.query(query)
_log_query(query, result)
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
self.id = result[0]['result'][0]["id"]
if result and len(result) > 0 and 'result' in result[0]:
result_data = result[0]['result']
if isinstance(result_data, str):
# If the result is a string (record ID), use it directly
self.id = result_data
elif isinstance(result_data, list) and len(result_data) > 0:
# If the result is a list with dictionary items
if isinstance(result_data[0], dict):
self.id = result_data[0]["id"]
else:
self.id = result_data[0]
logger.debug(f"asave result: {result}")

def save(self) -> None:
"""Synchronously save the model to the database.
Raises:
ValueError: If table_name is not set
RuntimeError: If the database operation fails
"""
"""Synchronously save the model to the database."""
if not self.table_name:
raise ValueError("table_name must be set")

if not self.created:
self.created = datetime.now(timezone.utc)
self.updated = datetime.now(timezone.utc)
if not self.created:
self.created = self.updated

data = _prepare_data(self)
query = f"UPDATE {self.table_name} SET {data}"
if self.id:
query = f"UPDATE {self.id} CONTENT {data}"
# Extract just the record ID part without the table name
record_id = str(self.id).split(":")[-1]
query = f"UPDATE {self.table_name}:{record_id} SET {data}"
else:
query = f"CREATE {self.table_name} CONTENT {data}"
query = f"CREATE {self.table_name} SET {data}"

_log_query(query)
with self._get_sync_db() as db:
result = db.query(query)
_log_query(query, result)
if result and len(result) > 0 and 'result' in result[0] and len(result[0]['result']) > 0:
self.id = result[0]['result'][0]["id"]
if result and len(result) > 0 and 'result' in result[0]:
result_data = result[0]['result']
if isinstance(result_data, str):
# If the result is a string (record ID), use it directly
self.id = result_data
elif isinstance(result_data, list) and len(result_data) > 0:
# If the result is a list with dictionary items
if isinstance(result_data[0], dict):
self.id = result_data[0]["id"]
else:
self.id = result_data[0]

async def adelete(self) -> None:
"""Asynchronously delete the record from the database.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def test_prepare_data() -> None:
model = TestModel(name="Test", age=25)
from surrantic.base import _prepare_data
data = _prepare_data(model)
assert 'name: "Test"' in data
assert 'age: 25' in data
assert 'name = "Test"' in data
assert 'age = 25' in data

def test_format_datetime_z() -> None:
dt = datetime(2023, 1, 1, tzinfo=timezone.utc)
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 86def76

Please sign in to comment.