diff --git a/customerio/client_base.py b/customerio/client_base.py index 0087faf..e333b08 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -3,14 +3,47 @@ """ import math +import socket from datetime import datetime, timezone from requests import Session -from requests.adapters import HTTPAdapter +from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter +from urllib3.connection import HTTPConnection from urllib3.util.retry import Retry from .__version__ import __version__ as ClientVersion +TCP_KEEPALIVE_IDLE_TIMEOUT = 300 +TCP_KEEPALIVE_INTERVAL = 60 + + +def _tcp_keepalive_socket_options(): + tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP) + tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None)) + + options = list(HTTPConnection.default_socket_options) + keepalive_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] + if tcp_keepidle is not None: + keepalive_options.append((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT)) + if hasattr(socket, "TCP_KEEPINTVL"): + keepalive_options.append((tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL)) + + for option in keepalive_options: + if option not in options: + options.append(option) + + return options + + +class TCPKeepAliveHTTPAdapter(HTTPAdapter): + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): + pool_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options()) + super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs) + + def proxy_manager_for(self, proxy, **proxy_kwargs): + proxy_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options()) + return super().proxy_manager_for(proxy, **proxy_kwargs) + class CustomerIOException(Exception): pass @@ -113,6 +146,6 @@ def _build_session(self): allowed_methods=None, status_forcelist=[500, 502, 503, 504], ) - session.mount("https://", HTTPAdapter(max_retries=retry)) + session.mount("https://", TCPKeepAliveHTTPAdapter(max_retries=retry)) return session diff --git a/tests/test_customerio.py b/tests/test_customerio.py index 698802d..22a7bec 100644 --- a/tests/test_customerio.py +++ b/tests/test_customerio.py @@ -1,12 +1,15 @@ import json +import socket import unittest from datetime import datetime from functools import partial import urllib3 from requests.auth import _basic_auth_str +from urllib3.connection import HTTPConnection from customerio import CustomerIO, CustomerIOException, Regions +from customerio.client_base import TCP_KEEPALIVE_IDLE_TIMEOUT, TCP_KEEPALIVE_INTERVAL from customerio.constants import CIOID, EMAIL, ID from tests.server import HTTPSTestCase @@ -64,6 +67,26 @@ def test_client_setup(self): with self.assertRaises(CustomerIOException): CustomerIO(site_id="site_id", api_key="api_key", region="au") + def test_keepalive_socket_options_are_configured_on_adapter(self): + default_socket_options = list(HTTPConnection.default_socket_options) + client = CustomerIO(site_id="site_id", api_key="api_key") + socket_options = client.http.adapters["https://"].poolmanager.connection_pool_kw[ + "socket_options" + ] + tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP) + tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None)) + + for option in default_socket_options: + self.assertIn(option, socket_options) + self.assertIn((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), socket_options) + if tcp_keepidle is not None: + self.assertIn((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT), socket_options) + if hasattr(socket, "TCP_KEEPINTVL"): + self.assertIn( + (tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL), socket_options + ) + self.assertEqual(HTTPConnection.default_socket_options, default_socket_options) + def test_client_connection_handling(self): retries = self.cio.retries # should not raise exception as i should be less than retries and