|
| 1 | +import time |
| 2 | +from typing import Any, Dict, Sequence, Optional, TYPE_CHECKING |
| 3 | + |
| 4 | +from .auth_exceptions import AuthFetchTokenException, AuthDiscoveryException, AuthInvalidCredentialsException |
| 5 | + |
| 6 | +if TYPE_CHECKING: # pragma: no cover |
| 7 | + from .scope import Scope |
| 8 | + from .session import Session |
| 9 | + from requests.models import PreparedRequest |
| 10 | + |
| 11 | + |
| 12 | +class _AuthClient: |
| 13 | + |
| 14 | + def __init__( |
| 15 | + self, username: str, password: str, scope: Sequence["Scope"], authorization_url: str, session: "Session" |
| 16 | + ) -> None: |
| 17 | + self._username = username |
| 18 | + self._password = password |
| 19 | + self.scope = " ".join((str(x) for x in scope)) |
| 20 | + self.authorization_url = authorization_url |
| 21 | + self.token_endpoint: Optional[str] = None |
| 22 | + self.session = session |
| 23 | + self.requests_auth = self._requests_auth |
| 24 | + self.access_token = "" |
| 25 | + self.token_response: Optional[Dict[str, Any]] = None |
| 26 | + self.leeway = 60 |
| 27 | + self.expires_at: Optional[int] = None |
| 28 | + self.fetch_token_get_time = time.time |
| 29 | + self.is_expired_get_time = time.time |
| 30 | + |
| 31 | + def fetch_token_if_necessary(self) -> bool: |
| 32 | + if self._is_expired(): |
| 33 | + self.fetch_token() |
| 34 | + return True |
| 35 | + return False |
| 36 | + |
| 37 | + def fetch_token(self) -> None: |
| 38 | + if self.token_endpoint is None: |
| 39 | + self.token_endpoint = self._discovery(self.authorization_url) |
| 40 | + |
| 41 | + self._fetch_token(self.token_endpoint) |
| 42 | + |
| 43 | + def _fetch_token(self, token_endpoint: str) -> None: |
| 44 | + payload = { |
| 45 | + "grant_type": "client_credentials", |
| 46 | + "client_id": self._username, |
| 47 | + "client_secret": self._password, |
| 48 | + "scope": self.scope, |
| 49 | + } |
| 50 | + response = self.session.requests_session.post( |
| 51 | + token_endpoint, data=payload, headers={"Content-Type": "application/x-www-form-urlencoded"} |
| 52 | + ) |
| 53 | + |
| 54 | + if response.status_code not in [200, 400]: |
| 55 | + raise AuthFetchTokenException("status code is not 200 or 400") |
| 56 | + |
| 57 | + try: |
| 58 | + json = self.token_response = response.json() |
| 59 | + except Exception as ex: |
| 60 | + raise AuthFetchTokenException("not valid json") from ex |
| 61 | + |
| 62 | + if not isinstance(json, dict): |
| 63 | + raise AuthFetchTokenException("no root obj in json") |
| 64 | + |
| 65 | + if response.status_code == 400: |
| 66 | + error = json.get("error") |
| 67 | + if error == "invalid_client": |
| 68 | + raise AuthInvalidCredentialsException("invalid client credentials") |
| 69 | + if isinstance(error, str): |
| 70 | + raise AuthFetchTokenException("error: " + error) |
| 71 | + raise AuthFetchTokenException("no error in response") |
| 72 | + |
| 73 | + if json.get("token_type") and json["token_type"] != "Bearer": |
| 74 | + raise AuthFetchTokenException("token_type is not Bearer") |
| 75 | + |
| 76 | + if json.get("expires_at"): |
| 77 | + self.expires_at = int(json["expires_at"]) |
| 78 | + elif json.get("expires_in"): |
| 79 | + self.expires_at = int(self.fetch_token_get_time()) + int(json["expires_in"]) |
| 80 | + else: |
| 81 | + raise AuthFetchTokenException("no expires_at or expires_in") |
| 82 | + |
| 83 | + if not json.get("access_token"): |
| 84 | + raise AuthFetchTokenException("No access_token") |
| 85 | + |
| 86 | + self.access_token = json["access_token"] |
| 87 | + |
| 88 | + def _discovery(self, url: str) -> str: |
| 89 | + |
| 90 | + response = self.session.requests_session.get(url + ".well-known/openid-configuration") |
| 91 | + if response.status_code != 200: |
| 92 | + raise AuthDiscoveryException("status code is not 200") |
| 93 | + |
| 94 | + try: |
| 95 | + json: Optional[Dict[str, Any]] = response.json() |
| 96 | + except Exception as ex: |
| 97 | + raise AuthDiscoveryException("not valid json.") from ex |
| 98 | + |
| 99 | + if not isinstance(json, dict): |
| 100 | + raise AuthDiscoveryException("no root obj in json.") |
| 101 | + |
| 102 | + token_endpoint: Optional[str] = json.get("token_endpoint") |
| 103 | + if token_endpoint is None: |
| 104 | + raise AuthDiscoveryException("token_endpoint in root obj.") |
| 105 | + |
| 106 | + return token_endpoint |
| 107 | + |
| 108 | + def _is_expired(self) -> bool: |
| 109 | + if not self.expires_at: |
| 110 | + return True |
| 111 | + expiration_threshold = self.expires_at - self.leeway |
| 112 | + return expiration_threshold < self.is_expired_get_time() |
| 113 | + |
| 114 | + def _requests_auth(self, r: "PreparedRequest") -> "PreparedRequest": |
| 115 | + r.headers["Authorization"] = "Bearer " + self.access_token |
| 116 | + return r |
0 commit comments