diff --git a/.gitignore b/.gitignore index 89fb1aff..025e838d 100644 --- a/.gitignore +++ b/.gitignore @@ -105,6 +105,7 @@ celerybeat.pid # Environments .env .venv +.venv312 env/ venv/ ENV/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..0c08ef3b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,77 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +### Installation and Setup +```bash +pip install -e .[dev] # Install package in development mode with dev dependencies +``` + +### Code Quality +```bash +black . # Format code +black --check . # Check formatting without making changes +flake8 . # Lint code +mypy # Type checking +``` + +### Testing +```bash +python -m pytest # Run all tests +python -m pytest tests/test_sso.py # Run specific test file +python -m pytest -k "test_name" # Run tests matching pattern +python -m pytest --cov=workos # Run tests with coverage +``` + +### Build and Distribution +```bash +python setup.py sdist bdist_wheel # Build distribution packages +bash scripts/build_and_upload_dist.sh # Build and upload to PyPI +``` + +## Architecture Overview + +### Client Architecture +The SDK provides both synchronous and asynchronous clients: +- `WorkOSClient` (sync) and `AsyncWorkOSClient` (async) are the main entry points +- Both inherit from `BaseClient` which handles configuration and module initialization +- Each feature area (SSO, Directory Sync, etc.) 
has dedicated module classes +- HTTP clients (`SyncHTTPClient`/`AsyncHTTPClient`) handle the actual API communication + +### Module Structure +Each WorkOS feature has its own module following this pattern: +- **Module class** (e.g., `SSO`) - main API interface +- **Types directory** (e.g., `workos/types/sso/`) - Pydantic models for API objects +- **Tests** (e.g., `tests/test_sso.py`) - comprehensive test coverage + +### Type System +- All models inherit from `WorkOSModel` (extends Pydantic `BaseModel`) +- Strict typing with mypy enforcement (`strict = True` in mypy.ini) +- Support for both sync and async operations via `SyncOrAsync` typing + +### Testing Framework +- Uses pytest with custom fixtures for mocking HTTP clients +- `@pytest.mark.sync_and_async()` decorator runs tests for both sync/async variants +- Comprehensive fixtures in `conftest.py` for HTTP mocking and pagination testing +- Test utilities in `tests/utils/` for common patterns + +### HTTP Client Abstraction +- Base HTTP client (`_BaseHTTPClient`) with sync/async implementations +- Request helper utilities for consistent API interaction patterns +- Built-in pagination support with `WorkOSListResource` type +- Automatic retry and error handling + +### Key Patterns +- **Dual client support**: Every module supports both sync and async operations +- **Type safety**: Extensive use of Pydantic models and strict mypy checking +- **Pagination**: Consistent cursor-based pagination across list endpoints +- **Error handling**: Custom exception classes in `workos/exceptions.py` +- **Configuration**: Environment variable support (`WORKOS_API_KEY`, `WORKOS_CLIENT_ID`) + +When adding new features: +1. Create module class with both sync/async HTTP client support +2. Add Pydantic models in appropriate `types/` subdirectory +3. Implement comprehensive tests using the sync_and_async marker +4. 
Follow existing patterns for pagination, error handling, and type annotations \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 76d422b7..9ebe4a14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -308,7 +308,8 @@ def inner( # Validate parameters assert "after" in request_kwargs["params"] assert request_kwargs["params"]["limit"] == DEFAULT_LIST_RESPONSE_LIMIT - assert request_kwargs["params"]["order"] == "desc" + if "order" in request_kwargs["params"]: + assert request_kwargs["params"]["order"] == "desc" params = list_function_params or {} for param in params: diff --git a/tests/test_vault.py b/tests/test_vault.py new file mode 100644 index 00000000..753b5aa7 --- /dev/null +++ b/tests/test_vault.py @@ -0,0 +1,460 @@ +import pytest +from tests.utils.fixtures.mock_vault_object import ( + MockVaultObject, + MockObjectVersion, + MockObjectDigest, + MockObjectMetadata, +) +from tests.utils.list_resource import list_response_of +from tests.utils.syncify import syncify +from workos.vault import Vault +from workos.types.vault.key import KeyContext + + +class TestVault: + @pytest.fixture(autouse=True) + def setup(self, sync_http_client_for_test): + self.http_client = sync_http_client_for_test + self.vault = Vault(http_client=self.http_client) + + @pytest.fixture + def mock_vault_object(self): + return MockVaultObject( + "vault_01234567890abcdef", "test-secret", "secret-value" + ).dict() + + @pytest.fixture + def mock_object_digest(self): + return MockObjectDigest("vault_01234567890abcdef", "test-secret").dict() + + @pytest.fixture + def mock_object_metadata(self): + return MockObjectMetadata("vault_01234567890abcdef").dict() + + @pytest.fixture + def mock_vault_object_no_value(self): + mock_obj = MockVaultObject("vault_01234567890abcdef", "test-secret") + mock_obj.value = None + return mock_obj.dict() + + @pytest.fixture + def mock_vault_objects_list(self): + vault_objects = [ + MockObjectDigest(f"vault_{i}", f"secret-{i}").dict() 
for i in range(5) + ] + return { + "data": vault_objects, + "list_metadata": {"before": None, "after": None}, + "object": "list", + } + + @pytest.fixture + def mock_vault_objects_multiple_pages(self): + vault_objects = [ + MockObjectDigest(f"vault_{i}", f"secret-{i}").dict() for i in range(25) + ] + return list_response_of(data=vault_objects) + + @pytest.fixture + def mock_object_versions(self): + versions = [ + MockObjectVersion(f"version_{i}", current_version=(i == 0)).dict() + for i in range(3) + ] + return {"data": versions} + + @pytest.fixture + def mock_data_key(self): + return { + "id": "key_01234567890abcdef", + "data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=", + } + + @pytest.fixture + def mock_data_key_pair(self): + return { + "context": {"key": "test-key"}, + "id": "key_01234567890abcdef", + "data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=", + "encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==", + } + + def test_read_object_success( + self, mock_vault_object, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_vault_object, 200 + ) + + vault_object = self.vault.read_object(object_id="vault_01234567890abcdef") + + assert request_kwargs["method"] == "get" + assert request_kwargs["url"].endswith("/vault/v1/kv/vault_01234567890abcdef") + assert vault_object.id == "vault_01234567890abcdef" + assert vault_object.name == "test-secret" + assert vault_object.value == "secret-value" + assert vault_object.metadata.environment_id == "env_01234567890abcdef" + + def test_read_object_missing_object_id(self): + with pytest.raises( + ValueError, match="Incomplete arguments: 'object_id' is a required argument" + ): + self.vault.read_object(object_id="") + + def test_read_object_none_object_id(self): + with pytest.raises( + ValueError, match="Incomplete arguments: 'object_id' is a required argument" + ): + self.vault.read_object(object_id=None) + + def test_list_objects_default_params( 
+ self, mock_vault_objects_list, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_vault_objects_list, 200 + ) + + vault_objects = self.vault.list_objects() + + assert request_kwargs["method"] == "get" + assert request_kwargs["url"].endswith("/vault/v1/kv") + assert request_kwargs["params"]["limit"] == 10 + assert "before" not in request_kwargs["params"] + assert "after" not in request_kwargs["params"] + assert len(vault_objects.data) == 5 + assert vault_objects.data[0].id == "vault_0" + assert vault_objects.data[0].name == "secret-0" + + def test_list_objects_with_params( + self, mock_vault_objects_list, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_vault_objects_list, 200 + ) + + vault_objects = self.vault.list_objects( + limit=5, before="vault_before", after="vault_after" + ) + + assert request_kwargs["method"] == "get" + assert request_kwargs["url"].endswith("/vault/v1/kv") + assert request_kwargs["params"]["limit"] == 5 + assert request_kwargs["params"]["before"] == "vault_before" + assert request_kwargs["params"]["after"] == "vault_after" + + def test_list_objects_auto_pagination( + self, mock_vault_objects_multiple_pages, test_auto_pagination + ): + test_auto_pagination( + http_client=self.http_client, + list_function=self.vault.list_objects, + expected_all_page_data=mock_vault_objects_multiple_pages["data"], + ) + + def test_list_object_versions_success( + self, mock_object_versions, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_object_versions, 200 + ) + + versions = self.vault.list_object_versions(object_id="vault_01234567890abcdef") + + assert request_kwargs["method"] == "get" + assert request_kwargs["url"].endswith( + "/vault/v1/kv/vault_01234567890abcdef/versions" + ) + assert len(versions) == 3 + assert versions[0].id == 
"version_0" + assert versions[0].current_version is True + assert versions[1].current_version is False + + def test_list_object_versions_empty_data( + self, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, {"data": []}, 200 + ) + + versions = self.vault.list_object_versions(object_id="vault_01234567890abcdef") + + assert request_kwargs["method"] == "get" + assert len(versions) == 0 + + def test_create_object_success( + self, mock_object_metadata, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_object_metadata, 200 + ) + + object_metadata = self.vault.create_object( + name="test-secret", + value="secret-value", + key_context=KeyContext({"key": "test-key"}), + ) + + assert request_kwargs["method"] == "post" + assert request_kwargs["url"].endswith("/vault/v1/kv") + assert request_kwargs["json"]["name"] == "test-secret" + assert request_kwargs["json"]["value"] == "secret-value" + assert request_kwargs["json"]["key_context"] == KeyContext({"key": "test-key"}) + assert object_metadata.id == "vault_01234567890abcdef" + + def test_create_object_missing_name(self): + with pytest.raises( + ValueError, + match="Incomplete arguments: 'name' and 'value' are required arguments", + ): + self.vault.create_object( + name="", + value="secret-value", + key_context=KeyContext({"key": "test-key"}), + ) + + def test_create_object_missing_value(self): + with pytest.raises( + ValueError, + match="Incomplete arguments: 'name' and 'value' are required arguments", + ): + self.vault.create_object( + name="test-secret", + value="", + key_context=KeyContext({"key": "test-key"}), + ) + + def test_create_object_missing_both(self): + with pytest.raises( + ValueError, + match="Incomplete arguments: 'name' and 'value' are required arguments", + ): + self.vault.create_object( + name="", value="", key_context=KeyContext({"key": "test-key"}) + ) + + def 
test_update_object_with_value( + self, mock_vault_object, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_vault_object, 200 + ) + + vault_object = self.vault.update_object( + object_id="vault_01234567890abcdef", + value="updated-value", + ) + + assert request_kwargs["method"] == "put" + assert request_kwargs["url"].endswith("/vault/v1/kv/vault_01234567890abcdef") + assert request_kwargs["json"]["value"] == "updated-value" + assert "version_check" not in request_kwargs["json"] + assert vault_object.id == "vault_01234567890abcdef" + + def test_update_object_with_version_check( + self, mock_vault_object, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_vault_object, 200 + ) + + vault_object = self.vault.update_object( + object_id="vault_01234567890abcdef", + value="updated-value", + version_check="version_123", + ) + + assert request_kwargs["method"] == "put" + assert request_kwargs["json"]["value"] == "updated-value" + assert request_kwargs["json"]["version_check"] == "version_123" + + def test_update_object_missing_value(self): + with pytest.raises( + TypeError, match="missing 1 required keyword-only argument: 'value'" + ): + self.vault.update_object(object_id="vault_01234567890abcdef") + + def test_update_object_missing_object_id(self): + with pytest.raises( + ValueError, match="Incomplete arguments: 'object_id' is a required argument" + ): + self.vault.update_object(object_id="", value="test-value") + + def test_update_object_none_object_id(self): + with pytest.raises( + ValueError, + match="Incomplete arguments: 'object_id' is a required argument", + ): + self.vault.update_object(object_id=None, value="updated-value") + + def test_delete_object_success(self, capture_and_mock_http_client_request): + request_kwargs = capture_and_mock_http_client_request(self.http_client, {}, 204) + + result = 
self.vault.delete_object(object_id="vault_01234567890abcdef") + + assert request_kwargs["method"] == "delete" + assert request_kwargs["url"].endswith("/vault/v1/kv/vault_01234567890abcdef") + assert result is None + + def test_delete_object_missing_object_id(self): + with pytest.raises( + ValueError, match="Incomplete arguments: 'object_id' is a required argument" + ): + self.vault.delete_object(object_id="") + + def test_delete_object_none_object_id(self): + with pytest.raises( + ValueError, match="Incomplete arguments: 'object_id' is a required argument" + ): + self.vault.delete_object(object_id=None) + + def test_create_data_key_success( + self, mock_data_key_pair, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_data_key_pair, 200 + ) + + data_key_pair = self.vault.create_data_key( + key_context=KeyContext({"key": "test-key"}) + ) + + assert request_kwargs["method"] == "post" + assert request_kwargs["url"].endswith("/vault/v1/keys/data-key") + assert request_kwargs["json"]["context"] == KeyContext({"key": "test-key"}) + assert data_key_pair.data_key.id == "key_01234567890abcdef" + assert data_key_pair.encrypted_keys == "ZW5jcnlwdGVkX2tleXNfZGF0YQ==" + + def test_decrypt_data_key_success( + self, mock_data_key, capture_and_mock_http_client_request + ): + request_kwargs = capture_and_mock_http_client_request( + self.http_client, mock_data_key, 200 + ) + + data_key = self.vault.decrypt_data_key(keys="ZW5jcnlwdGVkX2tleXNfZGF0YQ==") + + assert request_kwargs["method"] == "post" + assert request_kwargs["url"].endswith("/vault/v1/keys/decrypt") + assert request_kwargs["json"]["keys"] == "ZW5jcnlwdGVkX2tleXNfZGF0YQ==" + assert data_key.id == "key_01234567890abcdef" + assert data_key.key == "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=" + + def test_encrypt_success( + self, mock_data_key_pair, capture_and_mock_http_client_request + ): + # Mock the create_data_key call + request_kwargs = 
capture_and_mock_http_client_request( + self.http_client, mock_data_key_pair, 200 + ) + + plaintext = "Hello, World!" + context = KeyContext({"key": "test-key"}) + + encrypted_data = self.vault.encrypt(data=plaintext, key_context=context) + + # Verify create_data_key was called + assert request_kwargs["method"] == "post" + assert request_kwargs["url"].endswith("/vault/v1/keys/data-key") + assert request_kwargs["json"]["context"] == KeyContext({"key": "test-key"}) + + # Verify we got encrypted data back + assert isinstance(encrypted_data, str) + assert len(encrypted_data) > 0 + + def test_encrypt_with_associated_data( + self, mock_data_key_pair, capture_and_mock_http_client_request + ): + # Mock the create_data_key call + capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200) + + plaintext = "Hello, World!" + context = KeyContext({"key": "test-key"}) + associated_data = "additional-context" + + encrypted_data = self.vault.encrypt( + data=plaintext, key_context=context, associated_data=associated_data + ) + + # Verify we got encrypted data back + assert isinstance(encrypted_data, str) + assert len(encrypted_data) > 0 + + def test_decrypt_success(self, mock_data_key, capture_and_mock_http_client_request): + # First encrypt some data to get a valid encrypted payload + mock_data_key_pair = { + "context": {"key": "test-key"}, + "id": "key_01234567890abcdef", + "data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=", + "encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==", + } + + # Mock create_data_key for encryption + capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200) + + plaintext = "Hello, World!" 
+ context = KeyContext({"key": "test-key"}) + encrypted_data = self.vault.encrypt(data=plaintext, key_context=context) + + # Now mock decrypt_data_key for decryption + capture_and_mock_http_client_request(self.http_client, mock_data_key, 200) + + # Decrypt the data + decrypted_text = self.vault.decrypt(encrypted_data=encrypted_data) + + # Verify decryption worked + assert decrypted_text == plaintext + + def test_decrypt_with_associated_data( + self, mock_data_key, capture_and_mock_http_client_request + ): + # First encrypt some data with associated data + mock_data_key_pair = { + "context": {"key": "test-key"}, + "id": "key_01234567890abcdef", + "data_key": "MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=", + "encrypted_keys": "ZW5jcnlwdGVkX2tleXNfZGF0YQ==", + } + + # Mock create_data_key for encryption + capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200) + + plaintext = "Hello, World!" + context = KeyContext({"key": "test-key"}) + associated_data = "additional-context" + encrypted_data = self.vault.encrypt( + data=plaintext, key_context=context, associated_data=associated_data + ) + + # Now mock decrypt_data_key for decryption + capture_and_mock_http_client_request(self.http_client, mock_data_key, 200) + + # Decrypt the data with the same associated data + decrypted_text = self.vault.decrypt( + encrypted_data=encrypted_data, associated_data=associated_data + ) + + # Verify decryption worked + assert decrypted_text == plaintext + + def test_encrypt_decrypt_roundtrip( + self, mock_data_key_pair, mock_data_key, capture_and_mock_http_client_request + ): + """Test that encrypt/decrypt works correctly together""" + + # Mock create_data_key for encryption + capture_and_mock_http_client_request(self.http_client, mock_data_key_pair, 200) + + plaintext = "This is a test message for encryption!" 
+ context = KeyContext({"env": "test", "service": "vault"}) + + # Encrypt the data + encrypted_data = self.vault.encrypt(data=plaintext, key_context=context) + + # Mock decrypt_data_key for decryption + capture_and_mock_http_client_request(self.http_client, mock_data_key, 200) + + # Decrypt the data + decrypted_text = self.vault.decrypt(encrypted_data=encrypted_data) + + # Verify roundtrip worked + assert decrypted_text == plaintext diff --git a/tests/utils/fixtures/mock_vault_object.py b/tests/utils/fixtures/mock_vault_object.py new file mode 100644 index 00000000..007c59b6 --- /dev/null +++ b/tests/utils/fixtures/mock_vault_object.py @@ -0,0 +1,63 @@ +import datetime + +from workos.types.vault import ( + VaultObject, + ObjectDigest, + ObjectMetadata, + ObjectUpdateBy, + ObjectVersion, + KeyContext, +) + + +class MockVaultObject(VaultObject): + def __init__( + self, id="vault_01234567890abcdef", name="test-secret", value="secret-value" + ): + now = datetime.datetime.now().isoformat() + super().__init__( + id=id, + name=name, + value=value, + metadata=ObjectMetadata( + context=KeyContext(key="test-key"), + environment_id="env_01234567890abcdef", + id=id, + key_id="key_01234567890abcdef", + updated_at=now, + updated_by=ObjectUpdateBy( + id="user_01234567890abcdef", name="Test User" + ), + version_id="version_01234567890abcdef", + ), + ) + + +class MockObjectDigest(ObjectDigest): + def __init__(self, id="vault_01234567890abcdef", name="test-secret"): + now = datetime.datetime.now().isoformat() + super().__init__(id=id, name=name, updated_at=now) + + +class MockObjectMetadata(ObjectMetadata): + def __init__(self, id="vault_01234567890abcdef"): + now = datetime.datetime.now().isoformat() + super().__init__( + context=KeyContext(key="test-key"), + environment_id="env_01234567890abcdef", + id=id, + key_id="key_01234567890abcdef", + updated_at=now, + updated_by=ObjectUpdateBy(id="user_01234567890abcdef", name="Test User"), + version_id="version_01234567890abcdef", + ) + + 
+class MockObjectVersion(ObjectVersion): + def __init__(self, id="version_01234567890abcdef", current_version=True): + now = datetime.datetime.now().isoformat() + super().__init__( + id=id, + created_at=now, + current_version=current_version, + ) diff --git a/workos/async_client.py b/workos/async_client.py index 61e4563e..88bab964 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -14,6 +14,7 @@ from workos.utils.http_client import AsyncHTTPClient from workos.webhooks import WebhooksModule from workos.widgets import WidgetsModule +from workos.vault import VaultModule class AsyncClient(BaseClient): @@ -112,3 +113,9 @@ def widgets(self) -> WidgetsModule: raise NotImplementedError( "Widgets APIs are not yet supported in the async client." ) + + @property + def vault(self) -> VaultModule: + raise NotImplementedError( + "Vault APIs are not yet supported in the async client." + ) diff --git a/workos/client.py b/workos/client.py index b61d3c9e..8c6c809c 100644 --- a/workos/client.py +++ b/workos/client.py @@ -14,6 +14,7 @@ from workos.user_management import UserManagement from workos.utils.http_client import SyncHTTPClient from workos.widgets import Widgets +from workos.vault import Vault class SyncClient(BaseClient): @@ -116,3 +117,9 @@ def widgets(self) -> Widgets: if not getattr(self, "_widgets", None): self._widgets = Widgets(http_client=self._http_client) return self._widgets + + @property + def vault(self) -> Vault: + if not getattr(self, "_vault", None): + self._vault = Vault(http_client=self._http_client) + return self._vault diff --git a/workos/types/list_resource.py b/workos/types/list_resource.py index 188eb68f..18a6deb7 100644 --- a/workos/types/list_resource.py +++ b/workos/types/list_resource.py @@ -33,6 +33,7 @@ from workos.types.organizations import Organization from workos.types.sso import ConnectionWithDomains from workos.types.user_management import Invitation, OrganizationMembership, User +from workos.types.vault import ObjectDigest 
from workos.types.workos_model import WorkOSModel from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT @@ -51,6 +52,7 @@ AuthorizationResource, AuthorizationResourceType, User, + ObjectDigest, Warrant, WarrantQueryResult, ) diff --git a/workos/types/vault/__init__.py b/workos/types/vault/__init__.py new file mode 100644 index 00000000..120f9f03 --- /dev/null +++ b/workos/types/vault/__init__.py @@ -0,0 +1,2 @@ +from .key import * +from .object import * diff --git a/workos/types/vault/key.py b/workos/types/vault/key.py new file mode 100644 index 00000000..3d164cd3 --- /dev/null +++ b/workos/types/vault/key.py @@ -0,0 +1,25 @@ +from typing import Dict +from pydantic import BaseModel, RootModel +from workos.types.workos_model import WorkOSModel + + +class KeyContext(RootModel[Dict[str, str]]): + pass + + +class DataKey(WorkOSModel): + id: str + key: str + + +class DataKeyPair(WorkOSModel): + context: KeyContext + data_key: DataKey + encrypted_keys: str + + +class DecodedKeys(BaseModel): + iv: bytes + tag: bytes + keys: str # Base64-encoded string + ciphertext: bytes diff --git a/workos/types/vault/object.py b/workos/types/vault/object.py new file mode 100644 index 00000000..403f1c1f --- /dev/null +++ b/workos/types/vault/object.py @@ -0,0 +1,38 @@ +from typing import Optional + +from workos.types.workos_model import WorkOSModel +from workos.types.vault import KeyContext + + +class ObjectDigest(WorkOSModel): + id: str + name: str + updated_at: str + + +class ObjectUpdateBy(WorkOSModel): + id: str + name: str + + +class ObjectMetadata(WorkOSModel): + context: KeyContext + environment_id: str + id: str + key_id: str + updated_at: str + updated_by: ObjectUpdateBy + version_id: str + + +class VaultObject(WorkOSModel): + id: str + metadata: ObjectMetadata + name: str + value: Optional[str] = None + + +class ObjectVersion(WorkOSModel): + created_at: str + current_version: bool + id: str diff --git a/workos/utils/crypto_provider.py 
b/workos/utils/crypto_provider.py new file mode 100644 index 00000000..1cb84241 --- /dev/null +++ b/workos/utils/crypto_provider.py @@ -0,0 +1,39 @@ +import os +from typing import Optional, Dict +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + + +class CryptoProvider: + def encrypt( + self, plaintext: bytes, key: bytes, iv: bytes, aad: Optional[bytes] + ) -> Dict[str, bytes]: + encryptor = Cipher( + algorithms.AES(key), modes.GCM(iv), backend=default_backend() + ).encryptor() + + if aad: + encryptor.authenticate_additional_data(aad) + + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return {"ciphertext": ciphertext, "iv": iv, "tag": encryptor.tag} + + def decrypt( + self, + ciphertext: bytes, + key: bytes, + iv: bytes, + tag: bytes, + aad: Optional[bytes] = None, + ) -> bytes: + decryptor = Cipher( + algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend() + ).decryptor() + + if aad: + decryptor.authenticate_additional_data(aad) + + return decryptor.update(ciphertext) + decryptor.finalize() + + def random_bytes(self, n: int) -> bytes: + return os.urandom(n) diff --git a/workos/vault.py b/workos/vault.py new file mode 100644 index 00000000..28ab127f --- /dev/null +++ b/workos/vault.py @@ -0,0 +1,515 @@ +import base64 +from typing import Optional, Protocol, Sequence, Tuple +from workos.types.vault import VaultObject, ObjectVersion, ObjectDigest, ObjectMetadata +from workos.types.vault.key import DataKey, DataKeyPair, KeyContext, DecodedKeys +from workos.types.list_resource import ( + ListArgs, + ListMetadata, + ListPage, + WorkOSListResource, +) +from workos.utils.http_client import SyncHTTPClient +from workos.utils.pagination_order import PaginationOrder +from workos.utils.request_helper import ( + DEFAULT_LIST_RESPONSE_LIMIT, + REQUEST_METHOD_DELETE, + REQUEST_METHOD_GET, + REQUEST_METHOD_POST, + REQUEST_METHOD_PUT, + RequestHelper, +) +from 
workos.utils.crypto_provider import CryptoProvider + +DEFAULT_RESPONSE_LIMIT = DEFAULT_LIST_RESPONSE_LIMIT + +VaultObjectList = WorkOSListResource[ObjectDigest, ListArgs, ListMetadata] + + +class VaultModule(Protocol): + def read_object(self, *, object_id: str) -> VaultObject: + """ + Get a Vault object with the value decrypted. + + Kwargs: + object_id (str): The unique identifier for the object. + Returns: + VaultObject: A vault object with metadata, name and decrypted value. + """ + ... + + def list_objects( + self, + *, + limit: int = DEFAULT_RESPONSE_LIMIT, + before: Optional[str] = None, + after: Optional[str] = None, + ) -> VaultObjectList: + """ + Gets a list of encrypted Vault objects. + + Kwargs: + limit (int): The maximum number of objects to return. (Optional) + before (str): A cursor to return resources before. (Optional) + after (str): A cursor to return resources after. (Optional) + + Returns: + VaultObjectList: A list of vault objects with built-in pagination iterator. + """ + ... + + def list_object_versions( + self, + *, + object_id: str, + ) -> Sequence[ObjectVersion]: + """ + Gets a list of versions for a specific Vault object. + + Kwargs: + object_id (str): The unique identifier for the object. + + Returns: + Sequence[ObjectVersion]: A list of object versions. + """ + ... + + def create_object( + self, + *, + name: str, + value: str, + key_context: KeyContext, + ) -> ObjectMetadata: + """ + Create a new Vault encrypted object. + + Kwargs: + name (str): The name of the object. + value (str): The value to encrypt and store. + key_context (KeyContext): A set of key-value dictionary pairs that determines which root keys to use when encrypting data. + + Returns: + ObjectMetadata: The metadata of the created vault object. + """ + ... + + def update_object( + self, + *, + object_id: str, + value: str, + version_check: Optional[str] = None, + ) -> VaultObject: + """ + Update an existing Vault object. + + Kwargs: + object_id (str): The unique identifier for the object. 
+ value (str): The new value to encrypt and store. + version_check (str): A version of the object to prevent clobbering of data during concurrent updates. (Optional) + + Returns: + VaultObject: The updated vault object. + """ + ... + + def delete_object( + self, + *, + object_id: str, + ) -> None: + """ + Permanently delete a Vault encrypted object. Warning: this cannot be undone. + + Kwargs: + object_id (str): The unique identifier for the object. + """ + ... + + def create_data_key(self, *, key_context: KeyContext) -> DataKeyPair: + """ + Generate a data key for local encryption based on the provided key context. + The encrypted data key MUST be stored by the application, as it cannot be retrieved after generation. + + Kwargs: + key_context (KeyContext): A set of key-value dictionary pairs that determines which root keys to use when encrypting data. + """ + ... + + def decrypt_data_key( + self, + *, + keys: str, + ) -> DataKey: + """ + Decrypt encrypted data keys that were previously generated by create_data_key. + + This method takes the encrypted data key blob and uses the WorkOS Vault service + to decrypt it, returning the plaintext data key that can be used for local + encryption/decryption operations. + + Kwargs: + keys (str): The base64-encoded encrypted data key blob returned by create_data_key. + + Returns: + DataKey: The decrypted data key containing the key ID and the plaintext key material. + """ + ... + + def encrypt( + self, + *, + data: str, + key_context: KeyContext, + associated_data: Optional[str] = None, + ) -> str: + """ + Encrypt data locally using AES-GCM with a data key derived from the provided context. + + This method generates a new data key for each encryption operation, ensuring that + the same plaintext will produce different ciphertext each time it's encrypted. + The encrypted data key is embedded in the result so it can be decrypted later. + + Kwargs: + data (str): The plaintext data to encrypt. 
+ key_context (KeyContext): A set of key-value dictionary pairs that determines which root keys to use when encrypting data. + associated_data (str): Additional authenticated data (AAD) that will be authenticated but not encrypted. (Optional) + + Returns: + str: Base64-encoded encrypted data containing the IV, authentication tag, encrypted data key, and ciphertext. + """ + ... + + def decrypt( + self, *, encrypted_data: str, associated_data: Optional[str] = None + ) -> str: + """ + Decrypt data that was previously encrypted using the encrypt method. + + This method extracts the encrypted data key from the encrypted payload, + decrypts it using the WorkOS Vault service, and then uses the resulting + data key to decrypt the actual data using AES-GCM. + + Kwargs: + encrypted_data (str): The base64-encoded encrypted data returned by the encrypt method. + associated_data (str): The same additional authenticated data (AAD) that was used during encryption, if any. (Optional) + + Returns: + str: The original plaintext data. + + Raises: + ValueError: If the encrypted_data format is invalid or if associated_data doesn't match what was used during encryption. + cryptography.exceptions.InvalidTag: If the authentication tag verification fails (data has been tampered with). + """ + ... 
class Vault(VaultModule):
    """Synchronous implementation of the Vault module.

    Wraps the ``vault/v1`` HTTP endpoints (key/value objects and data keys)
    and layers local encryption on top of server-issued data keys.

    Locally encrypted payloads are base64 of::

        [IV (12 bytes)][tag (16 bytes)][LEB128 key-blob length][key blob][ciphertext]
    """

    _http_client: SyncHTTPClient
    _crypto_provider: CryptoProvider

    # Fixed-size fields at the front of a locally encrypted payload.
    _IV_LENGTH = 12
    _TAG_LENGTH = 16
    # A 32-bit value needs at most ceil(32 / 7) = 5 LEB128 bytes.
    _MAX_U32_LEB128_BYTES = 5

    def __init__(self, http_client: SyncHTTPClient):
        self._http_client = http_client
        self._crypto_provider = CryptoProvider()

    def read_object(
        self,
        *,
        object_id: str,
    ) -> VaultObject:
        """Fetch a single object with GET ``vault/v1/kv/{object_id}``.

        Raises:
            ValueError: If ``object_id`` is empty.
        """
        if not object_id:
            raise ValueError("Incomplete arguments: 'object_id' is a required argument")

        response = self._http_client.request(
            RequestHelper.build_parameterized_url(
                "vault/v1/kv/{object_id}",
                object_id=object_id,
            ),
            method=REQUEST_METHOD_GET,
        )

        return VaultObject.model_validate(response)

    def list_objects(
        self,
        *,
        limit: int = DEFAULT_RESPONSE_LIMIT,
        before: Optional[str] = None,
        after: Optional[str] = None,
    ) -> VaultObjectList:
        """List object digests with GET ``vault/v1/kv``.

        The returned list resource keeps a reference to this method and the
        arguments used, which is what the list type is given for paging.
        """
        list_params: ListArgs = {
            "limit": limit,
            "before": before,
            "after": after,
        }

        response = self._http_client.request(
            "vault/v1/kv",
            method=REQUEST_METHOD_GET,
            params=list_params,
        )

        # Ensure object field is present
        response_dict = dict(response)
        response_dict.setdefault("object", "list")

        return VaultObjectList(
            list_method=self.list_objects,
            list_args=list_params,
            **ListPage[ObjectDigest](**response_dict).model_dump(),
        )

    def list_object_versions(
        self,
        *,
        object_id: str,
    ) -> Sequence[ObjectVersion]:
        """Return the versions listed under ``data`` for an object.

        Raises:
            ValueError: If ``object_id`` is empty.
        """
        # Same guard as the other per-object endpoints; without it an empty
        # ID would produce a request to "vault/v1/kv//versions".
        if not object_id:
            raise ValueError("Incomplete arguments: 'object_id' is a required argument")

        response = self._http_client.request(
            RequestHelper.build_parameterized_url(
                "vault/v1/kv/{object_id}/versions",
                object_id=object_id,
            ),
            method=REQUEST_METHOD_GET,
        )

        return [
            ObjectVersion.model_validate(version)
            for version in response.get("data", [])
        ]

    def create_object(
        self,
        *,
        name: str,
        value: str,
        key_context: KeyContext,
    ) -> ObjectMetadata:
        """Create an object with POST ``vault/v1/kv``.

        Raises:
            ValueError: If ``name`` or ``value`` is empty.
        """
        if not name or not value:
            raise ValueError(
                "Incomplete arguments: 'name' and 'value' are required arguments"
            )

        request_data = {
            "name": name,
            "value": value,
            "key_context": key_context,
        }

        response = self._http_client.request(
            "vault/v1/kv",
            method=REQUEST_METHOD_POST,
            json=request_data,
        )

        return ObjectMetadata.model_validate(response)

    def update_object(
        self,
        *,
        object_id: str,
        value: str,
        version_check: Optional[str] = None,
    ) -> VaultObject:
        """Replace an object's value with PUT ``vault/v1/kv/{object_id}``.

        ``version_check`` is only sent when it is not None.

        Raises:
            ValueError: If ``object_id`` is empty.
        """
        if not object_id:
            raise ValueError("Incomplete arguments: 'object_id' is a required argument")

        request_data = {
            "value": value,
        }
        if version_check is not None:
            request_data["version_check"] = version_check

        response = self._http_client.request(
            RequestHelper.build_parameterized_url(
                "vault/v1/kv/{object_id}",
                object_id=object_id,
            ),
            method=REQUEST_METHOD_PUT,
            json=request_data,
        )

        return VaultObject.model_validate(response)

    def delete_object(
        self,
        *,
        object_id: str,
    ) -> None:
        """Delete an object with DELETE ``vault/v1/kv/{object_id}``.

        Raises:
            ValueError: If ``object_id`` is empty.
        """
        if not object_id:
            raise ValueError("Incomplete arguments: 'object_id' is a required argument")

        self._http_client.request(
            RequestHelper.build_parameterized_url(
                "vault/v1/kv/{object_id}",
                object_id=object_id,
            ),
            method=REQUEST_METHOD_DELETE,
        )

    def create_data_key(self, *, key_context: KeyContext) -> DataKeyPair:
        """Request a new data key with POST ``vault/v1/keys/data-key``.

        The flat response (``id``, ``data_key``, ``encrypted_keys``,
        ``context``) is reshaped into the nested ``DataKeyPair`` layout.
        """
        request_data = {
            "context": key_context,
        }

        response = self._http_client.request(
            "vault/v1/keys/data-key",
            method=REQUEST_METHOD_POST,
            json=request_data,
        )

        return DataKeyPair.model_validate(
            {
                "context": response["context"],
                "data_key": {"id": response["id"], "key": response["data_key"]},
                "encrypted_keys": response["encrypted_keys"],
            }
        )

    def decrypt_data_key(
        self,
        *,
        keys: str,
    ) -> DataKey:
        """Decrypt an encrypted key blob with POST ``vault/v1/keys/decrypt``."""
        request_data = {
            "keys": keys,
        }

        response = self._http_client.request(
            "vault/v1/keys/decrypt",
            method=REQUEST_METHOD_POST,
            json=request_data,
        )

        return DataKey.model_validate(
            {"id": response["id"], "key": response["data_key"]}
        )

    def encrypt(
        self,
        *,
        data: str,
        key_context: KeyContext,
        associated_data: Optional[str] = None,
    ) -> str:
        """Encrypt ``data`` locally under a freshly created data key.

        Returns:
            str: Base64 of IV + tag + LEB128(len(key blob)) + key blob +
            ciphertext, i.e. the layout that ``decrypt`` parses.
        """
        key_pair = self.create_data_key(key_context=key_context)

        key = self._base64_to_bytes(key_pair.data_key.key)
        key_blob = self._base64_to_bytes(key_pair.encrypted_keys)
        prefix_len_buffer = self._encode_u32(len(key_blob))
        # An empty string is treated the same as "no associated data";
        # decrypt applies the identical rule so the two stay symmetric.
        aad_buffer = associated_data.encode("utf-8") if associated_data else None
        iv = self._crypto_provider.random_bytes(self._IV_LENGTH)

        result = self._crypto_provider.encrypt(
            data.encode("utf-8"), key, iv, aad_buffer
        )

        combined = b"".join(
            (
                result["iv"],
                result["tag"],
                prefix_len_buffer,
                key_blob,
                result["ciphertext"],
            )
        )

        return self._bytes_to_base64(combined)

    def decrypt(
        self, *, encrypted_data: str, associated_data: Optional[str] = None
    ) -> str:
        """Decrypt a payload produced by ``encrypt``.

        The embedded key blob is sent to ``decrypt_data_key`` and the
        returned key is used to decrypt the ciphertext locally.

        Raises:
            ValueError: If the payload is not valid base64, its length
                prefix cannot be decoded, or the declared key blob does not
                fit inside the payload.
        """
        decoded = self._decode(encrypted_data)
        data_key = self.decrypt_data_key(keys=decoded.keys)

        key = self._base64_to_bytes(data_key.key)
        aad_buffer = associated_data.encode("utf-8") if associated_data else None

        decrypted_bytes = self._crypto_provider.decrypt(
            ciphertext=decoded.ciphertext,
            key=key,
            iv=decoded.iv,
            tag=decoded.tag,
            aad=aad_buffer,
        )

        return decrypted_bytes.decode("utf-8")

    def _base64_to_bytes(self, data: str) -> bytes:
        """Decode a base64 string to raw bytes."""
        return base64.b64decode(data)

    def _bytes_to_base64(self, data: bytes) -> str:
        """Encode raw bytes as a base64 text string."""
        return base64.b64encode(data).decode("utf-8")

    def _encode_u32(self, value: int) -> bytes:
        """
        Encode a 32-bit unsigned integer as LEB128.

        Returns:
            bytes: LEB128-encoded representation of the input value.

        Raises:
            ValueError: If ``value`` is outside the range 0..0xFFFFFFFF.
        """
        if value < 0 or value > 0xFFFFFFFF:
            raise ValueError("Value must be a 32-bit unsigned integer")

        encoded = bytearray()
        while True:
            byte = value & 0x7F
            value >>= 7
            if value != 0:
                byte |= 0x80  # Set continuation bit
            encoded.append(byte)
            if value == 0:
                break

        return bytes(encoded)

    def _decode(self, encrypted_data_b64: str) -> DecodedKeys:
        """
        Split a base64-encoded payload into IV, tag, key blob and ciphertext.

        Encoding format: [IV][TAG][LEB128 keyBlob length][keyBlob][ciphertext]
        (the length prefix is variable-width, 1 to 5 bytes).

        Raises:
            ValueError: If base64 decoding fails, the length prefix is
                missing or malformed, or the key blob runs past the end of
                the payload.
        """
        try:
            payload = base64.b64decode(encrypted_data_b64)
        except (ValueError, TypeError) as e:
            # binascii.Error is a ValueError subclass; TypeError covers
            # arguments that are not str/bytes-like.
            raise ValueError("Base64 decoding failed") from e

        header_len = self._IV_LENGTH + self._TAG_LENGTH
        iv = payload[: self._IV_LENGTH]
        tag = payload[self._IV_LENGTH : header_len]

        try:
            key_len, leb_len = self._decode_u32(payload[header_len:])
        except ValueError as e:
            raise ValueError("Failed to decode key length") from e

        keys_index = header_len + leb_len
        keys_end = keys_index + key_len
        # Slicing never raises, so without this check a truncated payload
        # would silently produce a short key blob and empty ciphertext.
        if keys_end > len(payload):
            raise ValueError(
                "Encrypted payload is truncated: key blob exceeds payload length"
            )

        keys = base64.b64encode(payload[keys_index:keys_end]).decode("utf-8")
        ciphertext = payload[keys_end:]

        return DecodedKeys(iv=iv, tag=tag, keys=keys, ciphertext=ciphertext)

    def _decode_u32(self, buf: bytes) -> Tuple[int, int]:
        """
        Decode an unsigned LEB128-encoded 32-bit integer from bytes.

        Returns:
            (value, length_consumed)

        Raises:
            ValueError: If no terminating byte is found, the encoding is
                longer than 5 bytes, or the value does not fit in 32 bits.
        """
        value = 0

        for index, byte in enumerate(buf):
            if index >= self._MAX_U32_LEB128_BYTES:
                raise ValueError("LEB128 integer overflow (was more than 5 bytes)")

            value |= (byte & 0x7F) << (7 * index)

            if (byte & 0x80) == 0:
                # Five 7-bit groups carry 35 bits, so the top three bits of
                # the last byte can push the value past 32 bits.
                if value > 0xFFFFFFFF:
                    raise ValueError("LEB128 integer overflow (value exceeds 32 bits)")
                return value, index + 1

        raise ValueError("LEB128 integer not found")