diff --git a/.gitignore b/.gitignore index 1b40c8dc0..e81f363fa 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ Makefile dr-logs /.venv* *.swp -CLAUDE.md +#CLAUDE.md AGENTS.md keeper_db.sqlite __pycache__/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..f8b777f2e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,97 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## ⚠️ Do not edit vendored directories + +**Never edit files under `keepercommander/discovery_common/` or `keepercommander/keeper_dag/`.** These are copied in from external "golden" repositories (`keeper-dag`, and the shared Gateway/KDNRM `discovery-common` code). Any edits made here will be **overwritten** the next time the directories are synced from upstream. If a change is needed in this code, make it in the upstream repo — not here. The same applies to generated protobuf files in `keepercommander/proto/` (`*_pb2.py` / `*.pyi`): regenerate from the `.proto` source, never hand-edit. + +## What this is + +Keeper Commander is a command-line and terminal-UI client for Keeper Password Manager and KeeperPAM. It is a Python package (`keepercommander`) that ships as the `keeper` console script. Beyond vault access it does enterprise administration, PAM (privileged access: rotation, tunnels, discovery), data import/export from other password managers, and can run as a REST service. + +## Setup & running + +```bash +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt +pip install -e . +pip install -e '.[email]' # optional email-sending extras +``` + +Run the CLI: +- `keeper help` — list all commands +- `keeper shell` — interactive command shell +- `keeper supershell` — full terminal vault UI (Textual) +- `keeper [args]` — run a single command and exit +- `python keeper.py ...` — equivalent entry point (calls `keepercommander.__main__:main`) + +Config defaults to `config.json` (cwd or platform data dir); override with `KEEPER_CONFIG_FILE`. Set `KEEPER_COMMANDER_DEBUG` for debug logging. + +## Testing + +CI runs only `unit-tests/` on Python 3.8 and 3.14 (`.github/workflows/test-with-pytest.yml`). Supported range is 3.8–3.14. + +```bash +pip install '.[test]' +pytest unit-tests/ # what CI runs +pytest unit-tests/test_sync_down.py # single file +pytest unit-tests/test_sync_down.py::TestClass::test_name # single test +pytest -m keeper_imports # smoke test: every module imports cleanly +``` + +Test marker semantics (see `pytest.ini`): +- `unit` — mocked, no live server (credential-provision) +- `e2e` — end-to-end, run manually (`pytest -m e2e`) +- `integration` — hits internal `dev.keepersecurity.com` accounts via a `config.json` (`tests/data_config.py`); not for general use +- `cross_enterprise` — excluded by default in `addopts` + +Note: `pytest.ini` excludes `venv/`, ignores `unit-tests/test_command_utils.py` (circular import) and `unit-tests/test_login.py` (connection errors), and disables warnings. The `tests/` directory holds heavier integration/e2e suites that are *not* run by CI. + +Lint config is `pylintrc`; `desired-pylint-warnings` documents which warning categories the team cares about. + +## Architecture + +**Entry & dispatch.** `__main__.py` loads config into a `KeeperParams`, then `cli.py` drives the REPL/one-shot execution. `cli.do_command()` is the central dispatcher: it parses a command line, resolves aliases, picks the right registry (`commands`, `enterprise_commands`, or `msp_commands`), and calls the command. Control characters in command input are rejected at this boundary. + +**Commands.** Every command subclasses `Command` (or `ArgparseCommand`) in `keepercommander/commands/base.py`. The contract: +- `get_parser()` returns an `argparse.ArgumentParser`; `execute_args()` parses the raw string and dispatches to `execute(params, **kwargs)`. +- `is_authorised()` gates whether login is required. +- `GroupCommand` / `GroupCommandNew` compose sub-verbs (e.g. `pam `); they own their own sub-registries and aliases. +- Mixins `RecordMixin` / `FolderMixin` provide shared record/folder resolution helpers. + +Commands are registered through `register_commands(commands, aliases, command_info)` in `commands/base.py`, which imports each command module and calls its `register_commands` / `register_command_info`. To add a command, create the module, implement the `Command`, and wire it into the appropriate `register_commands` function. The `commands/` subdirectories group large feature areas (`pam`, `pam_cloud`, `pam_import`, `discover`, `enterprise*`, `domain_management`, `remote_management`, `keeper_drive`, `pedm`, `scim`). + +**Session & state.** `KeeperParams` (`params.py`) is the single mutable object threaded through everything: session tokens, the in-memory vault cache (records, folders, shared folders, teams), enterprise data, and `RestApiContext` (server, transmission/encryption keys, QRC/EC key negotiation). Commands read and mutate this object rather than passing data around. + +**Network & data sync.** `api.py` is the transport layer: `login()`, `communicate()` / `communicate_rest()` (protobuf request/response with throttle retry), and `query_enterprise()`. `rest_api.py` / `loginv3.py` handle the low-level REST and login-v3 flows; `auth/` holds login-step and console-UI logic. `sync_down.py` pulls and decrypts the vault into `params`, then `prepare_folder_tree()` builds the folder hierarchy. Wire formats live in `proto/` (generated `*_pb2.py` — do not hand-edit). Crypto primitives are in `crypto.py`. + +**Vault data model.** `vault.py` defines the record types: `KeeperRecord` (abstract) with `PasswordRecord` (v2), `TypedRecord` (v3, field-based with `TypedField`), `FileRecord`, `ApplicationRecord`. `record_facades.py` / `vault_extensions.py` provide typed views; `subfolder.py` models the folder tree. + +**Local storage.** `storage/` (SQLite + in-memory DAOs) and `config_storage/` persist cache and config; secure config storage can be encrypted (`loader.SecureStorageException` path in `__main__.py`). + +**PAM / discovery / graph.** `keeper_dag/` and `discovery_common/` implement the directed-acyclic graph (DAG) backing PAM discovery, record-linking, and rotation. `commands/pam/`, `commands/pam_cloud/`, and `commands/discover/` build on top of them. These two directories are **vendored copies** of external golden repos — see the warning at the top of this file before touching them. + +**Importers.** `importer/` has per-product subpackages (1password, bitwarden, lastpass, keepass, dashlane, proton, thycotic, cyberark, etc.) plus generic csv/json. `imp_exp.py` orchestrates import/export. + +**Service mode.** `service/` is a Flask-based REST API server exposing Commander commands over HTTP with API-key auth, rate limiting, and optional response encryption. See `keepercommander/service/README.md`. Managed via `service-create` / `service-start` / etc. commands. + +**Plugins.** `plugins/` are rotation plugins loaded dynamically and registered like other commands. + +## Style + +Follow [PEP 8](https://peps.python.org/pep-0008/), with the project-specific settings enforced by `pylintrc`: +- **Line length: 100** (not PEP 8's default 79). +- `snake_case` for functions, methods, arguments, variables, and attributes. +- `PascalCase` for classes. +- `UPPER_CASE` for module-level and class constants. +- 4-space indentation, no tabs. + +Run `pylint keepercommander/.py` to check; `desired-pylint-warnings` documents which warning categories the team treats as meaningful. + +## Conventions + +- Match the surrounding file's style; most modules carry the Keeper ASCII-art header. +- Never edit `keeper_dag/`, `discovery_common/`, or generated `proto/` files (see the warning at the top of this file). +- Version lives in `keepercommander/__init__.py` (`__version__`); `setup.cfg` reads it via `attr:`. \ No newline at end of file diff --git a/keepercommander/__init__.py b/keepercommander/__init__.py index f16f8c4d4..845f8bb94 100644 --- a/keepercommander/__init__.py +++ b/keepercommander/__init__.py @@ -10,4 +10,4 @@ # Contact: commander@keepersecurity.com # -__version__ = '18.0.6' +__version__ = '18.0.7' diff --git a/keepercommander/auth/console_ui.py b/keepercommander/auth/console_ui.py index 0a60b077b..eca6bf3b4 100644 --- a/keepercommander/auth/console_ui.py +++ b/keepercommander/auth/console_ui.py @@ -18,6 +18,31 @@ def _stderr(msg=''): print(msg, file=sys.stderr) +_HEADLESS_AUTH_MSG_SHOWN = False + + +def _is_interactive(): + try: + return bool(sys.stdin) and sys.stdin.isatty() + except Exception: + return False + + +def _fail_headless_auth(step): + """In headless/service mode, persistent login often needs a follow-up prompt + (password, SSO, 2FA, device approval) that cannot be answered. Log once and + cancel so the caller exits cleanly instead of looping or spamming getpass.""" + global _HEADLESS_AUTH_MSG_SHOWN + if not _HEADLESS_AUTH_MSG_SHOWN: + _HEADLESS_AUTH_MSG_SHOWN = True + logging.error( + 'Persistent login is not working in this non-interactive environment ' + '(possibly due to an IP/location change). ' + 'Re-run Commander/Docker setup from this network, then restart the service.' + ) + step.cancel() + + class ConsoleLoginUi(login_steps.LoginUi): def __init__(self): self._show_device_approval_help = True @@ -28,6 +53,9 @@ def __init__(self): self._failed_password_attempt = 0 def on_device_approval(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_device_approval_help: _stderr(f"\n{Fore.YELLOW}Device Approval Required{Fore.RESET}\n") _stderr(f"{Fore.CYAN}Select an approval method:{Fore.RESET}") @@ -123,6 +151,9 @@ def two_factor_channel_to_desc(channel): # type: (login_steps.TwoFactorChannel return 'Backup Codes' def on_two_factor(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return channels = step.get_channels() if self._show_two_factor_help: @@ -273,6 +304,9 @@ def on_two_factor(self, step): logging.warning(f'{Fore.YELLOW}Invalid 2FA code. Please try again.{Fore.RESET}') def on_password(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_password_help: _stderr(f'{Fore.CYAN}Enter master password for {Fore.WHITE}{step.username}{Fore.RESET}') @@ -293,6 +327,9 @@ def on_password(self, step): step.cancel() def on_sso_redirect(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return try: wb = webbrowser.get() wrappers = set('xdg-open|gvfs-open|gnome-open|x-www-browser|www-browser'.split('|')) @@ -360,6 +397,9 @@ def on_sso_redirect(self, step): break def on_sso_data_key(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_sso_data_key_help: _stderr(f'\n{Fore.YELLOW}Device Approval Required for SSO{Fore.RESET}\n') _stderr(f'{Fore.CYAN}Select an approval method:{Fore.RESET}') diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 7671ef65d..8ee25fc1a 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -17,6 +17,7 @@ import re import time from datetime import datetime +from typing import Dict, Optional, Any, Set, List from urllib.parse import urlparse, urlunparse import requests @@ -89,10 +90,12 @@ from .pam_saas.config import PAMActionSaasConfigCommand from .pam_saas.update import PAMActionSaasUpdateCommand from .tunnel_and_connections import PAMTunnelCommand, PAMConnectionCommand, PAMRbiCommand, PAMSplitCommand +from .pam.cnapp_commands import PAMCnappCommand from .universalsecretsync import ( PAMUniversalSyncConfigCommand, PAMUniversalSyncRunCommand ) +from .pam.cnapp_commands import PAMCnappCommand # These characters are based on the Vault PAM_DEFAULT_SPECIAL_CHAR = '''!@#$%^?();',.=+[]<>{}-_/\\*&:"`~|''' @@ -200,8 +203,12 @@ def __init__(self): self.register_command('workflow', PAMWorkflowCommand(), 'Manage PAM Workflows', 'w') self.register_command('access', PAMPrivilegedAccessCommand(), 'Manage privileged cloud access operations', 'ac') + self.register_command('cnapp', PAMCnappCommand(), + 'Manage Cloud-Native Application Protection Platform integration', 'cn') self.register_command('universal-sync-config', PAMUniversalSyncConfigCommand(), 'Manage Universal Sync Configurations', 'usc') self.register_command('universal-sync-run', PAMUniversalSyncRunCommand(), 'Run Universal Sync', 'usr') + self.register_command('cnapp', PAMCnappCommand(), + 'Manage Cloud-Native Application Protection Platform integration', 'cn') class PAMGatewayCommand(GroupCommand): diff --git a/keepercommander/commands/helpers/record.py b/keepercommander/commands/helpers/record.py index 7769f5f7e..a2f97976f 100644 --- a/keepercommander/commands/helpers/record.py +++ b/keepercommander/commands/helpers/record.py @@ -1,9 +1,26 @@ +import re from typing import Set, Optional from ... import api +from ...error import CommandError from ...params import KeeperParams from ...subfolder import try_resolve_path +# Block shell chaining markers in `get` lookup tokens. +_GET_LOOKUP_CONTROL_CHARS_RE = re.compile(r'[\r\n\x00]') +_GET_LOOKUP_SHELL_METACHAR_RE = re.compile(r'[;|]') +_GET_LOOKUP_CHAIN_RE = re.compile(r'&&') + + +def raise_if_unsafe_get_lookup_token(token, command='get'): + # type: (str, str) -> None + if not token: + raise CommandError(command, 'Invalid record identifier: forbidden characters') + if (_GET_LOOKUP_CONTROL_CHARS_RE.search(token) + or _GET_LOOKUP_SHELL_METACHAR_RE.search(token) + or _GET_LOOKUP_CHAIN_RE.search(token)): + raise CommandError(command, 'Invalid record identifier: forbidden characters') + # Get record UID(s) given one of its identifiers: name (if current folder contains the record), path, or UID def get_record_uids(params, name): # type: (KeeperParams, str) -> Set[Optional[str]] diff --git a/keepercommander/commands/pam/cnapp_commands.py b/keepercommander/commands/pam/cnapp_commands.py new file mode 100644 index 000000000..92178ac6b --- /dev/null +++ b/keepercommander/commands/pam/cnapp_commands.py @@ -0,0 +1,647 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' bytes|None + """Encrypter keys are typically `openssl rand -base64 32`. Try standard base64 first, + then base64url (legacy notes). 16 or 32 bytes only — anything else is rejected so we + don't pass garbage to AES-GCM.""" + if not raw or not isinstance(raw, str): + return None + candidate = raw.strip() + for decoder in (base64.b64decode, base64.urlsafe_b64decode): + try: + padding = '=' * (-len(candidate) % 4) + data = decoder(candidate + padding) + except (binascii.Error, ValueError): + continue + if len(data) == 32: + return data + return None + + +def _load_encrypter_key(params, config_record_uid): + """Resolve the AES key from the CNAPP encrypter vault record. Returns None when the + record can't be loaded or doesn't carry a recognizable key — callers should fall back + to showing the encrypted payload as-is.""" + if not config_record_uid: + return None + try: + record = vault.KeeperRecord.load(params, config_record_uid) + except Exception as e: + logging.debug('CNAPP: failed to load encrypter record %s: %s', config_record_uid, e) + return None + if not isinstance(record, vault.TypedRecord): + return None + # Match vault/cloudSecurityUtils.ts: prefer `secret` then `note` labeled "Encryption Key", + # then the first unlabeled `note` field only when no labeled key field exists. + labeled_raws = [] + secret_field = record.get_typed_field('secret', CNAPP_ENCRYPTION_KEY_LABEL) + if secret_field and secret_field.value: + labeled_raws.append(secret_field.value[0]) + note_labeled = record.get_typed_field('note', CNAPP_ENCRYPTION_KEY_LABEL) + if note_labeled and note_labeled.value: + labeled_raws.append(note_labeled.value[0]) + for raw in labeled_raws: + key = _decode_aes_key(raw) + if key: + return key + logging.warning( + 'CNAPP: "%s" field is present on encrypter record %s but is not a valid AES-256 key; ' + 'not using other note fields.', + CNAPP_ENCRYPTION_KEY_LABEL, config_record_uid, + ) + return None + first_note = record.get_typed_field('note') + if first_note and first_note.value: + key = _decode_aes_key(first_note.value[0]) + if key: + return key + return None + + +def _decrypt_cnapp_payload(payload_bytes, key): + """Decrypt a CNAPP queue payload using the Encrypter's AES-256-GCM key. + + Wire format (matches vault's `decryptCnappQueueItem` in cloudSecurityUtils.ts): + payload_bytes (proto field, base64url-decoded by us) is UTF-8 base64url text + of a JSON envelope `{"encrypted_payload":"","alg":"AES-256-GCM","version":"1"}`. + encrypted_payload base64url-decodes to `nonce(12) || ciphertext || tag(16)` — + the standard layout AESGCM.decrypt expects. + + Returns a dict on success; raises Exception on bad envelope / wrong key / bad alg + so the caller can surface a meaningful warning.""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + envelope_b64 = payload_bytes.decode('utf-8') + envelope_json = base64.urlsafe_b64decode(envelope_b64 + '=' * (-len(envelope_b64) % 4)) + envelope = json.loads(envelope_json) + alg = envelope.get('alg') + if alg != 'AES-256-GCM': + raise ValueError(f"Unsupported or missing CNAPP payload algorithm: {alg!r}") + ciphertext_b64 = envelope.get('encrypted_payload') or '' + ciphertext = base64.urlsafe_b64decode(ciphertext_b64 + '=' * (-len(ciphertext_b64) % 4)) + if len(ciphertext) < 12 + 16: + raise ValueError('CNAPP ciphertext shorter than nonce+tag — corrupt payload') + nonce, body = ciphertext[:12], ciphertext[12:] + plaintext = AESGCM(key).decrypt(nonce, body, None) + return json.loads(plaintext.decode('utf-8')) + + +def _resolve_status(value, allow_all=True): # type: (str|int|None, bool) -> int + """Accept either the numeric status id or its case-insensitive name.""" + if value is None or value == '': + status_id = 0 + elif isinstance(value, int): + status_id = value + else: + s = str(value).strip().lower() + if s.lstrip('-').isdigit(): + status_id = int(s) + elif s in QUEUE_STATUS_BY_NAME: + status_id = QUEUE_STATUS_BY_NAME[s] + else: + raise CommandError( + 'pam cnapp', + f"Unknown status '{value}'. Valid: {', '.join(QUEUE_STATUS_BY_NAME)} or 0 for ALL.", + ) + if status_id == 0: + if allow_all: + return 0 + raise CommandError('pam cnapp', 'A specific status is required (cannot be 0/ALL).') + if status_id not in QUEUE_STATUS_BY_ID: + raise CommandError( + 'pam cnapp', + f"Unknown status id {status_id}. Valid ids: {', '.join(str(i) for i in sorted(QUEUE_STATUS_BY_ID))}.", + ) + return status_id + + +def _format_timestamp(epoch_ms): + """krouter emits epoch-millis for received/resolved timestamps; render as UTC ISO.""" + if not epoch_ms: + return '' + try: + return datetime.fromtimestamp(int(epoch_ms) / 1000, tz=timezone.utc).isoformat() + except (ValueError, TypeError, OSError): + return f'' + + +class PAMCnappCommand(GroupCommand): + """Root for the `pam cnapp ...` command tree.""" + + def __init__(self): + super(PAMCnappCommand, self).__init__() + self.register_command('config', PAMCnappConfigCommand(), + 'Manage CNAPP provider configuration', 'c') + self.register_command('queue', PAMCnappQueueCommand(), + 'Manage CNAPP issue queue', 'q') + self.default_verb = 'queue' + + +# --------------------------------------------------------------------------- +# Configuration sub-tree +# --------------------------------------------------------------------------- + +class PAMCnappConfigCommand(GroupCommand): + + def __init__(self): + super(PAMCnappConfigCommand, self).__init__() + self.register_command('set', PAMCnappConfigSetCommand(), + 'Create or update CNAPP provider configuration') + self.register_command('test', PAMCnappConfigTestCommand(), + 'Validate CNAPP provider credentials without saving') + self.register_command('test-encrypter', PAMCnappConfigTestEncrypterCommand(), + 'Health-check the customer Encrypter at /health') + self.register_command('read', PAMCnappConfigReadCommand(), + 'Read the persisted CNAPP configuration for a network') + self.register_command('delete', PAMCnappConfigDeleteCommand(), + 'Delete the CNAPP configuration on a network') + self.default_verb = '' + + +def _add_configuration_args(parser, require_secret=True, optional_secret_on_set=False): + parser.add_argument('--network-uid', '-n', required=True, dest='network_uid', + help='Network record UID (base64url).') + parser.add_argument('--provider', '-p', required=True, dest='provider', + help='CNAPP provider keyword: wiz (case-insensitive).') + parser.add_argument('--client-id', required=True, dest='client_id', + help='Provider API client ID / app ID.') + if optional_secret_on_set: + parser.add_argument('--client-secret', required=False, default=None, dest='client_secret', + help='Provider API client secret. Omit on `config set` to keep the existing secret.') + else: + parser.add_argument('--client-secret', required=require_secret, dest='client_secret', + help='Provider API client secret.') + parser.add_argument('--api-endpoint', required=True, dest='api_endpoint_url', + help='Provider API endpoint URL (e.g. https://api.us1.app.wiz.io/graphql).') + parser.add_argument('--auth-endpoint', required=True, dest='auth_endpoint_url', + help='Provider OAuth2 token endpoint URL (e.g. https://auth.app.wiz.io/oauth/token).') + + +class PAMCnappConfigSetCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp config set') + _add_configuration_args(parser, optional_secret_on_set=True) + parser.add_argument('--config-record', required=True, dest='cnapp_config_record_uid', + help='UID of the vault record holding the Encrypter URL + key.') + + def get_parser(self): + return PAMCnappConfigSetCommand.parser + + def execute(self, params, **kwargs): + provider = cnapp_helper.provider_from_name(kwargs.get('provider')) + response = cnapp_helper.set_cnapp_configuration( + params, + network_uid=kwargs.get('network_uid'), + provider=provider, + client_id=kwargs.get('client_id'), + client_secret='' if kwargs.get('client_secret') is None else kwargs.get('client_secret'), + api_endpoint_url=kwargs.get('api_endpoint_url'), + cnapp_config_record_uid=kwargs.get('cnapp_config_record_uid'), + auth_endpoint_url=kwargs.get('auth_endpoint_url'), + ) + print(f"{bcolors.OKGREEN}CNAPP configuration saved.{bcolors.ENDC}") + if response is not None: + _print_configuration(response) + return None + + +class PAMCnappConfigTestCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp config test') + _add_configuration_args(parser, require_secret=True) + + def get_parser(self): + return PAMCnappConfigTestCommand.parser + + def execute(self, params, **kwargs): + provider = cnapp_helper.provider_from_name(kwargs.get('provider')) + cnapp_helper.test_cnapp_configuration( + params, + network_uid=kwargs.get('network_uid'), + provider=provider, + client_id=kwargs.get('client_id'), + client_secret=kwargs.get('client_secret'), + api_endpoint_url=kwargs.get('api_endpoint_url'), + auth_endpoint_url=kwargs.get('auth_endpoint_url'), + ) + print(f"{bcolors.OKGREEN}CNAPP credentials validated successfully.{bcolors.ENDC}") + + +class PAMCnappConfigTestEncrypterCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp config test-encrypter') + parser.add_argument('--url', '-u', required=True, dest='url', + help='Base URL of the Encrypter. krouter probes /health.') + + def get_parser(self): + return PAMCnappConfigTestEncrypterCommand.parser + + def execute(self, params, **kwargs): + cnapp_helper.test_cnapp_encrypter(params, url_base_encrypter=kwargs.get('url')) + print(f"{bcolors.OKGREEN}Encrypter is reachable.{bcolors.ENDC}") + + +class PAMCnappConfigReadCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp config read') + parser.add_argument('--network-uid', '-n', required=True, dest='network_uid', + help='Network record UID (base64url).') + parser.add_argument('--provider', '-p', required=True, dest='provider', + help='CNAPP provider keyword: wiz.') + parser.add_argument('--format', dest='format', choices=['table', 'json'], default='table', + help='Output format.') + + def get_parser(self): + return PAMCnappConfigReadCommand.parser + + def execute(self, params, **kwargs): + provider = cnapp_helper.provider_from_name(kwargs.get('provider')) + response = cnapp_helper.read_cnapp_configuration( + params, + network_uid=kwargs.get('network_uid'), + provider=provider, + ) + if response is None: + logging.warning('No CNAPP configuration returned.') + return None + if kwargs.get('format') == 'json': + print(json.dumps(_configuration_to_dict(response), indent=2)) + return None + _print_configuration(response) + return None + + +class PAMCnappConfigDeleteCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp config delete') + parser.add_argument('--network-uid', '-n', required=True, dest='network_uid', + help='Network record UID (base64url).') + + def get_parser(self): + return PAMCnappConfigDeleteCommand.parser + + def execute(self, params, **kwargs): + cnapp_helper.delete_cnapp_configuration(params, network_uid=kwargs.get('network_uid')) + print(f"{bcolors.OKGREEN}CNAPP configuration deleted.{bcolors.ENDC}") + + +# --------------------------------------------------------------------------- +# Queue sub-tree +# --------------------------------------------------------------------------- + +class PAMCnappQueueCommand(GroupCommand): + + def __init__(self): + super(PAMCnappQueueCommand, self).__init__() + self.register_command('list', PAMCnappQueueListCommand(), 'List CNAPP queue items', 'l') + self.register_command('associate', PAMCnappQueueAssociateCommand(), + 'Attach a vault record to a queue item', 'a') + self.register_command('remediate', PAMCnappQueueRemediateCommand(), + 'Trigger a remediation action against the gateway', 'r') + self.register_command('set-status', PAMCnappQueueSetStatusCommand(), + 'Update local queue item status (notifies provider best-effort)', 's') + self.register_command('delete', PAMCnappQueueDeleteCommand(), 'Delete a queue item', 'd') + self.default_verb = 'list' + + +class PAMCnappQueueListCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp queue list') + parser.add_argument('--network-uid', '-n', required=True, dest='network_uid', + help='Network record UID (base64url).') + parser.add_argument('--status', '-s', required=False, dest='status', default=0, + help='Filter by status name or id (pending/in_progress/resolved/failed/cancelled). Default: all.') + parser.add_argument('--provider', '-p', required=False, dest='provider', default='wiz', + help='CNAPP provider keyword for the config lookup (default: wiz).') + parser.add_argument('--config-record', required=False, dest='config_record_uid', + help='Explicit encrypter vault record UID. Overrides the lookup via `config read`.') + parser.add_argument('--no-decrypt', dest='no_decrypt', action='store_true', + help="Skip payload decryption — show the raw encrypted envelope's metadata only.") + parser.add_argument('--format', dest='format', choices=['table', 'json'], default='table', + help='Output format. Table and JSON are mutually exclusive.') + + def get_parser(self): + return PAMCnappQueueListCommand.parser + + def _resolve_encrypter_key(self, params, kwargs): + """Resolve the AES key: --config-record wins; otherwise fetch `config read` to get + the cnappConfigRecordUid and load the encrypter record from the local vault.""" + if kwargs.get('no_decrypt'): + return None, None + config_record_uid = kwargs.get('config_record_uid') + if not config_record_uid: + try: + provider = cnapp_helper.provider_from_name(kwargs.get('provider') or 'wiz') + config = cnapp_helper.read_cnapp_configuration( + params, network_uid=kwargs.get('network_uid'), provider=provider) + except Exception as e: + logging.debug('CNAPP: could not read configuration for decryption: %s', e) + return None, None + if config is None or not config.cnappConfigRecordUid: + return None, None + config_record_uid = bytes_to_base64(config.cnappConfigRecordUid) + key = _load_encrypter_key(params, config_record_uid) + return key, config_record_uid + + @staticmethod + def _decrypted_summary(decrypted): + """Compact human-readable summary for the table column. Mirrors the columns the + vault Cloud Security view shows: severity, title, resource.""" + if not isinstance(decrypted, dict): + return '' + issue = decrypted.get('issue') or {} + resource = decrypted.get('resource') or {} + control = decrypted.get('control') or {} + bits = [] + sev = issue.get('severity') + if sev: + bits.append(str(sev).upper()) + title = control.get('name') or issue.get('id') + if title: + bits.append(str(title)) + resource_name = resource.get('name') or resource.get('id') + if resource_name: + bits.append(f"on {resource_name}") + return ' · '.join(bits) + + def execute(self, params, **kwargs): + status_filter = _resolve_status(kwargs.get('status')) + response = cnapp_helper.list_cnapp_queue( + params, + network_uid=kwargs.get('network_uid'), + status_filter=status_filter, + ) + items = list(response.items) if response is not None else [] + has_more = bool(response.hasMore) if response is not None else False + + encrypter_key, encrypter_uid = self._resolve_encrypter_key(params, kwargs) + decrypted_by_id = {} + decrypt_errors = {} # type: dict[int, str] + if encrypter_key: + for item in items: + if not item.payload: + continue + try: + decrypted_by_id[item.cnappQueueId] = _decrypt_cnapp_payload(item.payload, encrypter_key) + except Exception as e: + decrypt_errors[item.cnappQueueId] = str(e) + + if kwargs.get('format') == 'json': + json_items = [] + for item in items: + d = _queue_item_to_dict(item) + d.pop('payload', None) + if item.cnappQueueId in decrypted_by_id: + d['decryptedPayload'] = decrypted_by_id[item.cnappQueueId] + elif item.cnappQueueId in decrypt_errors: + d['decryptError'] = decrypt_errors[item.cnappQueueId] + json_items.append(d) + payload = {'items': json_items, 'hasMore': has_more} + print(json.dumps(payload, indent=2, default=str)) + return None + + if not items: + print('No CNAPP queue items.') + return None + + if encrypter_key is None and not kwargs.get('no_decrypt'): + print(f"{bcolors.WARNING}No encrypter key resolved — payloads will be shown as 'encrypted'. " + f"Pass --config-record or run after `pam cnapp config read` succeeds.{bcolors.ENDC}") + + headers = ['Queue ID', 'Provider', 'Status', 'Received (UTC)', 'Resolved (UTC)', 'Record UID', 'Issue'] + rows = [] + for item in items: + if item.cnappQueueId in decrypted_by_id: + issue_cell = self._decrypted_summary(decrypted_by_id[item.cnappQueueId]) + elif not item.payload: + issue_cell = '' + elif kwargs.get('no_decrypt'): + issue_cell = '' + else: + issue_cell = f"{bcolors.WARNING}{bcolors.ENDC}" + rows.append([ + item.cnappQueueId, + cnapp_helper.CnappProvider.Name(item.cnappProviderId), + QUEUE_STATUS_BY_ID.get(item.cnappQueueStatusId, str(item.cnappQueueStatusId)), + _format_timestamp(item.receivedAt), + _format_timestamp(item.resolvedAt), + bytes_to_base64(item.recordUid) if item.recordUid else '', + issue_cell, + ]) + dump_report_data(rows, headers, fmt='table', filename='', row_number=False) + for queue_id, msg in decrypt_errors.items(): + print(f"{bcolors.WARNING}Queue item {queue_id}: failed to decrypt payload ({msg}).{bcolors.ENDC}") + if has_more: + print(f"{bcolors.WARNING}More queue items exist (hasMore=true). " + f"CLI paging is not available yet — resolve or delete returned items to see more.{bcolors.ENDC}") + return None + + +class PAMCnappQueueAssociateCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp queue associate') + parser.add_argument('--queue-id', '-q', required=True, type=int, dest='cnapp_queue_id', + help='Queue item ID (from `pam cnapp queue list`).') + parser.add_argument('--record-uid', '-r', required=True, dest='record_uid', + help='Vault record UID to associate (base64url).') + + def get_parser(self): + return PAMCnappQueueAssociateCommand.parser + + def execute(self, params, **kwargs): + cnapp_helper.associate_cnapp_record( + params, + cnapp_queue_id=kwargs.get('cnapp_queue_id'), + record_uid=kwargs.get('record_uid'), + ) + print(f"{bcolors.OKGREEN}Record associated with queue item {kwargs.get('cnapp_queue_id')}.{bcolors.ENDC}") + + +class PAMCnappQueueRemediateCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp queue remediate') + parser.add_argument('--queue-id', '-q', required=True, type=int, dest='cnapp_queue_id', + help='Queue item ID.') + parser.add_argument('--action', '-a', required=True, dest='action_type', + help='Remediation action: rotate_credentials, manage_access, jit_access, remove_standing_privilege.') + parser.add_argument('--provider', '-p', required=False, dest='provider', + help='Provider keyword (wiz). Optional — krouter resolves from queue item if omitted.') + parser.add_argument('--config-record', required=False, dest='cnapp_config_record_uid', + help='Configuration record UID (only required for some action types).') + parser.add_argument('--resource-ref', required=False, dest='resource_ref', + help='Resource reference UID for the action.') + parser.add_argument('--pwd-complexity', required=False, dest='pwd_complexity', + help='Password complexity JSON (rotate_credentials).') + parser.add_argument('--controller-uid', required=False, dest='controller_uid', + help='Override gateway UID.') + parser.add_argument('--message-uid', required=False, dest='message_uid', + help='Client-generated conversation UID for streaming responses.') + parser.add_argument('--group-name', required=False, dest='group_name', + help='Group name (remove_standing_privilege only).') + + def get_parser(self): + return PAMCnappQueueRemediateCommand.parser + + def execute(self, params, **kwargs): + action = cnapp_helper.action_from_name(kwargs.get('action_type')) + provider = None + if kwargs.get('provider'): + provider = cnapp_helper.provider_from_name(kwargs.get('provider')) + response = cnapp_helper.remediate_cnapp_queue_item( + params, + cnapp_queue_id=kwargs.get('cnapp_queue_id'), + action_type=action, + provider=provider, + cnapp_config_record_uid=kwargs.get('cnapp_config_record_uid'), + resource_ref=kwargs.get('resource_ref'), + pwd_complexity=kwargs.get('pwd_complexity'), + controller_uid=kwargs.get('controller_uid'), + message_uid=kwargs.get('message_uid'), + group_name=kwargs.get('group_name'), + ) + if response is None: + print(f"{bcolors.OKGREEN}Remediation dispatched.{bcolors.ENDC}") + return None + action_name = cnapp_helper.CnappRemediationAction.Name(response.actionType) + status_name = QUEUE_STATUS_BY_ID.get(response.cnappQueueStatusId, str(response.cnappQueueStatusId)) + print(f"{bcolors.OKGREEN}Remediation dispatched.{bcolors.ENDC}") + print(f" Action: {action_name}") + print(f" Status: {status_name}") + if response.result: + print(f" Result: {response.result}") + return None + + +class PAMCnappQueueSetStatusCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp queue set-status') + parser.add_argument('--queue-id', '-q', required=True, type=int, dest='cnapp_queue_id', + help='Queue item ID.') + parser.add_argument('--status', '-s', required=True, dest='status', + help='New status: pending/in_progress/resolved/failed/cancelled, or its numeric id.') + parser.add_argument('--reason', required=False, dest='reason', + help='Free-form reason (forwarded to provider notification).') + + def get_parser(self): + return PAMCnappQueueSetStatusCommand.parser + + def execute(self, params, **kwargs): + status_id = _resolve_status(kwargs.get('status'), allow_all=False) + response = cnapp_helper.set_cnapp_queue_status( + params, + cnapp_queue_id=kwargs.get('cnapp_queue_id'), + cnapp_queue_status_id=status_id, + reason=kwargs.get('reason'), + ) + applied = response.cnappQueueStatusId if response is not None else status_id + print(f"{bcolors.OKGREEN}Status applied: {QUEUE_STATUS_BY_ID.get(applied, applied)}.{bcolors.ENDC}") + return None + + +class PAMCnappQueueDeleteCommand(Command): + parser = argparse.ArgumentParser(prog='pam cnapp queue delete') + parser.add_argument('--queue-id', '-q', required=True, type=int, dest='cnapp_queue_id', + help='Queue item ID to delete.') + + def get_parser(self): + return PAMCnappQueueDeleteCommand.parser + + def execute(self, params, **kwargs): + cnapp_helper.delete_cnapp_queue_item(params, cnapp_queue_id=kwargs.get('cnapp_queue_id')) + print(f"{bcolors.OKGREEN}Queue item {kwargs.get('cnapp_queue_id')} deleted.{bcolors.ENDC}") + + +# --------------------------------------------------------------------------- +# Formatting helpers +# --------------------------------------------------------------------------- + +def _configuration_to_dict(config): + return { + 'networkUid': bytes_to_base64(config.networkUid) if config.networkUid else '', + 'provider': cnapp_helper.CnappProvider.Name(config.provider), + 'clientId': config.clientId, + 'apiEndpointUrl': config.apiEndpointUrl, + 'authEndpointUrl': config.authEndpointUrl, + 'cnappConfigRecordUid': bytes_to_base64(config.cnappConfigRecordUid) if config.cnappConfigRecordUid else '', + } + + +def _queue_item_to_dict(item): + return { + 'cnappQueueId': item.cnappQueueId, + 'cnappProviderId': cnapp_helper.CnappProvider.Name(item.cnappProviderId), + 'cnappQueueStatusId': item.cnappQueueStatusId, + 'cnappQueueStatusName': QUEUE_STATUS_BY_ID.get(item.cnappQueueStatusId, str(item.cnappQueueStatusId)), + 'receivedAt': item.receivedAt, + 'resolvedAt': item.resolvedAt, + 'networkId': bytes_to_base64(item.networkId) if item.networkId else '', + 'recordUid': bytes_to_base64(item.recordUid) if item.recordUid else '', + } + + +def _uid_display(uid_bytes): + return bytes_to_base64(uid_bytes) if uid_bytes else '(none)' + + +def _print_configuration(config): + print(f"{bcolors.OKBLUE}CNAPP Configuration{bcolors.ENDC}") + print(f" Network UID : {_uid_display(config.networkUid)}") + print(f" Provider : {cnapp_helper.CnappProvider.Name(config.provider)}") + print(f" Client ID : {config.clientId or '(none)'}") + print(f" API Endpoint : {config.apiEndpointUrl or '(none)'}") + print(f" Auth Endpoint : {config.authEndpointUrl or '(none)'}") + print(f" Config Record : {_uid_display(config.cnappConfigRecordUid)}") diff --git a/keepercommander/commands/pam/cnapp_helper.py b/keepercommander/commands/pam/cnapp_helper.py new file mode 100644 index 000000000..c1f0302d6 --- /dev/null +++ b/keepercommander/commands/pam/cnapp_helper.py @@ -0,0 +1,265 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' set_cnapp_configuration + configuration/test -> test_cnapp_configuration + configuration/test-encrypter -> test_cnapp_encrypter + configuration/read -> read_cnapp_configuration + configuration/delete -> delete_cnapp_configuration + + Queue: + queue -> list_cnapp_queue + queue/associate -> associate_cnapp_record + queue/remediate -> remediate_cnapp_queue_item + queue/set-status -> set_cnapp_queue_status + queue/delete -> delete_cnapp_queue_item + +Failures from the helper layer bubble up as Python exceptions raised by the underlying +HTTP/proto plumbing; callers convert them to user-readable output. +""" + +from typing import Optional + +from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + +from ...params import KeeperParams +from ...proto import cnapp_pb2 + + +# NOTE: `router_helper` is imported lazily inside `_post_request_to_router` below. +# Importing it at module top creates this import-time chain: +# cnapp_helper -> router_helper -> gateway_helper +# -> keepercommander.commands.utils -> commands.ksm +# -> commands.record -> commands.ksm (ksm still partially loaded — crash) +# That `record <-> ksm` cycle is pre-existing and only works because production +# code paths load `record` first. Tests that import `cnapp_helper` cold hit the +# cycle directly. TODO(KC-1290): break the record↔ksm cycle so this wrapper can be removed. +def _post_request_to_router(params, endpoint, **kwargs): + """Lazy proxy to `router_helper._post_request_to_router`. + + Defined as a module-level function so callers (and `unittest.mock.patch.object`) + can keep referring to `cnapp_helper._post_request_to_router` as if it were the + original symbol.""" + from .router_helper import _post_request_to_router as _real_post + return _real_post(params, endpoint, **kwargs) + + +# Public re-exports — let commands/tests reach proto types via the helper module so they +# don't need to know the on-disk proto path. +CnappProvider = cnapp_pb2.CnappProvider +CnappRemediationAction = cnapp_pb2.CnappRemediationAction + + +# --------------------------------------------------------------------------- +# Conversion utilities +# --------------------------------------------------------------------------- + +def _to_uid_bytes(uid): # type: (Optional[str]) -> bytes + """Convert a base64url-encoded UID string to bytes; empty/None -> empty bytes.""" + if not uid: + return b'' + if isinstance(uid, bytes): + return uid + return url_safe_str_to_bytes(uid) + + +def provider_from_name(name): # type: (str) -> int + """Resolve a human-typed provider name (e.g. "wiz") to a CnappProvider enum value. + + Accepts the bare provider keyword ("wiz") or the full proto symbol + ("CNAPP_PROVIDER_WIZ"); case-insensitive. Raises ValueError on unknown input.""" + if not name: + return cnapp_pb2.CNAPP_PROVIDER_UNSPECIFIED + normalized = name.strip().upper() + if not normalized.startswith('CNAPP_PROVIDER_'): + normalized = 'CNAPP_PROVIDER_' + normalized + try: + return cnapp_pb2.CnappProvider.Value(normalized) + except ValueError as e: + valid = [n for n in cnapp_pb2.CnappProvider.keys() if n != 'CNAPP_PROVIDER_UNSPECIFIED'] + raise ValueError(f"Unknown CNAPP provider '{name}'. Valid options: {', '.join(valid)}") from e + + +def action_from_name(name): # type: (str) -> int + """Resolve a remediation action name to its enum int. Case-insensitive; accepts the + short keyword (e.g. "rotate_credentials") or the full proto symbol.""" + if not name: + return cnapp_pb2.UNSPECIFIED + normalized = name.strip().upper().replace('-', '_') + try: + return cnapp_pb2.CnappRemediationAction.Value(normalized) + except ValueError as e: + valid = [n for n in cnapp_pb2.CnappRemediationAction.keys() if n != 'UNSPECIFIED'] + raise ValueError(f"Unknown remediation action '{name}'. Valid options: {', '.join(valid)}") from e + + +# --------------------------------------------------------------------------- +# Configuration endpoints +# --------------------------------------------------------------------------- + +def _build_configuration(network_uid, provider, client_id=None, client_secret=None, + api_endpoint_url=None, cnapp_config_record_uid=None, + auth_endpoint_url=None): + # type: (str, int, Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> cnapp_pb2.CnappConfiguration + rq = cnapp_pb2.CnappConfiguration() + rq.networkUid = _to_uid_bytes(network_uid) + rq.provider = provider + if client_id: + rq.clientId = client_id + if client_secret: + rq.clientSecret = client_secret + if api_endpoint_url: + rq.apiEndpointUrl = api_endpoint_url + if cnapp_config_record_uid: + rq.cnappConfigRecordUid = _to_uid_bytes(cnapp_config_record_uid) + if auth_endpoint_url: + rq.authEndpointUrl = auth_endpoint_url + return rq + + +def set_cnapp_configuration(params, network_uid, provider, client_id, client_secret, + api_endpoint_url, cnapp_config_record_uid, auth_endpoint_url=None): + # type: (KeeperParams, str, int, str, str, str, str, Optional[str]) -> cnapp_pb2.CnappConfiguration + """Create or update the CNAPP provider configuration on a network. + + krouter validates the credentials against the provider before persisting; an empty + `client_secret` tells krouter to keep the previously stored value (useful for edits + that only change the endpoint or record UID). + + `auth_endpoint_url` is the provider's OAuth2 token endpoint, letting customers point + at their own tenant/region (e.g. EU vs US Wiz auth host) without a code change.""" + rq = _build_configuration(network_uid, provider, client_id, client_secret, + api_endpoint_url, cnapp_config_record_uid, auth_endpoint_url) + return _post_request_to_router(params, 'cnapp/configuration/set', rq_proto=rq, + rs_type=cnapp_pb2.CnappConfiguration) + + +def test_cnapp_configuration(params, network_uid, provider, client_id, client_secret, + api_endpoint_url, auth_endpoint_url=None): + # type: (KeeperParams, str, int, str, str, str, Optional[str]) -> None + """Probe the provider with the supplied credentials without persisting anything. + + Returns None on success; raises on validation failure (RRC_BAD_REQUEST with the + provider's reason in the message).""" + rq = _build_configuration(network_uid, provider, client_id, client_secret, + api_endpoint_url, cnapp_config_record_uid=None, + auth_endpoint_url=auth_endpoint_url) + return _post_request_to_router(params, 'cnapp/configuration/test', rq_proto=rq) + + +def test_cnapp_encrypter(params, url_base_encrypter): + # type: (KeeperParams, str) -> None + """Issue a `GET /health` against the customer-deployed Encrypter via krouter. + + Used by the UI/CLI to check that the Encrypter URL is reachable before saving a + configuration that references it. Raises on non-200 or transport error.""" + rq = cnapp_pb2.CnappTestEncrypterRequest() + rq.urlBaseEncrypter = url_base_encrypter + return _post_request_to_router(params, 'cnapp/configuration/test-encrypter', rq_proto=rq) + + +def read_cnapp_configuration(params, network_uid, provider): + # type: (KeeperParams, str, int) -> cnapp_pb2.CnappConfiguration + """Read the persisted CNAPP configuration for a network. Note: krouter never returns + the `clientSecret` field — only the endpoint, client id and config record UID.""" + rq = _build_configuration(network_uid, provider) + return _post_request_to_router(params, 'cnapp/configuration/read', rq_proto=rq, + rs_type=cnapp_pb2.CnappConfiguration) + + +def delete_cnapp_configuration(params, network_uid): + # type: (KeeperParams, str) -> None + """Remove the CNAPP configuration on a network. Raises RRC_BAD_STATE if none exists.""" + rq = cnapp_pb2.CnappDeleteConfigurationRequest() + rq.networkUid = _to_uid_bytes(network_uid) + return _post_request_to_router(params, 'cnapp/configuration/delete', rq_proto=rq) + + +# --------------------------------------------------------------------------- +# Queue endpoints +# --------------------------------------------------------------------------- + +def list_cnapp_queue(params, network_uid, status_filter=0): + # type: (KeeperParams, str, int) -> cnapp_pb2.CnappQueueListResponse + """List queued CNAPP issues for a network. `status_filter=0` returns all statuses.""" + rq = cnapp_pb2.CnappQueueListRequest() + rq.networkUid = _to_uid_bytes(network_uid) + rq.statusFilter = int(status_filter) if status_filter is not None else 0 + return _post_request_to_router(params, 'cnapp/queue', rq_proto=rq, + rs_type=cnapp_pb2.CnappQueueListResponse) + + +def associate_cnapp_record(params, cnapp_queue_id, record_uid): + # type: (KeeperParams, int, str) -> None + """Attach a vault record to a queue item — required before remediation.""" + rq = cnapp_pb2.CnappAssociateRequest() + rq.cnappQueueId = int(cnapp_queue_id) + rq.recordUid = _to_uid_bytes(record_uid) + return _post_request_to_router(params, 'cnapp/queue/associate', rq_proto=rq) + + +def remediate_cnapp_queue_item(params, cnapp_queue_id, action_type, provider=None, + cnapp_config_record_uid=None, resource_ref=None, + pwd_complexity=None, controller_uid=None, + message_uid=None, group_name=None): + # type: (KeeperParams, int, int, Optional[int], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> cnapp_pb2.CnappRemediateResponse + """Trigger a remediation action against the gateway for a queued issue. + + Currently krouter only dispatches `ROTATE_CREDENTIALS`; other actions return + RRC_BAD_REQUEST. The optional fields are forwarded as-is so this helper stays + forward-compatible with new action types.""" + rq = cnapp_pb2.CnappRemediateRequest() + rq.cnappQueueId = int(cnapp_queue_id) + rq.actionType = int(action_type) + if provider is not None: + rq.provider = int(provider) + if cnapp_config_record_uid: + rq.cnappConfigurationRecordUid = _to_uid_bytes(cnapp_config_record_uid) + if resource_ref: + rq.resourceRef = _to_uid_bytes(resource_ref) + if pwd_complexity: + rq.pwdComplexity = pwd_complexity + if controller_uid: + rq.controllerUid = controller_uid + if message_uid: + rq.messageUid = _to_uid_bytes(message_uid) + if group_name: + rq.groupName = group_name + return _post_request_to_router(params, 'cnapp/queue/remediate', rq_proto=rq, + rs_type=cnapp_pb2.CnappRemediateResponse) + + +def set_cnapp_queue_status(params, cnapp_queue_id, cnapp_queue_status_id, reason=None): + # type: (KeeperParams, int, int, Optional[str]) -> cnapp_pb2.CnappSetStatusResponse + """Set the local status on a queue item; krouter best-effort notifies the provider.""" + rq = cnapp_pb2.CnappSetStatusRequest() + rq.cnappQueueId = int(cnapp_queue_id) + rq.cnappQueueStatusId = int(cnapp_queue_status_id) + if reason: + rq.reason = reason + return _post_request_to_router(params, 'cnapp/queue/set-status', rq_proto=rq, + rs_type=cnapp_pb2.CnappSetStatusResponse) + + +def delete_cnapp_queue_item(params, cnapp_queue_id): + # type: (KeeperParams, int) -> None + """Remove a queue item entirely. Raises RRC_BAD_REQUEST if the queue id is unknown.""" + rq = cnapp_pb2.CnappDeleteQueueItemRequest() + rq.cnappQueueId = int(cnapp_queue_id) + return _post_request_to_router(params, 'cnapp/queue/delete', rq_proto=rq) diff --git a/keepercommander/commands/pam_import/kcm_import.py b/keepercommander/commands/pam_import/kcm_import.py index 5bc4e5f6e..bb2439003 100644 --- a/keepercommander/commands/pam_import/kcm_import.py +++ b/keepercommander/commands/pam_import/kcm_import.py @@ -1367,6 +1367,7 @@ def execute(self, params, **kwargs): fd = os.open(output_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) with os.fdopen(fd, 'w') as f: json.dump(out_data, f, indent=2) + utils.set_file_permissions(output_file) redact_note = '' if include_creds else ' (credentials redacted)' logging.warning('JSON written to %s (%d resources, %d users)%s', output_file, num_resources, num_users, redact_note) diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 0c8d152de..99355974c 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -20,6 +20,7 @@ import re from functools import reduce from typing import Dict, Any, List, Optional, Iterable, Tuple, Set +from .helpers.record import raise_if_unsafe_get_lookup_token from colorama import Fore, Back, Style @@ -294,6 +295,8 @@ def execute(self, params, **kwargs): if not uid: raise CommandError('get', 'UID parameter is required') + raise_if_unsafe_get_lookup_token(uid) + fmt = kwargs.get('format') or 'detail' # First try to interpret as UID diff --git a/keepercommander/discovery_common/__version__.py b/keepercommander/discovery_common/__version__.py index 87bb711d7..7025f5029 100644 --- a/keepercommander/discovery_common/__version__.py +++ b/keepercommander/discovery_common/__version__.py @@ -1 +1 @@ -__version__ = '1.1.14' +__version__ = '1.1.15' diff --git a/keepercommander/discovery_common/infrastructure.py b/keepercommander/discovery_common/infrastructure.py index 744c7951b..ee2675955 100644 --- a/keepercommander/discovery_common/infrastructure.py +++ b/keepercommander/discovery_common/infrastructure.py @@ -394,7 +394,9 @@ def to_dot(self, graph_format: str = "svg", show_hex_uid: bool = False, head_uids.append(edge.head_uid) def _render_edge(e): - + # _render_edge is invoked immediately within the loop below, so capturing the + # loop variables v/content is safe. + # pylint: disable=cell-var-from-loop edge_color = "grey" style = "solid" @@ -439,7 +441,7 @@ def _render_edge(e): tooltip=edge_tip) for head_uid in head_uids: - version, edge = v.get_highest_edge_version(head_uid) + _, edge = v.get_highest_edge_version(head_uid) _render_edge(edge) data_edge = v.get_data() diff --git a/keepercommander/discovery_common/process.py b/keepercommander/discovery_common/process.py index 836243c41..a99ee2b06 100644 --- a/keepercommander/discovery_common/process.py +++ b/keepercommander/discovery_common/process.py @@ -377,7 +377,7 @@ def _directory_exists(self, domain: str, directory_info_func: Callable, context: for provider_vertex in provider_vertices: content = DiscoveryObject.get_discovery_object(provider_vertex) found = False - for domain in domains: + for domain in domains: # pylint: disable=redefined-argument-from-local for provider_domain in content.item.info.get("domains", []): if domain.lower() in provider_domain.lower(): found = True @@ -453,7 +453,7 @@ def _find_directory_user(self, found_vertex = None if find_user is not None: - user, domain = split_user_and_domain(find_user) + user, _ = split_user_and_domain(find_user) if user_content.item.user.lower() == user.lower(): found_vertex = user_vertex elif user_content.item.user.lower() == find_user.lower(): @@ -1185,7 +1185,7 @@ def _process_admin_user(self, # We need to populate the id and uid of the content, now that we have data in the content. self.populate_admin_content_ids(admin_content, resource_vertex) - ad_user, ad_domain = split_user_and_domain(admin_content.item.user) + _, ad_domain = split_user_and_domain(admin_content.item.user) if ad_domain is not None and admin_content.item.source == LOCAL_USER: self.logger.debug("The admin is an directory user, but the source is set to a local user") diff --git a/keepercommander/discovery_common/rm_types.py b/keepercommander/discovery_common/rm_types.py index a647ca933..1aa513d88 100644 --- a/keepercommander/discovery_common/rm_types.py +++ b/keepercommander/discovery_common/rm_types.py @@ -465,7 +465,7 @@ class RmOracleUserAddMeta(RmMetaBase): class RmOracleRoleAddMeta(RmMetaBase): - not_identified: bool = False, + not_identified: bool = False identified_by_password: Optional[str] = None identified_using: Optional[str] = None identified_externally: bool = False diff --git a/keepercommander/discovery_common/types.py b/keepercommander/discovery_common/types.py index 5a79a9115..6942e87eb 100644 --- a/keepercommander/discovery_common/types.py +++ b/keepercommander/discovery_common/types.py @@ -736,7 +736,7 @@ def has_dn(self, user) -> bool: return False - + class PromptResult(BaseModel): # "add" and "ignore" are the only action diff --git a/keepercommander/discovery_common/verify.py b/keepercommander/discovery_common/verify.py index 5e90ab5a7..dd6e6a3fd 100644 --- a/keepercommander/discovery_common/verify.py +++ b/keepercommander/discovery_common/verify.py @@ -403,7 +403,7 @@ def _check(vertex: DAGVertex, indent: int = 0): # Get all the child vertices, allow self ref, so we can delete it if not already deleted. for next_vertex in vertex.has_vertices(allow_self_ref=True): if next_vertex.uid == vertex.uid: - version, edge = next_vertex.get_highest_edge_version(vertex.uid) + _, edge = next_vertex.get_highest_edge_version(vertex.uid) if edge.edge_type == EdgeType.DELETION: continue else: diff --git a/keepercommander/keeper_dag/__version__.py b/keepercommander/keeper_dag/__version__.py index 874042f3b..ed7133b36 100644 --- a/keepercommander/keeper_dag/__version__.py +++ b/keepercommander/keeper_dag/__version__.py @@ -1 +1 @@ -__version__ = '1.1.10' # pragma: no cover +__version__ = '1.1.11' # pragma: no cover diff --git a/keepercommander/keeper_dag/connection/__init__.py b/keepercommander/keeper_dag/connection/__init__.py index 762acde53..24ef3ba66 100644 --- a/keepercommander/keeper_dag/connection/__init__.py +++ b/keepercommander/keeper_dag/connection/__init__.py @@ -31,6 +31,8 @@ class ConnectionBase: ADD_DATA = "/add_data" SYNC = "/sync" + MULTI_SYNC = "/multi_sync" + GET_LEAFS = "/get_leafs" TIMEOUT = 30 @@ -59,7 +61,7 @@ def __init__(self, if self.log_transactions_dir is None: self.log_transactions_dir = "." - if self.log_transactions is True: + if self.log_transactions: self.logger.info("keeper-dag transaction logging is ENABLED; " f"write directory at {self.log_transactions_dir}") @@ -99,8 +101,11 @@ def get_encrypted_payload_data(encrypted_payload_data: bytes) -> bytes: @staticmethod def get_router_host(server_hostname: str): - if server_hostname and '://' in server_hostname: # accept URL-formatted inputs + # Defensive: accept URL-formatted inputs (e.g. "https://keepersecurity.com") + # and extract the bare hostname before the GovCloud subdomain check. + if server_hostname and '://' in server_hostname: server_hostname = server_hostname.split('://', 1)[1].split('/', 1)[0] + # Only PROD GovCloud strips the subdomain (workaround for prod infrastructure). # DEV/QA GOV (govcloud.dev.keepersecurity.us, govcloud.qa.keepersecurity.us) keep govcloud. if server_hostname == 'govcloud.keepersecurity.us': @@ -222,10 +227,10 @@ def sync(self, sync_query: Union[SyncQuery, gs_pb2.GraphSyncQuery], graph_id: Optional[int] = None, endpoint: Optional[str] = None, - agent: Optional[str] = None) -> bytes: + agent: Optional[str] = None) -> Optional[bytes]: if agent is None: - f"keeper-dag/{__version__}" + agent = f"keeper-dag/{__version__}" endpoint = self._endpoint(ConnectionBase.SYNC, endpoint) self.logger.debug(f"endpoint {endpoint}") @@ -238,13 +243,13 @@ def sync(self, headers=headers, payload=sync_query) - if self.use_read_protobuf: + if payload is not None and self.use_read_protobuf: try: self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") payload = self.get_encrypted_payload_data(payload) payload = decrypt_aes(payload, self.transmission_key) except Exception as err: - self.logger.error(f"Could not decrypt protobuf graph sync response: {type(err)}, {err}") + self.logger.error(f"Could not decrypt protobuf graph sync response: {err}") self.write_transaction_log( graph_id=graph_id, @@ -290,7 +295,7 @@ def add_data(self, agent: Optional[str] = None): if agent is None: - f"keeper-dag/{__version__}" + agent = f"keeper-dag/{__version__}" endpoint = self._endpoint(ConnectionBase.ADD_DATA, endpoint) self.logger.debug(f"endpoint {endpoint}") @@ -331,3 +336,135 @@ def add_data(self, error=str(err) ) raise DAGException(f"Could not create a new DAG structure: {err}") + + def multi_sync(self, + multi_query: Union[BaseModel, gs_pb2.GraphSyncMultiQuery], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """POST a GraphSyncMultiQuery to /multi_sync. + + Used by per-graph reads: after `get_leafs` discovers the stream refs + rooted at the graph's origin, `multi_sync` fetches sync data for all + those streams in one round-trip. Mirrors `sync()` in transport shape + (encrypt/headers, decrypt-on-read, transaction log, error handling). + """ + if agent is None: + agent = f"keeper-dag/{__version__}" + + endpoint = self._endpoint(ConnectionBase.MULTI_SYNC, endpoint) + self.logger.debug(f"endpoint {endpoint}") + + try: + multi_query, headers = self.payload_and_headers(multi_query) + payload = self.rest_call_to_router(http_method="POST", + endpoint=endpoint, + agent=agent, + headers=headers, + payload=multi_query) + + if self.use_read_protobuf: + try: + self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") + payload = self.get_encrypted_payload_data(payload) + payload = decrypt_aes(payload, self.transmission_key) + except Exception as err: + self.logger.error(f"Could not decrypt protobuf graph multi-sync response: {type(err)}, {err}") + + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=payload, + agent=agent, + endpoint=endpoint, + error=None + ) + + return payload + + except DAGConnectionException as err: + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise err + except Exception as err: + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise DAGException(f"Could not load the DAG structure (multi_sync): {err}") + + def get_leafs(self, + leafs_query: Union[BaseModel, gs_pb2.GraphSyncLeafsQuery], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """POST a GraphSyncLeafsQuery to /get_leafs. + + Returns the serialized GraphSyncRefsResult — the list of stream refs + rooted at the queried vertices. Used as the discovery step before a + `multi_sync` call (per the per-graph read pattern that Web Vault + already uses). + """ + if agent is None: + agent = f"keeper-dag/{__version__}" + + endpoint = self._endpoint(ConnectionBase.GET_LEAFS, endpoint) + self.logger.debug(f"endpoint {endpoint}") + + try: + leafs_query, headers = self.payload_and_headers(leafs_query) + payload = self.rest_call_to_router(http_method="POST", + endpoint=endpoint, + agent=agent, + headers=headers, + payload=leafs_query) + + if self.use_read_protobuf: + try: + self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") + payload = self.get_encrypted_payload_data(payload) + payload = decrypt_aes(payload, self.transmission_key) + except Exception as err: + self.logger.error(f"Could not decrypt protobuf get_leafs response: {type(err)}, {err}") + + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=payload, + agent=agent, + endpoint=endpoint, + error=None + ) + + return payload + + except DAGConnectionException as err: + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise err + except Exception as err: + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise DAGException(f"Could not get leafs: {err}") diff --git a/keepercommander/keeper_dag/connection/ksm.py b/keepercommander/keeper_dag/connection/ksm.py index 58fa4fac0..e24b85592 100644 --- a/keepercommander/keeper_dag/connection/ksm.py +++ b/keepercommander/keeper_dag/connection/ksm.py @@ -165,6 +165,7 @@ def authenticate(self, attempt = 0 while True: + err_msg = "no error message" try: attempt += 1 response = requests.get(url, diff --git a/keepercommander/keeper_dag/connection/local.py b/keepercommander/keeper_dag/connection/local.py index 6a0dac42b..0567860d5 100644 --- a/keepercommander/keeper_dag/connection/local.py +++ b/keepercommander/keeper_dag/connection/local.py @@ -582,6 +582,41 @@ def sync(self, hasMore=has_more ).model_dump_json().encode() + def multi_sync(self, + multi_query: Union[gs_pb2.GraphSyncMultiQuery, Any], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """Local mirror of the network per-graph ``multi_sync``. + + The local SQLite store has no per-graph URL routing, so each sub-query + in the ``GraphSyncMultiQuery`` is run through this connection's own + ``sync()`` — identical stream / sync-point / graph-id semantics as the + single-stream read and the save path — and the per-stream results are + assembled into the same multi-stream envelope the network endpoint + returns: ``GraphSyncMultiResult`` (protobuf) or ``{"results": [...]}`` + (JSON). + """ + is_protobuf = isinstance(multi_query, gs_pb2.GraphSyncMultiQuery) + queries = list(multi_query.queries) + + if is_protobuf: + multi = gs_pb2.GraphSyncMultiResult() + for sub_query in queries: + single = self.sync(sub_query, graph_id=graph_id, + endpoint=endpoint, agent=agent) + result = gs_pb2.GraphSyncResult() + result.ParseFromString(single) + multi.results.add().CopyFrom(result) + return multi.SerializeToString() + + results = [] + for sub_query in queries: + single = self.sync(sub_query, graph_id=graph_id, + endpoint=endpoint, agent=agent) + results.append(json.loads(single)) + return json.dumps({"results": results}).encode() + def debug_dump(self) -> str: ret = "" diff --git a/keepercommander/keeper_dag/dag.py b/keepercommander/keeper_dag/dag.py index 85a2d1296..a9066528e 100644 --- a/keepercommander/keeper_dag/dag.py +++ b/keepercommander/keeper_dag/dag.py @@ -15,7 +15,7 @@ import importlib import traceback import sys -from typing import Optional, Union, List, Any, Tuple, TYPE_CHECKING +from typing import Optional, Union, List, Any, Tuple, Dict, TYPE_CHECKING if TYPE_CHECKING: from .connection import ConnectionBase @@ -225,7 +225,7 @@ def close(self): try: # Safely get the root vertex without creating a new one if hasattr(self, '_vertices') and hasattr(self, 'uid') and hasattr(self, '_uid_lookup'): - if len(self._vertices) > 0 and self.uid in self._uid_lookup: + if len(self._vertices) > 0 and self.uid is not None and self.uid in self._uid_lookup: idx = self._uid_lookup[self.uid] if idx < len(self._vertices): root = self._vertices[idx] @@ -298,7 +298,7 @@ def debug_stacktrace(self): trc = 'Traceback (most recent call last):\n' msg = trc + ''.join(traceback.format_list(stack)) if exc is not None: - msg += ' ' + traceback.format_exc().lstrip(trc) + msg += ' ' + traceback.format_exc().removeprefix(trc) self.debug(msg) def __str__(self): @@ -310,6 +310,8 @@ def __str__(self): for v in self.all_vertices: ret += f" * {v.uid}, Keys: {v.keychain}, Active: {v.active}\n" for e in v.edges: + if e is None: + continue if e.edge_type == EdgeType.DATA: ret += " + has a DATA edge" if e.content is not None: @@ -504,11 +506,28 @@ def get_vertices_by_path_value(self, path: str, inc_deleted: bool = False) -> Li for vertex in vertices: for edge in vertex.edges: - if edge.path == path: + if edge is not None and edge.path == path: results.append(vertex) return results def _sync(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Dispatch to legacy single-stream sync or per-graph multi-stream sync. + + When `read_endpoint` is set, the server uses the per-graph URL pattern + (`/api/user/graph-sync//...`). That model splits the graph across + multiple streams, so a single-stream `sync` returns only a fragment. + Web Vault uses `get_leafs` -> `multi_sync` to read the full graph; + this client follows the same pattern. + + When only `graph_id` is set (legacy single-endpoint transport), the + single-stream sync remains correct. + """ + if self.read_endpoint is not None: + return self._sync_per_graph(sync_point) + return self._sync_legacy(sync_point) + + def _sync_legacy(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Single-stream sync against the legacy `/sync` endpoint.""" # The web service will send 500 items, if there is more the 'has_more' flag is set to True. has_more = True @@ -543,6 +562,61 @@ def _sync(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: return all_data, sync_point + def _sync_per_graph(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Multi-stream read against the per-graph endpoints. + + The graph's data lives in a single stream keyed by the graph's origin + (e.g. the PAM Configuration record's UID for TunnelDAG). We multi_sync + that stream directly — no `get_leafs` discovery step needed for this + caller pattern. (`Connection.get_leafs` remains available for callers + that start from leaf vertices and need to discover stream roots.) + + Returns aggregated (data, max_sync_point) just like `_sync_legacy`. + """ + + origin_bytes = urlsafe_str_to_bytes(self.uid) + + # Stream keyed by the graph's origin (e.g. config_uid for PAM linking). + per_stream_sync_point: Dict[bytes, int] = {origin_bytes: sync_point} + all_data: List[DAGData] = [] + max_sync_point = sync_point + + while per_stream_sync_point: + stream_ids = list(per_stream_sync_point.keys()) + multi_query = self.read_struct_obj.multi_sync_query( + stream_ids=stream_ids, + origin=origin_bytes, + sync_point=sync_point, + ) + # Per-stream syncPoint adjustment so each stream advances + # independently across pagination rounds (proto variant only; + # JSON variant builds via SyncQuery which already carries syncPoint). + try: + for inner, sid in zip(multi_query.queries, stream_ids): + inner.syncPoint = per_stream_sync_point[sid] + except Exception: # pragma: no cover - JSON variant has no .queries + pass + + multi_response = self.conn.multi_sync( + multi_query=multi_query, + graph_id=self.graph_id, + endpoint=self.read_endpoint, + agent=self.agent, + ) + multi_results = self.read_struct_obj.get_multi_sync_result(multi_response) + + next_per_stream: Dict[bytes, int] = {} + for result in multi_results: + all_data += result.data + if result.syncPoint and result.syncPoint > max_sync_point: + max_sync_point = result.syncPoint + if result.hasMore and result.streamId is not None: + next_per_stream[bytes(result.streamId)] = result.syncPoint + + per_stream_sync_point = next_per_stream + + return all_data, max_sync_point + def _load(self, sync_point: int = 0): """ @@ -619,11 +693,12 @@ def _load(self, sync_point: int = 0): name=data.parentRef.name, vertex_type=RefType.GENERAL ) + # Get the head vertex, which will exist now. - head = self.get_vertex(head_uid) + head = self.get_vertex_by_uid(head_uid) if head is None or head == "": head = tail - head = self.get_vertex_by_uid(head_uid) + self.debug(f" * tail {tail_uid} belongs to {head_uid}, " f"edge type {edge_type}", level=3) diff --git a/keepercommander/keeper_dag/struct/__init__.py b/keepercommander/keeper_dag/struct/__init__.py index 53aa771da..ac87baa7b 100644 --- a/keepercommander/keeper_dag/struct/__init__.py +++ b/keepercommander/keeper_dag/struct/__init__.py @@ -54,3 +54,34 @@ def payload(origin_ref: Union[Ref, gs_pb2.GraphSyncRef], graph_id: Optional[int] = None) -> Union[DataPayload, gs_pb2.GraphSyncAddDataRequest]: pass + + # --- Per-graph multi-stream read transport --------------------------- + # Used by DAG._sync_per_graph when read_endpoint is set. Two-step pattern: + # 1. leafs_query(...) -> get_leafs_result(...) discovers stream refs. + # 2. multi_sync_query(...) -> get_multi_sync_result(...) fetches data. + + def leafs_query(self, + vertices: List[str]) -> Union[BaseModel, gs_pb2.GraphSyncLeafsQuery]: + """Build a GraphSyncLeafsQuery from a list of vertex UIDs (URL-safe str).""" + pass + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + """Parse GraphSyncRefsResult bytes into a list of Ref objects. + Each Ref's `value` is the stream UID rooted under the queried vertex. + """ + pass + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> Union[BaseModel, gs_pb2.GraphSyncMultiQuery]: + """Build a GraphSyncMultiQuery wrapping one GraphSyncQuery per stream.""" + pass + + @staticmethod + def get_multi_sync_result(results: bytes): # -> List[SyncData] + """Parse GraphSyncMultiResult bytes into a list of SyncData, one per + inner GraphSyncResult (each carrying its own streamId/syncPoint/hasMore). + """ + pass diff --git a/keepercommander/keeper_dag/struct/default.py b/keepercommander/keeper_dag/struct/default.py index a58558bff..dd23a6256 100644 --- a/keepercommander/keeper_dag/struct/default.py +++ b/keepercommander/keeper_dag/struct/default.py @@ -1,8 +1,10 @@ from __future__ import annotations +import json from . import DataStructBase from ..types import SyncQuery, Ref, RefType, DAGData, DataPayload, EdgeType, SyncData -from ..crypto import generate_random_bytes, generate_uid_str, bytes_to_str +from ..crypto import generate_random_bytes, generate_uid_str, bytes_to_str, bytes_to_urlsafe_str import base64 +from pydantic import BaseModel from typing import Optional, List @@ -79,3 +81,62 @@ def payload(origin_ref: Ref, dataList=data_list, graphId=graph_id ) + + # --- Per-graph multi-stream read transport --------------------------- + + class _LeafsQuery(BaseModel): + vertices: List[str] + + class _MultiSyncQuery(BaseModel): + queries: List[SyncQuery] + + def leafs_query(self, vertices: List[str]) -> 'DataStruct._LeafsQuery': + return DataStruct._LeafsQuery(vertices=list(vertices)) + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + try: + obj = json.loads(results) + except Exception as err: + raise Exception(f"Could not parse the leafs JSON result: {err}") + refs_list = obj.get("refs", []) if isinstance(obj, dict) else obj + out: List[Ref] = [] + for r in refs_list: + # Server may return either {type, value, name} or just a value str. + if isinstance(r, dict): + value = r.get("value") + if isinstance(value, bytes): + value = bytes_to_urlsafe_str(value) + out.append(Ref( + type=RefType(r["type"]) if r.get("type") is not None else RefType.GENERAL, + value=value, + name=r.get("name") or None, + )) + return out + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> 'DataStruct._MultiSyncQuery': + queries = [ + SyncQuery( + streamId=bytes_to_urlsafe_str(sid), + deviceId=bytes_to_urlsafe_str(origin), + syncPoint=sync_point, + graphId=None, + ) + for sid in stream_ids + ] + return DataStruct._MultiSyncQuery(queries=queries) + + @staticmethod + def get_multi_sync_result(results: bytes) -> List[SyncData]: + try: + obj = json.loads(results) + except Exception as err: + raise Exception(f"Could not parse the multi_sync JSON result: {err}") + items = obj.get("results", []) if isinstance(obj, dict) else obj + out: List[SyncData] = [] + for item in items: + out.append(SyncData.model_validate(item)) + return out diff --git a/keepercommander/keeper_dag/struct/protobuf.py b/keepercommander/keeper_dag/struct/protobuf.py index fcd123d5d..7a449f918 100644 --- a/keepercommander/keeper_dag/struct/protobuf.py +++ b/keepercommander/keeper_dag/struct/protobuf.py @@ -58,23 +58,16 @@ def sync_query(self, ) @staticmethod - def get_sync_result(results: bytes) -> SyncData: - - try: - result = gs_pb2.GraphSyncResult() - result.ParseFromString(results) - except Exception as err: - raise Exception(f"Could not parse the GraphSyncResult message: {err}") - - message = gs_pb2.GraphSyncResult() - message.ParseFromString(results) - + def _sync_data_from_result(message: gs_pb2.GraphSyncResult) -> SyncData: + """Convert a single GraphSyncResult protobuf into a SyncData pydantic + model. Extracted so both single-`sync` and multi_sync code paths share + identical per-result decoding. + """ data_list: List[SyncDataItem] = [] for item in message.data: data_list.append( SyncDataItem( type=DataStruct.PB_TO_DATA_MAP.get(item.data.type), - # content=bytes_to_str(item.data.content), content=item.data.content, content_is_base64=False, ref=Ref( @@ -92,9 +85,21 @@ def get_sync_result(results: bytes) -> SyncData: return SyncData( syncPoint=message.syncPoint, data=data_list, - hasMore=message.hasMore + hasMore=message.hasMore, + streamId=bytes(message.streamId) if message.streamId else None, ) + @staticmethod + def get_sync_result(results: bytes) -> SyncData: + + try: + message = gs_pb2.GraphSyncResult() + message.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncResult message: {err}") + + return DataStruct._sync_data_from_result(message) + @staticmethod def origin_ref(origin_ref_value: bytes, name: str) -> gs_pb2.GraphSyncRef: @@ -149,3 +154,49 @@ def payload(origin_ref: gs_pb2.GraphSyncRef, return gs_pb2.GraphSyncAddDataRequest( origin=origin_ref, data=data_list) + + # --- Per-graph multi-stream read transport --------------------------- + + def leafs_query(self, vertices: List[str]) -> gs_pb2.GraphSyncLeafsQuery: + return gs_pb2.GraphSyncLeafsQuery( + vertices=[urlsafe_str_to_bytes(v) for v in vertices] + ) + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + msg = gs_pb2.GraphSyncRefsResult() + try: + msg.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncRefsResult message: {err}") + return [ + Ref( + type=DataStruct.PB_TO_REF_MAP.get(r.type), + value=bytes_to_urlsafe_str(r.value), + name=r.name or None, + ) + for r in msg.refs + ] + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> gs_pb2.GraphSyncMultiQuery: + return gs_pb2.GraphSyncMultiQuery(queries=[ + gs_pb2.GraphSyncQuery( + streamId=sid, + origin=origin, + syncPoint=sync_point, + maxCount=0, # let krouter default (currently 500) + ) + for sid in stream_ids + ]) + + @staticmethod + def get_multi_sync_result(results: bytes) -> List[SyncData]: + msg = gs_pb2.GraphSyncMultiResult() + try: + msg.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncMultiResult message: {err}") + return [DataStruct._sync_data_from_result(r) for r in msg.results] diff --git a/keepercommander/keeper_dag/types.py b/keepercommander/keeper_dag/types.py index 9ab242c88..4b614fbfe 100644 --- a/keepercommander/keeper_dag/types.py +++ b/keepercommander/keeper_dag/types.py @@ -124,6 +124,17 @@ class PamEndpoints(BaseEnum): PamGraphId.SERVICE_LINKS.value: PamEndpoints.SERVICE_LINKS, } +# Inverse map for callers that have a graph_id int and need the PamEndpoints enum +# to address the new /api/user/graph-sync// routes. +GRAPH_ID_TO_ENDPOINT = { + PamGraphId.PAM.value: PamEndpoints.PAM, + PamGraphId.DISCOVERY_RULES.value: PamEndpoints.DISCOVERY_RULES, + PamGraphId.DISCOVERY_JOBS.value: PamEndpoints.DISCOVERY_JOBS, + PamGraphId.INFRASTRUCTURE.value: PamEndpoints.INFRASTRUCTURE, + PamGraphId.SERVICE_LINKS.value: PamEndpoints.SERVICE_LINKS, +} + + class SyncQuery(BaseModel): streamId: Optional[str] = None # base64 of a user's ID who is syncing. deviceId: Optional[str] = None @@ -134,7 +145,10 @@ class SyncQuery(BaseModel): class SyncDataItem(BaseModel): ref: Ref parentRef: Optional[Ref] = None - content: Optional[str] = None + # Either a base64-encoded string (JSON wire format) or raw bytes + # (protobuf wire format). `content_is_base64` distinguishes them so the + # consumer can decode appropriately. + content: Optional[Union[str, bytes]] = None content_is_base64: bool = True type: Optional[str] = None path: Optional[str] = None @@ -145,6 +159,9 @@ class SyncData(BaseModel): syncPoint: int data: List[SyncDataItem] hasMore: bool + # Per-graph multi_sync: identifies which stream this result came from. + # None for single-stream `sync` results (backward compatible). + streamId: Optional[bytes] = None class Ref(BaseModel): diff --git a/keepercommander/keeper_dag/utils.py b/keepercommander/keeper_dag/utils.py index 43ad5a76e..51e33c921 100644 --- a/keepercommander/keeper_dag/utils.py +++ b/keepercommander/keeper_dag/utils.py @@ -55,5 +55,5 @@ def set_file_permissions(file_path): # type: (str) -> None check=False, capture_output=True) subprocess.run(["icacls", file_path, "/grant", f"{username}:M"], check=True, capture_output=True) logging.debug(f'Set secure permissions (owner Modify only) for Windows file: {file_path}') - except Exception: - logging.warning(f'Failed to set file permissions for {file_path}') + except (OSError, subprocess.SubprocessError) as err: + logging.warning(f'Failed to set file permissions for {file_path}: {err}') diff --git a/keepercommander/proto/cnapp_pb2.py b/keepercommander/proto/cnapp_pb2.py new file mode 100644 index 000000000..a146a3e0e --- /dev/null +++ b/keepercommander/proto/cnapp_pb2.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: cnapp.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x63napp.proto\x12\x05\x43NAPP\"A\n\x15\x43nappQueueListRequest\x12\x12\n\nnetworkUid\x18\x01 \x01(\x0c\x12\x14\n\x0cstatusFilter\x18\x02 \x01(\x05\"O\n\x16\x43nappQueueListResponse\x12$\n\x05items\x18\x01 \x03(\x0b\x32\x15.CNAPP.CnappQueueItem\x12\x0f\n\x07hasMore\x18\x02 \x01(\x08\"\xd0\x01\n\x0e\x43nappQueueItem\x12\x14\n\x0c\x63nappQueueId\x18\x01 \x01(\x05\x12-\n\x0f\x63nappProviderId\x18\x02 \x01(\x0e\x32\x14.CNAPP.CnappProvider\x12\x1a\n\x12\x63nappQueueStatusId\x18\x03 \x01(\x05\x12\x12\n\nreceivedAt\x18\x04 \x01(\x03\x12\x12\n\nresolvedAt\x18\x05 \x01(\x03\x12\x11\n\tnetworkId\x18\x06 \x01(\x0c\x12\x0f\n\x07payload\x18\x07 \x01(\x0c\x12\x11\n\trecordUid\x18\x08 \x01(\x0c\"@\n\x15\x43nappAssociateRequest\x12\x11\n\trecordUid\x18\x01 \x01(\x0c\x12\x14\n\x0c\x63nappQueueId\x18\x02 \x01(\x05\"4\n\x16\x43nappAssociateResponse\x12\x1a\n\x12\x63nappQueueStatusId\x18\x01 \x01(\x05\"\x97\x02\n\x15\x43nappRemediateRequest\x12\x14\n\x0c\x63nappQueueId\x18\x01 \x01(\x05\x12\x31\n\nactionType\x18\x02 \x01(\x0e\x32\x1d.CNAPP.CnappRemediationAction\x12#\n\x1b\x63nappConfigurationRecordUid\x18\x03 \x01(\x0c\x12\x15\n\rpwdComplexity\x18\x04 \x01(\t\x12\x13\n\x0bresourceRef\x18\x05 \x01(\x0c\x12&\n\x08provider\x18\x06 \x01(\x0e\x32\x14.CNAPP.CnappProvider\x12\x15\n\rcontrollerUid\x18\x07 \x01(\t\x12\x12\n\nmessageUid\x18\x08 \x01(\x0c\x12\x11\n\tgroupName\x18\t \x01(\t\"w\n\x16\x43nappRemediateResponse\x12\x31\n\nactionType\x18\x01 \x01(\x0e\x32\x1d.CNAPP.CnappRemediationAction\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\x1a\n\x12\x63nappQueueStatusId\x18\x03 \x01(\x05\"Y\n\x15\x43nappSetStatusRequest\x12\x14\n\x0c\x63nappQueueId\x18\x01 \x01(\x05\x12\x1a\n\x12\x63nappQueueStatusId\x18\x02 \x01(\x05\x12\x0e\n\x06reason\x18\x03 \x01(\t\"4\n\x16\x43nappSetStatusResponse\x12\x1a\n\x12\x63nappQueueStatusId\x18\x01 \x01(\x05\"3\n\x1b\x43nappDeleteQueueItemRequest\x12\x14\n\x0c\x63nappQueueId\x18\x01 \x01(\x05\"\x1e\n\x1c\x43nappDeleteQueueItemResponse\"\xc7\x01\n\x12\x43nappConfiguration\x12\x12\n\nnetworkUid\x18\x01 \x01(\x0c\x12&\n\x08provider\x18\x02 \x01(\x0e\x32\x14.CNAPP.CnappProvider\x12\x10\n\x08\x63lientId\x18\x03 \x01(\t\x12\x14\n\x0c\x63lientSecret\x18\x04 \x01(\t\x12\x16\n\x0e\x61piEndpointUrl\x18\x05 \x01(\t\x12\x1c\n\x14\x63nappConfigRecordUid\x18\x06 \x01(\x0c\x12\x17\n\x0f\x61uthEndpointUrl\x18\x07 \x01(\t\"5\n\x1f\x43nappDeleteConfigurationRequest\x12\x12\n\nnetworkUid\x18\x01 \x01(\x0c\"5\n\x19\x43nappTestEncrypterRequest\x12\x18\n\x10urlBaseEncrypter\x18\x01 \x01(\t*G\n\rCnappProvider\x12\x1e\n\x1a\x43NAPP_PROVIDER_UNSPECIFIED\x10\x00\x12\x16\n\x12\x43NAPP_PROVIDER_WIZ\x10\x01*\x83\x01\n\x16\x43nappRemediationAction\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x16\n\x12ROTATE_CREDENTIALS\x10\x01\x12\x11\n\rMANAGE_ACCESS\x10\x02\x12\x0e\n\nJIT_ACCESS\x10\x03\x12\x1d\n\x19REMOVE_STANDING_PRIVILEGE\x10\x04\x42!\n\x18\x63om.keepersecurity.protoB\x05\x43nappb\x06proto3') + +_CNAPPPROVIDER = DESCRIPTOR.enum_types_by_name['CnappProvider'] +CnappProvider = enum_type_wrapper.EnumTypeWrapper(_CNAPPPROVIDER) +_CNAPPREMEDIATIONACTION = DESCRIPTOR.enum_types_by_name['CnappRemediationAction'] +CnappRemediationAction = enum_type_wrapper.EnumTypeWrapper(_CNAPPREMEDIATIONACTION) +CNAPP_PROVIDER_UNSPECIFIED = 0 +CNAPP_PROVIDER_WIZ = 1 +UNSPECIFIED = 0 +ROTATE_CREDENTIALS = 1 +MANAGE_ACCESS = 2 +JIT_ACCESS = 3 +REMOVE_STANDING_PRIVILEGE = 4 + + +_CNAPPQUEUELISTREQUEST = DESCRIPTOR.message_types_by_name['CnappQueueListRequest'] +_CNAPPQUEUELISTRESPONSE = DESCRIPTOR.message_types_by_name['CnappQueueListResponse'] +_CNAPPQUEUEITEM = DESCRIPTOR.message_types_by_name['CnappQueueItem'] +_CNAPPASSOCIATEREQUEST = DESCRIPTOR.message_types_by_name['CnappAssociateRequest'] +_CNAPPASSOCIATERESPONSE = DESCRIPTOR.message_types_by_name['CnappAssociateResponse'] +_CNAPPREMEDIATEREQUEST = DESCRIPTOR.message_types_by_name['CnappRemediateRequest'] +_CNAPPREMEDIATERESPONSE = DESCRIPTOR.message_types_by_name['CnappRemediateResponse'] +_CNAPPSETSTATUSREQUEST = DESCRIPTOR.message_types_by_name['CnappSetStatusRequest'] +_CNAPPSETSTATUSRESPONSE = DESCRIPTOR.message_types_by_name['CnappSetStatusResponse'] +_CNAPPDELETEQUEUEITEMREQUEST = DESCRIPTOR.message_types_by_name['CnappDeleteQueueItemRequest'] +_CNAPPDELETEQUEUEITEMRESPONSE = DESCRIPTOR.message_types_by_name['CnappDeleteQueueItemResponse'] +_CNAPPCONFIGURATION = DESCRIPTOR.message_types_by_name['CnappConfiguration'] +_CNAPPDELETECONFIGURATIONREQUEST = DESCRIPTOR.message_types_by_name['CnappDeleteConfigurationRequest'] +_CNAPPTESTENCRYPTERREQUEST = DESCRIPTOR.message_types_by_name['CnappTestEncrypterRequest'] +CnappQueueListRequest = _reflection.GeneratedProtocolMessageType('CnappQueueListRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPQUEUELISTREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappQueueListRequest) + }) +_sym_db.RegisterMessage(CnappQueueListRequest) + +CnappQueueListResponse = _reflection.GeneratedProtocolMessageType('CnappQueueListResponse', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPQUEUELISTRESPONSE, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappQueueListResponse) + }) +_sym_db.RegisterMessage(CnappQueueListResponse) + +CnappQueueItem = _reflection.GeneratedProtocolMessageType('CnappQueueItem', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPQUEUEITEM, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappQueueItem) + }) +_sym_db.RegisterMessage(CnappQueueItem) + +CnappAssociateRequest = _reflection.GeneratedProtocolMessageType('CnappAssociateRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPASSOCIATEREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappAssociateRequest) + }) +_sym_db.RegisterMessage(CnappAssociateRequest) + +CnappAssociateResponse = _reflection.GeneratedProtocolMessageType('CnappAssociateResponse', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPASSOCIATERESPONSE, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappAssociateResponse) + }) +_sym_db.RegisterMessage(CnappAssociateResponse) + +CnappRemediateRequest = _reflection.GeneratedProtocolMessageType('CnappRemediateRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPREMEDIATEREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappRemediateRequest) + }) +_sym_db.RegisterMessage(CnappRemediateRequest) + +CnappRemediateResponse = _reflection.GeneratedProtocolMessageType('CnappRemediateResponse', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPREMEDIATERESPONSE, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappRemediateResponse) + }) +_sym_db.RegisterMessage(CnappRemediateResponse) + +CnappSetStatusRequest = _reflection.GeneratedProtocolMessageType('CnappSetStatusRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPSETSTATUSREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappSetStatusRequest) + }) +_sym_db.RegisterMessage(CnappSetStatusRequest) + +CnappSetStatusResponse = _reflection.GeneratedProtocolMessageType('CnappSetStatusResponse', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPSETSTATUSRESPONSE, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappSetStatusResponse) + }) +_sym_db.RegisterMessage(CnappSetStatusResponse) + +CnappDeleteQueueItemRequest = _reflection.GeneratedProtocolMessageType('CnappDeleteQueueItemRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPDELETEQUEUEITEMREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappDeleteQueueItemRequest) + }) +_sym_db.RegisterMessage(CnappDeleteQueueItemRequest) + +CnappDeleteQueueItemResponse = _reflection.GeneratedProtocolMessageType('CnappDeleteQueueItemResponse', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPDELETEQUEUEITEMRESPONSE, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappDeleteQueueItemResponse) + }) +_sym_db.RegisterMessage(CnappDeleteQueueItemResponse) + +CnappConfiguration = _reflection.GeneratedProtocolMessageType('CnappConfiguration', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPCONFIGURATION, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappConfiguration) + }) +_sym_db.RegisterMessage(CnappConfiguration) + +CnappDeleteConfigurationRequest = _reflection.GeneratedProtocolMessageType('CnappDeleteConfigurationRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPDELETECONFIGURATIONREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappDeleteConfigurationRequest) + }) +_sym_db.RegisterMessage(CnappDeleteConfigurationRequest) + +CnappTestEncrypterRequest = _reflection.GeneratedProtocolMessageType('CnappTestEncrypterRequest', (_message.Message,), { + 'DESCRIPTOR' : _CNAPPTESTENCRYPTERREQUEST, + '__module__' : 'cnapp_pb2' + # @@protoc_insertion_point(class_scope:CNAPP.CnappTestEncrypterRequest) + }) +_sym_db.RegisterMessage(CnappTestEncrypterRequest) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\030com.keepersecurity.protoB\005Cnapp' + _CNAPPPROVIDER._serialized_start=1446 + _CNAPPPROVIDER._serialized_end=1517 + _CNAPPREMEDIATIONACTION._serialized_start=1520 + _CNAPPREMEDIATIONACTION._serialized_end=1651 + _CNAPPQUEUELISTREQUEST._serialized_start=22 + _CNAPPQUEUELISTREQUEST._serialized_end=87 + _CNAPPQUEUELISTRESPONSE._serialized_start=89 + _CNAPPQUEUELISTRESPONSE._serialized_end=168 + _CNAPPQUEUEITEM._serialized_start=171 + _CNAPPQUEUEITEM._serialized_end=379 + _CNAPPASSOCIATEREQUEST._serialized_start=381 + _CNAPPASSOCIATEREQUEST._serialized_end=445 + _CNAPPASSOCIATERESPONSE._serialized_start=447 + _CNAPPASSOCIATERESPONSE._serialized_end=499 + _CNAPPREMEDIATEREQUEST._serialized_start=502 + _CNAPPREMEDIATEREQUEST._serialized_end=781 + _CNAPPREMEDIATERESPONSE._serialized_start=783 + _CNAPPREMEDIATERESPONSE._serialized_end=902 + _CNAPPSETSTATUSREQUEST._serialized_start=904 + _CNAPPSETSTATUSREQUEST._serialized_end=993 + _CNAPPSETSTATUSRESPONSE._serialized_start=995 + _CNAPPSETSTATUSRESPONSE._serialized_end=1047 + _CNAPPDELETEQUEUEITEMREQUEST._serialized_start=1049 + _CNAPPDELETEQUEUEITEMREQUEST._serialized_end=1100 + _CNAPPDELETEQUEUEITEMRESPONSE._serialized_start=1102 + _CNAPPDELETEQUEUEITEMRESPONSE._serialized_end=1132 + _CNAPPCONFIGURATION._serialized_start=1135 + _CNAPPCONFIGURATION._serialized_end=1334 + _CNAPPDELETECONFIGURATIONREQUEST._serialized_start=1336 + _CNAPPDELETECONFIGURATIONREQUEST._serialized_end=1389 + _CNAPPTESTENCRYPTERREQUEST._serialized_start=1391 + _CNAPPTESTENCRYPTERREQUEST._serialized_end=1444 +# @@protoc_insertion_point(module_scope) diff --git a/unit-tests/conftest.py b/unit-tests/conftest.py new file mode 100644 index 000000000..5650ef913 --- /dev/null +++ b/unit-tests/conftest.py @@ -0,0 +1,12 @@ +"""Pytest session hooks for unit-tests. + +CNAPP tests import `cnapp_helper` before `keepercommander.commands.record` is loaded, +which triggers a pre-existing record <-> ksm circular import. Loading `record` first +resolves the cycle (same as production startup order). +""" +import pytest + + +@pytest.fixture(scope='session', autouse=True) +def _preload_commands_record_module(): + import keepercommander.commands.record # noqa: F401 diff --git a/unit-tests/pam/test_cnapp.py b/unit-tests/pam/test_cnapp.py new file mode 100644 index 000000000..6b9b39756 --- /dev/null +++ b/unit-tests/pam/test_cnapp.py @@ -0,0 +1,913 @@ +"""Unit tests for the Commander CNAPP helper and command surface. + +Strategy: every test patches `_post_request_to_router` so we can assert on what the +helper sends to krouter and feed deterministic responses back into the commands. We +deliberately stay one layer below the network — no socket calls, no real protobuf +encryption — but we exercise the real proto serializers so wire-format breakage +surfaces here. +""" +import base64 +import io +import json +import os +import unittest +from contextlib import redirect_stdout +from unittest.mock import MagicMock, patch + +# isort: off +# Pre-load `record` before cnapp modules (record↔ksm cycle). Pytest also loads it via +# unit-tests/conftest.py; keep this guard for `python unit-tests/pam/test_cnapp.py`. +import keepercommander.commands.record # noqa: F401 +# isort: on + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM # noqa: E402 +from keeper_secrets_manager_core.utils import bytes_to_base64 # noqa: E402 + +from keepercommander.commands.pam import cnapp_helper # noqa: E402 +from keepercommander.commands.pam import cnapp_commands # noqa: E402 +from keepercommander.error import CommandError # noqa: E402 +from keepercommander.proto import cnapp_pb2 # noqa: E402 + + +# Sample 16-byte UIDs as base64url (the format Commander callers pass in). +NETWORK_UID = 'AAAAAAAAAAAAAAAAAAAAAA' # 16 zero bytes +RECORD_UID = 'AQEBAQEBAQEBAQEBAQEBAQ' # 16 0x01 bytes +CONFIG_RECORD_UID = 'AgICAgICAgICAgICAgICAg' + + +def _mock_params(): + """Minimal KeeperParams stand-in — the helpers only use it to drive the router_helper + transport, which is mocked here, so a MagicMock is enough.""" + return MagicMock() + + +# --------------------------------------------------------------------------- +# cnapp_helper: enum parsing +# --------------------------------------------------------------------------- + +class TestEnumParsing(unittest.TestCase): + """provider_from_name and action_from_name must accept short or full names and + reject unknown values with a helpful error listing valid options.""" + + def test_provider_short_name(self): + self.assertEqual(cnapp_helper.provider_from_name('wiz'), cnapp_pb2.CNAPP_PROVIDER_WIZ) + + def test_provider_full_name_case_insensitive(self): + self.assertEqual( + cnapp_helper.provider_from_name('cnapp_provider_wiz'), + cnapp_pb2.CNAPP_PROVIDER_WIZ, + ) + + def test_provider_empty_returns_unspecified(self): + self.assertEqual(cnapp_helper.provider_from_name(''), cnapp_pb2.CNAPP_PROVIDER_UNSPECIFIED) + + def test_provider_unknown_raises_with_valid_options(self): + with self.assertRaises(ValueError) as ctx: + cnapp_helper.provider_from_name('aws') + self.assertIn('WIZ', str(ctx.exception).upper()) + + def test_action_short_name(self): + self.assertEqual( + cnapp_helper.action_from_name('rotate_credentials'), + cnapp_pb2.ROTATE_CREDENTIALS, + ) + + def test_action_hyphenated(self): + # The CLI accepts hyphens (`--action remove-standing-privilege`) for ergonomics; + # helper must normalize before resolving the enum. + self.assertEqual( + cnapp_helper.action_from_name('remove-standing-privilege'), + cnapp_pb2.REMOVE_STANDING_PRIVILEGE, + ) + + def test_action_unknown_raises(self): + with self.assertRaises(ValueError): + cnapp_helper.action_from_name('teleport') + + def test_action_empty_returns_unspecified(self): + self.assertEqual(cnapp_helper.action_from_name(''), cnapp_pb2.UNSPECIFIED) + + +# --------------------------------------------------------------------------- +# cnapp_helper: configuration endpoints +# --------------------------------------------------------------------------- + +class TestConfigurationHelpers(unittest.TestCase): + """Each helper must dispatch to the right krouter path with a correctly populated + protobuf request and return the typed response.""" + + def setUp(self): + self.params = _mock_params() + + def _patch_post(self, return_value=None): + return patch.object(cnapp_helper, '_post_request_to_router', return_value=return_value) + + def test_set_configuration_dispatches_with_full_payload(self): + expected_response = cnapp_pb2.CnappConfiguration( + clientId='abc', apiEndpointUrl='https://api.wiz.io') + with self._patch_post(return_value=expected_response) as post: + result = cnapp_helper.set_cnapp_configuration( + self.params, + network_uid=NETWORK_UID, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + client_id='abc', + client_secret='secret', + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + auth_endpoint_url='https://auth.wiz.io/oauth/token', + ) + self.assertIs(result, expected_response) + args, kwargs = post.call_args + self.assertEqual(args[1], 'cnapp/configuration/set') + rq = kwargs['rq_proto'] + self.assertEqual(rq.provider, cnapp_pb2.CNAPP_PROVIDER_WIZ) + self.assertEqual(rq.clientId, 'abc') + self.assertEqual(rq.clientSecret, 'secret') + self.assertEqual(rq.apiEndpointUrl, 'https://api.wiz.io') + self.assertEqual(rq.authEndpointUrl, 'https://auth.wiz.io/oauth/token') + self.assertEqual(len(rq.networkUid), 16) + self.assertEqual(len(rq.cnappConfigRecordUid), 16) + self.assertIs(kwargs['rs_type'], cnapp_pb2.CnappConfiguration) + + def test_set_configuration_omits_empty_secret_to_keep_existing(self): + """Edge case: passing '' for client_secret on set must leave the field blank in + the request so krouter can splice in the previously stored secret.""" + with self._patch_post() as post: + cnapp_helper.set_cnapp_configuration( + self.params, + network_uid=NETWORK_UID, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + client_id='abc', + client_secret='', + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + ) + rq = post.call_args.kwargs['rq_proto'] + self.assertEqual(rq.clientSecret, '') + + def test_test_configuration_dispatches_to_test_endpoint(self): + with self._patch_post() as post: + cnapp_helper.test_cnapp_configuration( + self.params, + network_uid=NETWORK_UID, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + client_id='abc', + client_secret='secret', + api_endpoint_url='https://api.wiz.io', + auth_endpoint_url='https://auth.wiz.io/oauth/token', + ) + self.assertEqual(post.call_args.args[1], 'cnapp/configuration/test') + self.assertEqual(post.call_args.kwargs['rq_proto'].authEndpointUrl, 'https://auth.wiz.io/oauth/token') + # test endpoint never persists, so it must not require / send config record UID + self.assertEqual(post.call_args.kwargs['rq_proto'].cnappConfigRecordUid, b'') + + def test_test_encrypter_sets_url(self): + with self._patch_post() as post: + cnapp_helper.test_cnapp_encrypter(self.params, url_base_encrypter='https://encr.local') + rq = post.call_args.kwargs['rq_proto'] + self.assertEqual(post.call_args.args[1], 'cnapp/configuration/test-encrypter') + self.assertEqual(rq.urlBaseEncrypter, 'https://encr.local') + + def test_read_configuration_uses_read_endpoint(self): + with self._patch_post(return_value=cnapp_pb2.CnappConfiguration()) as post: + cnapp_helper.read_cnapp_configuration( + self.params, network_uid=NETWORK_UID, provider=cnapp_pb2.CNAPP_PROVIDER_WIZ) + self.assertEqual(post.call_args.args[1], 'cnapp/configuration/read') + self.assertIs(post.call_args.kwargs['rs_type'], cnapp_pb2.CnappConfiguration) + + def test_delete_configuration_uses_delete_endpoint(self): + with self._patch_post() as post: + cnapp_helper.delete_cnapp_configuration(self.params, network_uid=NETWORK_UID) + self.assertEqual(post.call_args.args[1], 'cnapp/configuration/delete') + self.assertEqual(len(post.call_args.kwargs['rq_proto'].networkUid), 16) + + +# --------------------------------------------------------------------------- +# cnapp_helper: queue endpoints +# --------------------------------------------------------------------------- + +class TestQueueHelpers(unittest.TestCase): + def setUp(self): + self.params = _mock_params() + + def _patch_post(self, return_value=None): + return patch.object(cnapp_helper, '_post_request_to_router', return_value=return_value) + + def test_list_queue_with_status_filter(self): + items = cnapp_pb2.CnappQueueListResponse( + items=[cnapp_pb2.CnappQueueItem(cnappQueueId=42)]) + with self._patch_post(return_value=items) as post: + response = cnapp_helper.list_cnapp_queue( + self.params, network_uid=NETWORK_UID, status_filter=1) + self.assertEqual(post.call_args.args[1], 'cnapp/queue') + self.assertEqual(post.call_args.kwargs['rq_proto'].statusFilter, 1) + self.assertEqual(response.items[0].cnappQueueId, 42) + + def test_list_queue_defaults_to_all_status(self): + with self._patch_post(return_value=cnapp_pb2.CnappQueueListResponse()) as post: + cnapp_helper.list_cnapp_queue(self.params, network_uid=NETWORK_UID) + self.assertEqual(post.call_args.kwargs['rq_proto'].statusFilter, 0) + + def test_associate_record_dispatches(self): + with self._patch_post() as post: + cnapp_helper.associate_cnapp_record( + self.params, cnapp_queue_id=7, record_uid=RECORD_UID) + rq = post.call_args.kwargs['rq_proto'] + self.assertEqual(post.call_args.args[1], 'cnapp/queue/associate') + self.assertEqual(rq.cnappQueueId, 7) + self.assertEqual(len(rq.recordUid), 16) + + def test_remediate_forwards_optional_fields(self): + with self._patch_post(return_value=cnapp_pb2.CnappRemediateResponse()) as post: + cnapp_helper.remediate_cnapp_queue_item( + self.params, + cnapp_queue_id=3, + action_type=cnapp_pb2.ROTATE_CREDENTIALS, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + cnapp_config_record_uid=CONFIG_RECORD_UID, + resource_ref=RECORD_UID, + pwd_complexity='{"len":24}', + controller_uid='gateway-1', + message_uid=RECORD_UID, + group_name='Admins', + ) + rq = post.call_args.kwargs['rq_proto'] + self.assertEqual(post.call_args.args[1], 'cnapp/queue/remediate') + self.assertEqual(rq.cnappQueueId, 3) + self.assertEqual(rq.actionType, cnapp_pb2.ROTATE_CREDENTIALS) + self.assertEqual(rq.provider, cnapp_pb2.CNAPP_PROVIDER_WIZ) + self.assertEqual(rq.pwdComplexity, '{"len":24}') + self.assertEqual(rq.controllerUid, 'gateway-1') + self.assertEqual(rq.groupName, 'Admins') + + def test_remediate_minimal_fields(self): + """No optional fields — only queueId and actionType must be set on the wire.""" + with self._patch_post(return_value=cnapp_pb2.CnappRemediateResponse()) as post: + cnapp_helper.remediate_cnapp_queue_item( + self.params, cnapp_queue_id=9, action_type=cnapp_pb2.ROTATE_CREDENTIALS) + rq = post.call_args.kwargs['rq_proto'] + self.assertEqual(rq.cnappQueueId, 9) + self.assertEqual(rq.provider, 0) + self.assertEqual(rq.pwdComplexity, '') + self.assertEqual(rq.controllerUid, '') + self.assertEqual(rq.groupName, '') + + def test_set_status_with_reason(self): + with self._patch_post(return_value=cnapp_pb2.CnappSetStatusResponse(cnappQueueStatusId=3)) as post: + response = cnapp_helper.set_cnapp_queue_status( + self.params, cnapp_queue_id=11, cnapp_queue_status_id=3, reason='Manually resolved') + self.assertEqual(post.call_args.args[1], 'cnapp/queue/set-status') + self.assertEqual(post.call_args.kwargs['rq_proto'].reason, 'Manually resolved') + self.assertEqual(response.cnappQueueStatusId, 3) + + def test_delete_queue_item_dispatches(self): + with self._patch_post() as post: + cnapp_helper.delete_cnapp_queue_item(self.params, cnapp_queue_id=11) + self.assertEqual(post.call_args.args[1], 'cnapp/queue/delete') + self.assertEqual(post.call_args.kwargs['rq_proto'].cnappQueueId, 11) + + +# --------------------------------------------------------------------------- +# cnapp_helper: error propagation +# --------------------------------------------------------------------------- + +class TestHelperErrorPropagation(unittest.TestCase): + """The router layer raises on RRC_!=OK; helpers must NOT swallow those errors.""" + + def test_set_configuration_propagates_router_error(self): + params = _mock_params() + with patch.object(cnapp_helper, '_post_request_to_router', + side_effect=Exception('Credential validation failed: Unauthorized')): + with self.assertRaises(Exception) as ctx: + cnapp_helper.set_cnapp_configuration( + params, + network_uid=NETWORK_UID, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + client_id='abc', + client_secret='bad', + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + ) + self.assertIn('Credential validation failed', str(ctx.exception)) + + +# --------------------------------------------------------------------------- +# cnapp_commands: status resolver +# --------------------------------------------------------------------------- + +class TestStatusResolver(unittest.TestCase): + + def test_numeric_passes_through(self): + self.assertEqual(cnapp_commands._resolve_status('1'), 1) + self.assertEqual(cnapp_commands._resolve_status(2), 2) + + def test_unknown_numeric_id_raises(self): + with self.assertRaises(CommandError): + cnapp_commands._resolve_status(99) + + def test_zero_is_all(self): + self.assertEqual(cnapp_commands._resolve_status('0'), 0) + self.assertEqual(cnapp_commands._resolve_status(None), 0) + self.assertEqual(cnapp_commands._resolve_status(''), 0) + + def test_named_status_case_insensitive(self): + self.assertEqual(cnapp_commands._resolve_status('PENDING'), 1) + self.assertEqual(cnapp_commands._resolve_status('in_progress'), 2) + self.assertEqual(cnapp_commands._resolve_status('Resolved'), 3) + + def test_unknown_status_raises_command_error(self): + with self.assertRaises(CommandError): + cnapp_commands._resolve_status('flapping') + + +# --------------------------------------------------------------------------- +# cnapp_commands: end-to-end (helpers patched) +# --------------------------------------------------------------------------- + +class TestConfigCommands(unittest.TestCase): + def setUp(self): + self.params = _mock_params() + + def _capture_stdout(self): + buf = io.StringIO() + return buf, redirect_stdout(buf) + + def test_config_set_calls_helper_with_resolved_provider(self): + with patch.object(cnapp_commands.cnapp_helper, 'set_cnapp_configuration', + return_value=cnapp_pb2.CnappConfiguration(clientId='abc', + apiEndpointUrl='https://api.wiz.io', + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ)) as helper: + buf, ctx = self._capture_stdout() + with ctx: + cnapp_commands.PAMCnappConfigSetCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='wiz', + client_id='abc', + client_secret='secret', + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + auth_endpoint_url='https://auth.wiz.io/oauth/token', + ) + helper.assert_called_once() + kwargs = helper.call_args.kwargs + self.assertEqual(kwargs['provider'], cnapp_pb2.CNAPP_PROVIDER_WIZ) + self.assertEqual(kwargs['client_secret'], 'secret') + self.assertEqual(kwargs['auth_endpoint_url'], 'https://auth.wiz.io/oauth/token') + self.assertIn('saved', buf.getvalue().lower()) + + def test_config_set_blank_secret_passes_through(self): + """Edge case: the CLI must forward an empty secret unchanged so krouter can + keep the existing value.""" + with patch.object(cnapp_commands.cnapp_helper, 'set_cnapp_configuration', + return_value=cnapp_pb2.CnappConfiguration()) as helper: + with redirect_stdout(io.StringIO()): + cnapp_commands.PAMCnappConfigSetCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='wiz', + client_id='abc', + client_secret='', # explicit + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + ) + self.assertEqual(helper.call_args.kwargs['client_secret'], '') + + def test_config_set_omitted_secret_keeps_existing(self): + with patch.object(cnapp_commands.cnapp_helper, 'set_cnapp_configuration', + return_value=cnapp_pb2.CnappConfiguration()) as helper: + with redirect_stdout(io.StringIO()): + cnapp_commands.PAMCnappConfigSetCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='wiz', + client_id='abc', + client_secret=None, + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + ) + self.assertEqual(helper.call_args.kwargs['client_secret'], '') + + def test_config_set_invalid_provider_raises(self): + with self.assertRaises(ValueError): + cnapp_commands.PAMCnappConfigSetCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='bogus', + client_id='abc', + client_secret='secret', + api_endpoint_url='https://api.wiz.io', + cnapp_config_record_uid=CONFIG_RECORD_UID, + ) + + def test_config_test_prints_success(self): + with patch.object(cnapp_commands.cnapp_helper, 'test_cnapp_configuration', return_value=None) as helper: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappConfigTestCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='wiz', + client_id='abc', + client_secret='secret', + api_endpoint_url='https://api.wiz.io', + auth_endpoint_url='https://auth.wiz.io/oauth/token', + ) + self.assertEqual(helper.call_args.kwargs['auth_endpoint_url'], 'https://auth.wiz.io/oauth/token') + self.assertIn('validated', buf.getvalue().lower()) + + def test_config_test_propagates_helper_error(self): + with patch.object(cnapp_commands.cnapp_helper, 'test_cnapp_configuration', + side_effect=Exception('Credential validation failed: bad')): + with self.assertRaises(Exception): + cnapp_commands.PAMCnappConfigTestCommand().execute( + self.params, + network_uid=NETWORK_UID, + provider='wiz', + client_id='abc', + client_secret='bad', + api_endpoint_url='https://api.wiz.io', + ) + + def test_config_test_encrypter_success(self): + with patch.object(cnapp_commands.cnapp_helper, 'test_cnapp_encrypter', return_value=None) as helper: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappConfigTestEncrypterCommand().execute( + self.params, url='https://encr.local') + helper.assert_called_once_with(self.params, url_base_encrypter='https://encr.local') + self.assertIn('reachable', buf.getvalue().lower()) + + def test_config_read_table_format(self): + config = cnapp_pb2.CnappConfiguration( + clientId='abc', + apiEndpointUrl='https://api.wiz.io', + authEndpointUrl='https://auth.wiz.io/oauth/token', + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + ) + with patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappConfigReadCommand().execute( + self.params, network_uid=NETWORK_UID, provider='wiz', format='table') + output = buf.getvalue() + self.assertIn('CNAPP Configuration', output) + self.assertIn('https://api.wiz.io', output) + self.assertIn('https://auth.wiz.io/oauth/token', output) + + def test_config_read_json_format(self): + config = cnapp_pb2.CnappConfiguration( + clientId='abc', + apiEndpointUrl='https://api.wiz.io', + authEndpointUrl='https://auth.wiz.io/oauth/token', + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + ) + with patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config): + buf = io.StringIO() + with redirect_stdout(buf): + result = cnapp_commands.PAMCnappConfigReadCommand().execute( + self.params, network_uid=NETWORK_UID, provider='wiz', format='json') + payload = json.loads(buf.getvalue()) + self.assertEqual(payload['clientId'], 'abc') + self.assertEqual(payload['provider'], 'CNAPP_PROVIDER_WIZ') + self.assertEqual(payload['apiEndpointUrl'], 'https://api.wiz.io') + self.assertEqual(payload['authEndpointUrl'], 'https://auth.wiz.io/oauth/token') + self.assertIsNone(result, 'JSON output is the channel — no value returned to the REPL') + + def test_config_read_handles_none_response(self): + with patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=None): + self.assertIsNone(cnapp_commands.PAMCnappConfigReadCommand().execute( + self.params, network_uid=NETWORK_UID, provider='wiz', format='table')) + + def test_config_delete_success(self): + with patch.object(cnapp_commands.cnapp_helper, 'delete_cnapp_configuration', return_value=None) as helper: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappConfigDeleteCommand().execute(self.params, network_uid=NETWORK_UID) + helper.assert_called_once_with(self.params, network_uid=NETWORK_UID) + self.assertIn('deleted', buf.getvalue().lower()) + + +class TestQueueCommands(unittest.TestCase): + def setUp(self): + self.params = _mock_params() + + def _queue_response(self, items=None, has_more=False): + return cnapp_pb2.CnappQueueListResponse(items=items or [], hasMore=has_more) + + def _queue_item(self, queue_id=1, status_id=1, record_uid=b''): + return cnapp_pb2.CnappQueueItem( + cnappQueueId=queue_id, + cnappProviderId=cnapp_pb2.CNAPP_PROVIDER_WIZ, + cnappQueueStatusId=status_id, + receivedAt=1700000000000, + networkId=b'\x00' * 16, + recordUid=record_uid, + ) + + def test_queue_list_empty(self): + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', + return_value=self._queue_response()): + buf = io.StringIO() + with redirect_stdout(buf): + result = cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + no_decrypt=True) + self.assertIn('No CNAPP queue items', buf.getvalue()) + self.assertIsNone(result, 'queue list must not return the proto so the REPL does not dump bytes') + + def test_queue_list_with_items_table(self): + item = self._queue_item(queue_id=99, status_id=2, record_uid=b'\x01' * 16) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', + return_value=self._queue_response([item])): + buf = io.StringIO() + with redirect_stdout(buf): + result = cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + no_decrypt=True) + output = buf.getvalue() + self.assertIn('99', output) + self.assertIn('IN_PROGRESS', output) + self.assertIn('CNAPP_PROVIDER_WIZ', output) + self.assertIsNone(result) + + def test_queue_list_filter_resolves_named_status(self): + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', + return_value=self._queue_response()) as helper: + with redirect_stdout(io.StringIO()): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status='pending', format='table', + no_decrypt=True) + self.assertEqual(helper.call_args.kwargs['status_filter'], 1) + + def test_queue_list_json_format(self): + item = self._queue_item(queue_id=5, status_id=3, record_uid=b'\x02' * 16) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', + return_value=self._queue_response([item], has_more=True)): + buf = io.StringIO() + with redirect_stdout(buf): + result = cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='json', + no_decrypt=True) + payload = json.loads(buf.getvalue()) + self.assertEqual(payload['items'][0]['cnappQueueId'], 5) + self.assertEqual(payload['items'][0]['cnappQueueStatusName'], 'RESOLVED') + self.assertTrue(payload['hasMore']) + self.assertEqual(payload['items'][0]['recordUid'], + bytes_to_base64(b'\x02' * 16)) + self.assertNotIn('payload', payload['items'][0], + 'raw encrypted payload bytes must not leak into JSON output') + self.assertIsNone(result, 'JSON output stream must not also return a value') + + def test_queue_list_warns_when_has_more(self): + item = self._queue_item() + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', + return_value=self._queue_response([item], has_more=True)): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + no_decrypt=True) + self.assertIn('hasMore=true', buf.getvalue()) + + def test_queue_associate_success(self): + with patch.object(cnapp_commands.cnapp_helper, 'associate_cnapp_record', return_value=None) as helper: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueAssociateCommand().execute( + self.params, cnapp_queue_id=12, record_uid=RECORD_UID) + helper.assert_called_once_with(self.params, cnapp_queue_id=12, record_uid=RECORD_UID) + self.assertIn('12', buf.getvalue()) + + def test_queue_remediate_prints_response(self): + response = cnapp_pb2.CnappRemediateResponse( + actionType=cnapp_pb2.ROTATE_CREDENTIALS, + result='Scheduled', + cnappQueueStatusId=2, + ) + with patch.object(cnapp_commands.cnapp_helper, 'remediate_cnapp_queue_item', + return_value=response): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueRemediateCommand().execute( + self.params, + cnapp_queue_id=4, + action_type='rotate_credentials', + provider='wiz', + ) + output = buf.getvalue() + self.assertIn('ROTATE_CREDENTIALS', output) + self.assertIn('IN_PROGRESS', output) + self.assertIn('Scheduled', output) + + def test_queue_remediate_unsupported_action_propagates(self): + with patch.object(cnapp_commands.cnapp_helper, 'remediate_cnapp_queue_item', + side_effect=Exception('Unsupported action type response code: RRC_BAD_REQUEST')): + with self.assertRaises(Exception) as ctx: + cnapp_commands.PAMCnappQueueRemediateCommand().execute( + self.params, + cnapp_queue_id=4, + action_type='jit_access', + ) + self.assertIn('Unsupported', str(ctx.exception)) + + def test_queue_remediate_invalid_action_name(self): + with self.assertRaises(ValueError): + cnapp_commands.PAMCnappQueueRemediateCommand().execute( + self.params, cnapp_queue_id=1, action_type='nuke_everything') + + def test_queue_set_status_normalizes_named(self): + response = cnapp_pb2.CnappSetStatusResponse(cnappQueueStatusId=3) + with patch.object(cnapp_commands.cnapp_helper, 'set_cnapp_queue_status', + return_value=response) as helper: + with redirect_stdout(io.StringIO()): + cnapp_commands.PAMCnappQueueSetStatusCommand().execute( + self.params, cnapp_queue_id=8, status='resolved', reason='manual') + kwargs = helper.call_args.kwargs + self.assertEqual(kwargs['cnapp_queue_status_id'], 3) + self.assertEqual(kwargs['reason'], 'manual') + + def test_queue_set_status_rejects_zero(self): + with self.assertRaises(CommandError): + cnapp_commands.PAMCnappQueueSetStatusCommand().execute( + self.params, cnapp_queue_id=8, status=0) + + def test_queue_set_status_rejects_unknown_name(self): + with self.assertRaises(CommandError): + cnapp_commands.PAMCnappQueueSetStatusCommand().execute( + self.params, cnapp_queue_id=8, status='snoozed') + + def test_queue_delete_success(self): + with patch.object(cnapp_commands.cnapp_helper, 'delete_cnapp_queue_item', + return_value=None) as helper: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueDeleteCommand().execute(self.params, cnapp_queue_id=22) + helper.assert_called_once_with(self.params, cnapp_queue_id=22) + self.assertIn('22', buf.getvalue()) + + def test_queue_delete_unknown_id_propagates_error(self): + with patch.object(cnapp_commands.cnapp_helper, 'delete_cnapp_queue_item', + side_effect=Exception('Queue item not found: 99 Response code: RRC_BAD_REQUEST')): + with self.assertRaises(Exception): + cnapp_commands.PAMCnappQueueDeleteCommand().execute(self.params, cnapp_queue_id=99) + + +# --------------------------------------------------------------------------- +# Command tree wiring +# --------------------------------------------------------------------------- + +class TestCommandTree(unittest.TestCase): + """Sanity check that the cnapp commands are reachable via `pam cnapp ...`.""" + + def test_pam_cnapp_subcommands(self): + from keepercommander.commands.discoveryrotation import PAMControllerCommand + pam = PAMControllerCommand() + self.assertIn('cnapp', pam.subcommands) + config = pam.subcommands['cnapp'].subcommands['config'] + queue = pam.subcommands['cnapp'].subcommands['queue'] + self.assertEqual( + sorted(config.subcommands), + ['delete', 'read', 'set', 'test', 'test-encrypter'], + ) + self.assertEqual( + sorted(queue.subcommands), + ['associate', 'delete', 'list', 'remediate', 'set-status'], + ) + + +# --------------------------------------------------------------------------- +# Payload decryption — round-trip an AES-256-GCM envelope and decrypt it back +# --------------------------------------------------------------------------- + +def _encrypt_cnapp_payload_for_test(plaintext_json, key): + """Produce a CNAPP queue payload byte string the way the Encrypter would so we can + exercise `_decrypt_cnapp_payload` end-to-end without mocking AES-GCM.""" + nonce = os.urandom(12) + ciphertext = AESGCM(key).encrypt(nonce, plaintext_json.encode('utf-8'), None) + enc_b64url = base64.urlsafe_b64encode(nonce + ciphertext).rstrip(b'=').decode('ascii') + envelope = json.dumps({ + 'encrypted_payload': enc_b64url, + 'alg': 'AES-256-GCM', + 'version': '1', + }).encode('utf-8') + envelope_b64url = base64.urlsafe_b64encode(envelope).rstrip(b'=').decode('ascii') + return envelope_b64url.encode('utf-8') + + +class TestPayloadDecryption(unittest.TestCase): + """`_decrypt_cnapp_payload` must round-trip the envelope produced by the customer + Encrypter (UTF-8 base64url envelope wrapping nonce||ciphertext||tag).""" + + def setUp(self): + self.key = os.urandom(32) + self.plaintext = { + 'issue': {'id': 'wiz-001', 'severity': 'HIGH', 'created': '2026-05-01T00:00:00Z'}, + 'resource': {'name': 'i-abc', 'type': 'EC2', 'cloudPlatform': 'AWS'}, + 'control': {'name': 'Public S3', 'risks': ['data-exposure']}, + 'tags': ['team:platform'], + } + + def test_roundtrip(self): + payload = _encrypt_cnapp_payload_for_test(json.dumps(self.plaintext), self.key) + decrypted = cnapp_commands._decrypt_cnapp_payload(payload, self.key) + self.assertEqual(decrypted['issue']['id'], 'wiz-001') + self.assertEqual(decrypted['resource']['name'], 'i-abc') + + def test_wrong_key_raises(self): + payload = _encrypt_cnapp_payload_for_test(json.dumps(self.plaintext), self.key) + with self.assertRaises(Exception): + cnapp_commands._decrypt_cnapp_payload(payload, os.urandom(32)) + + def test_unsupported_alg_raises(self): + envelope = json.dumps({'encrypted_payload': '', 'alg': 'ChaCha20', 'version': '1'}).encode('utf-8') + payload = base64.urlsafe_b64encode(envelope).rstrip(b'=') + with self.assertRaises(ValueError): + cnapp_commands._decrypt_cnapp_payload(payload, self.key) + + def test_missing_alg_raises(self): + envelope = json.dumps({'encrypted_payload': '', 'version': '1'}).encode('utf-8') + payload = base64.urlsafe_b64encode(envelope).rstrip(b'=') + with self.assertRaises(ValueError) as ctx: + cnapp_commands._decrypt_cnapp_payload(payload, self.key) + self.assertIn('missing', str(ctx.exception).lower()) + + def test_short_ciphertext_raises(self): + envelope = json.dumps({ + 'encrypted_payload': base64.urlsafe_b64encode(b'abc').rstrip(b'=').decode('ascii'), + 'alg': 'AES-256-GCM', + }).encode('utf-8') + payload = base64.urlsafe_b64encode(envelope).rstrip(b'=') + with self.assertRaises(ValueError): + cnapp_commands._decrypt_cnapp_payload(payload, self.key) + + +class TestKeyDecode(unittest.TestCase): + """`_decode_aes_key` must accept both standard and url-safe base64, only when the + decoded length is 16 or 32 bytes.""" + + def test_standard_base64_32(self): + raw = base64.b64encode(b'\x11' * 32).decode('ascii') + self.assertEqual(cnapp_commands._decode_aes_key(raw), b'\x11' * 32) + + def test_urlsafe_base64_32(self): + raw = base64.urlsafe_b64encode(b'\x22' * 32).decode('ascii') + self.assertEqual(cnapp_commands._decode_aes_key(raw), b'\x22' * 32) + + def test_wrong_length_returns_none(self): + raw = base64.b64encode(b'\x33' * 24).decode('ascii') + self.assertIsNone(cnapp_commands._decode_aes_key(raw)) + + def test_garbage_returns_none(self): + self.assertIsNone(cnapp_commands._decode_aes_key('not base64 at all!!!')) + self.assertIsNone(cnapp_commands._decode_aes_key('')) + self.assertIsNone(cnapp_commands._decode_aes_key(None)) + + +class TestQueueListDecryptionIntegration(unittest.TestCase): + """End-to-end: `queue list` resolves the encrypter key via the vault record, decrypts + each payload, and writes the human summary into the table cell.""" + + def setUp(self): + self.params = _mock_params() + self.key = os.urandom(32) + + def _make_item(self, queue_id, plaintext): + return cnapp_pb2.CnappQueueItem( + cnappQueueId=queue_id, + cnappProviderId=cnapp_pb2.CNAPP_PROVIDER_WIZ, + cnappQueueStatusId=1, + receivedAt=1700000000000, + networkId=b'\x00' * 16, + payload=_encrypt_cnapp_payload_for_test(json.dumps(plaintext), self.key), + ) + + def test_table_shows_decrypted_summary_when_key_resolves(self): + items = [self._make_item(101, { + 'issue': {'id': 'wiz-999', 'severity': 'CRITICAL'}, + 'control': {'name': 'Open SSH'}, + 'resource': {'name': 'prod-db-1'}, + })] + response = cnapp_pb2.CnappQueueListResponse(items=items) + config = cnapp_pb2.CnappConfiguration( + cnappConfigRecordUid=b'\xab' * 16, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ, + ) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config), \ + patch.object(cnapp_commands, '_load_encrypter_key', return_value=self.key): + buf = io.StringIO() + with redirect_stdout(buf): + result = cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + provider='wiz', config_record_uid=None, no_decrypt=False) + output = buf.getvalue() + self.assertIn('CRITICAL', output) + self.assertIn('Open SSH', output) + self.assertIn('prod-db-1', output) + self.assertNotIn('', output, 'payload should have been decrypted') + self.assertIsNone(result) + + def test_table_marks_encrypted_when_key_unavailable(self): + items = [self._make_item(7, {'issue': {'id': 'x'}})] + response = cnapp_pb2.CnappQueueListResponse(items=items) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', + return_value=cnapp_pb2.CnappConfiguration()): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + provider='wiz', no_decrypt=False) + output = buf.getvalue() + self.assertIn('', output) + self.assertIn('No encrypter key', output) + + def test_json_includes_decrypted_payload_and_no_raw_payload(self): + plaintext = {'issue': {'id': 'wiz-42'}, 'resource': {'name': 'i-xyz'}} + response = cnapp_pb2.CnappQueueListResponse(items=[self._make_item(42, plaintext)]) + config = cnapp_pb2.CnappConfiguration(cnappConfigRecordUid=b'\xcd' * 16, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config), \ + patch.object(cnapp_commands, '_load_encrypter_key', return_value=self.key): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='json', + provider='wiz', no_decrypt=False) + payload = json.loads(buf.getvalue()) + self.assertEqual(payload['items'][0]['decryptedPayload']['issue']['id'], 'wiz-42') + self.assertNotIn('payload', payload['items'][0]) + + def test_decrypt_failure_keeps_other_rows_and_reports(self): + good = self._make_item(1, { + 'issue': {'id': 'wiz-good-should-not-show'}, + 'control': {'name': 'Open SSH'}, + 'resource': {'name': 'good-resource'}, + }) + bad = cnapp_pb2.CnappQueueItem( + cnappQueueId=2, + cnappProviderId=cnapp_pb2.CNAPP_PROVIDER_WIZ, + cnappQueueStatusId=1, + payload=b'this-is-not-a-valid-envelope', + ) + response = cnapp_pb2.CnappQueueListResponse(items=[good, bad]) + config = cnapp_pb2.CnappConfiguration(cnappConfigRecordUid=b'\xef' * 16, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config), \ + patch.object(cnapp_commands, '_load_encrypter_key', return_value=self.key): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + provider='wiz', no_decrypt=False) + output = buf.getvalue() + self.assertIn('Open SSH', output) + self.assertNotIn('wiz-good-should-not-show', output) + self.assertIn('good-resource', output) + self.assertIn('', output) + self.assertIn('failed to decrypt payload', output) + + def test_json_reports_decrypt_error(self): + good = self._make_item(1, {'issue': {'id': 'wiz-1'}}) + bad = cnapp_pb2.CnappQueueItem( + cnappQueueId=2, + cnappProviderId=cnapp_pb2.CNAPP_PROVIDER_WIZ, + cnappQueueStatusId=1, + payload=b'not-valid', + ) + response = cnapp_pb2.CnappQueueListResponse(items=[good, bad]) + config = cnapp_pb2.CnappConfiguration(cnappConfigRecordUid=b'\xef' * 16, + provider=cnapp_pb2.CNAPP_PROVIDER_WIZ) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands.cnapp_helper, 'read_cnapp_configuration', return_value=config), \ + patch.object(cnapp_commands, '_load_encrypter_key', return_value=self.key): + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='json', + provider='wiz', no_decrypt=False) + items = json.loads(buf.getvalue())['items'] + self.assertIn('decryptedPayload', items[0]) + self.assertIn('decryptError', items[1]) + self.assertNotIn('decryptedPayload', items[1]) + + def test_no_decrypt_flag_skips_key_lookup(self): + items = [self._make_item(11, {'issue': {'id': 'x'}})] + response = cnapp_pb2.CnappQueueListResponse(items=items) + with patch.object(cnapp_commands.cnapp_helper, 'list_cnapp_queue', return_value=response), \ + patch.object(cnapp_commands, '_load_encrypter_key') as key_loader: + buf = io.StringIO() + with redirect_stdout(buf): + cnapp_commands.PAMCnappQueueListCommand().execute( + self.params, network_uid=NETWORK_UID, status=0, format='table', + no_decrypt=True) + key_loader.assert_not_called() + self.assertNotIn('No encrypter key', buf.getvalue()) + self.assertIn('', buf.getvalue()) + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_dag_layer_b_set_record_rotation.py b/unit-tests/pam/test_dag_layer_b_set_record_rotation.py index 19495e8b0..334cc0906 100644 --- a/unit-tests/pam/test_dag_layer_b_set_record_rotation.py +++ b/unit-tests/pam/test_dag_layer_b_set_record_rotation.py @@ -83,18 +83,30 @@ def test_set_record_rotation_hits_correct_url(): def test_set_record_rotation_sends_protobuf_body(): from keepercommander.commands.pam.router_helper import router_set_record_rotation_information + from keepercommander import crypto, utils rq = router_pb2.RouterRecordRotationRequest( recordUid=RECORD_UID, configurationUid=CONFIG_UID, resourceUid=RESOURCE_UID, schedule='0 0 * * *', ) + transmission_key = utils.generate_aes_key() with patch(REQUESTS_TARGET, return_value=_ok_router_response()) as mock_req: - router_set_record_rotation_information(_mock_params(), rq) + router_set_record_rotation_information(_mock_params(), rq, transmission_key=transmission_key) body = mock_req.call_args.kwargs.get('data') assert isinstance(body, (bytes, bytearray)) - # Encrypted blob, not JSON - assert not body.startswith(b'{') + # Body is the AES-GCM-encrypted protobuf — not the plaintext proto, and not + # JSON. Decrypt with the known transmission key and confirm it round-trips. + # (A first-byte heuristic like `not body.startswith(b'{')` is flaky: ~1/256 + # of ciphertexts legitimately start with 0x7b.) + assert body != rq.SerializeToString() + decrypted = crypto.decrypt_aes_v2(body, transmission_key) + parsed = router_pb2.RouterRecordRotationRequest() + parsed.ParseFromString(decrypted) + assert parsed.recordUid == RECORD_UID + assert parsed.configurationUid == CONFIG_UID + assert parsed.resourceUid == RESOURCE_UID + assert parsed.schedule == '0 0 * * *' # --------------------------------------------------------------------------- # diff --git a/unit-tests/pam/test_kcm_import.py b/unit-tests/pam/test_kcm_import.py index e8b5db5ad..147c7227f 100644 --- a/unit-tests/pam/test_kcm_import.py +++ b/unit-tests/pam/test_kcm_import.py @@ -11,6 +11,8 @@ import json import os +import platform +import subprocess import sys import tempfile import unittest @@ -516,9 +518,8 @@ def test_output_file_pipeline(self, mock_getpass, MockConnector): cmd = PAMProjectKCMImportCommand() params = MagicMock() - import tempfile - with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: - output_path = f.name + tmp_dir = tempfile.mkdtemp() + output_path = os.path.join(tmp_dir, 'output.json') try: # With --include-credentials, passwords are preserved @@ -547,6 +548,7 @@ def test_output_file_pipeline(self, mock_getpass, MockConnector): finally: if os.path.exists(output_path): os.unlink(output_path) + os.rmdir(tmp_dir) @patch('keepercommander.commands.pam_import.kcm_import.KCMDatabaseConnector') @patch('keepercommander.commands.pam_import.kcm_import.getpass.getpass', @@ -562,9 +564,8 @@ def test_output_file_redacted_by_default(self, mock_getpass, MockConnector): cmd = PAMProjectKCMImportCommand() params = MagicMock() - import tempfile - with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: - output_path = f.name + tmp_dir = tempfile.mkdtemp() + output_path = os.path.join(tmp_dir, 'output.json') try: cmd.execute(params, @@ -588,6 +589,7 @@ def test_output_file_redacted_by_default(self, mock_getpass, MockConnector): finally: if os.path.exists(output_path): os.unlink(output_path) + os.rmdir(tmp_dir) @patch('keepercommander.commands.pam_import.kcm_import.KCMDatabaseConnector') @patch('keepercommander.commands.pam_import.kcm_import.getpass.getpass', @@ -1290,7 +1292,6 @@ def test_output_file_owner_only(self, mock_getpass, MockConnector): cmd = PAMProjectKCMImportCommand() params = MagicMock() - import stat tmp_dir = tempfile.mkdtemp() output_path = os.path.join(tmp_dir, 'test_output.json') try: @@ -1298,9 +1299,21 @@ def test_output_file_owner_only(self, mock_getpass, MockConnector): db_host='127.0.0.1', output=output_path) - file_mode = os.stat(output_path).st_mode & 0o777 - self.assertEqual(file_mode, 0o600, - f'Expected 0o600, got {oct(file_mode)}') + # POSIX: verifies mode == 0o600 via os.stat. + # Windows: os.stat().st_mode reads DOS attributes, not the NTFS + # DACL, so verify the icacls-applied ACL has the principals + # stripped by utils.set_file_permissions (SYSTEM, Administrators). + if platform.system() == 'Windows': + acl = subprocess.run(['icacls', output_path], + capture_output=True, text=True, check=True).stdout + self.assertNotIn('NT AUTHORITY\\SYSTEM', acl, + f'SYSTEM should have been removed from ACL:\n{acl}') + self.assertNotIn('BUILTIN\\Administrators', acl, + f'Administrators should have been removed from ACL:\n{acl}') + else: + file_mode = os.stat(output_path).st_mode & 0o777 + self.assertEqual(file_mode, 0o600, + f'Expected 0o600, got {oct(file_mode)}') finally: if os.path.exists(output_path): os.unlink(output_path) diff --git a/unit-tests/test_command_record.py b/unit-tests/test_command_record.py index fc61833c9..ebfeb2f8f 100644 --- a/unit-tests/test_command_record.py +++ b/unit-tests/test_command_record.py @@ -248,6 +248,14 @@ def test_get_invalid_uid(self): with self.assertRaises(CommandError): cmd.execute(params, uid='invalid') + def test_get_rejects_shell_metacharacters_in_lookup_token(self): + params = get_synced_params() + cmd = record.RecordGetUidCommand() + + with self.assertRaises(CommandError) as context: + cmd.execute(params, uid='x;cd $HOME && id > pwned_keeper_rce.txt;#"unclosed') + self.assertIn('forbidden characters', context.exception.message) + def test_append_notes_command(self): params = get_synced_params() cmd = record_edit.RecordAppendNotesCommand()