Skip to content

Commit

Permalink
[PENG-489] Prepare monorepo for CHIP (#12728)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 2dc08b495e06c6046f108c4e629faaeb9253dcfc
  • Loading branch information
stephencpope authored and Descartes Labs Build committed Dec 4, 2024
1 parent 1e550c7 commit fa4e436
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 60 deletions.
17 changes: 9 additions & 8 deletions descarteslabs/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
from ..common.http import Retry, Session


# This is only for the existing DL production tenant, and must remain in place
# until the tenant is completely replaced, if ever.
LEGACY_DELEGATION_CLIENT_IDS = ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"]


# copied from descarteslabs/common/threading/local.py, but we need
# it standalone here to avoid any dependencies on our own packages
# for client configuration purposes
Expand Down Expand Up @@ -321,8 +326,8 @@ def __init__(
>>> auth.namespace # doctest: +SKIP
'a54d88e06612d820bc3be72877c74f257b561b19'
>>> auth = descarteslabs.auth.Auth(
... client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
... client_secret="b70B_ozH6CaV23WQ-toFQ8CaujGHDs-eC39QEJTRnZa9Z",
... client_id="some-client-id",
... client_secret="some-client-secret",
... )
>>> auth.namespace # doctest: +SKIP
'67f21eb1040f978fe1da32e5e33501d0f4a604ac'
Expand Down Expand Up @@ -426,7 +431,7 @@ def __init__(
and self.refresh_token == token_info.get(self.KEY_REFRESH_TOKEN)
):
self._token = token_info.get(self.KEY_JWT_TOKEN)
elif self.refresh_token and self.token_info_path is not None:
elif self.refresh_token and self.token_info_path:
# Make the saved JWT token file unique to the refresh token
token = self.refresh_token
token_sha1 = sha1(token.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -739,11 +744,7 @@ def _get_token(self, timeout=100):
self.AUTHORIZATION_ERROR.format(" (no client_secret or refresh_token)")
)

if self.client_id in [
# production tenant
"ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
]: # TODO(justin) remove legacy handling
# TODO (justin) insert deprecation warning
if self.client_id in LEGACY_DELEGATION_CLIENT_IDS:
if self.scope is None:
scope = ["openid", "name", "groups", "org", "email"]
else:
Expand Down
86 changes: 42 additions & 44 deletions descarteslabs/auth/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from descarteslabs.exceptions import AuthError

from .. import auth as auth_module
from ..auth import Auth
from ..auth import Auth, LEGACY_DELEGATION_CLIENT_IDS


def token_response_callback(request):
Expand All @@ -49,16 +49,19 @@ def token_response_callback(request):
return (
200,
{"Content-Type": "application/json"},
json.dumps(dict(id_token="id_token")),
json.dumps(dict(id_token="legacy-id-token")),
)

if data["grant_type"] == "refresh_token" and all(
field not in data for field in legacy_required_fields
):
# note: this used to return both an access_token and an id_token
# but that isn't how IAM works anymore: it only returns an id_token.
# this isn't really OAuth2, but it is what it is.
return (
200,
{"Content-Type": "application/json"},
json.dumps(dict(access_token="access_token", id_token="id_token")),
json.dumps(dict(id_token="id-token")),
)
return 400, {"Content-Type": "application/json"}, json.dumps(data)

Expand All @@ -79,10 +82,6 @@ def setUpClass(cls):
cls.env = dict(os.environ)
os.environ.clear()

# Make sure we're not picking up credentials from anywhere
auth_module.DEFAULT_TOKEN_INFO_DIR = "/tmp"
auth_module.DEFAULT_TOKEN_INFO_PATH = "/dev/null"

@classmethod
def tearDownClass(cls):
os.environ.update(cls.env)
Expand All @@ -109,30 +108,30 @@ def test_get_token(self):
responses.add(
responses.POST,
f"{domain}/token",
json=dict(access_token="access_token"),
json=dict(access_token="access-token"),
status=200,
)
auth = Auth(client_secret="client_secret", client_id="client_id")
auth = Auth(client_secret="client-secret", client_id="client-id")
auth._get_token()

assert "access_token" == auth._token
assert "access-token" == auth._token

@responses.activate
def test_get_token_legacy(self):
responses.add(
responses.POST,
f"{domain}/token",
json=dict(id_token="id_token"),
json=dict(id_token="id-token"),
status=200,
)
auth = Auth(client_secret="client_secret", client_id="client_id")
auth = Auth(client_secret="client-secret", client_id="client-id")
auth._get_token()

assert "id_token" == auth._token
assert "id-token" == auth._token

@patch.object(Auth, "payload", new=dict(sub="asdf"))
def test_get_namespace(self):
auth = Auth(client_secret="client_secret", client_id="client_id")
auth = Auth(client_secret="client-secret", client_id="client-id")
assert auth.namespace == "3da541559918a808c2402bba5012f6c60b27661c"

def test_init_token_no_path(self):
Expand All @@ -152,16 +151,17 @@ def test_get_token_schema_internal_only(self):
f"{domain}/token",
callback=token_response_callback,
)
auth = Auth(refresh_token="refresh_token", client_id="client_id")
auth = Auth(refresh_token="refresh-token", client_id="client-id")
auth._get_token()

assert "access_token" == auth._token
assert "id-token" == auth._token

auth = Auth(client_secret="refresh_token", client_id="client_id")
auth = Auth(client_secret="refresh-token", client_id="client-id")
auth._get_token()

assert "access_token" == auth._token
assert "id-token" == auth._token

@unittest.skipUnless(len(LEGACY_DELEGATION_CLIENT_IDS) > 0, "No legacy client IDs")
@responses.activate
def test_get_token_schema_legacy_internal_only(self):
responses.add_callback(
Expand All @@ -170,26 +170,24 @@ def test_get_token_schema_legacy_internal_only(self):
callback=token_response_callback,
)
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id=LEGACY_DELEGATION_CLIENT_IDS[0],
)
auth._get_token()
assert "id_token" == auth._token
assert "legacy-id-token" == auth._token

@patch.object(Auth, "_get_token")
def test_token(self, _get_token):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
token = b".".join(
(
base64.b64encode(to_bytes(p))
for p in [
"header",
json.dumps(
dict(exp=9999999999, aud="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c")
),
json.dumps(dict(exp=9999999999, aud="client-id")),
"sig",
]
)
Expand All @@ -202,8 +200,8 @@ def test_token(self, _get_token):
@patch.object(Auth, "_get_token")
def test_token_expired(self, _get_token):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
token = b".".join(
(
Expand All @@ -219,8 +217,8 @@ def test_token_expired(self, _get_token):
@patch.object(Auth, "_get_token", side_effect=AuthError("error"))
def test_token_expired_autherror(self, _get_token):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
token = b".".join(
(
Expand All @@ -237,8 +235,8 @@ def test_token_expired_autherror(self, _get_token):
@patch.object(Auth, "_get_token", side_effect=AuthError("error"))
def test_token_in_leeway_autherror(self, _get_token):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
exp = (
datetime.datetime.now(datetime.timezone.utc)
Expand Down Expand Up @@ -269,12 +267,12 @@ def test_auth_init_env_vars(self):
# should work with direct var
with patch.object(auth_module.os, "environ", environ):
auth = Auth(
client_id="client_id",
client_secret="client_secret",
refresh_token="client_secret",
client_id="client-id",
client_secret="client-secret",
refresh_token="client-secret",
)
assert auth.client_secret == "client_secret"
assert auth.client_id == "client_id"
assert auth.client_secret == "client-secret"
assert auth.client_id == "client-id"

# should work with namespaced env vars
with patch.object(auth_module.os, "environ", environ):
Expand Down Expand Up @@ -383,7 +381,7 @@ def test_clear_cached_jwt_token_different_secret(self):

def test_no_valid_auth_info(self):
with warnings.catch_warnings(record=True) as caught_warnings:
Auth(client_id="client_id")
Auth(client_id="client-id")
assert len(caught_warnings) == 1
assert caught_warnings[0].category == UserWarning
assert "No valid authentication info found" in str(
Expand Down Expand Up @@ -492,8 +490,8 @@ def test_domain(self):

def test_all_acl_subjects(self):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
token = b".".join(
(
Expand All @@ -506,7 +504,7 @@ def test_all_acl_subjects(self):
groups=["public"],
org="some-org",
exp=9999999999,
aud="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
aud="client-id",
)
),
"sig",
Expand All @@ -523,8 +521,8 @@ def test_all_acl_subjects(self):

def test_all_acl_subjects_ignores_bad_org_groups(self):
auth = Auth(
client_secret="client_secret",
client_id="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
client_secret="client-secret",
client_id="client-id",
)
token = b".".join(
(
Expand All @@ -537,7 +535,7 @@ def test_all_acl_subjects_ignores_bad_org_groups(self):
groups=["public", "some-org:baz", "other:baz"],
org="some-org",
exp=9999999999,
aud="ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
aud="client-id",
)
),
"sig",
Expand Down
6 changes: 3 additions & 3 deletions descarteslabs/config/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ def test_remaining_keys(self):
for key in config.keys():
settings.pop(key)

settings.pop("DEFAULT_DOMAIN", None) # Added since 1.11.0
settings.pop("DEFAULT_HOST", None) # Added since 1.11.0
settings.pop("DOMAIN") # Added since 1.11.0
settings.pop("DEFAULT_DOMAIN", None)
settings.pop("DOMAIN", None)
settings.pop("TOKEN_INFO_PATH", None) # picked up from test environment

assert settings.pop("ENV") == config_name
assert len(settings) == 0, f"{config_name}: {settings}"
2 changes: 1 addition & 1 deletion descarteslabs/core/catalog/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def setUp(self):
base64.b64encode(
json.dumps(
{
"aud": "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
"aud": "client-id",
"exp": time.time() + 3600,
"sub": "some|user",
"org": "some-org",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def setUp(self):
base64.b64encode(
json.dumps(
{
"aud": "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
"aud": "client-id",
"exp": time.time() + 3600,
}
).encode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setUp(self):
base64.b64encode(
json.dumps(
{
"aud": "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
"aud": "client-id",
"exp": time.time() + 3600,
}
).encode()
Expand Down
2 changes: 1 addition & 1 deletion descarteslabs/core/compute/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setUp(self):
base64.b64encode(
json.dumps(
{
"aud": "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
"aud": "client-id",
"exp": time.time() + 3600,
}
).encode()
Expand Down
2 changes: 1 addition & 1 deletion descarteslabs/core/vector/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setUp(self):
base64.b64encode(
json.dumps(
{
"aud": "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",
"aud": "client-id",
"exp": time.time() + 3600,
}
).encode()
Expand Down

0 comments on commit fa4e436

Please sign in to comment.