|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +import unittest |
| 5 | +from unittest.mock import MagicMock, patch, call |
| 6 | + |
| 7 | +import requests |
| 8 | + |
| 9 | +from PowerPlatform.Dataverse.core._http import _HttpClient |
| 10 | + |
| 11 | + |
| 12 | +class TestHttpClientTimeout(unittest.TestCase): |
| 13 | + """Tests for automatic timeout selection in _HttpClient._request.""" |
| 14 | + |
| 15 | + def _make_response(self, status=200): |
| 16 | + resp = MagicMock(spec=requests.Response) |
| 17 | + resp.status_code = status |
| 18 | + return resp |
| 19 | + |
| 20 | + def test_get_uses_10s_default_timeout(self): |
| 21 | + """GET requests use 10s default when no timeout is specified.""" |
| 22 | + client = _HttpClient(retries=1) |
| 23 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 24 | + client._request("get", "https://example.com/data") |
| 25 | + _, kwargs = mock_req.call_args |
| 26 | + self.assertEqual(kwargs["timeout"], 10) |
| 27 | + |
| 28 | + def test_post_uses_120s_default_timeout(self): |
| 29 | + """POST requests use 120s default when no timeout is specified.""" |
| 30 | + client = _HttpClient(retries=1) |
| 31 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 32 | + client._request("post", "https://example.com/data") |
| 33 | + _, kwargs = mock_req.call_args |
| 34 | + self.assertEqual(kwargs["timeout"], 120) |
| 35 | + |
| 36 | + def test_delete_uses_120s_default_timeout(self): |
| 37 | + """DELETE requests use 120s default when no timeout is specified.""" |
| 38 | + client = _HttpClient(retries=1) |
| 39 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 40 | + client._request("delete", "https://example.com/data") |
| 41 | + _, kwargs = mock_req.call_args |
| 42 | + self.assertEqual(kwargs["timeout"], 120) |
| 43 | + |
| 44 | + def test_default_timeout_overrides_per_method_default(self): |
| 45 | + """Explicit default_timeout on the client overrides per-method defaults.""" |
| 46 | + client = _HttpClient(retries=1, timeout=30.0) |
| 47 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 48 | + client._request("get", "https://example.com/data") |
| 49 | + _, kwargs = mock_req.call_args |
| 50 | + self.assertEqual(kwargs["timeout"], 30.0) |
| 51 | + |
| 52 | + def test_explicit_timeout_kwarg_takes_precedence(self): |
| 53 | + """If timeout is already in kwargs it is passed through unchanged.""" |
| 54 | + client = _HttpClient(retries=1, timeout=30.0) |
| 55 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 56 | + client._request("get", "https://example.com/data", timeout=5) |
| 57 | + _, kwargs = mock_req.call_args |
| 58 | + self.assertEqual(kwargs["timeout"], 5) |
| 59 | + |
| 60 | + |
| 61 | +class TestHttpClientRequester(unittest.TestCase): |
| 62 | + """Tests for session vs direct requests.request routing.""" |
| 63 | + |
| 64 | + def _make_response(self): |
| 65 | + resp = MagicMock(spec=requests.Response) |
| 66 | + resp.status_code = 200 |
| 67 | + return resp |
| 68 | + |
| 69 | + def test_uses_direct_request_without_session(self): |
| 70 | + """Without a session, _request uses requests.request directly.""" |
| 71 | + client = _HttpClient(retries=1) |
| 72 | + with patch("requests.request", return_value=self._make_response()) as mock_req: |
| 73 | + client._request("get", "https://example.com/data") |
| 74 | + mock_req.assert_called_once() |
| 75 | + |
| 76 | + def test_uses_session_request_when_session_provided(self): |
| 77 | + """With a session, _request uses session.request instead of requests.request.""" |
| 78 | + mock_session = MagicMock(spec=requests.Session) |
| 79 | + mock_session.request.return_value = self._make_response() |
| 80 | + client = _HttpClient(retries=1, session=mock_session) |
| 81 | + with patch("requests.request") as mock_req: |
| 82 | + client._request("get", "https://example.com/data") |
| 83 | + mock_session.request.assert_called_once() |
| 84 | + mock_req.assert_not_called() |
| 85 | + |
| 86 | + |
| 87 | +class TestHttpClientRetry(unittest.TestCase): |
| 88 | + """Tests for retry behavior on RequestException.""" |
| 89 | + |
| 90 | + def test_retries_on_request_exception_and_succeeds(self): |
| 91 | + """Retries after a RequestException and returns response on second attempt.""" |
| 92 | + resp = MagicMock(spec=requests.Response) |
| 93 | + resp.status_code = 200 |
| 94 | + client = _HttpClient(retries=2, backoff=0) |
| 95 | + with patch("requests.request", side_effect=[requests.exceptions.ConnectionError(), resp]) as mock_req: |
| 96 | + with patch("time.sleep"): |
| 97 | + result = client._request("get", "https://example.com/data") |
| 98 | + self.assertEqual(mock_req.call_count, 2) |
| 99 | + self.assertIs(result, resp) |
| 100 | + |
| 101 | + def test_raises_after_all_retries_exhausted(self): |
| 102 | + """Raises RequestException after all retry attempts fail.""" |
| 103 | + client = _HttpClient(retries=3, backoff=0) |
| 104 | + with patch("requests.request", side_effect=requests.exceptions.ConnectionError("timeout")): |
| 105 | + with patch("time.sleep"): |
| 106 | + with self.assertRaises(requests.exceptions.RequestException): |
| 107 | + client._request("get", "https://example.com/data") |
| 108 | + |
| 109 | + def test_backoff_delay_between_retries(self): |
| 110 | + """Sleeps with exponential backoff between retry attempts.""" |
| 111 | + resp = MagicMock(spec=requests.Response) |
| 112 | + resp.status_code = 200 |
| 113 | + client = _HttpClient(retries=3, backoff=1.0) |
| 114 | + side_effects = [ |
| 115 | + requests.exceptions.ConnectionError(), |
| 116 | + requests.exceptions.ConnectionError(), |
| 117 | + resp, |
| 118 | + ] |
| 119 | + with patch("requests.request", side_effect=side_effects): |
| 120 | + with patch("time.sleep") as mock_sleep: |
| 121 | + client._request("get", "https://example.com/data") |
| 122 | + # First retry: delay = 1.0 * 2^0 = 1.0, second retry: 1.0 * 2^1 = 2.0 |
| 123 | + mock_sleep.assert_has_calls([call(1.0), call(2.0)]) |
0 commit comments