diff --git a/cli/auth.py b/cli/auth.py index ddf0757..953a8bb 100644 --- a/cli/auth.py +++ b/cli/auth.py @@ -61,7 +61,7 @@ class TokenManager(AbstractTokenManager, AbstractSessionManager): def __init__( self, *args: Any, cache_file: Path | None = None, **kwargs: Any ) -> None: - """Initialise the token manager.""" + """Initialise the credentials manager (for access_token & session_id).""" super().__init__(*args, **kwargs) self._cache_file: Final = cache_file @@ -120,7 +120,7 @@ async def _write_cache_to_file(self, cache: CacheDataT) -> None: async with aiofiles.open(self.cache_file, "w") as fp: await fp.write(content) - async def load_cache(self) -> None: + async def load_from_cache(self) -> None: """Load the user entry from the cache.""" cache: CacheDataT = await self._read_cache_from_file() @@ -162,6 +162,12 @@ async def _load_session_id(self, cache: CacheDataT | None = None) -> None: if self._session_id_expires.isoformat() < session[SZ_SESSION_ID_EXPIRES]: self._import_session_id(session) + async def save_to_cache(self) -> None: + """Save the user entry to the cache.""" + + await self.save_access_token() + await self.save_session_id() + async def save_access_token(self) -> None: """Save the (serialized) access token to the cache. diff --git a/src/evohome/auth.py b/src/evohome/auth.py index 9165d48..9cb9f33 100644 --- a/src/evohome/auth.py +++ b/src/evohome/auth.py @@ -221,7 +221,17 @@ async def request( if method == HTTPMethod.PUT and "json" in kwargs: kwargs["json"] = convert_keys_to_camel_case(kwargs["json"]) - response = await self._make_request(method, url, **kwargs) + try: + response = await self._make_request(method, url, **kwargs) + except exc.ApiRequestFailedError as err: + if err.status != HTTPStatus.UNAUTHORIZED: # 401 + # leave it up to higher layers to handle the 401 as they can either be + # authentication errors: bad access_token, bad session_id + # authorization errors: bad URL (e.g. no access to that location) + self.logger.debug( + f"The access_token/session_id may be invalid (it shouldn't be): {err}" + ) + raise if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug(f"{method} {url}: {obscure_secrets(response)}") @@ -247,12 +257,11 @@ async def _make_request( rsp: aiohttp.ClientResponse | None = None + url = f"{self.url_base}/{url}" headers = await self._headers(kwargs.pop("headers", {})) try: - rsp = await self.websession.request( - method, f"{self.url_base}/{url}", headers=headers, **kwargs - ) + rsp = await self.websession.request(method, url, headers=headers, **kwargs) assert rsp is not None # mypy rsp.raise_for_status() @@ -271,6 +280,11 @@ async def _make_request( f"{method} {url}: response is not JSON: {await _payload(rsp)}" ) from err + # an invalid access_token / session_id will cause a 401/Unauthorized + # so, we'd need to re-authenticate - how do we effect that? + # unfortunately, other scenarios cause a 401 (e.g. bad loc_id in URL) + # for now, leave it up to the consumer to handle the 401 + except aiohttp.ClientResponseError as err: # if hint := _ERR_MSG_LOOKUP_BASE.get(err.status): # raise exc.ApiRequestFailedError(hint, status=err.status) from err diff --git a/src/evohomeasync/auth.py b/src/evohomeasync/auth.py index c95145f..35d7575 100644 --- a/src/evohomeasync/auth.py +++ b/src/evohomeasync/auth.py @@ -120,13 +120,15 @@ async def get_session_id(self) -> str: """ if not self.is_session_id_valid(): # may be invalid for other reasons - self.logger.warning("Null/Expired/Invalid session_id, re-authenticating.") + self.logger.warning( + "Null/Expired/Invalid session_id, will re-authenticate..." + ) await self._update_session_id() return self.session_id async def _update_session_id(self) -> None: - self.logger.warning("Authenticating with client_id/secret...") + self.logger.warning("Authenticating with client_id/secret") credentials = { "applicationId": _APPLICATION_ID, diff --git a/src/evohomeasync/main.py b/src/evohomeasync/main.py index 446a0a8..c252166 100644 --- a/src/evohomeasync/main.py +++ b/src/evohomeasync/main.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Final from evohome.helpers import camel_to_snake @@ -82,7 +83,22 @@ async def update( if self._user_info is None: url = "accountInfo" - self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment] + try: + self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment] + except exc.ApiRequestFailedError as err: + if err.status != HTTPStatus.UNAUTHORIZED: # 401 + raise + + # in this case, the 401 must be due to a bad session_id as the + # URL is well-known (and there is no no usr_id, loc_id, etc.) + # that is, the accountInfo URL is open to all authenticated users + + self.logger.warning( + f"The session_id appears invalid (will re-authenticate): {err}" + ) + + self._session_manager._clear_session_id() + self._user_info = await self.auth.get(url, schema=SCH_GET_ACCOUNT_INFO) # type: ignore[assignment] assert self._user_info is not None # mypy (internal hint) diff --git a/src/evohomeasync/schemas.py b/src/evohomeasync/schemas.py index 6127f47..d6dede4 100644 --- a/src/evohomeasync/schemas.py +++ b/src/evohomeasync/schemas.py @@ -264,7 +264,7 @@ class TccSessionResponseT(TypedDict): class TccUserAccountInfoResponseT(TypedDict): # NOTE: is not TccUserAccountResponseT """GET api/accountInfo""" - userId: _UserIdT + userID: _UserIdT username: str # email address firstname: str lastname: str diff --git a/src/evohomeasync2/auth.py b/src/evohomeasync2/auth.py index d9d34c8..951821f 100644 --- a/src/evohomeasync2/auth.py +++ b/src/evohomeasync2/auth.py @@ -160,7 +160,9 @@ async def get_access_token(self) -> str: """ if not self.is_access_token_valid(): # may be invalid for other reasons - self.logger.warning("Null/Expired/Invalid access_token, re-authenticating.") + self.logger.warning( + "Null/Expired/Invalid access_token, will re-authenticate..." + ) await self._update_access_token() return self.access_token @@ -169,7 +171,7 @@ async def _update_access_token(self) -> None: """Update the access token and save it to the store/cache.""" if self._refresh_token: - self.logger.warning("Authenticating with the refresh_token...") + self.logger.warning("Authenticating with the refresh_token") credentials = {SZ_REFRESH_TOKEN: self._refresh_token} @@ -184,7 +186,7 @@ async def _update_access_token(self) -> None: self._refresh_token = "" if not self._refresh_token: - self.logger.warning("Authenticating with client_id/secret...") + self.logger.warning("Authenticating with client_id/secret") # NOTE: the keys are case-sensitive: 'Username' and 'Password' credentials = {"Username": self._client_id, "Password": self._secret} diff --git a/src/evohomeasync2/main.py b/src/evohomeasync2/main.py index a231266..fc496b3 100644 --- a/src/evohomeasync2/main.py +++ b/src/evohomeasync2/main.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Final from zoneinfo import ZoneInfo, ZoneInfoNotFoundError @@ -70,7 +71,7 @@ def __init__( def __str__(self) -> str: return f"{self.__class__.__name__}(auth='{self.auth}')" - async def update( + async def update( # noqa: C901 self, /, *, @@ -92,7 +93,22 @@ async def update( if self._user_info is None: url = "userAccount" - self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment] + try: + self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment] + except exc.ApiRequestFailedError as err: + if err.status != HTTPStatus.UNAUTHORIZED: # 401 + raise + + # in this case, the 401 must be due to a bad access_token as the + # URL is well-known (and there is no no usr_id, loc_id, etc.) + # that is, the userAccount URL is open to all authenticated users + + self.logger.warning( + f"The access_token appears invalid (will re-authenticate): {err}" + ) + + self._token_manager._clear_access_token() + self._user_info = await self.auth.get(url, schema=SCH_USER_ACCOUNT) # type: ignore[assignment] assert self._user_info is not None # mypy (internal hint) diff --git a/tests/conftest.py b/tests/conftest.py index 7f88395..cd9946a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -132,16 +132,16 @@ async def cache_manager( credentials: tuple[str, str], cache_file: Path, ) -> AsyncGenerator[TokenManager]: - """Yield a token manager for the v2 API.""" + """Yield a credentials manager for access_token & session_id caching.""" manager = TokenManager(*credentials, client_session, cache_file=cache_file) - # await cache_manager.load_cache() + # await manager.load_from_cache() try: yield manager finally: - await manager.save_access_token() # for next run of tests + await manager.save_to_cache() # for next run of tests @pytest.fixture @@ -150,7 +150,7 @@ async def evohome_v0( ) -> AsyncGenerator[EvohomeClientv0]: """Yield an instance of a v0 EvohomeClient.""" - await cache_manager.load_cache() + await cache_manager.load_from_cache() evo = EvohomeClientv0(cache_manager) @@ -168,7 +168,7 @@ async def evohome_v2( ) -> AsyncGenerator[EvohomeClientv2]: """Yield an instance of a v2 EvohomeClient.""" - await cache_manager.load_cache() + await cache_manager.load_from_cache() evo = EvohomeClientv2(cache_manager) diff --git a/tests/tests/test_v0_auth.py b/tests/tests/test_v0_auth.py index d5cb7ee..7dac323 100644 --- a/tests/tests/test_v0_auth.py +++ b/tests/tests/test_v0_auth.py @@ -147,7 +147,7 @@ async def test_session_manager( # have not yet called get_access_token (so not loaded cache either) assert cache_manager.is_session_id_valid() is False - await cache_manager.load_cache() + await cache_manager.load_from_cache() assert cache_manager.is_session_id_valid() is False # @@ -157,7 +157,7 @@ async def test_session_manager( cache_manager = TokenManager(*credentials, client_session, cache_file=cache_file) - await cache_manager.load_cache() + await cache_manager.load_from_cache() assert cache_manager.is_session_id_valid() is True session_id = await cache_manager.get_session_id() diff --git a/tests/tests/test_v2_auth.py b/tests/tests/test_v2_auth.py index f836658..ee8ee10 100644 --- a/tests/tests/test_v2_auth.py +++ b/tests/tests/test_v2_auth.py @@ -154,7 +154,7 @@ async def test_token_manager( # have not yet called get_access_token (so not loaded cache either) assert cache_manager.is_session_id_valid() is False - await cache_manager.load_cache() + await cache_manager.load_from_cache() assert cache_manager.is_access_token_valid() is False # @@ -164,7 +164,7 @@ async def test_token_manager( cache_manager = TokenManager(*credentials, client_session, cache_file=cache_file) - await cache_manager.load_cache() + await cache_manager.load_from_cache() assert cache_manager.is_access_token_valid() is True access_token = await cache_manager.get_access_token() diff --git a/tests/tests_rf/test_v0_cred.py b/tests/tests_rf/test_v0_cred.py index 325d768..d2996ac 100644 --- a/tests/tests_rf/test_v0_cred.py +++ b/tests/tests_rf/test_v0_cred.py @@ -29,6 +29,8 @@ from tests.const import ( _DBG_TEST_CRED_URLS, _DBG_USE_REAL_AIOHTTP, + HEADERS_BASE, + HEADERS_CRED_V0 as HEADERS_CRED, URL_BASE_V0 as URL_BASE, URL_CRED_V0 as URL_CRED, ) @@ -41,21 +43,6 @@ ) -HEADERS_CRED = { - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded; charset=utf-8", # data= - "Cache-Control": "no-cache, no-store", - "Pragma": "no-cache", - "Connection": "Keep-Alive", - # _APPLICATION_ID is in the data -} -HEADERS_BASE = { - "Accept": "application/json", - "Content-Type": "application/json", # json= - "SessionId": "", # "e163b069-1234-..." -} - - async def handle_too_many_requests(rsp: aiohttp.ClientResponse) -> None: assert rsp.status == HTTPStatus.TOO_MANY_REQUESTS diff --git a/tests/tests_rf/test_v2_apis.py b/tests/tests_rf/test_v2_apis.py index 395722e..dc89ebb 100644 --- a/tests/tests_rf/test_v2_apis.py +++ b/tests/tests_rf/test_v2_apis.py @@ -26,7 +26,7 @@ async def _test_basics_apis(evo: EvohomeClientv2) -> None: """Test authentication, `user_account()` and `installation()`.""" # STEP 1: retrieve base data - await evo.update(_dont_update_status=False) + await evo.update(_dont_update_status=True) assert evo2.main.SCH_USER_ACCOUNT(evo.user_account) assert evo2.main.SCH_USER_LOCATIONS(evo.user_installation) diff --git a/tests/tests_rf/test_v2_cred.py b/tests/tests_rf/test_v2_cred.py index d659977..71cad1f 100644 --- a/tests/tests_rf/test_v2_cred.py +++ b/tests/tests_rf/test_v2_cred.py @@ -20,7 +20,6 @@ import aiohttp import pytest -from evohomeasync2.auth import _APPLICATION_ID from evohomeasync2.schemas import ( TCC_ERROR_RESPONSE, TCC_GET_USR_ACCOUNT, @@ -30,6 +29,8 @@ from tests.const import ( _DBG_TEST_CRED_URLS, _DBG_USE_REAL_AIOHTTP, + HEADERS_BASE, + HEADERS_CRED_V2 as HEADERS_CRED, URL_BASE_V2 as URL_BASE, URL_CRED_V2 as URL_CRED, ) @@ -42,20 +43,6 @@ TccUsrAccountResponseT, ) -HEADERS_CRED = { - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded; charset=utf-8", # data= - "Cache-Control": "no-cache, no-store", - "Pragma": "no-cache", - "Connection": "Keep-Alive", - "Authorization": "Basic " + _APPLICATION_ID, -} -HEADERS_BASE = { - "Accept": "application/json", - "Content-Type": "application/json", # json= - "Authorization": "Bearer ...", # "Bearer " + access_token -} - async def handle_too_many_requests(rsp: aiohttp.ClientResponse) -> None: assert rsp.status == HTTPStatus.TOO_MANY_REQUESTS # 429 diff --git a/tests/tests_rf/test_v2_urlz.py b/tests/tests_rf/test_v2_urlz.py index 61297de..6176dfc 100644 --- a/tests/tests_rf/test_v2_urlz.py +++ b/tests/tests_rf/test_v2_urlz.py @@ -49,7 +49,7 @@ async def test_tcs_urls( ) -> None: """Test Location, Gateway and TCS URLs.""" - await cache_manager.load_cache() + await cache_manager.load_from_cache() # # STEP 0: Create the Auth client @@ -139,7 +139,7 @@ async def test_zon_urls( ) -> None: """Test Zone URLs""" - await cache_manager.load_cache() + await cache_manager.load_from_cache() # # STEP 0: Create the Auth client, get the TCS config @@ -221,7 +221,7 @@ async def test_dhw_urls( ) -> None: """Test DHW URLs""" - await cache_manager.load_cache() + await cache_manager.load_from_cache() # # STEP 0: Create the Auth client, get the TCS config