diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 21337cfa51..0d78a5759b 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -242,6 +242,12 @@ async def _get_existing_credential( existing_credential = await refresher.refresh( existing_credential, self.auth_scheme ) + # Persist the refreshed credential so the next invocation + # reads the new tokens instead of the stale pre-refresh ones. + # Without this, providers that rotate refresh_tokens on each + # refresh (e.g. Salesforce, many OIDC providers) will fail + # because the old refresh_token has already been invalidated. + self._store_credential(existing_credential) return existing_credential return None diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py index a6babce651..1c9ace5c18 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py @@ -292,3 +292,75 @@ async def test_openid_connect_existing_oauth2_token_refresh( assert result.state == 'done' # The result should contain the refreshed credential after exchange assert result.auth_credential is not None + + +@patch( + 'google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler.OAuth2CredentialRefresher' +) +@pytest.mark.asyncio +async def test_refreshed_credential_is_persisted_to_store( + mock_oauth2_refresher, openid_connect_scheme, openid_connect_credential +): + """Test that refreshed OAuth2 credentials are persisted back to the store. + + Regression test for https://github.com/google/adk-python/issues/5329. + Without persisting, the next invocation reads stale pre-refresh tokens from + state. Providers that rotate refresh_tokens on each refresh (e.g. + Salesforce, many OIDC providers) will then fail because the old + refresh_token has already been invalidated. + """ + # Create existing OAuth2 credential with an "old" refresh token. + existing_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='old_access_token', + refresh_token='old_refresh_token', + ), + ) + + # The refresher will return a credential with rotated tokens. + refreshed_credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + access_token='new_access_token', + refresh_token='new_refresh_token', + ), + ) + + from unittest.mock import AsyncMock + + mock_refresher_instance = MagicMock() + mock_refresher_instance.is_refresh_needed = AsyncMock(return_value=True) + mock_refresher_instance.refresh = AsyncMock(return_value=refreshed_credential) + mock_oauth2_refresher.return_value = mock_refresher_instance + + tool_context = create_mock_tool_context() + credential_store = ToolContextCredentialStore(tool_context=tool_context) + + # Store the existing (stale) credential. + key = credential_store.get_credential_key( + openid_connect_scheme, openid_connect_credential + ) + credential_store.store_credential(key, existing_credential) + + handler = ToolAuthHandler( + tool_context, + openid_connect_scheme, + openid_connect_credential, + credential_store=credential_store, + ) + + await handler.prepare_auth_credentials() + + # The critical assertion: the *refreshed* credential must now be in the + # store so that the next invocation reads the new tokens, not the old ones. + persisted = credential_store.get_credential( + openid_connect_scheme, openid_connect_credential + ) + assert persisted is not None + assert persisted.oauth2.access_token == 'new_access_token' + assert persisted.oauth2.refresh_token == 'new_refresh_token'