Skip to content

Commit

Permalink
start using any io for better readability
Browse files Browse the repository at this point in the history
  • Loading branch information
07pepa committed Nov 7, 2024
1 parent 22e4acc commit 3cbdb88
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 118 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ For more installation options (eg: `aws`, `gcp`, `srv` ...) you can look in the
## Example

```python
import asyncio
import anyio
from typing import Optional

from motor.motor_asyncio import AsyncIOMotorClient
Expand Down Expand Up @@ -94,7 +94,7 @@ async def example():


if __name__ == "__main__":
asyncio.run(example())
anyio.run(example)
```

## Links
Expand Down
4 changes: 2 additions & 2 deletions beanie/executors/migrate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any

import anyio
import click
import toml

Expand Down Expand Up @@ -197,7 +197,7 @@ def migrate(
settings_kwargs["use_transaction"] = use_transaction
settings = MigrationSettings(**settings_kwargs)

asyncio.run(run_migrate(settings))
anyio.run(run_migrate, settings)


@migrations.command()
Expand Down
79 changes: 42 additions & 37 deletions beanie/migrations/controllers/iterative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from functools import partial
from inspect import isclass, signature
from typing import Any, List, Optional, Type, Union

from anyio import create_task_group

from beanie.migrations.controllers.base import BaseMigrationController
from beanie.migrations.utils import update_dict
from beanie.odm.documents import Document
Expand Down Expand Up @@ -92,43 +94,46 @@ def models(self) -> List[Type[Document]]:

async def run(self, session):
output_documents = []
all_migration_ops = []
async for input_document in self.input_document_model.find_all(
session=session
):
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
"output_document": output,
}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
await self.function(**function_kwargs)
output_dict = (
input_document.dict()
if not IS_PYDANTIC_V2
else input_document.model_dump()
)
update_dict(output_dict, output.dict())
output_document = parse_model(
self.output_document_model, output_dict
)
output_documents.append(output_document)

if len(output_documents) == self.batch_size:
all_migration_ops.append(
self.output_document_model.replace_many(
documents=output_documents, session=session
)
async with create_task_group() as tg:
async for input_document in self.input_document_model.find_all(
session=session
):
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
"output_document": output,
}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
await self.function(**function_kwargs)
output_dict = (
input_document.dict()
if not IS_PYDANTIC_V2
else input_document.model_dump()
)
output_documents = []

if output_documents:
all_migration_ops.append(
self.output_document_model.replace_many(
documents=output_documents, session=session
update_dict(output_dict, output.dict())
output_document = parse_model(
self.output_document_model, output_dict
)
output_documents.append(output_document)

if len(output_documents) == self.batch_size:
tg.start_soon(
partial(
self.output_document_model.replace_many,
documents=output_documents,
session=session,
)
)
output_documents = []

if output_documents:
tg.start_soon(
partial(
self.output_document_model.replace_many,
documents=output_documents,
session=session,
)
)
)
await asyncio.gather(*all_migration_ops)

return IterativeMigration
36 changes: 20 additions & 16 deletions beanie/odm/actions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import inspect
from enum import Enum
from functools import wraps
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -15,6 +14,7 @@
Union,
)

from anyio import create_task_group, to_thread
from typing_extensions import ParamSpec

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,7 +76,7 @@ def add_action(
:param document_class: document class
:param event_types: List[EventTypes]
:param action_direction: ActionDirections - before or after
:param funct: Callable - function
:param funct: Callable - function must be either thread safe or async safe
"""
if cls._actions.get(document_class) is None:
cls._actions[document_class] = {
Expand Down Expand Up @@ -130,19 +130,23 @@ async def run_actions(
actions_list = cls.get_action_list(
document_class, event_type, action_direction
)
coros = []
for action in actions_list:
if action.__name__ in exclude:
continue

if inspect.iscoroutinefunction(action):
coros.append(action(instance))
elif inspect.isfunction(action):
action(instance)
await asyncio.gather(*coros)


# `Any` because there is arbitrary attribute assignment on this type
async with create_task_group() as tg:
for action in actions_list:
if action.__name__ in exclude:
continue
if inspect.iscoroutinefunction(action):
tg.start_soon(action, instance)
elif inspect.isfunction(action):
tg.start_soon(
partial(
to_thread.run_sync,
partial(action, instance),
abandon_on_cancel=True,
)
)


# `Any` because there is an arbitrary attribute assignment on this type
F = TypeVar("F", bound=Any)


Expand Down
88 changes: 45 additions & 43 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import warnings
from datetime import datetime, timezone
from enum import Enum
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -20,6 +20,7 @@
)
from uuid import UUID, uuid4

from anyio import create_task_group
from bson import DBRef, ObjectId
from lazy_model import LazyModel
from motor.motor_asyncio import AsyncIOMotorClientSession
Expand Down Expand Up @@ -127,6 +128,10 @@
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


def only_documents(objets: List[Any]):
return filter(lambda obj: isinstance(obj, Document), objets)


def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
# remove excluded fields from the json schema
properties = schema.get("properties")
Expand Down Expand Up @@ -353,16 +358,16 @@ async def insert(
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.save(
link_rule=WriteRules.WRITE,
session=session,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.save,
link_rule=WriteRules.WRITE,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

result = await self.get_motor_collection().insert_one(
get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
Expand Down Expand Up @@ -513,18 +518,17 @@ async def replace(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.replace(
link_rule=link_rule,
bulk_writer=bulk_writer,
ignore_revision=ignore_revision,
session=session,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.replace,
link_rule=link_rule,
bulk_writer=bulk_writer,
ignore_revision=ignore_revision,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

use_revision_id = self.get_settings().use_revision
find_query: Dict[str, Any] = {"_id": self.id}
Expand Down Expand Up @@ -586,15 +590,15 @@ async def save(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.save(
link_rule=link_rule, session=session
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.save,
link_rule=link_rule,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

if self.get_settings().keep_nulls is False:
return await self.update(
Expand Down Expand Up @@ -911,16 +915,15 @@ async def delete(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.delete(
link_rule=DeleteRules.DELETE_LINKS,
**pymongo_kwargs,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.delete,
link_rule=DeleteRules.DELETE_LINKS,
**pymongo_kwargs,
)
)
for obj in value
if isinstance(obj, Document)
]
)

return await self.find_one({"_id": self.id}).delete(
session=session, bulk_writer=bulk_writer, **pymongo_kwargs
Expand Down Expand Up @@ -1182,12 +1185,11 @@ async def fetch_link(self, field: Union[str, Any]):
setattr(self, field, values)

async def fetch_all_links(self):
coros = []
link_fields = self.get_link_fields()
if link_fields is not None:
for ref in link_fields.values():
coros.append(self.fetch_link(ref.field_name)) # TODO lists
await asyncio.gather(*coros)
if link_fields is not None and len(link_fields.values()) > 0:
async with create_task_group() as tg:
for ref in link_fields.values():
tg.start_soon(self.fetch_link, ref.field_name)

@classmethod
def get_link_fields(cls) -> Optional[Dict[str, LinkInfo]]:
Expand Down
10 changes: 5 additions & 5 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
Expand All @@ -18,6 +17,7 @@
)
from typing import OrderedDict as OrderedDictType

from anyio import create_task_group
from bson import DBRef, ObjectId
from bson.errors import InvalidId
from pydantic import BaseModel
Expand Down Expand Up @@ -363,10 +363,10 @@ def repack_links(

@classmethod
async def fetch_many(cls, links: List[Link]):
coros = []
for link in links:
coros.append(link.fetch())
return await asyncio.gather(*coros)
if links:
async with create_task_group() as tg:
for link in links:
tg.start_soon(link.fetch)

if IS_PYDANTIC_V2:

Expand Down
Loading

0 comments on commit 3cbdb88

Please sign in to comment.