Skip to content

Commit 8663ea4

Browse files
committed
FIX: Credential instance cache
1 parent 064f543 commit 8663ea4

3 files changed

Lines changed: 264 additions & 6 deletions

File tree

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Benchmark: Credential Instance Caching for Azure AD Authentication
3+
4+
Measures the performance difference between:
5+
1. Creating a new DefaultAzureCredential + get_token() each call (old behavior)
6+
2. Reusing a cached DefaultAzureCredential instance (new behavior)
7+
8+
Prerequisites:
9+
- pip install azure-identity azure-core
10+
- az login (for AzureCliCredential to work)
11+
12+
Usage:
13+
python benchmarks/bench_credential_cache.py
14+
"""
15+
16+
import time
17+
import statistics
18+
19+
20+
def bench_no_cache(n: int) -> list[float]:
21+
"""Simulate the OLD behavior: new credential per call."""
22+
from azure.identity import DefaultAzureCredential
23+
24+
times = []
25+
for _ in range(n):
26+
start = time.perf_counter()
27+
cred = DefaultAzureCredential()
28+
cred.get_token("https://database.windows.net/.default")
29+
times.append(time.perf_counter() - start)
30+
return times
31+
32+
33+
def bench_with_cache(n: int) -> list[float]:
34+
"""Simulate the NEW behavior: reuse a single credential instance."""
35+
from azure.identity import DefaultAzureCredential
36+
37+
cred = DefaultAzureCredential()
38+
times = []
39+
for _ in range(n):
40+
start = time.perf_counter()
41+
cred.get_token("https://database.windows.net/.default")
42+
times.append(time.perf_counter() - start)
43+
return times
44+
45+
46+
def report(label: str, times: list[float]) -> None:
47+
print(f"\n{'=' * 50}")
48+
print(f" {label}")
49+
print(f"{'=' * 50}")
50+
print(f" Calls: {len(times)}")
51+
print(f" Total: {sum(times):.3f}s")
52+
print(f" Mean: {statistics.mean(times) * 1000:.1f}ms")
53+
print(f" Median: {statistics.median(times) * 1000:.1f}ms")
54+
print(f" Stdev: {statistics.stdev(times) * 1000:.1f}ms" if len(times) > 1 else "")
55+
print(f" Min: {min(times) * 1000:.1f}ms")
56+
print(f" Max: {max(times) * 1000:.1f}ms")
57+
58+
59+
def main() -> None:
60+
N = 10 # number of calls to benchmark
61+
62+
print("Credential Instance Cache Benchmark")
63+
print(f"Running {N} sequential token acquisitions for each scenario...\n")
64+
65+
try:
66+
print(">>> Without cache (new credential each call)...")
67+
no_cache_times = bench_no_cache(N)
68+
report("WITHOUT credential cache (old behavior)", no_cache_times)
69+
70+
print("\n>>> With cache (reuse credential instance)...")
71+
cache_times = bench_with_cache(N)
72+
report("WITH credential cache (new behavior)", cache_times)
73+
74+
speedup = statistics.mean(no_cache_times) / statistics.mean(cache_times)
75+
saved = (statistics.mean(no_cache_times) - statistics.mean(cache_times)) * 1000
76+
print(f"\n{'=' * 50}")
77+
print(f" SPEEDUP: {speedup:.1f}x ({saved:.0f}ms saved per call)")
78+
print(f"{'=' * 50}")
79+
except Exception as e:
80+
print(f"\nBenchmark failed: {e}")
81+
print("Make sure you are logged in via 'az login' and have azure-identity installed.")
82+
83+
84+
if __name__ == "__main__":
85+
main()

mssql_python/auth.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@
66

77
import platform
88
import struct
9+
import threading
910
from typing import Tuple, Dict, Optional, List
1011

1112
from mssql_python.logging import logger
1213
from mssql_python.constants import AuthType, ConstantsDDBC
1314

15+
# Module-level credential instance cache.
16+
# Reusing credential objects allows the Azure Identity SDK's built-in
17+
# in-memory token cache to work, avoiding redundant token acquisitions.
18+
# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
19+
_credential_cache: Dict[str, object] = {}
20+
_credential_cache_lock = threading.Lock()
21+
1422

1523
class AADAuth:
1624
"""Handles Azure Active Directory authentication"""
@@ -36,12 +44,11 @@ def get_token(auth_type: str) -> bytes:
3644

3745
@staticmethod
3846
def get_raw_token(auth_type: str) -> str:
39-
"""Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).
47+
"""Acquire a raw JWT for the mssql-py-core connection (bulk copy).
4048
41-
This deliberately does NOT cache the credential or token — each call
42-
creates a new Azure Identity credential instance and requests a token.
43-
A fresh acquisition avoids expired-token errors when bulkcopy() is
44-
called long after the original DDBC connect().
49+
Uses the cached credential instance so the Azure Identity SDK's
50+
built-in token cache can serve a valid token without a round-trip
51+
when the previous token has not yet expired.
4552
"""
4653
_, raw_token = AADAuth._acquire_token(auth_type)
4754
return raw_token
@@ -83,7 +90,19 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
8390
)
8491

8592
try:
86-
credential = credential_class()
93+
with _credential_cache_lock:
94+
if auth_type not in _credential_cache:
95+
logger.debug(
96+
"get_token: Creating new credential instance for auth_type=%s",
97+
auth_type,
98+
)
99+
_credential_cache[auth_type] = credential_class()
100+
else:
101+
logger.debug(
102+
"get_token: Reusing cached credential instance for auth_type=%s",
103+
auth_type,
104+
)
105+
credential = _credential_cache[auth_type]
87106
raw_token = credential.get_token("https://database.windows.net/.default").token
88107
logger.info(
89108
"get_token: Azure AD token acquired successfully - token_length=%d chars",

tests/test_008_auth.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
get_auth_token,
1616
process_connection_string,
1717
extract_auth_type,
18+
_credential_cache,
19+
_credential_cache_lock,
1820
)
1921
from mssql_python.constants import AuthType, ConstantsDDBC
2022
import secrets
@@ -71,6 +73,14 @@ class exceptions:
7173
del sys.modules[module]
7274

7375

76+
@pytest.fixture(autouse=True)
77+
def clear_credential_cache():
78+
"""Clear the module-level credential cache between tests."""
79+
_credential_cache.clear()
80+
yield
81+
_credential_cache.clear()
82+
83+
7484
class TestAuthType:
7585
def test_auth_type_constants(self):
7686
assert AuthType.INTERACTIVE.value == "activedirectoryinteractive"
@@ -403,6 +413,150 @@ def test_unsupported_auth(self):
403413
assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None
404414

405415

416+
class TestCredentialInstanceCache:
417+
"""Tests for the credential instance caching behavior."""
418+
419+
def test_credential_reused_across_calls(self):
420+
"""The same credential instance should be returned for repeated calls."""
421+
AADAuth.get_token("default")
422+
assert "default" in _credential_cache
423+
first_instance = _credential_cache["default"]
424+
425+
AADAuth.get_token("default")
426+
assert _credential_cache["default"] is first_instance
427+
428+
def test_different_auth_types_get_separate_instances(self):
429+
"""Each auth type should have its own cached credential."""
430+
AADAuth.get_token("default")
431+
AADAuth.get_token("devicecode")
432+
433+
assert "default" in _credential_cache
434+
assert "devicecode" in _credential_cache
435+
assert _credential_cache["default"] is not _credential_cache["devicecode"]
436+
437+
def test_get_raw_token_uses_cached_credential(self):
438+
"""get_raw_token should also use the cached credential instance."""
439+
AADAuth.get_token("default")
440+
cached = _credential_cache["default"]
441+
442+
AADAuth.get_raw_token("default")
443+
assert _credential_cache["default"] is cached
444+
445+
def test_cache_starts_empty(self):
446+
"""Cache should be empty at the start due to the clear_credential_cache fixture."""
447+
assert len(_credential_cache) == 0
448+
449+
450+
class TestAcquireTokenImportError:
451+
"""Test the ImportError path when azure-identity is not installed."""
452+
453+
def test_import_error_raises_runtime_error(self):
454+
"""_acquire_token raises RuntimeError when azure.identity is missing."""
455+
import sys
456+
457+
# Temporarily remove the mocked azure modules
458+
saved = {}
459+
for mod_name in list(sys.modules):
460+
if mod_name == "azure" or mod_name.startswith("azure."):
461+
saved[mod_name] = sys.modules.pop(mod_name)
462+
463+
# Make the import fail
464+
import builtins
465+
466+
real_import = builtins.__import__
467+
468+
def blocked_import(name, *args, **kwargs):
469+
if name.startswith("azure"):
470+
raise ImportError("No module named 'azure'")
471+
return real_import(name, *args, **kwargs)
472+
473+
builtins.__import__ = blocked_import
474+
try:
475+
with pytest.raises(
476+
RuntimeError, match="Azure authentication libraries are not installed"
477+
):
478+
AADAuth._acquire_token("default")
479+
finally:
480+
builtins.__import__ = real_import
481+
sys.modules.update(saved)
482+
483+
484+
class TestAcquireTokenClientAuthError:
485+
"""Test the ClientAuthenticationError path inside _acquire_token."""
486+
487+
def test_client_auth_error_in_acquire_token(self):
488+
"""ClientAuthenticationError during get_token is wrapped in RuntimeError."""
489+
import sys
490+
491+
azure_identity = sys.modules["azure.identity"]
492+
original = azure_identity.DefaultAzureCredential
493+
494+
from azure.core.exceptions import ClientAuthenticationError
495+
496+
class FailingCredential:
497+
def get_token(self, scope):
498+
raise ClientAuthenticationError("token request denied")
499+
500+
try:
501+
azure_identity.DefaultAzureCredential = FailingCredential
502+
with pytest.raises(RuntimeError, match="Azure AD authentication failed"):
503+
AADAuth._acquire_token("default")
504+
finally:
505+
azure_identity.DefaultAzureCredential = original
506+
507+
508+
class TestProcessAuthParametersEdgeCases:
509+
"""Cover empty-param and no-equals-sign branches."""
510+
511+
def test_empty_and_whitespace_params_skipped(self):
512+
params = ["Server=test", "", " ", "Database=db"]
513+
modified, auth_type = process_auth_parameters(params)
514+
assert "Server=test" in modified
515+
assert "Database=db" in modified
516+
assert auth_type is None
517+
518+
def test_param_without_equals_kept(self):
519+
params = ["Server=test", "SomeFlag", "Database=db"]
520+
modified, auth_type = process_auth_parameters(params)
521+
assert "SomeFlag" in modified
522+
assert "Server=test" in modified
523+
524+
525+
class TestGetAuthTokenEdgeCases:
526+
"""Cover the Windows-interactive and token-failure branches."""
527+
528+
def test_no_auth_type_returns_none(self):
529+
result = get_auth_token(None)
530+
assert result is None
531+
532+
def test_empty_auth_type_returns_none(self):
533+
result = get_auth_token("")
534+
assert result is None
535+
536+
def test_windows_interactive_returns_none(self, monkeypatch):
537+
monkeypatch.setattr(platform, "system", lambda: "Windows")
538+
result = get_auth_token("interactive")
539+
assert result is None
540+
541+
def test_token_acquisition_failure_returns_none(self):
542+
"""When AADAuth.get_token raises, get_auth_token returns None."""
543+
import sys
544+
545+
azure_identity = sys.modules["azure.identity"]
546+
original = azure_identity.DefaultAzureCredential
547+
548+
class FailingCredential:
549+
def __init__(self):
550+
raise RuntimeError("credential creation exploded")
551+
552+
try:
553+
azure_identity.DefaultAzureCredential = FailingCredential
554+
result = get_auth_token("default")
555+
assert result is None
556+
finally:
557+
azure_identity.DefaultAzureCredential = original
558+
559+
406560
def test_acquire_token_unsupported_auth_type():
407561
with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"):
408562
AADAuth._acquire_token("bogus")

0 commit comments

Comments
 (0)