Skip to content

Commit

Permalink
refactor auth - handle bad creds cache
Browse files Browse the repository at this point in the history
  • Loading branch information
zxdavb committed Dec 24, 2024
1 parent 14b9ed8 commit b3bbbd2
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 58 deletions.
10 changes: 8 additions & 2 deletions cli/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class TokenManager(AbstractTokenManager, AbstractSessionManager):
def __init__(
self, *args: Any, cache_file: Path | None = None, **kwargs: Any
) -> None:
"""Initialise the token manager."""
"""Initialise the credentials manager (for access_token & session_id)."""
super().__init__(*args, **kwargs)

self._cache_file: Final = cache_file
Expand Down Expand Up @@ -120,7 +120,7 @@ async def _write_cache_to_file(self, cache: CacheDataT) -> None:
async with aiofiles.open(self.cache_file, "w") as fp:
await fp.write(content)

async def load_cache(self) -> None:
async def load_from_cache(self) -> None:
"""Load the user entry from the cache."""

cache: CacheDataT = await self._read_cache_from_file()
Expand Down Expand Up @@ -162,6 +162,12 @@ async def _load_session_id(self, cache: CacheDataT | None = None) -> None:
if self._session_id_expires.isoformat() < session[SZ_SESSION_ID_EXPIRES]:
self._import_session_id(session)

async def save_to_cache(self) -> None:
"""Save the user entry to the cache."""

await self.save_access_token()
await self.save_session_id()

async def save_access_token(self) -> None:
"""Save the (serialized) access token to the cache.
Expand Down
22 changes: 18 additions & 4 deletions src/evohome/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,17 @@ async def request(
if method == HTTPMethod.PUT and "json" in kwargs:
kwargs["json"] = convert_keys_to_camel_case(kwargs["json"])

response = await self._make_request(method, url, **kwargs)
try:
response = await self._make_request(method, url, **kwargs)
except exc.ApiRequestFailedError as err:
if err.status != HTTPStatus.UNAUTHORIZED: # 401
# leave it up to higher layers to handle the 401 as they can either be
# authentication errors: bad access_token, bad session_id
# authorization errors: bad URL (e.g. no access to that location)
self.logger.debug(
f"The access_token/session_id may be invalid (it shouldn't be): {err}"
)
raise

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"{method} {url}: {obscure_secrets(response)}")
Expand All @@ -247,12 +257,11 @@ async def _make_request(

rsp: aiohttp.ClientResponse | None = None

url = f"{self.url_base}/{url}"
headers = await self._headers(kwargs.pop("headers", {}))

try:
rsp = await self.websession.request(
method, f"{self.url_base}/{url}", headers=headers, **kwargs
)
rsp = await self.websession.request(method, url, headers=headers, **kwargs)
assert rsp is not None # mypy

rsp.raise_for_status()
Expand All @@ -271,6 +280,11 @@ async def _make_request(
f"{method} {url}: response is not JSON: {await _payload(rsp)}"
) from err

# an invalid access_token / session_id will cause a 401/Unauthorized
# so, we'd need to re-authenticate - how do we effect that?
# unfortunately, other scenarios cause a 401 (e.g. bad loc_id in URL)
# for now, leave it up to the consumer to handle the 401

except aiohttp.ClientResponseError as err:
# if hint := _ERR_MSG_LOOKUP_BASE.get(err.status):
# raise exc.ApiRequestFailedError(hint, status=err.status) from err
Expand Down
6 changes: 4 additions & 2 deletions src/evohomeasync/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ async def get_session_id(self) -> str:
"""

if not self.is_session_id_valid(): # may be invalid for other reasons
self.logger.warning("Null/Expired/Invalid session_id, re-authenticating.")
self.logger.warning(
"Null/Expired/Invalid session_id, will re-authenticate..."
)
await self._update_session_id()

return self.session_id

async def _update_session_id(self) -> None:
self.logger.warning("Authenticating with client_id/secret...")
self.logger.warning("Authenticating with client_id/secret")

credentials = {
"applicationId": _APPLICATION_ID,
Expand Down
18 changes: 17 additions & 1 deletion src/evohomeasync/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Final

from evohome.helpers import camel_to_snake
Expand Down Expand Up @@ -82,7 +83,22 @@ async def update(

if self._user_info is None:
url = "accountInfo"
self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment]
try:
self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment]
except exc.ApiRequestFailedError as err:
if err.status != HTTPStatus.UNAUTHORIZED: # 401
raise

# in this case, the 401 must be due to a bad session_id as the
# URL is well-known (and there is no no usr_id, loc_id, etc.)
# that is, the accountInfo URL is open to all authenticated users

self.logger.warning(
f"The session_id appears invalid (will re-authenticate): {err}"
)

self._session_manager._clear_session_id()
self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment]

assert self._user_info is not None # mypy (internal hint)

Expand Down
2 changes: 1 addition & 1 deletion src/evohomeasync/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class TccSessionResponseT(TypedDict):
class TccUserAccountInfoResponseT(TypedDict): # NOTE: is not TccUserAccountResponseT
"""GET api/accountInfo"""

userId: _UserIdT
userID: _UserIdT
username: str # email address
firstname: str
lastname: str
Expand Down
8 changes: 5 additions & 3 deletions src/evohomeasync2/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ async def get_access_token(self) -> str:
"""

if not self.is_access_token_valid(): # may be invalid for other reasons
self.logger.warning("Null/Expired/Invalid access_token, re-authenticating.")
self.logger.warning(
"Null/Expired/Invalid access_token, will re-authenticate..."
)
await self._update_access_token()

return self.access_token
Expand All @@ -169,7 +171,7 @@ async def _update_access_token(self) -> None:
"""Update the access token and save it to the store/cache."""

if self._refresh_token:
self.logger.warning("Authenticating with the refresh_token...")
self.logger.warning("Authenticating with the refresh_token")

credentials = {SZ_REFRESH_TOKEN: self._refresh_token}

Expand All @@ -184,7 +186,7 @@ async def _update_access_token(self) -> None:
self._refresh_token = ""

if not self._refresh_token:
self.logger.warning("Authenticating with client_id/secret...")
self.logger.warning("Authenticating with client_id/secret")

# NOTE: the keys are case-sensitive: 'Username' and 'Password'
credentials = {"Username": self._client_id, "Password": self._secret}
Expand Down
20 changes: 18 additions & 2 deletions src/evohomeasync2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Final
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
def __str__(self) -> str:
return f"{self.__class__.__name__}(auth='{self.auth}')"

async def update(
async def update( # noqa: C901
self,
/,
*,
Expand All @@ -92,7 +93,22 @@ async def update(

if self._user_info is None:
url = "userAccount"
self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment]
try:
self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment]
except exc.ApiRequestFailedError as err:
if err.status != HTTPStatus.UNAUTHORIZED: # 401
raise

# in this case, the 401 must be due to a bad access_token as the
# URL is well-known (and there is no no usr_id, loc_id, etc.)
# that is, the userAccount URL is open to all authenticated users

self.logger.warning(
f"The access_token appears invalid (will re-authenticate): {err}"
)

self._token_manager._clear_access_token()
self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment]

assert self._user_info is not None # mypy (internal hint)

Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,16 @@ async def cache_manager(
credentials: tuple[str, str],
cache_file: Path,
) -> AsyncGenerator[TokenManager]:
"""Yield a token manager for the v2 API."""
"""Yield a credentials manager for access_token & session_id caching."""

manager = TokenManager(*credentials, client_session, cache_file=cache_file)

# await cache_manager.load_cache()
# await manager.load_from_cache()

try:
yield manager
finally:
await manager.save_access_token() # for next run of tests
await manager.save_to_cache() # for next run of tests


@pytest.fixture
Expand All @@ -150,7 +150,7 @@ async def evohome_v0(
) -> AsyncGenerator[EvohomeClientv0]:
"""Yield an instance of a v0 EvohomeClient."""

await cache_manager.load_cache()
await cache_manager.load_from_cache()

evo = EvohomeClientv0(cache_manager)

Expand All @@ -168,7 +168,7 @@ async def evohome_v2(
) -> AsyncGenerator[EvohomeClientv2]:
"""Yield an instance of a v2 EvohomeClient."""

await cache_manager.load_cache()
await cache_manager.load_from_cache()

evo = EvohomeClientv2(cache_manager)

Expand Down
4 changes: 2 additions & 2 deletions tests/tests/test_v0_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def test_session_manager(
# have not yet called get_access_token (so not loaded cache either)
assert cache_manager.is_session_id_valid() is False

await cache_manager.load_cache()
await cache_manager.load_from_cache()
assert cache_manager.is_session_id_valid() is False

#
Expand All @@ -157,7 +157,7 @@ async def test_session_manager(

cache_manager = TokenManager(*credentials, client_session, cache_file=cache_file)

await cache_manager.load_cache()
await cache_manager.load_from_cache()
assert cache_manager.is_session_id_valid() is True

session_id = await cache_manager.get_session_id()
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/test_v2_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def test_token_manager(
# have not yet called get_access_token (so not loaded cache either)
assert cache_manager.is_session_id_valid() is False

await cache_manager.load_cache()
await cache_manager.load_from_cache()
assert cache_manager.is_access_token_valid() is False

#
Expand All @@ -164,7 +164,7 @@ async def test_token_manager(

cache_manager = TokenManager(*credentials, client_session, cache_file=cache_file)

await cache_manager.load_cache()
await cache_manager.load_from_cache()
assert cache_manager.is_access_token_valid() is True

access_token = await cache_manager.get_access_token()
Expand Down
17 changes: 2 additions & 15 deletions tests/tests_rf/test_v0_cred.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from tests.const import (
_DBG_TEST_CRED_URLS,
_DBG_USE_REAL_AIOHTTP,
HEADERS_BASE,
HEADERS_CRED_V0 as HEADERS_CRED,
URL_BASE_V0 as URL_BASE,
URL_CRED_V0 as URL_CRED,
)
Expand All @@ -41,21 +43,6 @@
)


HEADERS_CRED = {
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8", # data=
"Cache-Control": "no-cache, no-store",
"Pragma": "no-cache",
"Connection": "Keep-Alive",
# _APPLICATION_ID is in the data
}
HEADERS_BASE = {
"Accept": "application/json",
"Content-Type": "application/json", # json=
"SessionId": "", # "e163b069-1234-..."
}


async def handle_too_many_requests(rsp: aiohttp.ClientResponse) -> None:
assert rsp.status == HTTPStatus.TOO_MANY_REQUESTS

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_rf/test_v2_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def _test_basics_apis(evo: EvohomeClientv2) -> None:
"""Test authentication, `user_account()` and `installation()`."""

# STEP 1: retrieve base data
await evo.update(_dont_update_status=False)
await evo.update(_dont_update_status=True)

assert evo2.main.SCH_USER_ACCOUNT(evo.user_account)
assert evo2.main.SCH_USER_LOCATIONS(evo.user_installation)
Expand Down
17 changes: 2 additions & 15 deletions tests/tests_rf/test_v2_cred.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import aiohttp
import pytest

from evohomeasync2.auth import _APPLICATION_ID
from evohomeasync2.schemas import (
TCC_ERROR_RESPONSE,
TCC_GET_USR_ACCOUNT,
Expand All @@ -30,6 +29,8 @@
from tests.const import (
_DBG_TEST_CRED_URLS,
_DBG_USE_REAL_AIOHTTP,
HEADERS_BASE,
HEADERS_CRED_V2 as HEADERS_CRED,
URL_BASE_V2 as URL_BASE,
URL_CRED_V2 as URL_CRED,
)
Expand All @@ -42,20 +43,6 @@
TccUsrAccountResponseT,
)

HEADERS_CRED = {
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8", # data=
"Cache-Control": "no-cache, no-store",
"Pragma": "no-cache",
"Connection": "Keep-Alive",
"Authorization": "Basic " + _APPLICATION_ID,
}
HEADERS_BASE = {
"Accept": "application/json",
"Content-Type": "application/json", # json=
"Authorization": "Bearer ...", # "Bearer " + access_token
}


async def handle_too_many_requests(rsp: aiohttp.ClientResponse) -> None:
assert rsp.status == HTTPStatus.TOO_MANY_REQUESTS # 429
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_rf/test_v2_urlz.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def test_tcs_urls(
) -> None:
"""Test Location, Gateway and TCS URLs."""

await cache_manager.load_cache()
await cache_manager.load_from_cache()

#
# STEP 0: Create the Auth client
Expand Down Expand Up @@ -139,7 +139,7 @@ async def test_zon_urls(
) -> None:
"""Test Zone URLs"""

await cache_manager.load_cache()
await cache_manager.load_from_cache()

#
# STEP 0: Create the Auth client, get the TCS config
Expand Down Expand Up @@ -221,7 +221,7 @@ async def test_dhw_urls(
) -> None:
"""Test DHW URLs"""

await cache_manager.load_cache()
await cache_manager.load_from_cache()

#
# STEP 0: Create the Auth client, get the TCS config
Expand Down

0 comments on commit b3bbbd2

Please sign in to comment.