diff --git a/custom_components/integration_linear/api.py b/custom_components/integration_linear/api.py index 53514b0..a4fe0c6 100644 --- a/custom_components/integration_linear/api.py +++ b/custom_components/integration_linear/api.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio import socket -from typing import Any, Callable, Awaitable +from collections.abc import Awaitable, Callable +from typing import Any import aiohttp import async_timeout @@ -67,7 +69,7 @@ def __init__( self._api_token = api_token self._session = session self._token_refresh_callback = token_refresh_callback - self._refresh_in_progress = False + self._refresh_lock = asyncio.Lock() async def async_validate_token(self) -> None: """Validate the API token by making a simple query.""" @@ -589,22 +591,22 @@ async def _api_wrapper( error_messages.append(message) extensions = err.get("extensions", {}) status_code = extensions.get("statusCode") - if status_code in (401, 403) or "unauthorized" in message.lower(): + is_unauthorized = "unauthorized" in message.lower() + if status_code in (401, 403) or is_unauthorized: is_auth_error = True break # Try to refresh token if we have a callback and this is an auth error - if is_auth_error and retry_on_auth_error and self._token_refresh_callback and not self._refresh_in_progress: - LOGGER.info("Authentication error detected, attempting token refresh") - try: - self._refresh_in_progress = True - new_token = await self._token_refresh_callback() - self._api_token = new_token - # Create new headers dict with updated token + has_refresh = self._token_refresh_callback is not None + if is_auth_error and retry_on_auth_error and has_refresh: + LOGGER.info( + "Authentication error detected, attempting token refresh" + ) + async with self._refresh_lock: + # After acquiring the lock, check if another task + # already refreshed by retrying with the current token retry_headers = dict(headers) if headers else {} - retry_headers["Authorization"] = new_token - # Retry the request once - LOGGER.debug("Retrying request with refreshed token") + retry_headers["Authorization"] = self._api_token response = await self._session.request( method=method, url=url, @@ -612,12 +614,39 @@ async def _api_wrapper( json=data, ) result = await response.json() - LOGGER.debug("Response after retry: %r", result) - except Exception as refresh_exception: - LOGGER.error("Token refresh failed: %s", refresh_exception) - _raise_authentication_error() - finally: - self._refresh_in_progress = False + + # Check if the retry succeeded (token was already refreshed) + is_still_unauthorized = response.status in ( + HTTP_STATUS_UNAUTHORIZED, + HTTP_STATUS_FORBIDDEN, + ) + if not is_still_unauthorized: + # Another task already refreshed the token + LOGGER.debug( + "Token was already refreshed by another request" + ) + else: + # We need to refresh the token + try: + LOGGER.debug("Refreshing token") + new_token = await self._token_refresh_callback() + self._api_token = new_token + # Retry the request with the new token + retry_headers["Authorization"] = new_token + LOGGER.debug("Retrying request with refreshed token") + response = await self._session.request( + method=method, + url=url, + headers=retry_headers, + json=data, + ) + result = await response.json() + LOGGER.debug("Response after retry: %r", result) + except Exception as refresh_exception: # noqa: BLE001 + LOGGER.error( + "Token refresh failed: %s", refresh_exception + ) + _raise_authentication_error() # Check for HTTP errors if response.status in (HTTP_STATUS_UNAUTHORIZED, HTTP_STATUS_FORBIDDEN):