diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 448b3a95..35c2fb85 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -20,10 +20,10 @@ import mssql_python from mssql_python.cursor import Cursor from mssql_python.helpers import ( - sanitize_connection_string, sanitize_user_input, validate_attribute_value, ) +from mssql_python.connection_string_parser import sanitize_connection_string from mssql_python.logging import logger from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager diff --git a/mssql_python/connection_string_parser.py b/mssql_python/connection_string_parser.py index cdf17620..a9b002bd 100644 --- a/mssql_python/connection_string_parser.py +++ b/mssql_python/connection_string_parser.py @@ -21,6 +21,8 @@ from mssql_python.helpers import sanitize_user_input from mssql_python.logging import logger +_SENSITIVE_KEYS = frozenset({"pwd", "password"}) + class _ConnectionStringParser: """ @@ -375,3 +377,48 @@ def _parse_braced_value(self, connection_str: str, start_pos: int) -> Tuple[str, # Reached end without finding closing '}' raise ValueError(f"Unclosed braced value starting at position {brace_start_pos}") + + +def sanitize_connection_string(conn_str: str) -> str: + """ + Sanitize a connection string by masking sensitive values (PWD, Password). + + Uses _ConnectionStringParser to correctly handle ODBC braced values + (e.g. PWD={Top;Secret}) rather than a simple regex, which would truncate + at the first semicolon and leak the tail of the password. + + If parsing fails (malformed input), the entire string is redacted to + prevent any partial password leakage. + + Args: + conn_str (str): The connection string to sanitize. + Returns: + str: The sanitized connection string. + """ + from mssql_python.connection_string_builder import _ConnectionStringBuilder + + logger.debug( + "sanitize_connection_string: Sanitizing connection string (length=%d)", len(conn_str) + ) + + try: + parser = _ConnectionStringParser(validate_keywords=False) + params = parser._parse(conn_str) + + sanitized_params = {} + for key, value in params.items(): + canonical = _ConnectionStringParser.normalize_key(key) + display_key = canonical if canonical else key + if key in _SENSITIVE_KEYS: + sanitized_params[display_key] = "***" + else: + sanitized_params[display_key] = value + + builder = _ConnectionStringBuilder(sanitized_params) + sanitized = builder.build() + except Exception: + logger.debug("sanitize_connection_string: Failed to parse, redacting entire string") + sanitized = "" + + logger.debug("sanitize_connection_string: Password fields masked") + return sanitized diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index 5ce0617a..65fab886 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -41,19 +41,20 @@ def check_error(handle_type: int, handle: Any, ret: int) -> None: def sanitize_connection_string(conn_str: str) -> str: """ Sanitize the connection string by removing sensitive information. + + Delegates to the parser-based implementation in connection_string_parser + which correctly handles ODBC braced values (e.g. PWD={Top;Secret}). + Args: conn_str (str): The connection string to sanitize. Returns: str: The sanitized connection string. """ - logger.debug( - "sanitize_connection_string: Sanitizing connection string (length=%d)", len(conn_str) + from mssql_python.connection_string_parser import ( + sanitize_connection_string as _sanitize, ) - # Remove sensitive information from the connection string, Pwd section - # Replace Pwd=...; or Pwd=... (end of string) with Pwd=***; - sanitized = re.sub(r"(Pwd\s*=\s*)[^;]*", r"\1***", conn_str, flags=re.IGNORECASE) - logger.debug("sanitize_connection_string: Password fields masked") - return sanitized + + return _sanitize(conn_str) def sanitize_user_input(user_input: str, max_length: int = 50) -> str: diff --git a/tests/test_007_logging.py b/tests/test_007_logging.py index daafc6fc..0e5e3d89 100644 --- a/tests/test_007_logging.py +++ b/tests/test_007_logging.py @@ -314,20 +314,84 @@ def test_pwd_sanitization(self, cleanup_logger): assert "secret123" not in sanitized def test_pwd_case_insensitive(self, cleanup_logger): - """PWD/Pwd/pwd should all be sanitized (case-insensitive)""" + """PWD/Pwd/pwd should all be sanitized to canonical PWD=***""" from mssql_python.helpers import sanitize_connection_string test_cases = [ - ("Server=localhost;PWD=secret;Database=test", "PWD=***"), - ("Server=localhost;Pwd=secret;Database=test", "Pwd=***"), - ("Server=localhost;pwd=secret;Database=test", "pwd=***"), + "Server=localhost;PWD=secret;Database=test", + "Server=localhost;Pwd=secret;Database=test", + "Server=localhost;pwd=secret;Database=test", ] - for conn_str, expected in test_cases: + for conn_str in test_cases: sanitized = sanitize_connection_string(conn_str) - assert expected in sanitized + assert "PWD=***" in sanitized assert "secret" not in sanitized + def test_pwd_braced_value_with_semicolon(self, cleanup_logger): + """PWD with braced value containing semicolons must be fully masked.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "Server=localhost;PWD={Top;Secret};Database=test" + sanitized = sanitize_connection_string(conn_str) + + assert "PWD=***" in sanitized + assert "Top" not in sanitized + assert "Secret" not in sanitized + + def test_pwd_braced_value_with_escaped_braces(self, cleanup_logger): + """PWD with escaped closing braces (}}) must be fully masked.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "Server=localhost;PWD={p}}w{{d};Database=test" + sanitized = sanitize_connection_string(conn_str) + + assert "PWD=***" in sanitized + assert "p}w{d" not in sanitized + + def test_pwd_braced_value_multiple_semicolons(self, cleanup_logger): + """PWD with multiple semicolons inside braces must be fully masked.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "Server=localhost;PWD={a;b;c;d};Database=test" + sanitized = sanitize_connection_string(conn_str) + + assert "PWD=***" in sanitized + for fragment in ("a;b;c;d", "{a;", "b;c", "c;d}"): + assert fragment not in sanitized + + def test_pwd_at_end_of_string(self, cleanup_logger): + """PWD at end of connection string (no trailing semicolon) must be masked.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "Server=localhost;Database=test;PWD=secret" + sanitized = sanitize_connection_string(conn_str) + + assert "PWD=***" in sanitized + assert "secret" not in sanitized + + def test_no_pwd_unchanged(self, cleanup_logger): + """Connection string without PWD should be returned intact.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "Server=localhost;Database=test;UID=user" + sanitized = sanitize_connection_string(conn_str) + + assert "Server=" in sanitized + assert "Database=" in sanitized + assert "UID=" in sanitized + + def test_malformed_string_fully_redacted(self, cleanup_logger): + """Malformed connection string should be fully redacted, not partially leaked.""" + from mssql_python.helpers import sanitize_connection_string + + conn_str = "PWD={unclosed" + sanitized = sanitize_connection_string(conn_str) + + assert "unclosed" not in sanitized + assert "PWD" not in sanitized + assert "redacted" in sanitized.lower() + def test_explicit_sanitization_in_logging(self, cleanup_logger): """Verify that explicit sanitization works when logging""" from mssql_python.helpers import sanitize_connection_string