diff --git a/keepercommander/__main__.py b/keepercommander/__main__.py index eaf7cba03..cb69ef2c4 100644 --- a/keepercommander/__main__.py +++ b/keepercommander/__main__.py @@ -12,15 +12,12 @@ import argparse -import certifi import json import logging import os import re import shlex import sys -import ssl -import platform from pathlib import Path from typing import Optional @@ -178,69 +175,6 @@ def handle_exceptions(exc_type, exc_value, exc_traceback): sys.exit(-1) -def get_ssl_cert_file(): - """Get SSL certificate file path, preferring system CA store for corporate environments like Zscaler""" - - # Allow user to override via environment variable - user_cert_file = os.getenv('KEEPER_SSL_CERT_FILE') - if user_cert_file: - if user_cert_file.lower() == 'system': - # User explicitly wants system certs - pass # Continue with system detection below - elif user_cert_file.lower() == 'certifi': - # User explicitly wants certifi - return certifi.where() - elif user_cert_file.lower() == 'none' or user_cert_file.lower() == 'false': - # User wants to disable SSL verification (not recommended) - return None - elif os.path.exists(user_cert_file): - # User provided specific cert file - return user_cert_file - else: - logging.warning(f"SSL cert file specified in KEEPER_SSL_CERT_FILE not found: {user_cert_file}") - - # Try to use system CA store first for corporate environments - try: - # On macOS, try Homebrew certificates first (better for corporate environments like Zscaler) - if platform.system() == 'Darwin': - system_ca_paths = [ - '/opt/homebrew/etc/ca-certificates/cert.pem', # Homebrew CA bundle (best for Zscaler) - '/usr/local/etc/ssl/cert.pem', # Homebrew SSL (older location) - '/etc/ssl/cert.pem', # macOS system CA bundle - ] - for ca_path in system_ca_paths: - if os.path.exists(ca_path): - return ca_path - - # On Linux/Unix systems - elif platform.system() == 'Linux': - system_ca_paths = [ - '/etc/ssl/certs/ca-certificates.crt', # Debian/Ubuntu - '/etc/pki/tls/certs/ca-bundle.crt', # RHEL/CentOS - '/etc/ssl/ca-bundle.pem', # OpenSUSE - '/etc/ssl/cert.pem', # Generic - ] - for ca_path in system_ca_paths: - if os.path.exists(ca_path): - return ca_path - - # Try to get default SSL context locations - try: - default_locations = ssl.get_default_verify_paths() - if default_locations.cafile and os.path.exists(default_locations.cafile): - return default_locations.cafile - if default_locations.capath and os.path.exists(default_locations.capath): - return default_locations.capath - except: - pass - - except Exception: - pass - - # Fall back to certifi if system CA not available - return certifi.where() - - def main(from_package=False): if sys.platform == 'win32': try: @@ -253,15 +187,12 @@ def main(from_package=False): if logger: logger.name = 'keepercommander' - # Use system CA certificates when available (supports Zscaler), fallback to certifi - ssl_cert_file = get_ssl_cert_file() + ssl_cert_file = utils.get_ssl_cert_file() if ssl_cert_file: os.environ['SSL_CERT_FILE'] = ssl_cert_file else: - # User explicitly disabled SSL verification logging.warning("Warning: SSL certificate verification has been disabled. This is not recommended for production use.") - if 'SSL_CERT_FILE' in os.environ: - del os.environ['SSL_CERT_FILE'] + os.environ.pop('SSL_CERT_FILE', None) errno = 0 diff --git a/keepercommander/commands/_cloud_import_base.py b/keepercommander/commands/_cloud_import_base.py index 8bd18ed89..aaadaf331 100644 --- a/keepercommander/commands/_cloud_import_base.py +++ b/keepercommander/commands/_cloud_import_base.py @@ -13,6 +13,7 @@ import json import logging +import os import re from typing import Callable, Dict, List, Optional, Tuple @@ -272,7 +273,8 @@ def _validate_folder(params, folder_uid, command_name): 'Use "list-sf" to find the correct shared folder UID, ' 'or run "sync-down" if the folder was recently shared.' ) - if not isinstance(folder, (SharedFolderNode, SharedFolderFolderNode)): + if not (isinstance(folder, (SharedFolderNode, SharedFolderFolderNode)) or + (hasattr(folder, 'type') and folder.type == 'nested_share_folder')): raise CommandError( command_name, f'"{folder_uid}" is a personal folder. ' @@ -290,8 +292,12 @@ def _run_import(self, params, secrets, folder_uid, record_type, filter_name, # type: (KeeperParams, list, str, str, Optional[str], Optional[str], Optional[str], Optional[str], List[Tuple[str, str]], bool, str, Optional[Callable]) -> None """ Iterate *secrets* (list of dicts with at minimum 'name' and 'tags'), - apply all filters, then create Keeper records via batched vault/records_add - calls (up to BATCH_SIZE records per request). + apply all filters, then create Keeper records via batched API calls + (up to BATCH_SIZE records per request). + + When *folder_uid* belongs to a nested_share_folder the records are + created via ``vault/records/v3/add``; otherwise the legacy + ``vault/records_add`` endpoint is used. *value_fetcher*, when provided, is called as ``value_fetcher(name) -> str`` for each secret that passes filters and is not a dry-run. This enables @@ -320,10 +326,9 @@ def _run_import(self, params, secrets, folder_uid, record_type, filter_name, if dry_run: return - # Phase 2 – fetch values (if needed) and build TypedRecord + protobuf - # objects without touching the Keeper API yet. + # Phase 2 – fetch values (if needed) and build TypedRecord objects. skipped = 0 - pending = [] # type: List[Tuple[vault.TypedRecord, record_pb2.RecordAdd]] + typed_records = [] # type: List[vault.TypedRecord] for item in matched: name = item['name'] @@ -339,17 +344,118 @@ def _run_import(self, params, secrets, folder_uid, record_type, filter_name, value = item.get('value', '') fields = self._parse_secret_string(value) - record = self._build_keeper_record(name, fields, record_type) - pb = record_management.add_record_to_folder(params, record, folder_uid, pb_only=True) - if pb is not None: - pending.append((record, pb)) + typed_records.append(self._build_keeper_record(name, fields, record_type)) - if not pending: + if not typed_records: print(f'{command_name}: 0 record(s) created, {skipped} skipped.') return - # Phase 3 – send in batches of up to BATCH_SIZE to vault/records_add. + # Phase 3 – send in batches, routing to the appropriate API endpoint. + is_nsf = folder_uid in getattr(params, 'nested_share_folders', {}) + if is_nsf: + created, nsf_skipped = self._send_nsf_batches( + params, typed_records, folder_uid, command_name) + skipped += nsf_skipped + else: + created, legacy_skipped = self._send_legacy_batches( + params, typed_records, folder_uid, command_name) + skipped += legacy_skipped + + if created: + params.sync_data = True + print(f'{command_name}: {created} record(s) created, {skipped} skipped.') + + # ------------------------------------------------------------------ + # Phase-3 dispatch helpers + # ------------------------------------------------------------------ + + @staticmethod + def _typed_record_to_data(record): + # type: (vault.TypedRecord) -> dict + """Serialise a TypedRecord to the dict format expected by the v3 record API.""" + data = { + 'type': record.type_name, + 'title': record.title, + 'fields': [ + {'type': f.type, 'label': f.label or '', 'value': list(f.value)} + for f in record.fields + ], + 'custom': [ + {'type': f.type, 'label': f.label or '', 'value': list(f.value)} + for f in record.custom + ], + } + if record.notes: + data['notes'] = record.notes + return data + + @staticmethod + def _send_nsf_batches(params, typed_records, folder_uid, command_name): + # type: (KeeperParams, List[vault.TypedRecord], str, str) -> Tuple[int, int] + """Send *typed_records* to a nested_share_folder via ``vault/records/v3/add``.""" + from ..nested_share_folder.common import get_folder_key + from ..nested_share_folder.record_api import create_record_data_v3, record_add_v3 + + folder_key = get_folder_key(params, folder_uid, raise_on_missing=True) + + created = 0 + skipped = 0 + + for batch_start in range(0, len(typed_records), BATCH_SIZE): + batch = typed_records[batch_start:batch_start + BATCH_SIZE] + batch_num = batch_start // BATCH_SIZE + 1 + logging.info('%s: sending batch %d (%d record(s))', command_name, batch_num, len(batch)) + + uid_to_title = {} + adds = [] + for record in batch: + uid = utils.generate_uid() + rk = os.urandom(32) + data = CloudImportMixin._typed_record_to_data(record) + ra = create_record_data_v3( + record_uid=uid, record_key=rk, data=data, + folder_uid=folder_uid, folder_key=folder_key, + data_key=params.data_key, + client_modified_time=utils.current_milli_time(), + ) + adds.append(ra) + uid_to_title[uid] = record.title + + try: + rs = record_add_v3(params, adds) + except Exception as exc: + logging.warning('%s: batch %d failed: %s', command_name, batch_num, exc) + skipped += len(batch) + continue + + for r in rs.records: + uid = utils.base64_url_encode(r.record_uid) + title = uid_to_title.get(uid, uid) + if r.status == record_pb2.RS_SUCCESS: + logging.debug('%s: created record "%s"', command_name, title) + created += 1 + else: + logging.warning('%s: failed to create record "%s": status=%s', + command_name, title, r.status) + skipped += 1 + + return created, skipped + + @staticmethod + def _send_legacy_batches(params, typed_records, folder_uid, command_name): + # type: (KeeperParams, List[vault.TypedRecord], str, str) -> Tuple[int, int] + """Send *typed_records* to a regular shared folder via ``vault/records_add``.""" created = 0 + skipped = 0 + + pending = [] # type: List[Tuple[vault.TypedRecord, record_pb2.RecordAdd]] + for record in typed_records: + pb = record_management.add_record_to_folder(params, record, folder_uid, pb_only=True) + if pb is not None: + pending.append((record, pb)) + + if not pending: + return 0, 0 for batch_start in range(0, len(pending), BATCH_SIZE): batch = pending[batch_start:batch_start + BATCH_SIZE] @@ -385,6 +491,4 @@ def _run_import(self, params, secrets, folder_uid, record_type, filter_name, command_name, record.title, rs_rec.status) skipped += 1 - if created: - params.sync_data = True - print(f'{command_name}: {created} record(s) created, {skipped} skipped.') + return created, skipped diff --git a/keepercommander/commands/enterprise.py b/keepercommander/commands/enterprise.py index c6213d2b3..110e0d237 100644 --- a/keepercommander/commands/enterprise.py +++ b/keepercommander/commands/enterprise.py @@ -40,6 +40,7 @@ from .base import user_choice, suppress_exit, raise_parse_exception, dump_report_data, Command, field_to_title, \ report_output_parser from .enterprise_common import EnterpriseCommand +from .helpers.enterprise import is_valid_name_length, simplify_batch_responses from .automator import AutomatorListCommand from .enterprise_push import EnterprisePushCommand, enterprise_push_parser from .transfer_account import EnterpriseTransferUserCommand, transfer_user_parser @@ -811,10 +812,6 @@ def tree_node(node): role_ids.update(team_roles[team_uid]) if column == 'role_count': row.append(len(role_ids)) - elif kwargs.get('format') == 'json': - role_info = [{'role_id': rid, 'role_name': roles[rid]['name']} - for rid in role_ids if rid in roles] - row.append(role_info) else: role_names = [roles[role_id]['name'] for role_id in role_ids if role_id in roles] row.append(role_names) @@ -1170,7 +1167,10 @@ def execute(self, params, **kwargs): if not node_name or not node_name.strip(): logging.warning('Empty node name provided. Skipping.') continue - + + if not is_valid_name_length(node_name, 'Node name', 'enterprise-node'): + continue + n = node_lookup.get(node_name) if not n: n = node_lookup.get(node_name.lower()) @@ -1497,6 +1497,8 @@ def traverse_to_root(node_id, depth): return elif parent_id or kwargs.get('displayname'): display_name = kwargs.get('displayname') + if display_name and not is_valid_name_length(display_name, 'Node display name', 'enterprise-node'): + display_name = None def is_in_chain(node_id, parent_id): if node_id == parent_id: return True @@ -1506,8 +1508,10 @@ def is_in_chain(node_id, parent_id): return is_in_chain(nn['parent_id'], parent_id) if display_name and len(matched_nodes) > 1: - logging.warning('Cannot assign the same name to % nodes', len(matched_nodes)) + logging.warning('Cannot assign the same name to %s nodes', len(matched_nodes)) display_name = None + if not parent_id and not display_name: + return if not parent_id or not display_name: for node in matched_nodes: encrypted_data = node['encrypted_data'] @@ -1531,6 +1535,7 @@ def is_in_chain(node_id, parent_id): if request_batch: rss = api.execute_batch(params, request_batch) + simplify_batch_responses(rss) for rq, rs in zip(request_batch, rss): command = rq.get('command') if command == 'node_add': @@ -2000,6 +2005,7 @@ def execute(self, params, **kwargs): results = None if request_batch: results = api.execute_batch(params, request_batch) + simplify_batch_responses(results) for rq, rs in zip(request_batch, results): command = rq.get('command') if command == 'enterprise_user_add': @@ -2306,6 +2312,8 @@ def execute(self, params, **kwargs): # Collect role_ids for newly created roles new_role_ids = [] for role_name in role_names: + if not is_valid_name_length(role_name, 'Role name', 'enterprise-role'): + continue data = json.dumps({ "displayname": role_name }).encode('utf-8') role_id = self.get_enterprise_id(params) new_role_ids.append(role_id) @@ -2821,6 +2829,8 @@ def execute(self, params, **kwargs): role = matched_roles[0] if not role_name: role_name = role['data'].get('displayname') + if not is_valid_name_length(role_name, 'Role name', 'enterprise-role'): + return if not node_id: node_id = role['node_id'] dt = json.dumps({ "displayname": role_name }) @@ -2832,7 +2842,8 @@ def execute(self, params, **kwargs): "encrypted_data": utils.base64_url_encode( crypto.encrypt_aes_v1(dt.encode('utf-8'), params.enterprise['unencrypted_tree_key'])), "visible_below": role.get('visible_below') or False, - "new_user_inherit": role.get('new_user_inherit') or False + "new_user_inherit": role.get('new_user_inherit') or False, + "role_name": role_name, } request_batch.append(rq) if 'role_enforcements' in params.enterprise: @@ -2889,6 +2900,12 @@ def execute(self, params, **kwargs): logging.warning('Cannot assign the same name to %s roles', len(matched_roles)) kwargs['name'] = None + if kwargs.get('name') and not is_valid_name_length(kwargs.get('name'), 'Role name', 'enterprise-role'): + kwargs['name'] = None + + if not (node_id or kwargs.get('visible_below') or kwargs.get('new_user') or kwargs.get('name')): + return + for role in matched_roles: encrypted_data = role['encrypted_data'] if kwargs.get('name'): @@ -2910,11 +2927,12 @@ def execute(self, params, **kwargs): if request_batch: rss = api.execute_batch(params, request_batch) + simplify_batch_responses(rss) for rq, rs in zip(request_batch, rss): command = rq.get('command') if command == 'role_add': if rs['result'] == 'success': - logging.info('%s Role created with Role ID : %s', rq['role_name'], rq['role_id']) + logging.info('%s Role created with Role ID : %s', rq.get('role_name') or rq.get('role_id'), rq['role_id']) else: logging.warning('Failed to create role: %s', rs['message']) else: @@ -3403,6 +3421,8 @@ def execute(self, params, **kwargs): for item in queue: is_new_team = type(item) == str team_name = item if is_new_team else item['name'] + if is_new_team and not is_valid_name_length(team_name, 'Team name', 'enterprise-team'): + continue team_node_id = node_id if is_new_team else item['node_id'] team_uid = api.generate_record_uid() if is_new_team else item['team_uid'] team_key = api.generate_aes_key() @@ -3608,6 +3628,12 @@ def execute(self, params, **kwargs): logging.warning('Cannot set same name to %s teams', len(matched_teams)) kwargs['name'] = None + if kwargs.get('name') and not is_valid_name_length(kwargs.get('name'), 'Team name', 'enterprise-team'): + kwargs['name'] = None + + if not (node_id or kwargs.get('name') or kwargs.get('restrict_edit') or kwargs.get('restrict_share') or kwargs.get('restrict_view')): + return + for team in matched_teams: rq = { 'command': 'team_update', @@ -3622,6 +3648,7 @@ def execute(self, params, **kwargs): if request_batch: rss = api.execute_batch(params, request_batch) + simplify_batch_responses(rss) for rq, rs in zip(request_batch, rss): command = rq.get('command') team_name = None @@ -4014,6 +4041,7 @@ def execute(self, params, **kwargs): if request_batch: if not kwargs.get('dry_run'): rs = api.execute_batch(params, request_batch) + simplify_batch_responses(rs) if rs: team_add_success = 0 team_add_failure = 0 diff --git a/keepercommander/commands/helpers/enterprise.py b/keepercommander/commands/helpers/enterprise.py index cd0976876..97179f5f2 100644 --- a/keepercommander/commands/helpers/enterprise.py +++ b/keepercommander/commands/helpers/enterprise.py @@ -1,8 +1,57 @@ import logging +import re from ... import utils, crypto from ...params import KeeperParams +MAX_ENTERPRISE_NAME_LENGTH = 255 +_NAME_PREVIEW_LENGTH = 40 + +# Backend length-violation responses look like: ``max=185, length=255, value=`` +_BACKEND_LENGTH_ERROR_RE = re.compile( + r'max\s*=\s*(\d+)\s*,\s*length\s*=\s*(\d+)(?:\s*,\s*value\s*=.*)?', + re.IGNORECASE | re.DOTALL, +) + + +def is_valid_name_length(name, field_label, command_label): + """Return True if name fits within the enterprise name length limit; otherwise warn and return False.""" + if name is None: + return True + name = str(name) + if len(name) <= MAX_ENTERPRISE_NAME_LENGTH: + return True + preview = name[:_NAME_PREVIEW_LENGTH] + if len(name) > _NAME_PREVIEW_LENGTH: + preview += '...' + logging.warning( + '%s: %s \'%s\' is %d characters long. Maximum allowed is %d. Skipping.', + command_label, field_label, preview, len(name), MAX_ENTERPRISE_NAME_LENGTH, + ) + return False + + +def simplify_backend_message(message): + """Rewrite the backend's ``max=N, length=N, value=...`` length error into a friendlier sentence.""" + if not message or not isinstance(message, str): + return message + match = _BACKEND_LENGTH_ERROR_RE.search(message) + if not match: + return message + actual_len = int(match.group(2)) + max_len = int(match.group(1)) + return 'value is {0} characters but the maximum allowed is {1}'.format(actual_len, max_len) + + +def simplify_batch_responses(responses): + """Rewrite known noisy server validation messages in place on each response dict.""" + if not responses: + return + for rs in responses: + if isinstance(rs, dict) and rs.get('message'): + rs['message'] = simplify_backend_message(rs['message']) + + def is_addon_enabled(params, addon_name): # type: (KeeperParams, Dict[str, ]) -> Boolean def is_enabled(addon): return addon.get('enabled') or addon.get('included_in_product') diff --git a/keepercommander/commands/nested_share_folder/folder_commands.py b/keepercommander/commands/nested_share_folder/folder_commands.py index fa1218a84..57c04a7af 100644 --- a/keepercommander/commands/nested_share_folder/folder_commands.py +++ b/keepercommander/commands/nested_share_folder/folder_commands.py @@ -61,52 +61,102 @@ def execute(self, params, **kwargs): if current and current in getattr(params, 'nested_share_folders', {}): base_folder_uid = current - folder_name = self._parse_path(folder_path) + segments = self._parse_path(folder_path) + + parent_uid = base_folder_uid + last_idx = len(segments) - 1 + created_uid = None + + for idx, segment in enumerate(segments): + is_leaf = (idx == last_idx) + existing_uid = self._find_existing_child(params, segment, parent_uid) + if existing_uid: + if is_leaf: + logging.warning('nsf-mkdir: Folder "%s" already exists', segment) + return existing_uid + parent_uid = existing_uid + continue - existing_uid = self._find_existing_child(params, folder_name, base_folder_uid) - if existing_uid: - logging.warning('nsf-mkdir: Folder "%s" already exists', folder_name) - return existing_uid + seg_color = color if is_leaf else None + seg_inherit = inherit_permissions if is_leaf else True - with command_error_handler('nsf-mkdir'): - result = _nsf.create_folder_v3( - params=params, folder_name=folder_name, - parent_uid=base_folder_uid, - color=color, - inherit_permissions=inherit_permissions, - ) - check_result(result, 'nsf-mkdir') + with command_error_handler('nsf-mkdir'): + result = _nsf.create_folder_v3( + params=params, folder_name=segment, + parent_uid=parent_uid, + color=seg_color, + inherit_permissions=seg_inherit, + ) + check_result(result, 'nsf-mkdir') + + created_uid = result['folder_uid'] + self._cache_new_folder( + params, created_uid, segment, parent_uid, + folder_key=result.get('folder_key_unencrypted')) + + if not is_leaf: + logging.debug('nsf-mkdir: Created intermediate folder "%s"', segment) + parent_uid = created_uid params.sync_data = True - return result['folder_uid'] + return created_uid @staticmethod def _parse_path(folder_path): - # Collapse escaped slashes (//) to a sentinel so we can detect any - # stray path separator and refuse it — nsf-mkdir creates a single - # folder, not a nested hierarchy. - collapsed = folder_path.replace('//', '\x00') - if '/' in collapsed: - raise CommandError('nsf-mkdir', - 'Character "/" is reserved. Use "//" inside folder name') - name = collapsed.replace('\x00', '/').strip() - if not name: + """Split *folder_path* into a list of segment names. + """ + sentinel = '\x00' + collapsed = folder_path.replace('//', sentinel) + raw_segments = collapsed.split('/') + segments = [] + for raw in raw_segments: + name = raw.replace(sentinel, '/').strip() + if name: + segments.append(name) + if not segments: raise CommandError('nsf-mkdir', 'Invalid folder name') - return name + return segments + + @staticmethod + def _cache_new_folder(params, folder_uid, name, parent_uid, folder_key=None): + """Insert a just-created folder into the local NSF cache so that + subsequent segments in the same path can discover it as a parent + without requiring a full sync round-trip. + """ + nsf = getattr(params, 'nested_share_folders', None) + if nsf is None: + return + entry = { + 'name': name, + 'parent_uid': parent_uid or '', + } + if folder_key: + entry['folder_key_unencrypted'] = folder_key + nsf[folder_uid] = entry @staticmethod def _find_existing_child(params, folder_name, parent_uid): + """Find an existing NSF folder named *folder_name* whose parent matches + *parent_uid*. ``parent_uid=None`` means "root level". + """ nsf_folders = getattr(params, 'nested_share_folders', {}) name_lower = folder_name.lower() - expected_parent = parent_uid or '' + looking_for_root = not parent_uid for fuid, fobj in nsf_folders.items(): if fobj.get('name', '').lower() != name_lower: continue - existing_parent = normalize_parent_uid(fobj.get('parent_uid', '')) - if existing_parent == 'root': - existing_parent = '' - if existing_parent == expected_parent: - return fuid + raw_parent = fobj.get('parent_uid') or '' + normalized = normalize_parent_uid(raw_parent) + is_root_child = ( + normalized in ('', 'root') + or (raw_parent and raw_parent not in nsf_folders) + ) + if looking_for_root: + if is_root_child: + return fuid + else: + if raw_parent == parent_uid: + return fuid return None @@ -459,7 +509,7 @@ def _preview_and_confirm(self, params, removals, operation, force, dry_run, quie self._impact_summary(pr['folder_uid'], name, operation, pr.get('impact'), quiet) ) - if summary_lines: + if summary_lines and (dry_run or not force): for line in summary_lines: print(line) diff --git a/keepercommander/commands/nested_share_folder/parsers.py b/keepercommander/commands/nested_share_folder/parsers.py index 1539371ea..0d0e821b0 100644 --- a/keepercommander/commands/nested_share_folder/parsers.py +++ b/keepercommander/commands/nested_share_folder/parsers.py @@ -36,7 +36,9 @@ def _make_parser(prog, description): 'nsf-mkdir', 'Create a new Nested Share Folder using v3 API') nested_share_folder_mkdir_parser.add_argument( 'folder', type=str, - help='Folder name to create (use "//" to embed a literal "/" in the name)') + help='Folder name or path (e.g. "Parent/Child/Grand") to create. ' + 'Intermediate folders are created automatically. ' + 'Use "//" to embed a literal "/" in a segment name.') nested_share_folder_mkdir_parser.add_argument( '--color', type=str, choices=['none', 'red', 'orange', 'yellow', 'green', 'blue', 'gray'], diff --git a/keepercommander/commands/nested_share_folder/record_commands.py b/keepercommander/commands/nested_share_folder/record_commands.py index 164acab01..bbcdfe005 100644 --- a/keepercommander/commands/nested_share_folder/record_commands.py +++ b/keepercommander/commands/nested_share_folder/record_commands.py @@ -179,6 +179,26 @@ def __init__(self): def get_parser(self): return nested_share_record_update_parser + def _resolve_field_value(self, parsed): + raw = parsed.value + if not raw: + return raw + + action_params = [] + if self.is_json_value(raw, action_params): + return action_params[0] if action_params else None + action_params.clear() + if self.is_generate_value(raw, action_params): + if parsed.type == 'password': + return self.generate_password(action_params) + if parsed.type in ('oneTimeCode', 'otp'): + return self.generate_totp_url() + return raw + action_params.clear() + if self.is_base64_value(raw, action_params): + return action_params[0] if action_params else None + return raw + def execute(self, params, **kwargs): if kwargs.get('syntax_help'): print(record_fields_description) @@ -198,15 +218,24 @@ def execute(self, params, **kwargs): for spec in [f.strip() for f in kwargs.get('fields', []) if f.strip()]: try: parsed = RecordEditMixin.parse_field(spec) + value = self._resolve_field_value(parsed) + if value is None: + continue if parsed.type in fields: existing = fields[parsed.type] fields[parsed.type] = ([existing] if not isinstance(existing, list) - else existing) + [parsed.value] + else existing) + [value] else: - fields[parsed.type] = parsed.value + fields[parsed.type] = value except ValueError as e: raise CommandError('nsf-record-update', f'Invalid field specification: {e}') + if self.warnings: + for w in self.warnings: + logging.warning(w) + if not kwargs.get('force'): + return + with command_error_handler('nsf-record-update'): for identifier in record_uids: record_uid = _nsf.resolve_nested_share_record_uid(params, identifier) @@ -524,28 +553,35 @@ def _build_removals(self, params, record_args, folder_uid, operation): def _preview_and_confirm(self, params, removals, operation, force, dry_run): result = _nsf.remove_record_v3(params, removals, dry_run=True) any_error = False - summary_lines = [] + error_lines = [] + info_lines = [] for pr in result['preview_results']: title = self._record_title(params, pr['record_uid']) if pr.get('error'): any_error = True err = pr['error'] - summary_lines.append( + error_lines.append( f" {title} [{pr['record_uid']}]: " f"{err.get('code', '')} — {err.get('message', '')}" ) else: - summary_lines.extend( + info_lines.extend( self._impact_summary(pr['record_uid'], title, operation, pr.get('impact')) ) - for line in summary_lines: - print(line) - + # Errors must always surface, even in --force mode, so the caller (or + # Service Mode HTTP layer) can see why the operation aborted. if any_error: + for line in error_lines: + print(line) print('\nOne or more records could not be previewed. Aborting.') return + + if dry_run or not force: + for line in info_lines: + print(line) + if dry_run: print('\n[Dry-run] No records were deleted.') return diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py index 1435ab112..503478a75 100644 --- a/keepercommander/commands/pam_launch/launch.py +++ b/keepercommander/commands/pam_launch/launch.py @@ -62,7 +62,6 @@ unregister_tunnel_session, unregister_conversation_key, get_keeper_tokens, - escalate_close, CloseConnectionReasons, ) from ..tunnel.port_forward.TunnelGraph import TunnelDAG @@ -1570,17 +1569,13 @@ def signal_handler_fn(signum, frame): original_handler = signal.signal(signal.SIGINT, signal_handler_fn) - # Workflow lease expiry: schedule a hard kill at expiresOn matching + # Workflow lease expiry: schedule teardown at expiresOn matching # the web vault (immediate teardown, no grace period, no reconnect). # The "Access expired" line is printed AFTER terminal reset in finally # so the message survives reset_local_terminal_after_pam_session(). - # On expiry we soft-close the tube and escalate to force_close_tube - # after FORCE_CLOSE_DELAY_SECONDS so any in-flight forwarded streams - # (SSH bytes etc.) are severed instead of lingering until the user - # disconnects manually. Escalation is gated on local hasattr + - # remote SDP version (FORCE_CLOSE_MIN_VERSION). + # On expiry we close the tube; the connection-closed cleanup path then + # stops the websocket and unregisters the tunnel session. lease_timer = None - force_close_timer_holder = {} # mutable holder so cleanup can cancel if workflow_expires_on_ms and workflow_expires_on_ms > 0: seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - time.time() _lease_tube_id = tunnel_result['tunnel'].get('tube_id') @@ -1591,27 +1586,20 @@ def _on_lease_expired(): lease_expired = True shutdown_requested = True if _lease_tube_id and _lease_tube_registry is not None: - # Fetch remote version lazily: the SDP answer arrives - # asynchronously; capturing eagerly at schedule time - # would race for short leases scheduled before SDP. - remote_ver = tunnel_result['tunnel'].get('remote_webrtc_version') - if not remote_ver: - sess = get_tunnel_session(_lease_tube_id) - remote_ver = ( - getattr(sess, 'remote_webrtc_version', None) - if sess else None + try: + _lease_tube_registry.close_tube( + _lease_tube_id, + reason=CloseConnectionReasons.AdminClosed, + ) + except Exception as e: + logging.debug( + f"[lease-expiry launch tube={_lease_tube_id[:8]}] " + f"close_tube failed: {e}" ) - force_close_timer_holder['t'] = escalate_close( - _lease_tube_registry, - _lease_tube_id, - remote_webrtc_version=remote_ver, - reason=CloseConnectionReasons.AdminClosed, - log_prefix=f"[lease-expiry launch tube={_lease_tube_id[:8]}] ", - ) if seconds_until_expiry <= 0: - # Already expired at session start: run the close-and-escalate - # path immediately so cleanup goes through the same flow as a + # Already expired at session start: run the close path + # immediately so cleanup goes through the same flow as a # mid-session expiry. _on_lease_expired() else: diff --git a/keepercommander/commands/pam_launch/rust_log_filter.py b/keepercommander/commands/pam_launch/rust_log_filter.py index b1ee55e96..54f75d8fa 100644 --- a/keepercommander/commands/pam_launch/rust_log_filter.py +++ b/keepercommander/commands/pam_launch/rust_log_filter.py @@ -191,23 +191,9 @@ def enter_pam_launch_terminal_rust_logging(): # the channel is torn down, or TURN ``fail to refresh permissions`` warnings # from the relay-conn task as it observes the deallocated allocation. # -# The window must outlive both: -# 1. The soft→hard close escalation in ``escalate_close`` -# (``FORCE_CLOSE_DELAY_SECONDS`` = 3 s) -# 2. A brief TURN refresh-task latency after the PeerConnection drop cascade -# -# Imported lazily below to avoid a top-level cycle (this module is imported -# during pam_launch init, before the tunnel helpers are loaded for some -# callers). -def _force_close_delay_seconds(): - try: - from ..tunnel.port_forward.tunnel_helpers import FORCE_CLOSE_DELAY_SECONDS - return FORCE_CLOSE_DELAY_SECONDS - except Exception: - return 3.0 - - -_DEFAULT_RUST_LOG_FILTER_GRACE_SEC = _force_close_delay_seconds() + 1.5 +# The window must outlive both the tube close + teardown cascade (~3 s) and a +# brief TURN refresh-task latency after the PeerConnection drop cascade. +_DEFAULT_RUST_LOG_FILTER_GRACE_SEC = 4 # Refcount of active pam-launch sessions that have rust-log filtering installed. # Incremented in enter_*, decremented at the END of the grace timer in @@ -282,7 +268,8 @@ def _do_exit_rust_logging(token): def exit_pam_launch_terminal_rust_logging(token, grace_sec=_DEFAULT_RUST_LOG_FILTER_GRACE_SEC): """Restore Rust/webrtc logger state after pam launch terminal session. - The filter is removed after ``grace_sec`` seconds (default 2.5s) so that + The filter is removed after ``grace_sec`` seconds (default + ``_DEFAULT_RUST_LOG_FILTER_GRACE_SEC``) so that late records from the Rust runtime (e.g. ``webrtc-sctp`` stream teardown messages that arrive just after session exit) are still caught by the filter and do not leak to the console in front of the subsequent diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index a507c0bd6..422243d21 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -143,9 +143,11 @@ def register_command_info(aliases, command_info): search_parser = argparse.ArgumentParser(prog='search', description='Search the vault. Words can be in any order.') search_parser.add_argument('pattern', nargs='*', type=str, action='store', help='search terms (space-separated, order independent)') search_parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='verbose output') -search_parser.add_argument('-c', '--categories', dest='categories', action='store', - help='One or more of these letters for categories to search: "r" = records, ' - '"s" = shared folders, "t" = teams, "d" = Nested Share Folders') +search_parser.add_argument('-c', '--categories', dest='categories', action='append', + help='Category to search — repeatable: "r" = records, "s" = shared folders, ' + '"t" = teams, "d" = Nested Share folders. ' + 'Pass multiple times (e.g. -c s -c d) or combine letters (e.g. -c sd). ' + 'Default when omitted: all categories (rstd).') search_parser.add_argument('--regex', dest='regex', action='store_true', help='treat pattern as a regular expression instead of space-separated search terms') search_parser.add_argument('--device', dest='device', action='store_true', @@ -1455,7 +1457,7 @@ def execute(self, params, **kwargs): else: pattern = '' # Empty pattern matches all in token mode - categories = (kwargs.get('categories') or 'rstd').lower() + categories = ''.join(kwargs.get('categories') or ['rstd']).lower() skip_details = not verbose nsf_records_map = getattr(params, 'nested_share_records', {}) or {} @@ -1575,13 +1577,14 @@ def execute(self, params, **kwargs): f"Type: {item['record_type']}, Description: {item['description']}, Record Category: {item.get('record_category', 'Classic')}"] elif item['type'] == 'shared_folder': row = [item['type'], item['shared_folder_uid'], item['name'], - f"Can Edit: {item['can_edit']}, Can Share: {item['can_share']}"] + f"Folder Category: Classic, Can Edit: {item['can_edit']}, Can Share: {item['can_share']}"] elif item['type'] == 'team': row = [item['type'], item['team_uid'], item['name'], f"Restrict Edit: {item['restrict_edit']}, Restrict View: {item['restrict_view']}, Restrict Share: {item['restrict_share']}"] elif item['type'] == 'nested_share_folder': - details = (f"Parent UID: {item['parent_uid']}" - if item.get('parent_uid') else '') + details = 'Folder Category: NestedShare' + if item.get('parent_uid'): + details += f", Parent UID: {item['parent_uid']}" row = [item['type'], item['folder_uid'], item['name'], details] table.append(row) diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 82a494ff9..1349f71d2 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -129,20 +129,6 @@ def parse(v): return False -# Minimum keeper-pam-webrtc-rs version that exposes force_close_tube. Both the -# local Rust crate AND the remote peer must satisfy this gate before Commander -# escalates a soft close to a force close. Local check uses hasattr (the binding -# attribute is missing on older crates), remote check uses the SDP-advertised -# version string. -FORCE_CLOSE_MIN_VERSION = "2.1.18" - -# Default delay between the soft close and the force-close escalation. Matches -# the consumer-side budget agreed with the gateway (gateway-side -# KEEPER_GATEWAY_FORCE_CLOSE_TIMEOUT is 6s; we run faster on the consumer because -# at lease expiry there is no reason to wait long). -FORCE_CLOSE_DELAY_SECONDS = 3.0 - - def print_above_keeper_prompt(msg): """Print ``msg`` so the keeper-shell prompt redraws itself underneath it. @@ -181,66 +167,6 @@ def print_above_keeper_prompt(msg): pass -def escalate_close( - tube_registry, - tube_id, - *, - remote_webrtc_version=None, - reason=None, - hard_after_seconds=FORCE_CLOSE_DELAY_SECONDS, - log_prefix="", -): - """ - Soft-close a tube now, then escalate to force_close_tube after - `hard_after_seconds` if both endpoints support it. - - The soft close stops new channel creation and emits CloseConnection control - frames; the force close (when available) drops the local TCP listener, - severs in-flight forwarded TCP streams (SSH, MySQL, etc.) and tears down - the peer connection on a short bounded budget. - - Returns the scheduled `threading.Timer` (or None if escalation is not - available) so callers can cancel it on a clean exit. - """ - if reason is None: - reason = CloseConnectionReasons.AdminClosed - - try: - tube_registry.close_tube(tube_id, reason=reason) - except Exception as e: - logging.debug(f"{log_prefix}soft close_tube failed: {e}") - - has_local = hasattr(tube_registry, "force_close_tube") - has_remote = _version_at_least(remote_webrtc_version, FORCE_CLOSE_MIN_VERSION) - if not has_local: - logging.debug( - f"{log_prefix}force_close_tube unavailable in local keeper_pam_webrtc_rs - " - f"soft close only" - ) - return None - if not has_remote: - logging.debug( - f"{log_prefix}remote keeper-pam-webrtc {remote_webrtc_version!r} < " - f"{FORCE_CLOSE_MIN_VERSION} - soft close only" - ) - return None - - def _do_force_close(): - try: - logging.debug( - f"{log_prefix}escalating to force_close_tube({tube_id}) after " - f"{hard_after_seconds}s" - ) - tube_registry.force_close_tube(tube_id, reason=reason) - except Exception as e: - logging.debug(f"{log_prefix}force_close_tube failed: {e}") - - timer = threading.Timer(hard_after_seconds, _do_force_close) - timer.daemon = True - timer.start() - return timer - - # Constants NONCE_LENGTH = 12 MAIN_NONCE_LENGTH = 16 @@ -834,6 +760,32 @@ def get_config_uid(params, encrypted_session_token, encrypted_transmission_key, return None +def get_config_uid_via_pam_link(params, record_uid): + """Resolve a resource record's PAM Config UID via KRouter's PAM_LINK graph. + + Returns the config_uid that owns the resource. Used as a fallback when + ``get_config_uid`` (legacy ``/api/user/get_leafs`` with graphId=0) returns + nothing for resources whose link lives only in the new PAM_LINK stream. + + Returns the config_uid as a base64-url-safe string, or empty string on + failure / no link. + """ + try: + from ....keeper_dag.proto import GraphSync_pb2 as gs_pb2 + from ...pam.router_helper import _post_request_to_router + record_uid_bytes = url_safe_str_to_bytes(record_uid) + rq = gs_pb2.GraphSyncLeafsQuery(vertices=[record_uid_bytes]) + rs = _post_request_to_router(params, 'graph-sync/pam/get_leafs', + rq_proto=rq, rs_type=gs_pb2.GraphSyncRefsResult) + if rs and rs.refs: + for ref in rs.refs: + if ref.value: + return utils.base64_url_encode(ref.value) + except Exception as e: + logging.debug('get_config_uid_via_pam_link: lookup failed for %s: %s', record_uid, e) + return '' + + def get_keeper_tokens(params): transmission_key = generate_random_bytes(32) server_public_key = rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id] diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 77ef48fff..c5f165c35 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -29,10 +29,11 @@ from .base import Command, GroupCommand, dump_report_data, RecordMixin from .tunnel.port_forward.TunnelGraph import TunnelDAG -from .tunnel.port_forward.tunnel_helpers import find_open_port, get_config_uid, get_keeper_tokens, \ +from .tunnel.port_forward.tunnel_helpers import find_open_port, get_config_uid, get_config_uid_via_pam_link, \ + get_keeper_tokens, \ get_or_create_tube_registry, get_gateway_uid_from_record, resolve_record, resolve_pam_config, resolve_folder, \ remove_field, start_rust_tunnel, get_tunnel_session, unregister_tunnel_session, CloseConnectionReasons, \ - wait_for_tunnel_connection, create_rust_webrtc_settings, escalate_close, \ + wait_for_tunnel_connection, create_rust_webrtc_settings, \ print_above_keeper_prompt from .pam.router_helper import get_dag_leafs from .tunnel_registry import ( @@ -1014,13 +1015,9 @@ def execute(self, params, **kwargs): self._print_keeperdb_proxy_banner(host, port, db_type_for_banner) # Workflow lease expiry handling. # - # At expiresOn we soft-close the tube (stops new channels, sends - # CloseConnection control frames) and, after a short delay, escalate - # to force_close_tube which drops the local TCP listener and severs - # any active forwarded streams (SSH, MySQL, etc.). The escalation - # only fires when both the local Rust crate and the remote peer - # advertise FORCE_CLOSE_MIN_VERSION; older peers get the soft close - # only and the in-flight session lingers until natural disconnect. + # At expiresOn we close the tube (stops new channels, sends + # CloseConnection control frames); the connection-closed cleanup + # path then stops the websocket and unregisters the tunnel session. if workflow_expires_on_ms and workflow_expires_on_ms > 0: seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - time.time() tube_id = result.get('tube_id') @@ -1045,15 +1042,15 @@ def _close_on_lease_expiry(_tube_id=tube_id, _record_uid=record_uid): f"forwarded connections will be terminated." f"{bcolors.ENDC}" ) - sess = get_tunnel_session(_tube_id) - remote_ver = getattr(sess, 'remote_webrtc_version', None) if sess else None - escalate_close( - tube_registry, - _tube_id, - remote_webrtc_version=remote_ver, - reason=CloseConnectionReasons.AdminClosed, - log_prefix=f"[lease-expiry tunnel record={_record_uid}] ", - ) + try: + tube_registry.close_tube( + _tube_id, reason=CloseConnectionReasons.AdminClosed, + ) + except Exception as e: + logging.debug( + f"[lease-expiry tunnel record={_record_uid}] " + f"close_tube failed: {e}" + ) # Wake any --foreground / --run blocking wait so the # process self-terminates. Default interactive mode # does not register an event here. @@ -2624,6 +2621,10 @@ class PAMConnectionEditCommand(Command): 'the port from the record will be used.') parser.add_argument('--key-events', '-k', dest='key_events', choices=choices, help='Toggle Key Events settings') + parser.add_argument('--scrollback', '-sb', required=False, dest='scrollback', action='store', + help='Maximum Scrollback Size (terminal history). Integer to set, ' + 'empty string to remove. Supported only for pamDatabase (any DB protocol) and ' + 'pamMachine/pamDirectory (ssh/telnet/kubernetes).') parser.add_argument('--rotate-on-termination', required=False, dest='rotate_on_termination', choices=['on', 'off'], help='Rotate launch credentials when the PAM session ends (DAG resource meta)') @@ -2672,6 +2673,51 @@ def execute(self, params, **kwargs): f"pamRemoteBrowser, pamNetworkConfiguration pamAwsConfiguration, and " f"pamAzureConfiguration records{bcolors.ENDC}") + # --scrollback: validate record type + effective protocol before any mutation + scrollback_arg = kwargs.get('scrollback', None) + scrollback_clear = False + scrollback_value = None # parsed int, or None to skip apply + if scrollback_arg is not None: + db_scrollback_protocols = {'mysql', 'postgresql', 'sql-server', 'mariadb', 'oracle', + 'mongodb', 'redis', 'elasticsearch', 'clickhouse', 'dynamodb'} + terminal_scrollback_protocols = {'ssh', 'telnet', 'kubernetes'} + if record_type == 'pamDatabase': + allowed_protocols = db_scrollback_protocols + elif record_type in ('pamMachine', 'pamDirectory'): + allowed_protocols = terminal_scrollback_protocols + else: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--scrollback is only supported for pamDatabase, pamMachine, and pamDirectory ' + f'records. Record "{record_uid}" is of type "{record_type}".{bcolors.ENDC}') + + existing_ps = record.get_typed_field('pamSettings') + existing_protocol = '' + if existing_ps and existing_ps.value and isinstance(existing_ps.value[0], dict): + existing_protocol = existing_ps.value[0].get('connection', {}).get('protocol') or '' + new_protocol_arg = kwargs.get('protocol', None) + if kwargs.get('connections') == 'on' and new_protocol_arg is not None: + effective_protocol = new_protocol_arg # may be '' to clear + else: + effective_protocol = existing_protocol + if effective_protocol not in allowed_protocols: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--scrollback is not supported for protocol "{effective_protocol or "(unset)"}" ' + f'on {record_type} records. Allowed protocols: {", ".join(sorted(allowed_protocols))}.{bcolors.ENDC}') + + if scrollback_arg == '': + scrollback_clear = True + else: + try: + scrollback_value = int(scrollback_arg) + except (ValueError, TypeError): + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--scrollback must be a non-negative integer or empty string. ' + f'Got: "{scrollback_arg}".{bcolors.ENDC}') + if scrollback_value < 0: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--scrollback must be a non-negative integer or empty string. ' + f'Got: "{scrollback_arg}".{bcolors.ENDC}') + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) if record_type in "pamNetworkConfiguration pamAwsConfiguration pamAzureConfiguration".split(): tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True, @@ -2758,6 +2804,24 @@ def execute(self, params, **kwargs): else: logging.debug(f'Unexpected value for --key-events {key_events} (ignored)') + # --scrollback: apply (validated above; record_type + effective protocol already checked) + if scrollback_clear or scrollback_value is not None: + psv = pam_settings.value[0] if pam_settings and pam_settings.value else {} + vcon = psv.get('connection', {}) if isinstance(psv, dict) else {} + current_sb = vcon.get('scrollback') if isinstance(vcon, dict) else None + if scrollback_clear: + if current_sb is not None: + pam_settings.value[0]["connection"].pop('scrollback', None) + dirty = True + else: + logging.debug(f'scrollback is already unset on record={record_uid}') + else: + if current_sb != scrollback_value: + pam_settings.value[0]["connection"]["scrollback"] = scrollback_value + dirty = True + else: + logging.debug(f'scrollback is already {scrollback_value} on record={record_uid}') + if dirty: record_management.update_record(params, record) api.sync_down(params) @@ -2768,8 +2832,36 @@ def execute(self, params, **kwargs): f"Please make sure you have edit rights to record {record_uid} {bcolors.ENDC}") dirty = False + # If only record-level args were passed (e.g. --scrollback, --key-events, --protocol + # alone), the record update above is complete — skip the DAG/config lookup, which + # would otherwise raise a misleading "No PAM Configuration UID set" error when the + # resource isn't linked to a config yet. Mirrors the PAMRbiEditCommand pattern. + dag_affecting = (kwargs.get('config') or kwargs.get('admin') or kwargs.get('launch_user') + or kwargs.get('clear_launch_user') or kwargs.get('connections') + or kwargs.get('recording') or kwargs.get('typescriptrecording') + or kwargs.get('rotate_on_termination')) + if not dag_affecting: + return + existing_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) + # When the user did not pass --configuration, fall back to + # first the legacy /api/user/get_leafs if that also returned nothing, + # then the new /api/user/graph-sync/pam/get_leafs (PAM_LINK stream) + # Without this, `tdag` would be built with config_uid=None, has_graph=False, + # and the misleading "No PAM Configuration UID set" error would fire + # even though the link is resolvable. + if not config_uid: + if existing_config_uid: + config_uid = existing_config_uid + logging.debug('pam connection edit: using DAG-resolved config_uid: %s', config_uid) + else: + found = get_config_uid_via_pam_link(params, record_uid) + if found: + logging.debug('pam connection edit: resolved config_uid via graph-sync/pam fallback: %s', found) + existing_config_uid = found + config_uid = found + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid, transmission_key=transmission_key) old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid, @@ -3149,6 +3241,10 @@ class PAMRbiEditCommand(Command): help='Allow navigation via direct URL manipulation (on/off/default)') parser.add_argument('--ignore-server-cert', '-isc', dest='ignore_server_cert', choices=choices, help='Ignore server certificate errors (on/off/default)') + parser.add_argument('--allow-file-uploads', '-fu', dest='allow_file_uploads', choices=choices, + help='Allow file uploads in RBI sessions (on/off/default)') + parser.add_argument('--allow-file-downloads', '-fd', dest='allow_file_downloads', choices=choices, + help='Allow file downloads in RBI sessions (on/off/default)') # URL Filtering parser.add_argument('--allowed-urls', '-au', dest='allowed_urls', action='append', @@ -3197,6 +3293,8 @@ def execute(self, params, **kwargs): # New RBI settings (Phase 1 - KC-1034) allow_url_navigation = kwargs.get('allow_url_navigation') # on/off/default/None ignore_server_cert = kwargs.get('ignore_server_cert') # on/off/default/None + allow_file_uploads = kwargs.get('allow_file_uploads') # on/off/default/None + allow_file_downloads = kwargs.get('allow_file_downloads') # on/off/default/None allowed_urls = kwargs.get('allowed_urls') # list or None allowed_resource_urls = kwargs.get('allowed_resource_urls') # list or None autofill_targets = kwargs.get('autofill_targets') # list or None @@ -3214,6 +3312,8 @@ def execute(self, params, **kwargs): has_new_settings = any([ allow_url_navigation is not None, ignore_server_cert is not None, + allow_file_uploads is not None, + allow_file_downloads is not None, allowed_urls is not None, allowed_resource_urls is not None, autofill_targets is not None, @@ -3389,6 +3489,14 @@ def update_connection_int(field_name, value): if ignore_server_cert: update_connection_toggle('ignoreInitialSslCert', ignore_server_cert) + # Browser Settings - allowFileUploads (on/off/default) + if allow_file_uploads: + update_connection_toggle('allowFileUploads', allow_file_uploads) + + # Browser Settings - allowFileDownloads (on/off/default) + if allow_file_downloads: + update_connection_toggle('allowFileDownloads', allow_file_downloads) + # URL Filtering - allowedUrlPatterns (multi-value, joined with newlines) if allowed_urls is not None: update_connection_string('allowedUrlPatterns', allowed_urls) diff --git a/keepercommander/config_storage/loader.py b/keepercommander/config_storage/loader.py index 978a37d61..219e5951e 100644 --- a/keepercommander/config_storage/loader.py +++ b/keepercommander/config_storage/loader.py @@ -110,7 +110,7 @@ def _get_plugin(url): # type: (str) -> SecureStorageBase if not par.scheme: raise SecureStorageException(f'Configuration file error: "{CONFIG_STORAGE_URL}" is not URL') - plugin_name = par.scheme + plugin_name = par.scheme.replace('-', '_') if not any(x for x in pkgutil.iter_modules(config_storage.__path__) if x.name == plugin_name): raise SecureStorageException(f'Protected storage "{plugin_name}" is not supported') diff --git a/keepercommander/config_storage/os-keychain/__init__.py b/keepercommander/config_storage/os_keychain/__init__.py similarity index 100% rename from keepercommander/config_storage/os-keychain/__init__.py rename to keepercommander/config_storage/os_keychain/__init__.py diff --git a/keepercommander/config_storage/os-keychain/os_keychain_storage.py b/keepercommander/config_storage/os_keychain/os_keychain_storage.py similarity index 100% rename from keepercommander/config_storage/os-keychain/os_keychain_storage.py rename to keepercommander/config_storage/os_keychain/os_keychain_storage.py diff --git a/keepercommander/nested_share_folder/folder_api.py b/keepercommander/nested_share_folder/folder_api.py index 94f90179e..e6d8f2307 100644 --- a/keepercommander/nested_share_folder/folder_api.py +++ b/keepercommander/nested_share_folder/folder_api.py @@ -93,7 +93,7 @@ def _prepare_folder_for_creation(params, folder_uid, folder_name, parent_uid, else SetBooleanValue.BOOLEAN_FALSE), color=color) fd.folderKey = encrypted_fk - return fd + return fd, folder_key # ══════════════════════════════════════════════════════════════════════════ @@ -173,13 +173,14 @@ def resolve_folder_identifier(params, folder_identifier): def create_folder_v3(params, folder_name, parent_uid=None, color=None, inherit_permissions=True): uid = utils.generate_uid() - fd = _prepare_folder_for_creation(params, uid, folder_name, parent_uid, - color, inherit_permissions) + fd, folder_key = _prepare_folder_for_creation( + params, uid, folder_name, parent_uid, color, inherit_permissions) response = folder_add_v3(params, [fd]) if response.folderAddResults: r = response.folderAddResults[0] return { 'folder_uid': uid, + 'folder_key_unencrypted': folder_key, 'status': folder_pb2.FolderModifyStatus.Name(r.status), 'message': r.message, 'success': r.status == folder_pb2.SUCCESS, @@ -190,20 +191,22 @@ def create_folder_v3(params, folder_name, parent_uid=None, color=None, def create_folders_batch_v3(params, folder_specs): if len(folder_specs) > 100: raise ValueError("Maximum 100 folders at a time") - fd_list, uid_map = [], {} + fd_list, uid_map, key_map = [], {}, {} for idx, spec in enumerate(folder_specs): uid = utils.generate_uid() uid_map[idx] = uid name = spec.get('name') if not name: raise ValueError(f"Spec at index {idx} missing 'name'") - fd = _prepare_folder_for_creation( + fd, folder_key = _prepare_folder_for_creation( params, uid, name, spec.get('parent_uid'), spec.get('color'), spec.get('inherit_permissions', True)) fd_list.append(fd) + key_map[idx] = folder_key response = folder_add_v3(params, fd_list) return [{ 'folder_uid': uid_map.get(i, utils.base64_url_encode(r.folderUid)), + 'folder_key_unencrypted': key_map.get(i), 'name': folder_specs[i].get('name'), 'status': folder_pb2.FolderModifyStatus.Name(r.status), 'message': r.message, diff --git a/keepercommander/nested_share_folder/record_api.py b/keepercommander/nested_share_folder/record_api.py index 7eb432567..db886a78b 100644 --- a/keepercommander/nested_share_folder/record_api.py +++ b/keepercommander/nested_share_folder/record_api.py @@ -151,7 +151,9 @@ def update_record_v3(params, record_uid, data=None, title=None, if 'data_unencrypted' in rec: raw = rec['data_unencrypted'] if isinstance(raw, bytes): - existing = json.loads(raw.decode()) + existing = json.loads(raw.decode('utf-8')) + elif isinstance(raw, str): + existing = json.loads(raw) data = existing.copy() if existing else {'fields': []} if title is not None: data['title'] = title diff --git a/keepercommander/utils.py b/keepercommander/utils.py index fd0ebfdde..d9c8c7cef 100644 --- a/keepercommander/utils.py +++ b/keepercommander/utils.py @@ -531,66 +531,54 @@ def value_to_boolean(value): return None def get_ssl_cert_file(): - """Get SSL certificate file path, preferring system CA store for corporate environments like Zscaler""" - import ssl - import platform + """Resolve the SSL CA bundle path. + + KEEPER_SSL_CERT_FILE accepts: 'certifi', 'system', 'none'/'false', or a PEM path. + Defaults to the bundled certifi store; does not consult SSL_CERT_FILE from the + environment to avoid silent trust-store hijacking by unrelated tools. + """ import certifi - import os - - # Allow user to override via environment variable + user_cert_file = os.getenv('KEEPER_SSL_CERT_FILE') if user_cert_file: - if user_cert_file.lower() == 'system': - pass # Continue with system detection below - elif user_cert_file.lower() == 'certifi': + choice = user_cert_file.lower() + if choice == 'certifi': return certifi.where() - elif user_cert_file.lower() == 'none' or user_cert_file.lower() == 'false': - return False # Disable SSL verification - elif os.path.exists(user_cert_file): - return user_cert_file - else: - # Don't use logging here as it can interfere with main logging config - print(f"Warning: SSL cert file specified in KEEPER_SSL_CERT_FILE not found: {user_cert_file}", file=sys.stderr) - - # Try to use system CA store first for corporate environments + if choice in ('none', 'false'): + return False + if choice != 'system': + if os.path.exists(user_cert_file): + return user_cert_file + print( + f"Warning: KEEPER_SSL_CERT_FILE points to a non-existent file: " + f"{user_cert_file}; falling back to the bundled certifi store.", + file=sys.stderr, + ) + return certifi.where() + try: - # On macOS, try Homebrew certificates first (better for corporate environments like Zscaler) - if platform.system() == 'Darwin': - system_ca_paths = [ - '/opt/homebrew/etc/ca-certificates/cert.pem', # Homebrew CA bundle (best for Zscaler) - '/usr/local/etc/ssl/cert.pem', # Homebrew SSL (older location) - '/etc/ssl/cert.pem', # macOS system CA bundle - ] - for ca_path in system_ca_paths: - if os.path.exists(ca_path): - return ca_path - - # On Linux/Unix systems - elif platform.system() == 'Linux': - system_ca_paths = [ - '/etc/ssl/certs/ca-certificates.crt', # Debian/Ubuntu - '/etc/pki/tls/certs/ca-bundle.crt', # RHEL/CentOS - '/etc/ssl/ca-bundle.pem', # OpenSUSE - '/etc/ssl/cert.pem', # Generic - ] - for ca_path in system_ca_paths: - if os.path.exists(ca_path): - return ca_path - - # Try to get default SSL context locations - try: - default_locations = ssl.get_default_verify_paths() - if default_locations.cafile and os.path.exists(default_locations.cafile): - return default_locations.cafile - if default_locations.capath and os.path.exists(default_locations.capath): - return default_locations.capath - except: - pass - + system = platform.system() + if system == 'Darwin': + candidates = ( + '/opt/homebrew/etc/ca-certificates/cert.pem', + '/usr/local/etc/ssl/cert.pem', + '/etc/ssl/cert.pem', + ) + elif system == 'Linux': + candidates = ( + '/etc/ssl/certs/ca-certificates.crt', + '/etc/pki/tls/certs/ca-bundle.crt', + '/etc/ssl/ca-bundle.pem', + '/etc/ssl/cert.pem', + ) + else: + candidates = () + for ca_path in candidates: + if os.path.exists(ca_path): + return ca_path except Exception: pass - - # Fall back to certifi if system CA not available + return certifi.where() diff --git a/requirements.txt b/requirements.txt index 5264fa503..7af34666c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ requests>=2.31.0 cryptography>=46.0.6 protobuf>=5.29.6 keeper-secrets-manager-core>=16.6.0 -keeper_pam_webrtc_rs>=2.1.17 +keeper_pam_webrtc_rs>=2.1.18 pydantic>=2.6.4 flask pyngrok>=7.5.0 diff --git a/setup.cfg b/setup.cfg index dd9a215b7..6c52380c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ install_requires = requests>=2.31.0 tabulate websockets - keeper_pam_webrtc_rs>=2.1.17 + keeper_pam_webrtc_rs>=2.1.18 pydantic>=2.6.4 fpdf2>=2.8.3 tzlocal>=5.0 diff --git a/unit-tests/pam/test_get_config_uid_via_pam_link.py b/unit-tests/pam/test_get_config_uid_via_pam_link.py new file mode 100644 index 000000000..54735fa23 --- /dev/null +++ b/unit-tests/pam/test_get_config_uid_via_pam_link.py @@ -0,0 +1,70 @@ +""" +Unit tests for tunnel_helpers.get_config_uid_via_pam_link — the KRouter +graph-sync/pam/get_leafs fallback used by `pam connection edit` when the +legacy get_leafs lookup misses. Mirrors the web vault's +`getMultiLeafsPamLinkDag` call (see vault/js/lib/api/pam/api-dag-pam-link.ts). +""" + +import unittest +from unittest import mock + +skip_tests = False +skip_reason = "" +try: + from keepercommander.commands.tunnel.port_forward.tunnel_helpers import get_config_uid_via_pam_link + from keepercommander.keeper_dag.proto import GraphSync_pb2 as gs_pb2 + from keepercommander import utils +except ImportError as e: + skip_tests = True + skip_reason = f"Cannot import tunnel_helpers/GraphSync_pb2: {e}" + + +@unittest.skipIf(skip_tests, skip_reason) +class TestGetConfigUidViaPamLink(unittest.TestCase): + def setUp(self): + self.params = mock.MagicMock() + self.record_uid = 'AAAAAAAAAAAAAAAAAAAAAA' # roundtrip-safe base64url for 16 bytes + self.config_uid = 'AQEBAQEBAQEBAQEBAQEBAQ' # roundtrip-safe base64url for 16 bytes + + @mock.patch('keepercommander.commands.pam.router_helper._post_request_to_router') + def test_returns_config_uid_on_single_ref(self, mock_post): + cfg_bytes = utils.base64_url_decode(self.config_uid) + mock_post.return_value = gs_pb2.GraphSyncRefsResult( + refs=[gs_pb2.GraphSyncRef(type=gs_pb2.RFT_PAM_NETWORK, value=cfg_bytes, name='')] + ) + result = get_config_uid_via_pam_link(self.params, self.record_uid) + self.assertEqual(result, self.config_uid) + # Verify endpoint + query shape + args, kwargs = mock_post.call_args + self.assertEqual(args[1], 'graph-sync/pam/get_leafs') + rq = kwargs.get('rq_proto') or args[2] + self.assertEqual(len(rq.vertices), 1) + self.assertEqual(rq.vertices[0], utils.base64_url_decode(self.record_uid)) + + @mock.patch('keepercommander.commands.pam.router_helper._post_request_to_router') + def test_returns_empty_string_when_no_refs(self, mock_post): + mock_post.return_value = gs_pb2.GraphSyncRefsResult(refs=[]) + self.assertEqual(get_config_uid_via_pam_link(self.params, self.record_uid), '') + + @mock.patch('keepercommander.commands.pam.router_helper._post_request_to_router') + def test_returns_empty_string_when_response_is_none(self, mock_post): + mock_post.return_value = None + self.assertEqual(get_config_uid_via_pam_link(self.params, self.record_uid), '') + + @mock.patch('keepercommander.commands.pam.router_helper._post_request_to_router') + def test_skips_refs_with_empty_value(self, mock_post): + cfg_bytes = utils.base64_url_decode(self.config_uid) + mock_post.return_value = gs_pb2.GraphSyncRefsResult(refs=[ + gs_pb2.GraphSyncRef(type=gs_pb2.RFT_PAM_NETWORK, value=b'', name=''), + gs_pb2.GraphSyncRef(type=gs_pb2.RFT_PAM_NETWORK, value=cfg_bytes, name=''), + ]) + self.assertEqual(get_config_uid_via_pam_link(self.params, self.record_uid), self.config_uid) + + @mock.patch('keepercommander.commands.pam.router_helper._post_request_to_router') + def test_swallows_exceptions_and_returns_empty(self, mock_post): + mock_post.side_effect = RuntimeError('krouter unreachable') + self.assertEqual(get_config_uid_via_pam_link(self.params, self.record_uid), '') + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_pam_connection_edit_scrollback.py b/unit-tests/pam/test_pam_connection_edit_scrollback.py new file mode 100644 index 000000000..7f758652d --- /dev/null +++ b/unit-tests/pam/test_pam_connection_edit_scrollback.py @@ -0,0 +1,309 @@ +""" +Unit tests for PAM Connection Edit `--scrollback` flag. + +Covers argument parsing and the up-front validation that runs before any DAG / +record mutation: allowed record types (pamDatabase, pamMachine, pamDirectory), +allowed protocols per type, and value parsing (int / empty string / invalid). +""" + +import unittest +from unittest import mock + +skip_tests = False +skip_reason = "" +try: + from keepercommander.commands.tunnel_and_connections import PAMConnectionEditCommand + from keepercommander.error import CommandError + from keepercommander import vault +except ImportError as e: + skip_tests = True + skip_reason = f"Cannot import tunnel_and_connections: {e}" + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamConnectionEditScrollbackArgs(unittest.TestCase): + def setUp(self): + self.parser = PAMConnectionEditCommand.parser + + def test_scrollback_int(self): + args = self.parser.parse_args(['rec', '--scrollback', '1234']) + self.assertEqual(args.scrollback, '1234') + + def test_scrollback_empty_string(self): + args = self.parser.parse_args(['rec', '--scrollback', '']) + self.assertEqual(args.scrollback, '') + + def test_scrollback_short_alias(self): + args = self.parser.parse_args(['rec', '-sb', '5000']) + self.assertEqual(args.scrollback, '5000') + + def test_scrollback_not_provided(self): + args = self.parser.parse_args(['rec']) + self.assertIsNone(args.scrollback) + + def test_help_includes_scrollback(self): + help_text = self.parser.format_help() + self.assertIn('--scrollback', help_text) + self.assertIn('-sb', help_text) + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamConnectionEditScrollbackValidation(unittest.TestCase): + """Validation runs before DAG / token operations, so we can drive execute() + with mocks that only need to satisfy resolve_single_record + the typed-field + accessor for pamSettings.""" + + def _mock_record(self, record_type, protocol): + rec = mock.MagicMock(spec=vault.TypedRecord) + rec.record_uid = 'rec-uid' + rec.record_type = record_type + rec.version = 3 + ps_field = mock.MagicMock() + if protocol is None: + ps_field.value = [] + else: + ps_field.value = [{'connection': {'protocol': protocol}}] + rec.get_typed_field.side_effect = lambda name: ps_field if name == 'pamSettings' else None + return rec + + def _execute(self, record, **kwargs): + cmd = PAMConnectionEditCommand() + params = mock.MagicMock() + with mock.patch( + 'keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record', + return_value=record, + ): + cmd.execute(params, record='rec', **kwargs) + + def test_pam_remote_browser_rejected(self): + rec = self._mock_record('pamRemoteBrowser', 'http') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('--scrollback is only supported for pamDatabase, pamMachine, and pamDirectory', + str(ctx.exception)) + + def test_pam_user_rejected(self): + rec = self._mock_record('pamUser', None) + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + # pamUser fails the outer record-type check before scrollback validation runs + self.assertIn("type is not supported for connections", str(ctx.exception)) + + def test_pam_network_configuration_rejected(self): + rec = self._mock_record('pamNetworkConfiguration', None) + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('--scrollback is only supported for pamDatabase, pamMachine, and pamDirectory', + str(ctx.exception)) + + def test_pam_machine_rdp_rejected(self): + rec = self._mock_record('pamMachine', 'rdp') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + msg = str(ctx.exception) + self.assertIn('not supported for protocol "rdp"', msg) + self.assertIn('pamMachine', msg) + + def test_pam_machine_vnc_rejected(self): + rec = self._mock_record('pamMachine', 'vnc') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('not supported for protocol "vnc"', str(ctx.exception)) + + def test_pam_machine_no_protocol_rejected(self): + rec = self._mock_record('pamMachine', None) + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('not supported for protocol "(unset)"', str(ctx.exception)) + + def test_pam_directory_http_rejected(self): + rec = self._mock_record('pamDirectory', 'http') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('not supported for protocol "http"', str(ctx.exception)) + + def test_non_numeric_rejected(self): + rec = self._mock_record('pamMachine', 'ssh') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='not-a-number') + self.assertIn('--scrollback must be a non-negative integer', str(ctx.exception)) + + def test_float_rejected(self): + rec = self._mock_record('pamMachine', 'ssh') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='1.5') + self.assertIn('--scrollback must be a non-negative integer', str(ctx.exception)) + + def test_negative_integer_rejected(self): + rec = self._mock_record('pamMachine', 'ssh') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='-100') + self.assertIn('--scrollback must be a non-negative integer', str(ctx.exception)) + + def test_negative_one_rejected(self): + rec = self._mock_record('pamMachine', 'ssh') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='-1') + self.assertIn('--scrollback must be a non-negative integer', str(ctx.exception)) + + def test_zero_accepted(self): + """Zero is a non-negative integer and should pass validation.""" + rec = self._mock_record('pamMachine', 'ssh') + try: + self._execute(rec, scrollback='0') + except CommandError as e: + self.assertNotIn('--scrollback must be', str(e)) + except Exception: + pass # downstream DAG failure expected; only validation is under test + + def test_protocol_change_in_same_command_validated_against_new(self): + """When --connections=on and --protocol=rdp are passed alongside --scrollback, + validation uses the new (post-mutation) protocol — rdp -> reject.""" + rec = self._mock_record('pamMachine', 'ssh') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100', connections='on', protocol='rdp') + self.assertIn('not supported for protocol "rdp"', str(ctx.exception)) + + def test_protocol_change_without_connections_uses_existing(self): + """--protocol is only honored alongside --connections=on; without it, + validation uses the existing record protocol.""" + rec = self._mock_record('pamMachine', 'ssh') + # Without --connections=on, the bogus --protocol is ignored, existing + # 'ssh' wins. Validation should not raise (it gets past the protocol + # check); we expect it to proceed to the DAG layer and fail there. + # We just assert the error is NOT the scrollback-protocol error. + with self.assertRaises(Exception) as ctx: + self._execute(rec, scrollback='100', protocol='rdp') + self.assertNotIn('not supported for protocol', str(ctx.exception)) + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamConnectionEditScrollbackAllowedCombinations(unittest.TestCase): + """For each allowed (record_type, protocol) pair, validation must not raise + a scrollback-related error. We don't run the full execute path (which would + require mocking the entire DAG layer), only verify validation passes.""" + + DB_PROTOCOLS = ['mysql', 'postgresql', 'sql-server', 'mariadb', 'oracle', + 'mongodb', 'redis', 'elasticsearch', 'clickhouse', 'dynamodb'] + TERMINAL_PROTOCOLS = ['ssh', 'telnet', 'kubernetes'] + + def _assert_validation_passes(self, record_type, protocol): + rec = mock.MagicMock(spec=vault.TypedRecord) + rec.record_uid = 'rec-uid' + rec.record_type = record_type + rec.version = 3 + ps_field = mock.MagicMock() + ps_field.value = [{'connection': {'protocol': protocol}}] + rec.get_typed_field.side_effect = lambda name: ps_field if name == 'pamSettings' else None + + cmd = PAMConnectionEditCommand() + params = mock.MagicMock() + with mock.patch( + 'keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record', + return_value=rec, + ): + try: + cmd.execute(params, record='rec', scrollback='100') + except CommandError as e: + self.assertNotIn('--scrollback is only supported', str(e)) + self.assertNotIn('--scrollback is not supported for protocol', str(e)) + self.assertNotIn('--scrollback must be an integer', str(e)) + except Exception: + pass # downstream DAG/token failures are not what we're testing + + def test_pam_database_all_db_protocols(self): + for proto in self.DB_PROTOCOLS: + with self.subTest(protocol=proto): + self._assert_validation_passes('pamDatabase', proto) + + def test_pam_machine_terminal_protocols(self): + for proto in self.TERMINAL_PROTOCOLS: + with self.subTest(protocol=proto): + self._assert_validation_passes('pamMachine', proto) + + def test_pam_directory_terminal_protocols(self): + for proto in self.TERMINAL_PROTOCOLS: + with self.subTest(protocol=proto): + self._assert_validation_passes('pamDirectory', proto) + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamConnectionEditScrollbackEarlyReturn(unittest.TestCase): + """When only record-level args (scrollback, key-events, protocol alone) are passed, + the command should return after the record update without touching the DAG. Locks in + the fix for the misleading 'No PAM Configuration UID set' error when --scrollback is + used on a resource that isn't linked to a config.""" + + def _mock_record(self, record_type='pamMachine', protocol='ssh'): + rec = mock.MagicMock(spec=vault.TypedRecord) + rec.record_uid = 'rec-uid' + rec.record_type = record_type + rec.version = 3 + rec.fields = [] + rec.custom = [] + ps_field = mock.MagicMock() + ps_field.value = [{'connection': {'protocol': protocol}}] + rec.get_typed_field.side_effect = lambda name: ps_field if name == 'pamSettings' else ( + mock.MagicMock(value=['seed']) if name == 'trafficEncryptionSeed' else None + ) + return rec + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + @mock.patch('keepercommander.commands.tunnel_and_connections.get_keeper_tokens', + return_value=(b'st', b'tk', b'tr')) + @mock.patch('keepercommander.commands.tunnel_and_connections.get_config_uid') + @mock.patch('keepercommander.commands.tunnel_and_connections.TunnelDAG') + def test_scrollback_alone_skips_dag(self, mock_tdag, mock_get_config_uid, + mock_tokens, mock_sync, mock_update, mock_resolve): + """Running with only --scrollback should NOT invoke get_config_uid or TunnelDAG.""" + rec = self._mock_record() + mock_resolve.return_value = rec + cmd = PAMConnectionEditCommand() + cmd.execute(mock.MagicMock(), record='rec', scrollback='1234') + mock_get_config_uid.assert_not_called() + mock_tdag.assert_not_called() + # The record update IS expected to run (scrollback was written) + mock_update.assert_called_once() + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + @mock.patch('keepercommander.commands.tunnel_and_connections.get_keeper_tokens', + return_value=(b'st', b'tk', b'tr')) + @mock.patch('keepercommander.commands.tunnel_and_connections.get_config_uid') + @mock.patch('keepercommander.commands.tunnel_and_connections.TunnelDAG') + def test_key_events_alone_skips_dag(self, mock_tdag, mock_get_config_uid, + mock_tokens, mock_sync, mock_update, mock_resolve): + """Same early-return applies to --key-events alone.""" + rec = self._mock_record() + mock_resolve.return_value = rec + cmd = PAMConnectionEditCommand() + cmd.execute(mock.MagicMock(), record='rec', key_events='on') + mock_get_config_uid.assert_not_called() + mock_tdag.assert_not_called() + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + @mock.patch('keepercommander.commands.tunnel_and_connections.get_keeper_tokens', + return_value=(b'st', b'tk', b'tr')) + @mock.patch('keepercommander.commands.tunnel_and_connections.get_config_uid', + return_value=None) + def test_scrollback_with_connections_still_runs_dag(self, mock_get_config_uid, + mock_tokens, mock_sync, mock_update, mock_resolve): + """When --connections is passed alongside --scrollback, the DAG block must still run + (and is expected to surface its own errors). We just verify get_config_uid is reached.""" + rec = self._mock_record() + mock_resolve.return_value = rec + cmd = PAMConnectionEditCommand() + try: + cmd.execute(mock.MagicMock(), record='rec', scrollback='1234', connections='on') + except Exception: + pass # downstream TunnelDAG instantiation will fail; not under test here + mock_get_config_uid.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_pam_rbi_edit.py b/unit-tests/pam/test_pam_rbi_edit.py index 94c91ee84..ce52be84d 100644 --- a/unit-tests/pam/test_pam_rbi_edit.py +++ b/unit-tests/pam/test_pam_rbi_edit.py @@ -2,7 +2,7 @@ Unit tests for PAM RBI Edit command - KC-1034 Feature Parity Tests the new CLI arguments added to expose RBI settings: -- Browser Settings: --allow-url-navigation, --ignore-server-cert (on/off/default) +- Browser Settings: --allow-url-navigation, --ignore-server-cert, --allow-file-uploads, --allow-file-downloads (on/off/default) - URL Filtering: --allowed-urls, --allowed-resource-urls (multi-value) - Autofill: --autofill-targets (multi-value) - Clipboard: --allow-copy, --allow-paste (on/off/default) @@ -67,6 +67,46 @@ def test_ignore_server_cert_default(self): args = self.parser.parse_args(['--record', 'test-record', '--ignore-server-cert', 'default']) self.assertEqual(args.ignore_server_cert, 'default') + def test_allow_file_uploads_on(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-uploads', 'on']) + self.assertEqual(args.allow_file_uploads, 'on') + + def test_allow_file_uploads_off(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-uploads', 'off']) + self.assertEqual(args.allow_file_uploads, 'off') + + def test_allow_file_uploads_default(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-uploads', 'default']) + self.assertEqual(args.allow_file_uploads, 'default') + + def test_allow_file_uploads_invalid(self): + with self.assertRaises(SystemExit): + self.parser.parse_args(['--record', 'test-record', '--allow-file-uploads', 'invalid']) + + def test_allow_file_uploads_not_provided(self): + args = self.parser.parse_args(['--record', 'test-record', '--key-events', 'on']) + self.assertIsNone(args.allow_file_uploads) + + def test_allow_file_downloads_on(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-downloads', 'on']) + self.assertEqual(args.allow_file_downloads, 'on') + + def test_allow_file_downloads_off(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-downloads', 'off']) + self.assertEqual(args.allow_file_downloads, 'off') + + def test_allow_file_downloads_default(self): + args = self.parser.parse_args(['--record', 'test-record', '--allow-file-downloads', 'default']) + self.assertEqual(args.allow_file_downloads, 'default') + + def test_allow_file_downloads_invalid(self): + with self.assertRaises(SystemExit): + self.parser.parse_args(['--record', 'test-record', '--allow-file-downloads', 'invalid']) + + def test_allow_file_downloads_not_provided(self): + args = self.parser.parse_args(['--record', 'test-record', '--key-events', 'on']) + self.assertIsNone(args.allow_file_downloads) + def test_allowed_urls_single(self): args = self.parser.parse_args(['--record', 'test-record', '--allowed-urls', '*.example.com']) self.assertEqual(args.allowed_urls, ['*.example.com']) @@ -214,6 +254,56 @@ def test_ignore_server_cert_on_sets_true(self, mock_sync, mock_update, mock_reso self.command.execute(self.mock_params, record='test-record', ignore_server_cert='on') self.assertEqual(self.pam_settings['connection'].get('ignoreInitialSslCert'), True) + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_uploads_on_sets_true(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', allow_file_uploads='on') + self.assertEqual(self.pam_settings['connection'].get('allowFileUploads'), True) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_uploads_off_sets_false(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', allow_file_uploads='off') + self.assertEqual(self.pam_settings['connection'].get('allowFileUploads'), False) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_uploads_default_removes_field(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.pam_settings['connection']['allowFileUploads'] = True + self.command.execute(self.mock_params, record='test-record', allow_file_uploads='default') + self.assertNotIn('allowFileUploads', self.pam_settings['connection']) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_downloads_on_sets_true(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', allow_file_downloads='on') + self.assertEqual(self.pam_settings['connection'].get('allowFileDownloads'), True) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_downloads_off_sets_false(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', allow_file_downloads='off') + self.assertEqual(self.pam_settings['connection'].get('allowFileDownloads'), False) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_allow_file_downloads_default_removes_field(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.pam_settings['connection']['allowFileDownloads'] = True + self.command.execute(self.mock_params, record='test-record', allow_file_downloads='default') + self.assertNotIn('allowFileDownloads', self.pam_settings['connection']) + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') @@ -356,6 +446,8 @@ def test_help_includes_new_arguments(self): help_text = PAMRbiEditCommand.parser.format_help() self.assertIn('--allow-url-navigation', help_text) self.assertIn('--ignore-server-cert', help_text) + self.assertIn('--allow-file-uploads', help_text) + self.assertIn('--allow-file-downloads', help_text) self.assertIn('--allowed-urls', help_text) self.assertIn('--allowed-resource-urls', help_text) self.assertIn('--autofill-targets', help_text) @@ -386,6 +478,14 @@ def test_alias_isc(self): args = self.parser.parse_args(['--record', 'test-record', '-isc', 'on']) self.assertEqual(args.ignore_server_cert, 'on') + def test_alias_fu(self): + args = self.parser.parse_args(['--record', 'test-record', '-fu', 'on']) + self.assertEqual(args.allow_file_uploads, 'on') + + def test_alias_fd(self): + args = self.parser.parse_args(['--record', 'test-record', '-fd', 'on']) + self.assertEqual(args.allow_file_downloads, 'on') + def test_alias_au(self): args = self.parser.parse_args(['--record', 'test-record', '-au', '*.example.com']) self.assertEqual(args.allowed_urls, ['*.example.com'])