|
15 | 15 | get_auth_token, |
16 | 16 | process_connection_string, |
17 | 17 | extract_auth_type, |
| 18 | + _credential_cache, |
| 19 | + _credential_cache_lock, |
18 | 20 | ) |
19 | 21 | from mssql_python.constants import AuthType, ConstantsDDBC |
20 | 22 | import secrets |
@@ -71,6 +73,14 @@ class exceptions: |
71 | 73 | del sys.modules[module] |
72 | 74 |
|
73 | 75 |
|
| 76 | +@pytest.fixture(autouse=True) |
| 77 | +def clear_credential_cache(): |
| 78 | + """Clear the module-level credential cache between tests.""" |
| 79 | + _credential_cache.clear() |
| 80 | + yield |
| 81 | + _credential_cache.clear() |
| 82 | + |
| 83 | + |
74 | 84 | class TestAuthType: |
75 | 85 | def test_auth_type_constants(self): |
76 | 86 | assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" |
@@ -403,6 +413,150 @@ def test_unsupported_auth(self): |
403 | 413 | assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None |
404 | 414 |
|
405 | 415 |
|
| 416 | +class TestCredentialInstanceCache: |
| 417 | + """Tests for the credential instance caching behavior.""" |
| 418 | + |
| 419 | + def test_credential_reused_across_calls(self): |
| 420 | + """The same credential instance should be returned for repeated calls.""" |
| 421 | + AADAuth.get_token("default") |
| 422 | + assert "default" in _credential_cache |
| 423 | + first_instance = _credential_cache["default"] |
| 424 | + |
| 425 | + AADAuth.get_token("default") |
| 426 | + assert _credential_cache["default"] is first_instance |
| 427 | + |
| 428 | + def test_different_auth_types_get_separate_instances(self): |
| 429 | + """Each auth type should have its own cached credential.""" |
| 430 | + AADAuth.get_token("default") |
| 431 | + AADAuth.get_token("devicecode") |
| 432 | + |
| 433 | + assert "default" in _credential_cache |
| 434 | + assert "devicecode" in _credential_cache |
| 435 | + assert _credential_cache["default"] is not _credential_cache["devicecode"] |
| 436 | + |
| 437 | + def test_get_raw_token_uses_cached_credential(self): |
| 438 | + """get_raw_token should also use the cached credential instance.""" |
| 439 | + AADAuth.get_token("default") |
| 440 | + cached = _credential_cache["default"] |
| 441 | + |
| 442 | + AADAuth.get_raw_token("default") |
| 443 | + assert _credential_cache["default"] is cached |
| 444 | + |
| 445 | + def test_cache_starts_empty(self): |
| 446 | + """Cache should be empty at the start due to the clear_credential_cache fixture.""" |
| 447 | + assert len(_credential_cache) == 0 |
| 448 | + |
| 449 | + |
| 450 | +class TestAcquireTokenImportError: |
| 451 | + """Test the ImportError path when azure-identity is not installed.""" |
| 452 | + |
| 453 | + def test_import_error_raises_runtime_error(self): |
| 454 | + """_acquire_token raises RuntimeError when azure.identity is missing.""" |
| 455 | + import sys |
| 456 | + |
| 457 | + # Temporarily remove the mocked azure modules |
| 458 | + saved = {} |
| 459 | + for mod_name in list(sys.modules): |
| 460 | + if mod_name == "azure" or mod_name.startswith("azure."): |
| 461 | + saved[mod_name] = sys.modules.pop(mod_name) |
| 462 | + |
| 463 | + # Make the import fail |
| 464 | + import builtins |
| 465 | + |
| 466 | + real_import = builtins.__import__ |
| 467 | + |
| 468 | + def blocked_import(name, *args, **kwargs): |
| 469 | + if name.startswith("azure"): |
| 470 | + raise ImportError("No module named 'azure'") |
| 471 | + return real_import(name, *args, **kwargs) |
| 472 | + |
| 473 | + builtins.__import__ = blocked_import |
| 474 | + try: |
| 475 | + with pytest.raises( |
| 476 | + RuntimeError, match="Azure authentication libraries are not installed" |
| 477 | + ): |
| 478 | + AADAuth._acquire_token("default") |
| 479 | + finally: |
| 480 | + builtins.__import__ = real_import |
| 481 | + sys.modules.update(saved) |
| 482 | + |
| 483 | + |
| 484 | +class TestAcquireTokenClientAuthError: |
| 485 | + """Test the ClientAuthenticationError path inside _acquire_token.""" |
| 486 | + |
| 487 | + def test_client_auth_error_in_acquire_token(self): |
| 488 | + """ClientAuthenticationError during get_token is wrapped in RuntimeError.""" |
| 489 | + import sys |
| 490 | + |
| 491 | + azure_identity = sys.modules["azure.identity"] |
| 492 | + original = azure_identity.DefaultAzureCredential |
| 493 | + |
| 494 | + from azure.core.exceptions import ClientAuthenticationError |
| 495 | + |
| 496 | + class FailingCredential: |
| 497 | + def get_token(self, scope): |
| 498 | + raise ClientAuthenticationError("token request denied") |
| 499 | + |
| 500 | + try: |
| 501 | + azure_identity.DefaultAzureCredential = FailingCredential |
| 502 | + with pytest.raises(RuntimeError, match="Azure AD authentication failed"): |
| 503 | + AADAuth._acquire_token("default") |
| 504 | + finally: |
| 505 | + azure_identity.DefaultAzureCredential = original |
| 506 | + |
| 507 | + |
| 508 | +class TestProcessAuthParametersEdgeCases: |
| 509 | + """Cover empty-param and no-equals-sign branches.""" |
| 510 | + |
| 511 | + def test_empty_and_whitespace_params_skipped(self): |
| 512 | + params = ["Server=test", "", " ", "Database=db"] |
| 513 | + modified, auth_type = process_auth_parameters(params) |
| 514 | + assert "Server=test" in modified |
| 515 | + assert "Database=db" in modified |
| 516 | + assert auth_type is None |
| 517 | + |
| 518 | + def test_param_without_equals_kept(self): |
| 519 | + params = ["Server=test", "SomeFlag", "Database=db"] |
| 520 | + modified, auth_type = process_auth_parameters(params) |
| 521 | + assert "SomeFlag" in modified |
| 522 | + assert "Server=test" in modified |
| 523 | + |
| 524 | + |
| 525 | +class TestGetAuthTokenEdgeCases: |
| 526 | + """Cover the Windows-interactive and token-failure branches.""" |
| 527 | + |
| 528 | + def test_no_auth_type_returns_none(self): |
| 529 | + result = get_auth_token(None) |
| 530 | + assert result is None |
| 531 | + |
| 532 | + def test_empty_auth_type_returns_none(self): |
| 533 | + result = get_auth_token("") |
| 534 | + assert result is None |
| 535 | + |
| 536 | + def test_windows_interactive_returns_none(self, monkeypatch): |
| 537 | + monkeypatch.setattr(platform, "system", lambda: "Windows") |
| 538 | + result = get_auth_token("interactive") |
| 539 | + assert result is None |
| 540 | + |
| 541 | + def test_token_acquisition_failure_returns_none(self): |
| 542 | + """When AADAuth.get_token raises, get_auth_token returns None.""" |
| 543 | + import sys |
| 544 | + |
| 545 | + azure_identity = sys.modules["azure.identity"] |
| 546 | + original = azure_identity.DefaultAzureCredential |
| 547 | + |
| 548 | + class FailingCredential: |
| 549 | + def __init__(self): |
| 550 | + raise RuntimeError("credential creation exploded") |
| 551 | + |
| 552 | + try: |
| 553 | + azure_identity.DefaultAzureCredential = FailingCredential |
| 554 | + result = get_auth_token("default") |
| 555 | + assert result is None |
| 556 | + finally: |
| 557 | + azure_identity.DefaultAzureCredential = original |
| 558 | + |
| 559 | + |
406 | 560 | def test_acquire_token_unsupported_auth_type(): |
407 | 561 | with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"): |
408 | 562 | AADAuth._acquire_token("bogus") |
|
0 commit comments