forked from Significant-Gravitas/AutoGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(libs): Add integration credentials store (Significant-Gravitas#7826
) - Add `SupabaseIntegrationCredentialsStore` in `.supabase_integration_credentials_store` - Add `supabase` dependency - Add `pydantic` dependency --------- Co-authored-by: Reinier van der Leer <[email protected]>
- Loading branch information
1 parent
012bad7
commit 95af63b
Showing
5 changed files
with
1,140 additions
and
1 deletion.
There are no files selected for viewing
8 changes: 8 additions & 0 deletions
8
rnd/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .store import SupabaseIntegrationCredentialsStore | ||
from .types import APIKeyCredentials, OAuth2Credentials | ||
|
||
__all__ = [ | ||
"SupabaseIntegrationCredentialsStore", | ||
"APIKeyCredentials", | ||
"OAuth2Credentials", | ||
] |
91 changes: 91 additions & 0 deletions
91
rnd/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/store.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import cast | ||
|
||
from supabase import Client, create_client | ||
|
||
from .types import Credentials, OAuth2Credentials, UserMetadata, UserMetadataRaw | ||
|
||
|
||
class SupabaseIntegrationCredentialsStore: | ||
def __init__(self, url: str, key: str): | ||
self.supabase: Client = create_client(url, key) | ||
|
||
def add_creds(self, user_id: str, credentials: Credentials) -> None: | ||
if self.get_creds_by_id(user_id, credentials.id): | ||
raise ValueError( | ||
f"Can not re-create existing credentials with ID {credentials.id} " | ||
f"for user with ID {user_id}" | ||
) | ||
self._set_user_integration_creds( | ||
user_id, [*self.get_all_creds(user_id), credentials] | ||
) | ||
|
||
def get_all_creds(self, user_id: str) -> list[Credentials]: | ||
user_metadata = self._get_user_metadata(user_id) | ||
return UserMetadata.model_validate(user_metadata).integration_credentials | ||
|
||
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None: | ||
credentials = self.get_all_creds(user_id) | ||
return next((c for c in credentials if c.id == credentials_id), None) | ||
|
||
def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]: | ||
credentials = self.get_all_creds(user_id) | ||
return [c for c in credentials if c.provider == provider] | ||
|
||
def get_authorized_providers(self, user_id: str) -> list[str]: | ||
credentials = self.get_all_creds(user_id) | ||
return list(set(c.provider for c in credentials)) | ||
|
||
def update_creds(self, user_id: str, updated: Credentials) -> None: | ||
current = self.get_creds_by_id(user_id, updated.id) | ||
if not current: | ||
raise ValueError( | ||
f"Credentials with ID {updated.id} " | ||
f"for user with ID {user_id} not found" | ||
) | ||
if type(current) is not type(updated): | ||
raise TypeError( | ||
f"Can not update credentials with ID {updated.id} " | ||
f"from type {type(current)} " | ||
f"to type {type(updated)}" | ||
) | ||
|
||
# Ensure no scopes are removed when updating credentials | ||
if ( | ||
isinstance(updated, OAuth2Credentials) | ||
and isinstance(current, OAuth2Credentials) | ||
and not set(updated.scopes).issuperset(current.scopes) | ||
): | ||
raise ValueError( | ||
f"Can not update credentials with ID {updated.id} " | ||
f"and scopes {current.scopes} " | ||
f"to more restrictive set of scopes {updated.scopes}" | ||
) | ||
|
||
# Update the credentials | ||
updated_credentials_list = [ | ||
updated if c.id == updated.id else c for c in self.get_all_creds(user_id) | ||
] | ||
self._set_user_integration_creds(user_id, updated_credentials_list) | ||
|
||
def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None: | ||
filtered_credentials = [ | ||
c for c in self.get_all_creds(user_id) if c.id != credentials_id | ||
] | ||
self._set_user_integration_creds(user_id, filtered_credentials) | ||
|
||
def _set_user_integration_creds( | ||
self, user_id: str, credentials: list[Credentials] | ||
) -> None: | ||
raw_metadata = self._get_user_metadata(user_id) | ||
raw_metadata.update( | ||
{"integration_credentials": [c.model_dump() for c in credentials]} | ||
) | ||
self.supabase.auth.admin.update_user_by_id( | ||
user_id, {"user_metadata": raw_metadata} | ||
) | ||
|
||
def _get_user_metadata(self, user_id: str) -> UserMetadataRaw: | ||
response = self.supabase.auth.admin.get_user_by_id(user_id) | ||
if not response.user: | ||
raise ValueError(f"User with ID {user_id} not found") | ||
return cast(UserMetadataRaw, response.user.user_metadata) |
45 changes: 45 additions & 0 deletions
45
rnd/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Annotated, Any, Literal, Optional, TypedDict | ||
from uuid import uuid4 | ||
|
||
from pydantic import BaseModel, Field, SecretStr, field_serializer | ||
|
||
|
||
class _BaseCredentials(BaseModel): | ||
id: str = Field(default_factory=lambda: str(uuid4())) | ||
provider: str | ||
title: str | ||
|
||
@field_serializer("*") | ||
def dump_secret_strings(value: Any, _info): | ||
if isinstance(value, SecretStr): | ||
return value.get_secret_value() | ||
return value | ||
|
||
|
||
class OAuth2Credentials(_BaseCredentials): | ||
type: Literal["oauth2"] = "oauth2" | ||
access_token: SecretStr | ||
access_token_expires_at: int | ||
refresh_token: SecretStr | ||
refresh_token_expires_at: Optional[int] | ||
scopes: list[str] | ||
|
||
|
||
class APIKeyCredentials(_BaseCredentials): | ||
type: Literal["api_key"] = "api_key" | ||
api_key: SecretStr | ||
expires_at: Optional[int] | ||
|
||
|
||
Credentials = Annotated[ | ||
OAuth2Credentials | APIKeyCredentials, | ||
Field(discriminator="type"), | ||
] | ||
|
||
|
||
class UserMetadata(BaseModel): | ||
integration_credentials: list[Credentials] = Field(default_factory=list) | ||
|
||
|
||
class UserMetadataRaw(TypedDict, total=False): | ||
integration_credentials: list[dict] |
Oops, something went wrong.