Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/test_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ def mock_magic_link_profile(self):
def mock_connection(self):
return MockConnection("conn_01E4ZCR3C56J083X43JQXF3JK5").dict()

@pytest.fixture
def mock_connection_updated(self):
connection = MockConnection("conn_01FHT48Z8J8295GZNQ4ZP1J81T").dict()

connection["options"] = {
"signing_cert": "signing_cert",
}

return connection

@pytest.fixture
def mock_connections(self):
connection_list = [MockConnection(id=str(i)).dict() for i in range(10)]
Expand Down Expand Up @@ -339,6 +349,33 @@ def test_list_connections_with_connection_type(
"order": "desc",
}

def test_update_connection(
self, mock_connection_updated, capture_and_mock_http_client_request
):
request_kwargs = capture_and_mock_http_client_request(
self.http_client, mock_connection_updated, 200
)

updated_connection = syncify(
self.sso.update_connection(
connection_id="conn_01EHT88Z8J8795GZNQ4ZP1J81T",
saml_options_signing_key="signing_key",
saml_options_signing_cert="signing_cert",
)
)

assert request_kwargs["url"].endswith(
"/connections/conn_01EHT88Z8J8795GZNQ4ZP1J81T"
)

assert request_kwargs["method"] == "put"
assert request_kwargs["json"] == {
"options": {"signing_key": "signing_key", "signing_cert": "signing_cert"}
}
assert updated_connection.id == "conn_01FHT48Z8J8295GZNQ4ZP1J81T"
assert updated_connection.name == "Foo Corporation"
assert updated_connection.options.signing_cert == "signing_cert"

def test_delete_connection(self, capture_and_mock_http_client_request):
request_kwargs = capture_and_mock_http_client_request(
self.http_client,
Expand Down
7 changes: 6 additions & 1 deletion tests/utils/fixtures/mock_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import datetime
from workos.types.sso import ConnectionDomain, ConnectionWithDomains
from workos.types.sso import (
ConnectionDomain,
ConnectionWithDomains,
SamlConnectionOptions,
)


class MockConnection(ConnectionWithDomains):
Expand All @@ -14,6 +18,7 @@ def __init__(self, id):
state="active",
created_at=now,
updated_at=now,
options=SamlConnectionOptions(signing_cert="signing_cert"),
domains=[
ConnectionDomain(
id="connection_domain_abc123",
Expand Down
65 changes: 64 additions & 1 deletion workos/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
REQUEST_METHOD_POST,
QueryParameters,
RequestHelper,
REQUEST_METHOD_PUT,
)
from workos.types.list_resource import (
ListArgs,
Expand Down Expand Up @@ -167,11 +168,29 @@ def list_connections(
"""
...

def update_connection(
self,
*,
connection_id: str,
saml_options_signing_key: Optional[str] = None,
saml_options_signing_cert: Optional[str] = None,
) -> SyncOrAsync[ConnectionWithDomains]:
"""Updates a single connection

Args:
connection_id (str): Connection unique identifier
saml_options_signing_key (str): Signing key for the connection (Optional)
saml_options_signing_cert (str): Signing certificate for the connection (Optional)
Returns:
None
"""
...

def delete_connection(self, connection_id: str) -> SyncOrAsync[None]:
"""Deletes a single Connection

Args:
connection (str): Connection unique identifier
connection_id (str): Connection unique identifier

Returns:
None
Expand Down Expand Up @@ -255,6 +274,28 @@ def list_connections(
**ListPage[ConnectionWithDomains](**response).model_dump(),
)

def update_connection(
self,
*,
connection_id: str,
saml_options_signing_key: Optional[str] = None,
saml_options_signing_cert: Optional[str] = None,
) -> ConnectionWithDomains:
json = {
"options": {
"signing_key": saml_options_signing_key,
"signing_cert": saml_options_signing_cert,
}
}

response = self._http_client.request(
f"connections/{connection_id}",
method=REQUEST_METHOD_PUT,
json=json,
)

return ConnectionWithDomains.model_validate(response)

def delete_connection(self, connection_id: str) -> None:
self._http_client.request(
f"connections/{connection_id}", method=REQUEST_METHOD_DELETE
Expand Down Expand Up @@ -335,6 +376,28 @@ async def list_connections(
**ListPage[ConnectionWithDomains](**response).model_dump(),
)

async def update_connection(
self,
*,
connection_id: str,
saml_options_signing_key: Optional[str] = None,
saml_options_signing_cert: Optional[str] = None,
) -> ConnectionWithDomains:
json = {
"options": {
"signing_key": saml_options_signing_key,
"signing_cert": saml_options_signing_cert,
}
}

response = await self._http_client.request(
f"connections/{connection_id}",
method=REQUEST_METHOD_PUT,
json=json,
)

return ConnectionWithDomains.model_validate(response)

async def delete_connection(self, connection_id: str) -> None:
await self._http_client.request(
f"connections/{connection_id}",
Expand Down
9 changes: 8 additions & 1 deletion workos/types/sso/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Sequence
from typing import Literal, Sequence, Optional
from workos.types.sso.connection_domain import ConnectionDomain
from workos.types.workos_model import WorkOSModel
from workos.typing.literals import LiteralOrUntyped
Expand Down Expand Up @@ -45,6 +45,12 @@
]


class SamlConnectionOptions(WorkOSModel):
"""Representation of options payload of a Connection Response."""

signing_cert: Optional[str]


class Connection(WorkOSModel):
object: Literal["connection"]
id: str
Expand All @@ -54,6 +60,7 @@ class Connection(WorkOSModel):
state: LiteralOrUntyped[ConnectionState]
created_at: str
updated_at: str
options: Optional[SamlConnectionOptions] = None


class ConnectionWithDomains(Connection):
Expand Down