|
18 | 18 | import abc |
19 | 19 | import inspect |
20 | 20 |
|
| 21 | +from google.auth import _regional_access_boundary_utils |
21 | 22 | from google.auth import credentials |
22 | 23 |
|
23 | 24 |
|
@@ -64,8 +65,28 @@ async def before_request(self, request, method, url, headers): |
64 | 65 | await self.refresh(request) |
65 | 66 | else: |
66 | 67 | self.refresh(request) |
| 68 | + |
| 69 | + if inspect.iscoroutinefunction(self._after_refresh): |
| 70 | + await self._after_refresh(request, method, url, headers) |
| 71 | + else: |
| 72 | + self._after_refresh(request, method, url, headers) |
| 73 | + |
67 | 74 | self.apply(headers) |
68 | 75 |
|
| 76 | + def _after_refresh(self, request, method, url, headers): |
| 77 | + """Hook for subclasses to perform actions after refresh but before |
| 78 | + applying credentials to headers. |
| 79 | +
|
| 80 | + Args: |
| 81 | + request (google.auth.transport.Request): The object used to make |
| 82 | + HTTP requests. |
| 83 | + method (str): The request's HTTP method or the RPC method being |
| 84 | + invoked. |
| 85 | + url (str): The request's URI or the RPC service's URI. |
| 86 | + headers (Mapping[str, str]): The request's headers. |
| 87 | + """ |
| 88 | + pass |
| 89 | + |
69 | 90 |
|
70 | 91 | class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject): |
71 | 92 | """Abstract base for credentials supporting ``with_quota_project`` factory""" |
@@ -169,3 +190,74 @@ def with_scopes_if_required(credentials, scopes): |
169 | 190 |
|
170 | 191 | class Signing(credentials.Signing, metaclass=abc.ABCMeta): |
171 | 192 | """Interface for credentials that can cryptographically sign messages.""" |
| 193 | + |
| 194 | + |
| 195 | +class CredentialsWithRegionalAccessBoundary( |
| 196 | + Credentials, credentials.CredentialsWithRegionalAccessBoundary |
| 197 | +): |
| 198 | + """Async base for credentials supporting regional access boundary configuration.""" |
| 199 | + |
| 200 | + def __init__(self): |
| 201 | + super().__init__() |
| 202 | + self._rab_manager.refresh_manager = ( |
| 203 | + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() |
| 204 | + ) |
| 205 | + |
| 206 | + def __setstate__(self, state): |
| 207 | + super().__setstate__(state) |
| 208 | + self._rab_manager.refresh_manager = ( |
| 209 | + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() |
| 210 | + ) |
| 211 | + |
| 212 | + async def _after_refresh(self, request, method, url, headers): |
| 213 | + """Triggers the Regional Access Boundary lookup asynchronously if necessary.""" |
| 214 | + await self._maybe_start_regional_access_boundary_refresh_async(request, url) |
| 215 | + |
| 216 | + async def _maybe_start_regional_access_boundary_refresh_async(self, request, url): |
| 217 | + """Starts a background refresh or performs a blocking refresh asynchronously. |
| 218 | +
|
| 219 | + Args: |
| 220 | + request (google.auth.aio.transport.Request): The object used to make |
| 221 | + HTTP requests. |
| 222 | + url (str): The URL of the request. |
| 223 | + """ |
| 224 | + # Do not perform a lookup if the request is for a regional endpoint. |
| 225 | + if self._is_regional_endpoint(url): |
| 226 | + return |
| 227 | + |
| 228 | + # A refresh is only needed if the feature is enabled. |
| 229 | + if not self._is_regional_access_boundary_lookup_required(): |
| 230 | + return |
| 231 | + |
| 232 | + # Trigger background or blocking refresh if needed. |
| 233 | + await self._rab_manager.maybe_start_refresh_async(self, request) |
| 234 | + |
| 235 | + async def _lookup_regional_access_boundary(self, request, fail_fast=False): |
| 236 | + """Calls the Regional Access Boundary lookup API asynchronously. |
| 237 | +
|
| 238 | + Args: |
| 239 | + request (google.auth.aio.transport.Request): The object used to make |
| 240 | + HTTP requests. |
| 241 | + fail_fast (bool): Whether the lookup should fail fast (short timeout, no retries). |
| 242 | +
|
| 243 | + Returns: |
| 244 | + Optional[Dict[str, str]]: The Regional Access Boundary information |
| 245 | + returned by the lookup API, or None if the lookup failed. |
| 246 | + """ |
| 247 | + url_builder = self._build_regional_access_boundary_lookup_url |
| 248 | + if inspect.iscoroutinefunction(url_builder): |
| 249 | + url = await url_builder(request=request) |
| 250 | + else: |
| 251 | + url = url_builder(request=request) |
| 252 | + |
| 253 | + if not url: |
| 254 | + return None |
| 255 | + |
| 256 | + headers = {} |
| 257 | + self._apply(headers) |
| 258 | + |
| 259 | + from google.oauth2 import _client_async |
| 260 | + |
| 261 | + return await _client_async._lookup_regional_access_boundary( |
| 262 | + request, url, headers=headers, fail_fast=fail_fast |
| 263 | + ) |
0 commit comments