Skip to content

Commit

Permalink
Upgrade pydantic>=2.6.1 (amun-ai#597)
Browse files Browse the repository at this point in the history
* Upgrade pydantic>=2.6.1

* Fix pydantic

* Fix expires_in and at
  • Loading branch information
oeway authored Mar 3, 2024
1 parent f0ff414 commit f42ac6e
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 133 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.15.45"
"version": "0.15.46"
}
20 changes: 10 additions & 10 deletions hypha/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def list_apps(
workspace = context["from"].split("/")[0]

workspace = await self.store.get_workspace(workspace)
return [app_info.dict() for app_info in workspace.applications.values()]
return [app_info.model_dump() for app_info in workspace.applications.values()]

async def save_application(
self,
Expand Down Expand Up @@ -278,7 +278,7 @@ async def save_file(key, content):
await save_file(f"{app_dir}/{att['name']}", att["source"])
files.append(att["name"])

content = json.dumps(rdf.dict(), indent=4)
content = json.dumps(rdf.model_dump(), indent=4)
await save_file(f"{app_dir}/rdf.json", content)
logger.info("Saved application (%s)to workspace: %s", mhash, workspace)

Expand Down Expand Up @@ -328,7 +328,7 @@ async def download_file(key, local_path):
async with aiofiles.open(
local_app_dir / "rdf.json", "r", encoding="utf-8"
) as fil:
rdf = RDF.parse_obj(json.loads(await fil.read()))
rdf = RDF.model_validate(json.loads(await fil.read()))

if rdf.attachments:
files = rdf.attachments.get("files")
Expand Down Expand Up @@ -385,7 +385,7 @@ async def install(
if not workspace:
workspace = context["from"].split("/")[0]

user_info = UserInfo.parse_obj(context["user"])
user_info = UserInfo.model_validate(context["user"])
workspace = await self.store.get_workspace(workspace)

if not await self.store.check_permission(workspace, user_info):
Expand Down Expand Up @@ -456,10 +456,10 @@ async def install(
"public_url": public_url,
}
)
rdf = RDF.parse_obj(rdf_obj)
rdf = RDF.model_validate(rdf_obj)
await self.save_application(app_id, rdf, source, attachments)
ws = await self.store.get_workspace_interface(workspace.name)
await ws.install_application(rdf.dict())
await ws.install_application(rdf.model_dump())
return rdf_obj

async def uninstall(self, app_id: str, context: Optional[dict] = None) -> None:
Expand All @@ -472,7 +472,7 @@ async def uninstall(self, app_id: str, context: Optional[dict] = None) -> None:
workspace_name, mhash = app_id.split("/")
workspace = await self.store.get_workspace(workspace_name)

user_info = UserInfo.parse_obj(context["user"])
user_info = UserInfo.model_validate(context["user"])
if not await self.store.check_permission(workspace, user_info):
raise Exception(
f"User {user_info.id} does not have permission"
Expand Down Expand Up @@ -523,7 +523,7 @@ async def launch(

def _client_deleted(self, client: dict) -> None:
"""Called when client is deleted."""
client = ClientInfo.parse_obj(client)
client = ClientInfo.model_validate(client)
page_id = f"{client.workspace}/{client.id}"
if page_id in self._client_callbacks:
callbacks = self._client_callbacks[page_id]
Expand All @@ -534,7 +534,7 @@ def _client_deleted(self, client: dict) -> None:

def _client_updated(self, client: dict) -> None:
"""Called when client is updated."""
client = ClientInfo.parse_obj(client)
client = ClientInfo.model_validate(client)
page_id = f"{client.workspace}/{client.id}"
if page_id in self._client_callbacks:
callbacks = self._client_callbacks[page_id]
Expand Down Expand Up @@ -572,7 +572,7 @@ async def start(
ws = await self.store.get_workspace_interface(workspace)
token = await ws.generate_token({"parent_client": context["from"]})

user_info = UserInfo.parse_obj(context["user"])
user_info = UserInfo.model_validate(context["user"])
if not await self.store.check_permission(workspace, user_info):
raise Exception(
f"User {user_info.id} does not have permission"
Expand Down
6 changes: 3 additions & 3 deletions hypha/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.service.config.get("require_context"):
authorization = scope["headers"].get("authorization")
user_info = parse_token(authorization, allow_anonymouse=True)
result = await func(scope, {"user": user_info.dict()})
result = await func(scope, {"user": user_info.model_dump()})
else:
result = await func(scope)
headers = Headers(headers=result.get("headers"))
Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(

async def mount_asgi_app(self, service: dict):
"""Mount the ASGI apps from new services."""
service = ServiceInfo.parse_obj(service)
service = ServiceInfo.model_validate(service)

if service.type in ["ASGI", "functions"]:
workspace = service.config.workspace
Expand All @@ -177,7 +177,7 @@ async def mount_asgi_app(self, service: dict):

async def umount_asgi_app(self, service: dict):
"""Unmount the ASGI apps."""
service = ServiceInfo.parse_obj(service)
service = ServiceInfo.model_validate(service)
if service.type in ["ASGI", "functions"]:
service_id = service.id
if ":" in service_id: # Remove client_id
Expand Down
63 changes: 31 additions & 32 deletions hypha/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pydantic import ( # pylint: disable=no-name-in-module
BaseModel,
EmailStr,
Extra,
PrivateAttr,
constr,
)
Expand All @@ -29,9 +28,9 @@ class TokenConfig(BaseModel):
"""Represent a token configuration."""

scopes: List[str]
expires_in: Optional[int]
email: Optional[EmailStr]
parent_client: Optional[str]
expires_in: Optional[float] = None
email: Optional[EmailStr] = None
parent_client: Optional[str] = None


class VisibilityEnum(str, Enum):
Expand Down Expand Up @@ -71,7 +70,7 @@ class ServiceInfo(BaseModel):
class Config:
"""Set the config for pydantic."""

extra = Extra.allow
extra='allow'

def is_singleton(self):
"""Check if the service is singleton."""
Expand All @@ -84,10 +83,10 @@ class UserInfo(BaseModel):
id: str
roles: List[str]
is_anonymous: bool
email: Optional[EmailStr]
parent: Optional[str]
scopes: Optional[List[str]] # a list of workspace
expires_at: Optional[int]
email: Optional[EmailStr] = None
parent: Optional[str] = None
scopes: Optional[List[str]] = None # a list of workspace
expires_at: Optional[float] = None
_metadata: Dict[str, Any] = PrivateAttr(
default_factory=lambda: {}
) # e.g. s3 credential
Expand All @@ -107,8 +106,8 @@ class ClientInfo(BaseModel):
"""Represent service."""

id: str
parent: Optional[str]
name: Optional[str]
parent: Optional[str] = None
name: Optional[str] = None
workspace: str
services: List[ServiceInfo] = []
user_info: UserInfo
Expand All @@ -120,25 +119,25 @@ class RDF(BaseModel):
name: str
id: str
tags: List[str]
documentation: Optional[str]
covers: Optional[List[str]]
badges: Optional[List[str]]
authors: Optional[List[Dict[str, str]]]
attachments: Optional[Dict[str, List[Any]]]
config: Optional[Dict[str, Any]]
documentation: Optional[str] = None
covers: Optional[List[str]] = None
badges: Optional[List[str]] = None
authors: Optional[List[Dict[str, str]]] = None
attachments: Optional[Dict[str, List[Any]]] = None
config: Optional[Dict[str, Any]] = None
type: str
format_version: str = "0.2.1"
version: str = "0.1.0"
links: Optional[List[str]]
maintainers: Optional[List[Dict[str, str]]]
license: Optional[str]
git_repo: Optional[str]
source: Optional[str]
links: Optional[List[str]] = None
maintainers: Optional[List[Dict[str, str]]] = None
license: Optional[str] = None
git_repo: Optional[str] = None
source: Optional[str] = None

class Config:
"""Set the config for pydantic."""

extra = Extra.allow
extra='allow'


class ApplicationInfo(RDF):
Expand All @@ -154,15 +153,15 @@ class WorkspaceInfo(BaseModel):
persistent: bool
owners: List[str]
visibility: VisibilityEnum
description: Optional[str]
icon: Optional[str]
covers: Optional[List[str]]
docs: Optional[str]
allow_list: Optional[List[str]]
deny_list: Optional[List[str]]
read_only: bool = False
applications: Dict[str, RDF] = {} # installed applications
interfaces: Dict[str, List[Any]] = {}
description: Optional[str] = None
icon: Optional[str] = None
covers: Optional[List[str]] = None
docs: Optional[str] = None
allow_list: Optional[List[str]] = None
deny_list: Optional[List[str]] = None
applications: Optional[Dict[str, RDF]] = {} # installed applications
interfaces: Optional[Dict[str, List[Any]]] = {}


class RedisRPCConnection:
Expand All @@ -182,7 +181,7 @@ def __init__(
self._workspace = workspace
self._client_id = client_id
assert "/" not in client_id
self._user_info = user_info.dict()
self._user_info = user_info.model_dump()

def on_message(self, handler: Callable):
"""Setting message handler."""
Expand Down
34 changes: 17 additions & 17 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self.local_base_url = local_base_url
self._public_services: List[ServiceInfo] = []
self._ready = False
self._public_workspace = WorkspaceInfo.parse_obj(
self._public_workspace = WorkspaceInfo.model_validate(
{
"name": "public",
"persistent": True,
Expand Down Expand Up @@ -140,7 +140,7 @@ async def init(self, reset_redis, startup_functions=None):
await self.cleanup_disconnected_clients()
for service in self._public_services:
try:
await self._public_workspace_interface.register_service(service.dict())
await self._public_workspace_interface.register_service(service.model_dump())
except Exception: # pylint: disable=broad-except
logger.exception("Failed to register public service: %s", service)
raise
Expand All @@ -153,14 +153,14 @@ async def init(self, reset_redis, startup_functions=None):

async def _register_public_service(self, service: dict):
"""Register the public service."""
service = ServiceInfo.parse_obj(service)
service = ServiceInfo.model_validate(service)
assert ":" in service.id and service.config.workspace
# Add public service to the registry
if (
service.config.visibility == VisibilityEnum.public
and not service.id.endswith(":built-in")
):
service_dict = service.dict()
service_dict = service.model_dump()
service_summary = {k: service_dict[k] for k in SERVICE_SUMMARY_FIELD}
if "/" not in service.id:
service_id = service.config.workspace + "/" + service.id
Expand All @@ -171,7 +171,7 @@ async def _register_public_service(self, service: dict):
)

async def _unregister_public_service(self, service: dict):
service = ServiceInfo.parse_obj(service)
service = ServiceInfo.model_validate(service)
if (
service.config.visibility == VisibilityEnum.public
and not service.id.endswith(":built-in")
Expand All @@ -187,7 +187,7 @@ async def add_disconnected_client(self, client_info: ClientInfo):
await self._redis.hset(
"clients:disconnected",
f"{client_info.workspace}/{client_info.id}",
json.dumps({"client": client_info.json(), "timestamp": time.time()}),
json.dumps({"client": client_info.model_dump_json(), "timestamp": time.time()}),
)

async def remove_disconnected_client(
Expand Down Expand Up @@ -225,35 +225,35 @@ async def cleanup_disconnected_clients(self):

async def register_user(self, user_info: UserInfo):
"""Register a user."""
await self._redis.hset("users", user_info.id, user_info.json())
await self._redis.hset("users", user_info.id, user_info.model_dump_json())

async def get_user(self, user_id: str):
"""Get a user."""
user_info = await self._redis.hget("users", user_id)
if user_info is None:
return None
return UserInfo.parse_obj(json.loads(user_info.decode()))
return UserInfo.model_validate(json.loads(user_info.decode()))

async def get_user_workspace(self, user_id: str):
"""Get a user."""
workspace_info = await self._redis.hget("workspaces", user_id)
if workspace_info is None:
return None
workspace_info = WorkspaceInfo.parse_obj(json.loads(workspace_info.decode()))
workspace_info = WorkspaceInfo.model_validate(json.loads(workspace_info.decode()))
return workspace_info

async def get_all_users(self):
"""Get all users."""
users = await self._redis.hgetall("users")
return [
UserInfo.parse_obj(json.loads(user.decode())) for user in users.values()
UserInfo.model_validate(json.loads(user.decode())) for user in users.values()
]

async def get_all_workspace(self):
"""Get all workspaces."""
workspaces = await self._redis.hgetall("workspaces")
return [
WorkspaceInfo.parse_obj(json.loads(v.decode()))
WorkspaceInfo.model_validate(json.loads(v.decode()))
for k, v in workspaces.items()
]

Expand All @@ -263,7 +263,7 @@ async def workspace_exists(self, workspace_name: str):

async def register_workspace(self, workspace: dict, overwrite=False):
"""Add a workspace."""
workspace = WorkspaceInfo.parse_obj(workspace)
workspace = WorkspaceInfo.model_validate(workspace)
if not overwrite and await self._redis.hexists("workspaces", workspace.name):
raise RuntimeError(f"Workspace {workspace.name} already exists.")
if overwrite:
Expand All @@ -278,17 +278,17 @@ async def register_workspace(self, workspace: dict, overwrite=False):
raise KeyError(
f"Client does not exist: {workspace.name}/{client_id}"
)
client_info = ClientInfo.parse_obj(json.loads(client_info.decode()))
client_info = ClientInfo.model_validate(json.loads(client_info.decode()))
await self._redis.srem(
f"user:{client_info.user_info.id}:clients", client_info.id
)
# assert ret >= 1, f"Client not found in user({client_info.user_info.id})'s clients list: {client_info.id}"
await self._redis.hdel(f"{workspace.name}:clients", client_id)
await self._redis.delete(f"{workspace}:clients")
await self._redis.hset("workspaces", workspace.name, workspace.json())
await self._redis.hset("workspaces", workspace.name, workspace.model_dump_json())
await self.get_workspace_manager(workspace.name, setup=True)

self._event_bus.emit("workspace_registered", workspace.dict())
self._event_bus.emit("workspace_registered", workspace.model_dump())

async def connect_to_workspace(self, workspace: str, client_id: str):
"""Connect to a workspace."""
Expand Down Expand Up @@ -443,9 +443,9 @@ def register_public_service(self, service: dict):
assert (
"require_context" not in service
), "`require_context` should be placed inside `config`"
formated_service = ServiceInfo.parse_obj(service)
formated_service = ServiceInfo.model_validate(service)
# Note: service can set its `visibility` to `public` or `protected`
self._public_services.append(ServiceInfo.parse_obj(formated_service))
self._public_services.append(ServiceInfo.model_validate(formated_service))
return {
"id": formated_service.id,
"workspace": "public",
Expand Down
Loading

0 comments on commit f42ac6e

Please sign in to comment.