Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ async def _get_existing_credential(
existing_credential = await refresher.refresh(
existing_credential, self.auth_scheme
)
# Persist the refreshed credential so the next invocation
# reads the new tokens instead of the stale pre-refresh ones.
# Without this, providers that rotate refresh_tokens on each
# refresh (e.g. Salesforce, many OIDC providers) will fail
# because the old refresh_token has already been invalidated.
self._store_credential(existing_credential)
return existing_credential
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,75 @@ async def test_openid_connect_existing_oauth2_token_refresh(
assert result.state == 'done'
# The result should contain the refreshed credential after exchange
assert result.auth_credential is not None


@patch(
'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher'
)
@pytest.mark.asyncio
async def test_refreshed_credential_is_persisted_to_store(
mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential
):
"""Test that refreshed OAuth2 credentials are persisted back to the store.

Regression test for https://github.com/google/adk-python/issues/5329.
Without persisting, the next invocation reads stale pre-refresh tokens from
state. Providers that rotate refresh_tokens on each refresh (e.g.
Salesforce, many OIDC providers) will then fail because the old
refresh_token has already been invalidated.
"""
# Create existing OAuth2 credential with an "old" refresh token.
existing_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id='test_client_id',
client_secret='test_client_secret',
access_token='old_access_token',
refresh_token='old_refresh_token',
),
)

# The refresher will return a credential with rotated tokens.
refreshed_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id='test_client_id',
client_secret='test_client_secret',
access_token='new_access_token',
refresh_token='new_refresh_token',
),
)

from unittest.mock import AsyncMock

mock_refresher_instance = MagicMock()
mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True)
mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential)
mock_oauth2_refresher.return_value = mock_refresher_instance

tool_context = create_mock_tool_context()
credential_store = ToolContextCredentialStore(tool_context=tool_context)

# Store the existing (stale) credential.
key = credential_store.get_credential_key(
openid_connect_scheme, openid_connect_credential
)
credential_store.store_credential(key, existing_credential)

handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_store=credential_store,
)

await handler.prepare_auth_credentials()

# The critical assertion: the *refreshed* credential must now be in the
# store so that the next invocation reads the new tokens, not the old ones.
persisted = credential_store.get_credential(
openid_connect_scheme, openid_connect_credential
)
assert persisted is not None
assert persisted.oauth2.access_token == 'new_access_token'
assert persisted.oauth2.refresh_token == 'new_refresh_token'