Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock

import pytest

from trakt.api import HttpClient
from trakt.api import TokenAuth
from trakt.config import AuthConfig
from trakt.core import api
from trakt.errors import OAuthException, OAuthRefreshException
from trakt.tv import TVShow


Expand All @@ -15,3 +23,31 @@ def test_tvshow_properties():
show = TVShow("Game of Thrones")
assert show.title == "Game of Thrones"
assert show.certification == "TV-MA"


def test_token_refresh_failure_raises_dedicated_exception():
config = AuthConfig('missing.json').update(
CLIENT_ID='client-id',
CLIENT_SECRET='client-secret',
OAUTH_TOKEN='stale-token',
OAUTH_REFRESH='refresh-token',
OAUTH_EXPIRES_AT=int((datetime.now(tz=timezone.utc) - timedelta(minutes=1)).timestamp()),
)
response = Mock()
response.json.return_value = {
'error': 'invalid_grant',
'error_description': 'refresh token is invalid',
}
response.text = 'refresh token is invalid'
client = Mock()
client.post.side_effect = OAuthException(response=response)

auth = TokenAuth(client=client, config=config)

with pytest.raises(OAuthRefreshException) as exc_info:
auth.get_token()

assert exc_info.value.error == 'invalid_grant'
assert exc_info.value.error_description == 'refresh token is invalid'
assert auth.TOKEN_UNDER_REFRESH is False
assert auth.OAUTH_TOKEN_VALID is False
22 changes: 20 additions & 2 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
"""unit tests to define behavior of custom exception types"""
from unittest.mock import Mock

from trakt.errors import (BadRequestException, ConflictException,
ForbiddenException, NotFoundException,
OAuthException, ProcessException, RateLimitException,
TraktException, TraktInternalException,
OAuthException, OAuthRefreshException,
ProcessException, RateLimitException, TraktException,
TraktInternalException,
TraktUnavailable)


Expand All @@ -27,6 +30,21 @@ def test_401_exception():
assert str(texc) == texc.message


def test_oauth_refresh_exception_uses_api_error_details():
response = Mock()
response.json.return_value = {
'error': 'invalid_grant',
'error_description': 'refresh token is invalid',
}

texc = OAuthRefreshException(response=response)

assert texc.http_code == 401
assert texc.error == 'invalid_grant'
assert texc.error_description == 'refresh token is invalid'
assert str(texc) == 'Unauthorized - OAuth token refresh failed: invalid_grant - refresh token is invalid'


def test_403_exception():
texc = ForbiddenException()
assert texc.http_code == 403
Expand Down
115 changes: 80 additions & 35 deletions trakt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from trakt.config import AuthConfig
from trakt.core import TIMEOUT
from trakt.errors import (BadRequestException, BadResponseException,
OAuthException)
OAuthException, OAuthRefreshException)

__author__ = 'Elan Ruusamäe'

Expand Down Expand Up @@ -223,25 +223,29 @@ def validate_token(self):
critical operations while also maximizing the token's useful lifetime.
"""

current = datetime.now(tz=timezone.utc)
expires_at = datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc)
margin = expires_at - current
if margin > timedelta(**self.TOKEN_REFRESH_MARGIN):
self.OAUTH_TOKEN_VALID = True
else:
self.logger.debug("Token expires in %s, refreshing (margin: %s)", margin, self.TOKEN_REFRESH_MARGIN)
self.refresh_token()

self.TOKEN_UNDER_REFRESH = False
try:
expires_at_timestamp = self.config.OAUTH_EXPIRES_AT
if expires_at_timestamp is None:
self.OAUTH_TOKEN_VALID = False
raise OAuthRefreshException(
error='missing_token_expiry',
error_description='OAuth token expiry is missing from the current configuration.',
)

current = datetime.now(tz=timezone.utc)
expires_at = datetime.fromtimestamp(expires_at_timestamp, tz=timezone.utc)
margin = expires_at - current
if margin > timedelta(**self.TOKEN_REFRESH_MARGIN):
self.OAUTH_TOKEN_VALID = True
else:
self.logger.debug("Token expires in %s, refreshing (margin: %s)", margin, self.TOKEN_REFRESH_MARGIN)
self.refresh_token()
finally:
self.TOKEN_UNDER_REFRESH = False

def refresh_token(self):
"""Request Trakt API for a new valid OAuth token using refresh_token"""

if self.refresh_attempts >= self.MAX_RETRIES:
self.logger.error("Max token refresh attempts reached. Manual intervention required.")
return
self.refresh_attempts += 1

self.logger.info("OAuth token has expired, refreshing now...")
data = {
'client_id': self.config.CLIENT_ID,
Expand All @@ -251,26 +255,38 @@ def refresh_token(self):
'grant_type': 'refresh_token'
}

try:
response = self.client.post('oauth/token', data)
last_error = None
response = None
for attempt in range(1, self.MAX_RETRIES + 1):
self.refresh_attempts = attempt
try:
response = self.client.post('oauth/token', data)
self.refresh_attempts = 0
break
except (OAuthException, BadRequestException) as exc:
last_error = self._build_refresh_exception(exc)
self.logger.error(
"%s - Unable to refresh expired OAuth token (%s) %s",
exc.http_code,
last_error.error or 'unknown_error',
last_error.error_description or ''
)
else:
self.OAUTH_TOKEN_VALID = False
self.refresh_attempts = 0
except (OAuthException, BadRequestException) as e:
if e.response is not None:
try:
data = e.response.json()
error = data.get("error")
error_description = data.get("error_description")
except JSONDecodeError:
error = "Invalid JSON response"
error_description = e.response.text
else:
error = "No error description"
error_description = ""
self.logger.error(
"%s - Unable to refresh expired OAuth token (%s) %s",
e.http_code, error, error_description
if last_error is None:
raise OAuthRefreshException(
error='unknown_refresh_error',
error_description='OAuth token refresh failed without an explicit API error.',
)
raise last_error

if response is None:
self.OAUTH_TOKEN_VALID = False
raise OAuthRefreshException(
error='empty_refresh_response',
error_description='OAuth token refresh completed without a response payload.',
)
return

self.config.update(
OAUTH_TOKEN=response.get("access_token"),
Expand All @@ -279,9 +295,38 @@ def refresh_token(self):
)
self.OAUTH_TOKEN_VALID = True

expires_at_timestamp = self.config.OAUTH_EXPIRES_AT
if expires_at_timestamp is None:
self.OAUTH_TOKEN_VALID = False
raise OAuthRefreshException(
error='invalid_refresh_state',
error_description='OAuth token refresh did not persist an expiry timestamp.',
)

self.logger.info(
"OAuth token successfully refreshed, valid until {}".format(
datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc)
datetime.fromtimestamp(expires_at_timestamp, tz=timezone.utc)
)
)
self.config.store()

@staticmethod
def _build_refresh_exception(exc):
error = 'No error description'
error_description = ''

if exc.response is not None:
try:
data = exc.response.json()
error = data.get("error") or error
error_description = data.get("error_description") or error_description
except JSONDecodeError:
error = 'Invalid JSON response'
error_description = exc.response.text

return OAuthRefreshException(
response=exc.response,
error=error,
error_description=error_description,
cause=exc,
)
31 changes: 27 additions & 4 deletions trakt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,40 @@ class OAuthException(TraktException):


class OAuthRefreshException(OAuthException):
def __init__(self, response=None):
"""Raised when an OAuth access token could not be refreshed."""

message = 'Unauthorized - OAuth token refresh failed'

def __init__(self, response=None, error=None, error_description=None, cause=None):
super().__init__(response)
self.data = self.response.json()
self.cause = cause
self.data = self._load_data()
self._error = error or self.data.get("error")
self._error_description = error_description or self.data.get("error_description")
Comment on lines +79 to +81
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard against non-object JSON payloads in OAuthRefreshException.

Line 80 and Line 81 assume self.data is a dict. If response.json() returns a JSON array/string/number, this path raises AttributeError and hides the actual refresh failure context.

Proposed fix
     def _load_data(self):
         if self.response is None:
             return {}

         try:
-            return self.response.json()
-        except (AttributeError, ValueError):
+            data = self.response.json()
+            return data if isinstance(data, dict) else {}
+        except (AttributeError, ValueError, TypeError):
             return {}

Also applies to: 83-90

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@trakt/errors.py` around lines 79 - 81, OAuthRefreshException assumes
self.data is a dict and calls .get, which will raise AttributeError for
non-object JSON payloads; change the initialization in the constructor that
calls self._load_data() to defensively coerce non-dict JSON into an empty dict
(e.g., data = self._load_data(); if not isinstance(data, dict): data = {}), then
use that safe `data` for assigning self._error, self._error_description and any
other attributes set via data.get in the block around OAuthRefreshException
(lines setting _error/_error_description and the subsequent .get uses between
~80-90), so all .get calls operate on a dict fallback rather than potentially
non-object JSON.


def _load_data(self):
if self.response is None:
return {}

try:
return self.response.json()
except (AttributeError, ValueError):
return {}

@property
def error(self):
return self.data["error"]
return self._error

@property
def error_description(self):
return self.data["error_description"]
return self._error_description

def __str__(self):
if self.error and self.error_description:
return f'{self.message}: {self.error} - {self.error_description}'
if self.error:
return f'{self.message}: {self.error}'
return self.message


class ForbiddenException(TraktException):
Expand Down