Skip to content

Commit ec67173

Browse files
fix: address pr review comments
1 parent 229cdb3 commit ec67173

3 files changed

Lines changed: 170 additions & 4 deletions

File tree

docs/MultipleCustomDomain.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ claims = await api_client.verify_request(
110110
)
111111
```
112112

113+
> [!WARNING]
114+
> If using `Host` or `X-Forwarded-Host` headers in your resolver, do not trust them without validation — these headers can be spoofed by clients. Always validate against a known allowlist of your API hostnames, or use server-determined values from your reverse proxy / API gateway that are not client-controllable.
115+
113116
### Tenant Lookup
114117

115118
Resolve domains from a database or configuration service:
@@ -126,7 +129,13 @@ def tenant_resolver(context: DomainsResolverContext) -> list[str]:
126129
```
127130

128131
> [!NOTE]
129-
> The resolver runs synchronously. If your lookup requires async I/O (database queries, HTTP calls), wrap it with `asyncio.run()` or pre-load the domain mapping at startup.
132+
> The resolver can be synchronous or asynchronous. If your resolver is an `async def`, the SDK will automatically `await` the result.
133+
>
134+
> ```python
135+
> async def async_resolver(context: DomainsResolverContext) -> list[str]:
136+
> domains = await fetch_domains_from_db(context["unverified_iss"])
137+
> return domains
138+
> ```
130139
131140
---
132141

src/auth0_api_python/api_client.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,20 @@ def __init__(self, options: ApiClientOptions):
8888
"Use 'domain' for single-domain mode, 'domains' for multi-domain support."
8989
)
9090

91+
# Validate that domain is set when client_id is configured
92+
if options.client_id and not options.domain:
93+
raise ConfigurationError(
94+
"The 'domain' parameter is required when 'client_id' is configured."
95+
)
9196
self.options = options
9297

98+
# Validate cache configuration
99+
if not isinstance(options.cache_ttl_seconds, (int, float)) or options.cache_ttl_seconds < 0:
100+
raise ConfigurationError("cache_ttl_seconds must be a non-negative number")
101+
102+
if not isinstance(options.cache_max_entries, int) or options.cache_max_entries < 2:
103+
raise ConfigurationError("cache_max_entries must be an integer greater than 1")
104+
93105
if options.cache_adapter:
94106
self._discovery_cache = options.cache_adapter
95107
self._jwks_cache = options.cache_adapter
@@ -177,7 +189,12 @@ async def _resolve_allowed_domains(
177189
)
178190

179191
# Normalize domains from resolver
180-
allowed_domains = [normalize_domain(d) for d in result]
192+
try:
193+
allowed_domains = [normalize_domain(d) for d in result]
194+
except ValueError as e:
195+
raise DomainsResolverError(
196+
f"Domains resolver returned invalid domain: {str(e)}"
197+
) from e
181198
else:
182199
# Should never happen due to __init__ validation
183200
raise ConfigurationError("Invalid _allowed_domains type")
@@ -432,7 +449,10 @@ async def verify_access_token(
432449
raise VerifyAccessTokenError("Token missing 'iss' claim")
433450

434451
# Normalize issuer for validation
435-
normalized_iss = normalize_domain(unverified_iss)
452+
try:
453+
normalized_iss = normalize_domain(unverified_iss)
454+
except ValueError as e:
455+
raise VerifyAccessTokenError(f"Invalid token issuer format: {str(e)}") from e
436456

437457
# Validate issuer against allowed domains (MCD)
438458
if self._allowed_domains is not None:
@@ -461,7 +481,10 @@ async def verify_access_token(
461481
raise VerifyAccessTokenError("Discovery metadata missing 'issuer' field")
462482

463483
# Normalize discovery issuer for comparison
464-
normalized_discovery_issuer = normalize_domain(discovery_issuer)
484+
try:
485+
normalized_discovery_issuer = normalize_domain(discovery_issuer)
486+
except ValueError as e:
487+
raise VerifyAccessTokenError(f"Invalid discovery issuer format: {str(e)}") from e
465488

466489
if normalized_iss != normalized_discovery_issuer:
467490
raise VerifyAccessTokenError(

tests/test_api_client.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,76 @@ async def test_mcd_init_with_non_string_domains_list():
29042904
))
29052905

29062906

2907+
@pytest.mark.asyncio
2908+
async def test_mcd_init_client_id_requires_domain():
2909+
"""Test that client_id requires domain to be set (needed for token endpoint operations)."""
2910+
with pytest.raises(ConfigurationError, match="'domain' parameter is required when 'client_id'"):
2911+
ApiClient(ApiClientOptions(
2912+
domains=["tenant.auth0.com"],
2913+
client_id="my-client",
2914+
client_secret="my-secret",
2915+
audience="my-audience"
2916+
))
2917+
2918+
# Should work with both domain and domains
2919+
client = ApiClient(ApiClientOptions(
2920+
domain="tenant.auth0.com",
2921+
domains=["tenant.auth0.com", "custom.example.com"],
2922+
client_id="my-client",
2923+
client_secret="my-secret",
2924+
audience="my-audience"
2925+
))
2926+
assert client.options.client_id == "my-client"
2927+
2928+
# Should work with domains only (no client_id)
2929+
client = ApiClient(ApiClientOptions(
2930+
domains=["tenant.auth0.com"],
2931+
audience="my-audience"
2932+
))
2933+
assert client._allowed_domains is not None
2934+
2935+
2936+
@pytest.mark.asyncio
2937+
async def test_cache_config_validation():
2938+
"""Test that cache_max_entries and cache_ttl_seconds are validated at init."""
2939+
with pytest.raises(ConfigurationError, match="cache_ttl_seconds must be a non-negative number"):
2940+
ApiClient(ApiClientOptions(
2941+
domain="tenant.auth0.com",
2942+
audience="my-audience",
2943+
cache_ttl_seconds=-1
2944+
))
2945+
2946+
with pytest.raises(ConfigurationError, match="cache_max_entries must be an integer greater than 1"):
2947+
ApiClient(ApiClientOptions(
2948+
domain="tenant.auth0.com",
2949+
audience="my-audience",
2950+
cache_max_entries=0
2951+
))
2952+
2953+
with pytest.raises(ConfigurationError, match="cache_max_entries must be an integer greater than 1"):
2954+
ApiClient(ApiClientOptions(
2955+
domain="tenant.auth0.com",
2956+
audience="my-audience",
2957+
cache_max_entries=1
2958+
))
2959+
2960+
with pytest.raises(ConfigurationError, match="cache_max_entries must be an integer greater than 1"):
2961+
ApiClient(ApiClientOptions(
2962+
domain="tenant.auth0.com",
2963+
audience="my-audience",
2964+
cache_max_entries=-5
2965+
))
2966+
2967+
# cache_ttl_seconds=0 is valid (always refetch), cache_max_entries=2 is minimum
2968+
client = ApiClient(ApiClientOptions(
2969+
domain="tenant.auth0.com",
2970+
audience="my-audience",
2971+
cache_ttl_seconds=0,
2972+
cache_max_entries=2
2973+
))
2974+
assert client._cache_ttl == 0
2975+
2976+
29072977
@pytest.mark.asyncio
29082978
async def test_mcd_resolve_allowed_domains_static_list():
29092979
"""Test _resolve_allowed_domains with static list."""
@@ -3111,6 +3181,21 @@ def bad_resolver(context):
31113181
await api_client._resolve_allowed_domains("https://tenant1.auth0.com/")
31123182

31133183

3184+
@pytest.mark.asyncio
3185+
async def test_mcd_resolver_returns_invalid_domain_format():
3186+
"""Test that resolver returning domains with invalid format raises DomainsResolverError."""
3187+
def resolver_with_http(context):
3188+
return ["tenant1.auth0.com", "http://invalid.com"]
3189+
3190+
api_client = ApiClient(ApiClientOptions(
3191+
domains=resolver_with_http,
3192+
audience="my-audience"
3193+
))
3194+
3195+
with pytest.raises(DomainsResolverError, match="Domains resolver returned invalid domain"):
3196+
await api_client._resolve_allowed_domains("https://tenant1.auth0.com/")
3197+
3198+
31143199
@pytest.mark.asyncio
31153200
async def test_mcd_verify_rejects_symmetric_algorithm():
31163201
"""Test that verify_access_token rejects tokens with symmetric algorithms (HS256)."""
@@ -3319,6 +3404,55 @@ async def test_mcd_second_issuer_validation(httpx_mock):
33193404
assert "verified token issuer does not match the discovery issuer" in str(err.value).lower()
33203405

33213406

3407+
@pytest.mark.asyncio
3408+
async def test_mcd_malformed_token_issuer_format(httpx_mock):
3409+
"""Test that a token with malformed iss claim raises VerifyAccessTokenError."""
3410+
# Generate token with http:// issuer (rejected by normalize_domain)
3411+
token = await generate_token(
3412+
domain="evil.com",
3413+
user_id="user123",
3414+
audience="my-audience",
3415+
issuer="http://evil.com"
3416+
)
3417+
3418+
api_client = ApiClient(ApiClientOptions(
3419+
domain="tenant1.auth0.com",
3420+
audience="my-audience"
3421+
))
3422+
3423+
with pytest.raises(VerifyAccessTokenError, match="Invalid token issuer format"):
3424+
await api_client.verify_access_token(token)
3425+
3426+
3427+
@pytest.mark.asyncio
3428+
async def test_mcd_malformed_discovery_issuer_format(httpx_mock):
3429+
"""Test that malformed issuer in discovery metadata raises VerifyAccessTokenError."""
3430+
token = await generate_token(
3431+
domain="tenant1.auth0.com",
3432+
user_id="user123",
3433+
audience="my-audience",
3434+
issuer="https://tenant1.auth0.com/"
3435+
)
3436+
3437+
# Mock discovery returning malformed issuer with http://
3438+
httpx_mock.add_response(
3439+
method="GET",
3440+
url="https://tenant1.auth0.com/.well-known/openid-configuration",
3441+
json={
3442+
"issuer": "http://tenant1.auth0.com/",
3443+
"jwks_uri": "https://tenant1.auth0.com/.well-known/jwks.json"
3444+
}
3445+
)
3446+
3447+
api_client = ApiClient(ApiClientOptions(
3448+
domain="tenant1.auth0.com",
3449+
audience="my-audience"
3450+
))
3451+
3452+
with pytest.raises(VerifyAccessTokenError, match="Invalid discovery issuer format"):
3453+
await api_client.verify_access_token(token)
3454+
3455+
33223456
@pytest.mark.asyncio
33233457
async def test_mcd_discovery_missing_issuer_field(httpx_mock):
33243458
"""Test that missing issuer field in discovery causes clear error."""

0 commit comments

Comments
 (0)