|
7 | 7 | import pytest |
8 | 8 | import platform |
9 | 9 | import sys |
| 10 | +import threading |
10 | 11 | from unittest.mock import patch, MagicMock |
11 | 12 | from mssql_python.auth import ( |
12 | 13 | AADAuth, |
@@ -622,3 +623,164 @@ def test_auth_type_stored_on_connection(self, mock_ddbc_conn): |
622 | 623 | conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") |
623 | 624 | assert conn._auth_type == "default" |
624 | 625 | conn.close() |
| 626 | + |
| 627 | + |
| 628 | +class TestCredentialCacheThreadSafety: |
| 629 | + """Verify thread-safe behavior of credential instance cache.""" |
| 630 | + |
| 631 | + def test_concurrent_access_creates_only_one_instance(self): |
| 632 | + """Multiple threads calling get_token concurrently should result in |
| 633 | + exactly one credential instance per auth type in the cache.""" |
| 634 | + import sys |
| 635 | + |
| 636 | + azure_identity = sys.modules["azure.identity"] |
| 637 | + original = azure_identity.DefaultAzureCredential |
| 638 | + |
| 639 | + instances_created = [] |
| 640 | + |
| 641 | + class TrackingCredential: |
| 642 | + def __init__(self): |
| 643 | + instances_created.append(self) |
| 644 | + |
| 645 | + def get_token(self, scope): |
| 646 | + class Token: |
| 647 | + token = SAMPLE_TOKEN |
| 648 | + |
| 649 | + return Token() |
| 650 | + |
| 651 | + try: |
| 652 | + azure_identity.DefaultAzureCredential = TrackingCredential |
| 653 | + |
| 654 | + errors = [] |
| 655 | + barrier = threading.Barrier(10) |
| 656 | + |
| 657 | + def worker(): |
| 658 | + try: |
| 659 | + barrier.wait(timeout=5) |
| 660 | + AADAuth.get_token("default") |
| 661 | + except Exception as e: |
| 662 | + errors.append(e) |
| 663 | + |
| 664 | + threads = [threading.Thread(target=worker) for _ in range(10)] |
| 665 | + for t in threads: |
| 666 | + t.start() |
| 667 | + for t in threads: |
| 668 | + t.join(timeout=10) |
| 669 | + |
| 670 | + assert not errors, f"Threads raised errors: {errors}" |
| 671 | + # Only one credential instance should exist in the cache |
| 672 | + assert "default" in _credential_cache |
| 673 | + # All threads should use the same cached instance |
| 674 | + cached = _credential_cache["default"] |
| 675 | + assert isinstance(cached, TrackingCredential) |
| 676 | + # Due to the lock, only one instance should have been created |
| 677 | + assert len(instances_created) == 1 |
| 678 | + finally: |
| 679 | + azure_identity.DefaultAzureCredential = original |
| 680 | + |
| 681 | + |
| 682 | +class TestCacheStateAfterErrors: |
| 683 | + """Verify credential cache state after various error scenarios.""" |
| 684 | + |
| 685 | + def test_client_auth_error_leaves_credential_in_cache(self): |
| 686 | + """When get_token raises ClientAuthenticationError, the credential |
| 687 | + instance should still remain in the cache since it was created |
| 688 | + successfully — only the token acquisition failed.""" |
| 689 | + import sys |
| 690 | + |
| 691 | + azure_identity = sys.modules["azure.identity"] |
| 692 | + original = azure_identity.DefaultAzureCredential |
| 693 | + from azure.core.exceptions import ClientAuthenticationError |
| 694 | + |
| 695 | + class CredentialThatFailsGetToken: |
| 696 | + def get_token(self, scope): |
| 697 | + raise ClientAuthenticationError("token denied") |
| 698 | + |
| 699 | + try: |
| 700 | + azure_identity.DefaultAzureCredential = CredentialThatFailsGetToken |
| 701 | + |
| 702 | + with pytest.raises(RuntimeError, match="Azure AD authentication failed"): |
| 703 | + AADAuth._acquire_token("default") |
| 704 | + |
| 705 | + # Credential was created and cached before get_token failed |
| 706 | + assert "default" in _credential_cache |
| 707 | + assert isinstance(_credential_cache["default"], CredentialThatFailsGetToken) |
| 708 | + finally: |
| 709 | + azure_identity.DefaultAzureCredential = original |
| 710 | + |
| 711 | + def test_init_error_does_not_leave_stale_entry_in_cache(self): |
| 712 | + """When credential_class() raises during __init__, no entry should |
| 713 | + be left in _credential_cache since the dict assignment never completes.""" |
| 714 | + import sys |
| 715 | + |
| 716 | + azure_identity = sys.modules["azure.identity"] |
| 717 | + original = azure_identity.DefaultAzureCredential |
| 718 | + |
| 719 | + class CredentialThatFailsInit: |
| 720 | + def __init__(self): |
| 721 | + raise ValueError("init exploded") |
| 722 | + |
| 723 | + try: |
| 724 | + azure_identity.DefaultAzureCredential = CredentialThatFailsInit |
| 725 | + |
| 726 | + with pytest.raises(RuntimeError, match="Failed to create"): |
| 727 | + AADAuth.get_token("default") |
| 728 | + |
| 729 | + # The cache should NOT contain a stale entry |
| 730 | + assert "default" not in _credential_cache |
| 731 | + finally: |
| 732 | + azure_identity.DefaultAzureCredential = original |
| 733 | + |
| 734 | + |
| 735 | +class TestCacheOutputCorrectness: |
| 736 | + """Verify the returned token bytes are correct on both cache-miss and cache-hit.""" |
| 737 | + |
| 738 | + def test_token_output_correct_on_cache_miss_and_hit(self): |
| 739 | + """get_token should return correct token bytes on both |
| 740 | + the initial (cache-miss) and subsequent (cache-hit) calls.""" |
| 741 | + # First call — cache miss |
| 742 | + token_1 = AADAuth.get_token("default") |
| 743 | + assert isinstance(token_1, bytes) |
| 744 | + assert len(token_1) > 4 |
| 745 | + expected = AADAuth.get_token_struct(SAMPLE_TOKEN) |
| 746 | + assert token_1 == expected |
| 747 | + |
| 748 | + # Second call — cache hit |
| 749 | + token_2 = AADAuth.get_token("default") |
| 750 | + assert isinstance(token_2, bytes) |
| 751 | + assert token_2 == expected |
| 752 | + |
| 753 | + # Same credential instance for both |
| 754 | + assert "default" in _credential_cache |
| 755 | + |
| 756 | + |
| 757 | +class TestProcessConnectionStringTokenFailureFallthrough: |
| 758 | + """Cover the path where get_auth_token returns None and |
| 759 | + process_connection_string falls through without attrs.""" |
| 760 | + |
| 761 | + def test_returns_none_attrs_when_token_acquisition_fails(self): |
| 762 | + """When auth type is detected but token acquisition fails, |
| 763 | + process_connection_string should return (conn_str, None, auth_type).""" |
| 764 | + import sys |
| 765 | + |
| 766 | + azure_identity = sys.modules["azure.identity"] |
| 767 | + original = azure_identity.DefaultAzureCredential |
| 768 | + |
| 769 | + class CredentialThatAlwaysFails: |
| 770 | + def __init__(self): |
| 771 | + raise RuntimeError("cannot create credential") |
| 772 | + |
| 773 | + try: |
| 774 | + azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails |
| 775 | + conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" |
| 776 | + result_str, attrs, auth_type = process_connection_string(conn_str) |
| 777 | + |
| 778 | + # Auth type was detected |
| 779 | + assert auth_type == "default" |
| 780 | + # But token acquisition failed, so attrs is None |
| 781 | + assert attrs is None |
| 782 | + # Connection string is still returned (sensitive params removed) |
| 783 | + assert "Server=test" in result_str |
| 784 | + assert "Database=testdb" in result_str |
| 785 | + finally: |
| 786 | + azure_identity.DefaultAzureCredential = original |
0 commit comments