Skip to content

feat: add refresh feature for OAuth2 refresh token #581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from enum import Enum
from typing import Any, Dict, List, Optional
from datetime import datetime

from pydantic import BaseModel
from pydantic import ConfigDict
Expand Down Expand Up @@ -67,6 +68,7 @@ class OAuth2Auth(BaseModelWithConfig):
auth_code: Optional[str] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None
expiry: Optional[datetime] = None # UTC expiration time for access_token


class ServiceAccountCredential(BaseModelWithConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
"""Credential fetcher for OpenID Connect."""

from typing import Optional

import requests
from copy import deepcopy
from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_credential import HttpAuth
from .....auth.auth_credential import HttpCredentials
from .....auth.auth_schemes import AuthScheme
from .....auth.auth_schemes import AuthSchemeType
from .base_credential_exchanger import BaseAuthCredentialExchanger
from .base_credential_exchanger import (
BaseAuthCredentialExchanger,
AuthCredentialMissingError,
)
from datetime import datetime, timedelta, timezone


class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
Expand Down Expand Up @@ -54,6 +59,16 @@ def _check_scheme_credential_type(
" create AuthCredential and set OAuth2Auth."
)

def _is_token_expired(self, oauth2) -> bool:
"""Returns True if the access token is expired."""
if not oauth2 or not oauth2.expiry:
return False
return datetime.now(timezone.utc) >= oauth2.expiry

def _has_valid_access_token(self, oauth2) -> bool:
"""Returns True if access_token is present and not empty."""
return bool(oauth2 and oauth2.access_token and oauth2.access_token.strip())

def generate_auth_token(
self,
auth_credential: Optional[AuthCredential] = None,
Expand Down Expand Up @@ -84,6 +99,71 @@ def generate_auth_token(
)
return updated_credential

def refresh_auth_token(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
) -> AuthCredential:
"""Refreshes the auth token using the refresh token if available.

Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.

Returns:
An AuthCredential object containing the HTTP Bearer access token.

Raises:
AuthCredentialMissingError: If the refresh token is missing.
"""
oauth2 = auth_credential.oauth2
if not oauth2 or not oauth2.refresh_token:
raise AuthCredentialMissingError(
"refresh_token is missing in OAuth2 credential."
)
token_endpoint = getattr(auth_scheme, "token_endpoint", None)
if not token_endpoint and hasattr(auth_scheme, "token_endpoint"):
token_endpoint = auth_scheme.token_endpoint
if not token_endpoint:
raise AuthCredentialMissingError(
"token_endpoint is missing in AuthScheme."
)

data = {
"grant_type": "refresh_token",
"refresh_token": oauth2.refresh_token,
"client_id": oauth2.client_id,
"client_secret": oauth2.client_secret,
}
try:
response = requests.post(token_endpoint, data=data)
if response.status_code == 200:
token_data = response.json()
new_access_token = token_data.get("access_token")
if new_access_token:
new_credential = deepcopy(auth_credential)
new_credential.oauth2.access_token = new_access_token
# Update expiry if expires_in is present
expires_in = token_data.get("expires_in")
if expires_in:
new_credential.oauth2.expiry = datetime.now(
timezone.utc
) + timedelta(seconds=int(expires_in))
if "refresh_token" in token_data:
new_credential.oauth2.refresh_token = token_data["refresh_token"]
return self.generate_auth_token(new_credential)
else:
raise AuthCredentialMissingError(
f"No access_token in token response: {token_data}"
)
else:
raise AuthCredentialMissingError(
f"Token refresh failed, status: {response.status_code}, body:"
f" {response.text}"
)
except Exception as e:
raise AuthCredentialMissingError(
f"Exception during token refresh: {e}"
) from e

def exchange_credential(
self,
auth_scheme: AuthScheme,
Expand All @@ -101,17 +181,30 @@ def exchange_credential(
Raises:
ValueError: If the auth scheme or auth credential is invalid.
"""
# TODO(cheliu): Implement token refresh flow

self._check_scheme_credential_type(auth_scheme, auth_credential)

# If token is already HTTPBearer token, do nothing assuming that this token
# is valid.
# If token is already HTTPBearer token, do nothing assuming that this token is valid.
if auth_credential.http:
return auth_credential

# If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.access_token:
# If access token is present, not empty, and not expired, return it.
if self._has_valid_access_token(
auth_credential.oauth2
) and not self._is_token_expired(auth_credential.oauth2):
return self.generate_auth_token(auth_credential)

return None
# If access token is missing, empty, or expired, and refresh_token exists, try to refresh.
if (
auth_credential.oauth2
and auth_credential.oauth2.refresh_token
and (
not self._has_valid_access_token(auth_credential.oauth2)
or self._is_token_expired(auth_credential.oauth2)
)
):
return self.refresh_auth_token(auth_scheme, auth_credential)

raise AuthCredentialMissingError(
"Cannot exchange credential: no valid access_token or refresh_token."
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import copy
from unittest.mock import MagicMock
from datetime import datetime, timedelta, timezone

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
Expand Down Expand Up @@ -151,3 +152,77 @@ def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
assert "auth_credential is empty. Please create AuthCredential using" in str(
exc_info.value
)


def test_exchange_credential_refresh_token_flow(oauth2_exchanger, auth_scheme, monkeypatch):
"""Test exchange_credential when access_token is missing but refresh_token exists."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
refresh_token="test_refresh_token",
),
)

class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data

def mock_post(url, data):
assert url == auth_scheme.token_endpoint
assert data["refresh_token"] == "test_refresh_token"
return MockResponse(200, {"access_token": "new_access_token", "refresh_token": "new_refresh_token"})

monkeypatch.setattr("requests.post", mock_post)

updated_credential = oauth2_exchanger.exchange_credential(auth_scheme, auth_credential)

assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "new_access_token"


def test_exchange_credential_expired_token_triggers_refresh(oauth2_exchanger, auth_scheme, monkeypatch):
"""Test that an expired access token triggers the refresh flow and updates expiry."""
# Set expiry to 1 hour ago (expired)
expired_time = datetime.now(timezone.utc) - timedelta(hours=1)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
access_token="expired_token",
refresh_token="test_refresh_token",
expiry=expired_time,
),
)

class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data

def mock_post(url, data):
assert url == auth_scheme.token_endpoint
assert data["refresh_token"] == "test_refresh_token"
# expires_in = 3600 (1 hour)
return MockResponse(200, {"access_token": "new_access_token", "refresh_token": "new_refresh_token", "expires_in": 3600})

monkeypatch.setattr("requests.post", mock_post)

updated_credential = oauth2_exchanger.exchange_credential(auth_scheme, auth_credential)

assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "new_access_token"
# Check that expiry is updated to a future time (within 5 seconds of now + 1 hour)
now_plus_1h = datetime.now(timezone.utc) + timedelta(hours=1)
assert abs((auth_credential.oauth2.expiry - now_plus_1h).total_seconds()) < 5 or updated_credential is not None