diff --git a/.github/workflows/sbom.yml b/.github/workflows/sbom.yml new file mode 100644 index 000000000..519b8902d --- /dev/null +++ b/.github/workflows/sbom.yml @@ -0,0 +1,61 @@ +name: Generate SBOM + +on: + workflow_dispatch: + release: + types: [published] + +jobs: + generate-sbom: + name: "Generate and Publish SBOM" + environment: prod + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install Commander and SBOM tools + run: | + python -m pip install --upgrade pip wheel setuptools + pip install . + pip install cyclonedx-bom + + VERSION=$(python3 -c "import keepercommander.__init__ as init; print(init.__version__)") + echo "PACKAGE_VERSION=${VERSION}" >> $GITHUB_ENV + + pip freeze > installed_packages.txt + + - name: Generate CycloneDX SBOM + run: cyclonedx-py environment -o sbom.cdx.json + + - name: Upload SBOM to Manifest-Cyber + run: | + sbom="$(base64 -w 0 sbom.cdx.json)" + cat < manifest-request.json + { + "base64BomContents": "$sbom", + "source": "github-actions", + "relationship": "first", + "filename": "sbom.cdx.json" + } + EOF + + curl --location --fail --request PUT 'https://api.manifestcyber.com/v1/sbom/upload' \ + --header 'Authorization: Bearer ${{ secrets.MANIFEST_TOKEN }}' \ + --header 'Content-Type: application/json' \ + --data-binary "@manifest-request.json" + + - name: Archive SBOM + uses: actions/upload-artifact@v4 + with: + name: sbom-keepercommander-${{ env.PACKAGE_VERSION }} + path: | + sbom.cdx.json + installed_packages.txt + retention-days: 30 diff --git a/.github/workflows/test-with-pytest.yml b/.github/workflows/test-with-pytest.yml index 95e697242..164247cc3 100644 --- a/.github/workflows/test-with-pytest.yml +++ b/.github/workflows/test-with-pytest.yml @@ -9,9 +9,9 @@ jobs: test-with-pytest: strategy: matrix: - python-version: ['3.7', '3.12'] + python-version: ['3.8', '3.14'] - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: Checkout branch diff --git a/keepercommander/commands/base.py b/keepercommander/commands/base.py index eb8f2ff1f..bc9e3c0f7 100644 --- a/keepercommander/commands/base.py +++ b/keepercommander/commands/base.py @@ -161,26 +161,22 @@ def register_commands(commands, aliases, command_info): commands['biometric'] = BiometricCommand() command_info['biometric'] = 'Biometric (Passkey) login management' - if sys.version_info.major == 3 and sys.version_info.minor >= 8: - from .start_service import register_commands as service_commands, register_command_info as service_command_info - service_commands(commands) - service_command_info(aliases, command_info) + from .start_service import register_commands as service_commands, register_command_info as service_command_info + service_commands(commands) + service_command_info(aliases, command_info) toggle_pam_legacy_commands(legacy=False) def toggle_pam_legacy_commands(legacy: bool): - if sys.version_info.major > 3 or (sys.version_info.major == 3 and sys.version_info.minor >= 8): - from . import discoveryrotation - from . import discoveryrotation_v1 - if legacy is True: - discoveryrotation_v1.register_commands(commands) - discoveryrotation_v1.register_command_info(aliases, command_info) - else: - discoveryrotation.register_commands(commands) - discoveryrotation.register_command_info(aliases, command_info) + from . import discoveryrotation + from . import discoveryrotation_v1 + if legacy is True: + discoveryrotation_v1.register_commands(commands) + discoveryrotation_v1.register_command_info(aliases, command_info) else: - logging.debug('pam commands require Python 3.8 or newer') + discoveryrotation.register_commands(commands) + discoveryrotation.register_command_info(aliases, command_info) def register_enterprise_commands(commands, aliases, command_info): diff --git a/keepercommander/commands/pam_extended/__init__.py b/keepercommander/commands/pam_extended/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/keepercommander/commands/pam_extended/discovery_rule_commands.py b/keepercommander/commands/pam_extended/discovery_rule_commands.py new file mode 100644 index 000000000..4bf99c3bc --- /dev/null +++ b/keepercommander/commands/pam_extended/discovery_rule_commands.py @@ -0,0 +1,185 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' ] + pam extended rule add --type --cidr --config-uid + pam extended rule delete --config-uid +""" +from __future__ import annotations + +import argparse +import json +import logging +import os +from typing import TYPE_CHECKING + +from ..base import ArgparseCommand +from ...error import CommandError +from ... import utils + +if TYPE_CHECKING: + from ...params import KeeperParams + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _get_dag_rules(params: "KeeperParams", config_uid: str) -> list[dict]: + """Return discovery rules from the PAM DAG for a configuration.""" + try: + from ...keeper_dag.types import PamGraphId + from ...keeper_dag.vertex import DAGVertex + except ImportError: + return [] + + config_uid_bytes = utils.base64_url_decode(config_uid) + dag = getattr(params, "pam_dag", None) + if dag is None: + return [] + + config_vertex = dag.get_vertex(config_uid_bytes) + if config_vertex is None: + return [] + + rules_vertex = config_vertex.get_child(PamGraphId.DISCOVERY_RULES) + if rules_vertex is None: + return [] + + rows = [] + for child in rules_vertex.children: + data = child.data + if isinstance(data, (bytes, bytearray)): + try: + data = json.loads(data) + except Exception: + data = {} + rows.append({"uid": child.uid.hex() if isinstance(child.uid, bytes) else child.uid, + **data}) + return rows + + +def _modify_dag_rule(params: "KeeperParams", config_uid: str, + operation: str, rule_data: dict, + element_uid: bytes | None = None) -> None: + """Apply an ADD / UPDATE / DELETE operation on a discovery rule DAG element.""" + from ...proto import pam_pb2 + from ...api import communicate_rest + + op_map = {"ADD": pam_pb2.PAMOperationType.ADD, + "UPDATE": pam_pb2.PAMOperationType.UPDATE, + "DELETE": pam_pb2.PAMOperationType.DELETE} + if operation not in op_map: + raise CommandError(f"Unknown operation: {operation}") + + config_uid_bytes = utils.base64_url_decode(config_uid) + element_uid_bytes = element_uid or os.urandom(16) + + data_op = pam_pb2.PAMDataOperation() + data_op.operationType = op_map[operation] + + element = pam_pb2.PAMElementData() + element.elementUid = element_uid_bytes + element.parentUid = config_uid_bytes + element.data = json.dumps(rule_data).encode() + data_op.element.CopyFrom(element) + + rq = pam_pb2.PAMModifyRequest() + rq.operations.append(data_op) + communicate_rest(params, rq, "pam/modify", rs_type=pam_pb2.PAMModifyResult) + + +# --------------------------------------------------------------------------- +# Commands +# --------------------------------------------------------------------------- + +class PamExtendedRuleListCommand(ArgparseCommand): + """``pam extended rule list``.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser(prog="list", description="List PAM discovery rules") + parser.add_argument("--config-uid", dest="config_uid", required=True) + parser.add_argument("--format", dest="fmt", choices=["table", "json"], default="table") + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + rows = _get_dag_rules(params, kwargs["config_uid"]) + if kwargs.get("fmt") == "json": + print(json.dumps(rows, indent=2)) + else: + if not rows: + print("No discovery rules found.") + return + for r in rows: + print(f" {r.get('uid', '?')} name={r.get('name', '?')} " + f"type={r.get('target_type', '?')} cidr={r.get('target_cidr', '?')}") + + +class PamExtendedRuleAddCommand(ArgparseCommand): + """``pam extended rule add``.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser(prog="add", description="Add a PAM discovery rule") + parser.add_argument("name", help="Rule name") + parser.add_argument( + "--type", dest="target_type", + choices=["machine", "user", "database"], default="machine", + ) + parser.add_argument("--cidr", dest="target_cidr", required=True, help="Target CIDR range") + parser.add_argument( + "--protocol", dest="protocol", + choices=["ssh", "rdp", "database"], default="ssh", + ) + parser.add_argument("--config-uid", dest="config_uid", required=True) + parser.add_argument( + "--credential-uid", dest="credential_uid", default=None, + help="Credential record UID", + ) + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + rule_data = { + "name": kwargs["name"], + "target_type": kwargs.get("target_type", "machine"), + "target_cidr": kwargs["target_cidr"], + "protocol": kwargs.get("protocol", "ssh"), + } + if kwargs.get("credential_uid"): + rule_data["credential_uid_ref"] = kwargs["credential_uid"] + + _modify_dag_rule(params, kwargs["config_uid"], "ADD", rule_data) + print(f"Discovery rule '{kwargs['name']}' added to config {kwargs['config_uid']}") + + +class PamExtendedRuleDeleteCommand(ArgparseCommand): + """``pam extended rule delete``.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser(prog="delete", description="Delete a PAM discovery rule") + parser.add_argument("uid", help="Rule element UID (hex)") + parser.add_argument("--config-uid", dest="config_uid", required=True) + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + element_uid = bytes.fromhex(kwargs["uid"]) + _modify_dag_rule( + params, kwargs["config_uid"], "DELETE", {}, + element_uid=element_uid, + ) + print(f"Discovery rule {kwargs['uid']} deleted from config {kwargs['config_uid']}") diff --git a/keepercommander/commands/pam_extended/group_command.py b/keepercommander/commands/pam_extended/group_command.py new file mode 100644 index 000000000..b0f6c07b4 --- /dev/null +++ b/keepercommander/commands/pam_extended/group_command.py @@ -0,0 +1,58 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' None: + super().__init__("Manage PAM rotation schedules") + self.register_command_new(PamExtendedScheduleListCommand(), "list") + self.register_command_new(PamExtendedScheduleSetCommand(), "set") + self.register_command_new(PamExtendedScheduleDeleteCommand(), "delete") + + +class PamExtendedRuleGroup(GroupCommandNew): + """``pam extended rule`` sub-group.""" + + def __init__(self) -> None: + super().__init__("Manage PAM discovery rules") + self.register_command_new(PamExtendedRuleListCommand(), "list") + self.register_command_new(PamExtendedRuleAddCommand(), "add") + self.register_command_new(PamExtendedRuleDeleteCommand(), "delete") + + +class PamExtendedCommand(GroupCommandNew): + """``pam extended`` — advanced PAM schedule and discovery-rule management.""" + + def __init__(self) -> None: + super().__init__("Advanced PAM schedule and discovery-rule management") + self.register_command_new(PamExtendedScheduleGroup(), "schedule") + self.register_command_new(PamExtendedRuleGroup(), "rule") diff --git a/keepercommander/commands/pam_extended/schedule_commands.py b/keepercommander/commands/pam_extended/schedule_commands.py new file mode 100644 index 000000000..701c812ab --- /dev/null +++ b/keepercommander/commands/pam_extended/schedule_commands.py @@ -0,0 +1,169 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' ] + pam extended schedule set --cron [--config-uid ] + pam extended schedule delete [--config-uid ] +""" +from __future__ import annotations + +import argparse +import json +import logging +from typing import TYPE_CHECKING + +from ..base import ArgparseCommand +from ...error import CommandError + +if TYPE_CHECKING: + from ...params import KeeperParams + +logger = logging.getLogger(__name__) + + +def _set_schedule(params: "KeeperParams", record_uid: str, + config_uid: str, cron_expr: str, notify_emails: list[str] | None = None) -> None: + """Write a named rotation schedule via the PAM rotation REST endpoint.""" + from ...proto import pam_pb2 + from ...api import communicate_rest + + rq = pam_pb2.PAMRotationSchedule() + rq.recordUid = bytes.fromhex(record_uid) if len(record_uid) == 32 else record_uid.encode() + if config_uid: + rq.configurationUid = ( + bytes.fromhex(config_uid) if len(config_uid) == 32 else config_uid.encode() + ) + schedule_data: dict = {"type": "cron", "cron": cron_expr} + if notify_emails: + schedule_data["notifyEmails"] = notify_emails + rq.scheduleData = json.dumps(schedule_data) + communicate_rest(params, rq, "pam/set_pam_rotation_schedule") + + +def _delete_schedule(params: "KeeperParams", record_uid: str) -> None: + """Remove a rotation schedule (set noSchedule=True).""" + from ...proto import pam_pb2 + from ...api import communicate_rest + + rq = pam_pb2.PAMRotationSchedule() + rq.recordUid = bytes.fromhex(record_uid) if len(record_uid) == 32 else record_uid.encode() + rq.noSchedule = True + communicate_rest(params, rq, "pam/set_pam_rotation_schedule") + + +def _list_schedules(params: "KeeperParams", config_uid: str | None = None) -> list[dict]: + """Return rotation schedules visible to the authenticated user.""" + from ...proto import pam_pb2 + from ...api import communicate_rest + + rq = pam_pb2.PAMGenericUidsRequest() if config_uid else pam_pb2.PAMGenericUidRequest.__new__( + pam_pb2.PAMGenericUidRequest + ) + rs = communicate_rest( + params, rq, + "pam/get_rotation_schedules", + rs_type=pam_pb2.PAMRotationSchedulesResponse, + ) + rows = [] + for s in rs.schedules: + entry: dict = { + "record_uid": s.recordUid.hex() if isinstance(s.recordUid, bytes) else s.recordUid, + "no_schedule": s.noSchedule, + } + if s.scheduleData: + try: + entry["schedule"] = json.loads(s.scheduleData) + except Exception: + entry["schedule_raw"] = s.scheduleData + rows.append(entry) + return rows + + +class PamExtendedScheduleListCommand(ArgparseCommand): + """``pam extended schedule list`` — list rotation schedules.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser( + prog="list", description="List PAM rotation schedules" + ) + parser.add_argument( + "--config-uid", dest="config_uid", default=None, + help="Filter by PAM configuration UID", + ) + parser.add_argument( + "--format", dest="fmt", choices=["table", "json"], default="table", + ) + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + rows = _list_schedules(params, config_uid=kwargs.get("config_uid")) + if kwargs.get("fmt") == "json": + print(json.dumps(rows, indent=2)) + else: + if not rows: + print("No rotation schedules found.") + return + for r in rows: + sched = r.get("schedule", {}) + cron = sched.get("cron", "(none)") + print(f" {r['record_uid']} cron={cron}") + + +class PamExtendedScheduleSetCommand(ArgparseCommand): + """``pam extended schedule set`` — create or update a named rotation schedule.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser( + prog="set", description="Create or update a PAM rotation schedule" + ) + parser.add_argument("uid_ref", help="PAM record UID") + parser.add_argument("--cron", dest="cron", required=True, help="Cron expression (5-field)") + parser.add_argument( + "--config-uid", dest="config_uid", default=None, + help="PAM configuration UID (optional)", + ) + parser.add_argument( + "--notify", dest="notify", action="append", default=None, + metavar="EMAIL", help="Email(s) to notify on schedule fire", + ) + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + uid_ref: str = kwargs["uid_ref"] + cron: str = kwargs["cron"] + config_uid: str | None = kwargs.get("config_uid") + notify: list[str] | None = kwargs.get("notify") + + _set_schedule(params, uid_ref, config_uid or "", cron, notify_emails=notify) + logger.info("Rotation schedule set: uid=%s cron=%s", uid_ref, cron) + print(f"Rotation schedule set for {uid_ref} (cron: {cron})") + + +class PamExtendedScheduleDeleteCommand(ArgparseCommand): + """``pam extended schedule delete`` — remove a rotation schedule.""" + + def __init__(self) -> None: + parser = argparse.ArgumentParser( + prog="delete", description="Remove a PAM rotation schedule" + ) + parser.add_argument("uid_ref", help="PAM record UID") + super().__init__(parser) + + def execute(self, params: "KeeperParams", **kwargs) -> None: + uid_ref: str = kwargs["uid_ref"] + _delete_schedule(params, uid_ref) + logger.info("Rotation schedule deleted: uid=%s", uid_ref) + print(f"Rotation schedule removed for {uid_ref}") diff --git a/keepercommander/commands/pam_import/README.md b/keepercommander/commands/pam_import/README.md index e5baae283..233f8f37a 100644 --- a/keepercommander/commands/pam_import/README.md +++ b/keepercommander/commands/pam_import/README.md @@ -303,7 +303,7 @@ Each Machine (pamMachine, pamDatabase, pamDirectory) can specify **Administrativ > **Note 3:** Post rotation scripts (a.k.a. `scripts`) are executed in following order: `pamUser` scripts after any **successful** rotation for that user, `pamMachine` scripts after any **successful** rotation on the machine and `pamConfiguration` scripts after any rotation using that configuration. > **Note 4:** When `allow_supply_user` is false and JIT ephemeral is not used, vault may require a launch credential; import can provide it via `launch_credentials` in the resource's `connection` block. -JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. +JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. **Workflow** (approvals / checkout / temporal restrictions) is supported on all four resource types: pamMachine, pamDatabase, pamDirectory, **and** pamRemoteBrowser.
Just-In-Time Access (JIT) @@ -406,6 +406,79 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine ```
+Workflow (Approvals, Checkout, Temporal Access) + +Workflow controls how privileged access to a resource is gated: how many approvals are needed, whether sessions require check-out, MFA, reason/ticket, what time windows access is allowed in, and who can approve (with optional escalation). Workflow is applied via the Keeper Router **after** the resource record and DAG/JIT/AI steps are complete and is not stored on the record itself. + +**How to Configure:** Add `pam_settings.options.workflow` to any pamMachine, pamDatabase, pamDirectory, or pamRemoteBrowser. The workflow object maps directly to the Web Vault's "Workflow" tab on a resource record. + +```json +{ + "pam_settings": { + "options": { + "workflow": { + "approvals_needed": 2, + "checkout_needed": true, + "start_access_on_approval": false, + "require_reason": true, + "require_ticket": false, + "require_mfa": true, + "access_duration": "8h", + "allowed_times": { + "allowed_days": ["mon", "tue", "wed", "thu", "fri"], + "time_ranges": [ + { "start": "09:00", "end": "17:30" } + ], + "timezone": "America/New_York" + }, + "approvers": [ + { + "principal": { "type": "user", "email": "primary.approver@example.com" }, + "escalation": false + }, + { + "principal": { "type": "user", "email": "second.approver@example.com" }, + "escalation": false + }, + { + "principal": { + "type": "team", + "team_uid_base64url": "REPLACE_TEAM_UID_BASE64URL" + }, + "escalation": true, + "escalation_after": "45m" + } + ] + } + } + } +} +``` + +**Field reference:** +- `approvals_needed` *(int, default `0`)* — number of approvals required to grant access. +- `checkout_needed` *(bool, default `false`)* — require explicit check-out before launching a session. +- `start_access_on_approval` *(bool, default `false`)* — start the access window the moment approval is granted (rather than at session launch). +- `require_reason` / `require_ticket` *(bool, default `false`)* — prompt the user for a reason / ticket reference at request time. +- `require_mfa` *(bool, default `false`)* — require MFA at session launch. +- `access_duration` *(string, default `"1d"`)* — how long approved access remains valid. Accepts `Xm` / `Xh` / `Xd` (e.g. `"30m"`, `"8h"`, `"2d"`); a bare integer is interpreted as minutes. Must be positive. +- `allowed_times.allowed_days` *(list of strings)* — restrict access to these weekdays. Accepts 3-letter (`mon`..`sun`) or full names (`monday`..`sunday`), case-insensitive. +- `allowed_times.time_ranges` *(list of `{start, end}` objects)* — one or more allowed daily time windows in `HH:MM` (24-hour) format. **Multiple ranges per day are supported.** A single range whose `end` is earlier than its `start` (e.g. an overnight `22:00–06:00`) **should be split into two ranges** that both fall inside one day (e.g. `22:00–23:59` and `00:00–06:00`) +- `allowed_times.timezone` *(string)* — IANA timezone name (e.g. `"UTC"`, `"America/New_York"`). **Required when `time_ranges` is non-empty.** +- `approvers[]` — list of approver entries. + - `principal.type` — `"user"` or `"team"`. + - For users: `principal.email` (must exist in the enterprise). + - For teams: `principal.team_uid_base64url` (the team's vault UID, base64url-encoded; validated against the local team cache during import — unknown UIDs fail in dry-run). + - `escalation` *(bool)* — whether this approver is in the escalation chain. + - `escalation_after` *(duration string, optional)* — wait this long before escalating to this approver. **Requires `escalation: true`.** + +**Behavior notes:** +- **Trivial workflow is a no-op.** If none of `approvals_needed > 0`, `checkout_needed`, `require_mfa`, `start_access_on_approval`, `allowed_times.allowed_days`, or `allowed_times.time_ranges` is set, the workflow block is treated as absent and no Router call is made. +- **Pre-flight validation runs in `--dry-run`.** Bad durations, malformed `HH:MM`, missing timezone, escalation rule violations, and unknown team UIDs are reported during dry-run before any vault writes. +- **Dry-run skips the Router calls.** Workflow is applied (Router create/update + approver reconcile) only on a real run. +- **`extend` only applies workflow to newly created resources** (existing resources are not touched). +
+
pam_data.resources.pamMachine (RDP) ```json @@ -435,7 +508,8 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine "ai_threat_detection": "off", "ai_terminate_session_on_detection": "off", "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "allow_supply_host": false, "port_forward": { diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index 22137b8cf..e5cf37835 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -22,9 +22,11 @@ from typing import Any, Dict, Optional, List, Union from ..record_edit import RecordAddCommand as RecordEditAddCommand +from ..workflow.helpers import RecordResolver, WorkflowFormatter from ... import api, attachment, utils, vault, vault_extensions, \ record_facades, record_management from ...display import bcolors +from ...error import CommandError from ...recordv3 import RecordV3 @@ -69,7 +71,8 @@ "pam_settings": { "options" : { "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "connection" : {} }, @@ -611,6 +614,144 @@ def load(cls, data: Union[str, dict]): return obj +class PamWorkflowOptions: + """Parsed workflow settings from pam_settings.options.workflow. + Not stored on record fields nor in DAG; applied via Krouter after record/DAG creation. + """ + + _DEFAULT_DURATION_MS = 86_400_000 # "1d" + + def __init__(self): + self.approvals_needed: int = 0 + self.checkout_needed: bool = False + self.start_access_on_approval: bool = False + self.require_reason: bool = False + self.require_ticket: bool = False + self.require_mfa: bool = False + self.access_duration_ms: int = self._DEFAULT_DURATION_MS + self.allowed_days: List[str] = [] # canonical 3-letter tokens: "mon".."sun" + self.time_ranges: List[dict] = [] # each: {"start": "HH:MM", "end": "HH:MM"} + self.timezone: str = "" + self.approvers: List[dict] = [] # each: {principal_type, email, team_uid_b64, escalation, escalation_after_ms} + + @staticmethod + def _parse_duration(value) -> int: + """Return milliseconds. Raises CommandError on invalid/non-positive value. + Delegates to WorkflowFormatter.parse_duration; adds a None -> default-1d shim + (the CLI command always supplies a string, but the JSON import may omit the key). + """ + if value is None: + return PamWorkflowOptions._DEFAULT_DURATION_MS + return WorkflowFormatter.parse_duration(str(value)) + + @classmethod + def load(cls, data) -> Optional['PamWorkflowOptions']: + """Parse workflow JSON dict. Returns None when absent / null / trivial (V2 guard).""" + if not data or not isinstance(data, dict): + return None + + obj = cls() + obj.approvals_needed = max(0, int(data.get('approvals_needed', 0) or 0)) + obj.checkout_needed = bool(data.get('checkout_needed', False)) + obj.start_access_on_approval = bool(data.get('start_access_on_approval', False)) + obj.require_reason = bool(data.get('require_reason', False)) + obj.require_ticket = bool(data.get('require_ticket', False)) + obj.require_mfa = bool(data.get('require_mfa', False)) + + # V9: access_duration — default "1d" + obj.access_duration_ms = cls._parse_duration(data.get('access_duration')) + + # allowed_times + at = data.get('allowed_times') or {} + if isinstance(at, dict): + days_raw = at.get('allowed_days') or [] + if isinstance(days_raw, list): + for day in days_raw: + d = str(day).lower().strip() + if d not in WorkflowFormatter.DAY_PARSE_MAP: + raise CommandError('', f'workflow: invalid allowed_times.allowed_days token "{day}"') + obj.allowed_days.append(d[:3]) # store as "mon".."sun" + + ranges_raw = at.get('time_ranges') or [] + if isinstance(ranges_raw, list): + for r in ranges_raw: + if isinstance(r, dict): + start = str(r.get('start', '') or '').strip() + end = str(r.get('end', '') or '').strip() + if start and end: + obj.time_ranges.append({'start': start, 'end': end}) + + obj.timezone = str(at.get('timezone', '') or '').strip() + + # V8: time_ranges non-empty => timezone required + if obj.time_ranges and not obj.timezone: + raise CommandError('', 'workflow: allowed_times.time_ranges requires timezone') + + # approvers + for idx, a in enumerate(data.get('approvers') or []): + if not isinstance(a, dict): + continue + principal = a.get('principal') or {} + if not isinstance(principal, dict): + continue + ptype = str(principal.get('type', '') or '').lower() + escalation = bool(a.get('escalation', False)) + esc_after_raw = a.get('escalation_after') + esc_after_ms = cls._parse_duration(esc_after_raw) if esc_after_raw else 0 + # V7: escalation_after requires escalation: true + if esc_after_ms and not escalation: + raise CommandError('', f'workflow: approvers[{idx}] escalation_after requires escalation: true') + if ptype == 'user': + email = str(principal.get('email', '') or '').strip() + if not email: + raise CommandError('', f'workflow: approvers[{idx}] user principal requires non-empty email') + obj.approvers.append({ + 'principal_type': 'user', 'email': email, 'team_uid_b64': None, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + elif ptype == 'team': + uid_b64 = str(principal.get('team_uid_base64url', '') or '').strip() + if not uid_b64: + raise CommandError('', f'workflow: approvers[{idx}] team principal requires non-empty team_uid_base64url') + obj.approvers.append({ + 'principal_type': 'team', 'email': None, 'team_uid_b64': uid_b64, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + else: + raise CommandError('', f'workflow: approvers[{idx}] principal.type must be "user" or "team", got "{ptype}"') + + # V2: non-trivial guard — at least one meaningful flag must be set + is_trivial = ( + obj.approvals_needed == 0 + and not obj.start_access_on_approval + and not obj.checkout_needed + and not obj.require_mfa + and not obj.allowed_days + and not obj.time_ranges + ) + if is_trivial: + return None # nothing to persist; caller treats as delete/no-op + + # V4 warning: approvals_needed > 0 with no approvers + if obj.approvals_needed > 0 and not obj.approvers: + logging.warning('workflow: approvals_needed > 0 but no approvers specified') + + return obj + + def validate_principals(self, params, resource_title: str = '') -> None: + """Validate team UIDs via RecordResolver.validate_team (which checks both + team_cache and enterprise.teams). Raises CommandError on first unknown UID. + """ + for idx, a in enumerate(self.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{resource_title}": ' if resource_title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') + + class DagJitSettingsObject(): def __init__(self): self.create_ephemeral: bool = False @@ -2900,10 +3041,12 @@ class PamRemoteBrowserSettings: def __init__( self, options: Optional[DagSettingsObject] = None, - connection: Optional[ConnectionSettingsHTTP] = None + connection: Optional[ConnectionSettingsHTTP] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.options = options self.connection = connection + self.workflow = workflow # not on record nor in DAG; applied via Krouter @classmethod def load(cls, data: Optional[Union[str, dict]]): @@ -2912,9 +3055,14 @@ def load(cls, data: Optional[Union[str, dict]]): except: logging.error(f"PAM RBI Settings field failed to load from: {str(data)[:80]}...") if not isinstance(data, dict): return obj - options = DagSettingsObject.load(data.get("options", {})) + options_dict = data.get("options", {}) or {} + options = DagSettingsObject.load(options_dict) if not is_empty_instance(options): obj.options = options + if isinstance(options_dict, dict): + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) cdata = data.get("connection", {}) # TO DO: if isinstance(cdata, str): lookup_by_name(pam_data.connections) @@ -2944,6 +3092,7 @@ def __init__( options: Optional[DagSettingsObject] = None, jit_settings: Optional[DagJitSettingsObject] = None, ai_settings: Optional[DagAiSettingsObject] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.allowSupplyHost = allowSupplyHost self.connection = connection @@ -2951,6 +3100,7 @@ def __init__( self.options = options self.jit_settings = jit_settings self.ai_settings = ai_settings + self.workflow = workflow # not on record nor in DAG; applied via Krouter # PamConnectionSettings excludes ConnectionSettingsHTTP pam_connection_classes = [ @@ -2981,8 +3131,8 @@ def is_empty(self): empty = is_empty_instance(self.options) empty = empty and is_empty_instance(self.portForward) empty = empty and is_empty_instance(self.connection, ["protocol"]) - # NB! JIT and AI settings are in import json but not in record json (just DAG json) - empty = empty and self.jit_settings is None and self.ai_settings is None + # NB! JIT, AI, workflow are in import json but not in record json (not DAG either for workflow) + empty = empty and self.jit_settings is None and self.ai_settings is None and self.workflow is None return empty @classmethod @@ -3008,6 +3158,9 @@ def load(cls, data: Union[str, dict]): ai_settings = DagAiSettingsObject.load(ai_value) if ai_settings: obj.ai_settings = ai_settings + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) portForward = PamPortForwardSettings.load(data.get("port_forward", {})) if not is_empty_instance(portForward): diff --git a/keepercommander/commands/pam_import/edit.py b/keepercommander/commands/pam_import/edit.py index 0b5d35686..d80fd8354 100644 --- a/keepercommander/commands/pam_import/edit.py +++ b/keepercommander/commands/pam_import/edit.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, List, Union from .keeper_ai_settings import set_resource_jit_settings, set_resource_keeper_ai_settings, refresh_meta_to_latest, refresh_link_to_config_to_latest +from .workflow_apply import apply_workflow, validate_workflow_principals from .base import ( PAM_RESOURCES_RECORD_TYPES, PROJECT_IMPORT_JSON_TEMPLATE, @@ -1642,6 +1643,9 @@ def process_data(self, params, project): resolve_domain_admin(pce, users) # only resolve here - create after machine and user creation + # pre-flight: validate workflow team UIDs before any vault writes (runs in dry-run too) + validate_workflow_principals(params, resources) + # dry run if project["options"].get("dry_run", False) is True: print("Will import file data here...") @@ -1696,6 +1700,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: # machine/db/directory args = parse_command_options(mach, True) if admin_uid: args["admin"] = admin_uid @@ -1739,6 +1746,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + # Machine - create its users (if any) users = getattr(mach, "users", []) users = users if isinstance(users, list) else [] diff --git a/keepercommander/commands/pam_import/extend.py b/keepercommander/commands/pam_import/extend.py index 82fb5522b..c21a2a6f2 100644 --- a/keepercommander/commands/pam_import/extend.py +++ b/keepercommander/commands/pam_import/extend.py @@ -53,6 +53,7 @@ refresh_meta_to_latest, refresh_link_to_config_to_latest, ) +from .workflow_apply import apply_workflow, validate_workflow_principals from ...keeper_dag import EdgeType from ...keeper_dag.types import RefType from ..base import Command @@ -549,6 +550,10 @@ def execute(self, params, **kwargs): fp = (getattr(u, "folder_path", None) or "").strip() u.resolved_folder_uid = path_to_folder_uid.get(fp) or usr_folder_uid + # pre-flight: validate workflow team UIDs for new resources (runs in dry-run too) + new_rscs = [r for r in project.get('mapped_resources', []) if getattr(r, '_extend_tag', None) == 'new'] + validate_workflow_principals(params, new_rscs) + if dry_run: print("[DRY RUN COMPLETE] No changes were made. All actions were validated but not executed.") return @@ -1402,6 +1407,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: args = parse_command_options(mach, True) if admin_uid: @@ -1444,6 +1452,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + mach_users = getattr(mach, "users", []) or [] for user in mach_users: if getattr(user, "_extend_tag", None) != "new": diff --git a/keepercommander/commands/pam_import/workflow_apply.py b/keepercommander/commands/pam_import/workflow_apply.py new file mode 100644 index 000000000..65ffc8c5f --- /dev/null +++ b/keepercommander/commands/pam_import/workflow_apply.py @@ -0,0 +1,262 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' bool: + if isinstance(e, KeeperApiError) and e.result_code == 429: + return True + msg = str(getattr(e, 'message', None) or e).lower() + return 'throttle' in msg or 'too many' in msg + + +def _post_with_throttle_retry(params, path: str, **kwargs): + """Wrap _post_request_to_router with progressive backoff on 429 / throttle errors. + Non-throttle errors propagate immediately. Final retry's exception is re-raised. + """ + wait = _THROTTLE_BASE_WAIT + for attempt in range(1, _THROTTLE_MAX_RETRIES + 1): + try: + return _post_request_to_router(params, path, **kwargs) + except Exception as e: + if not _is_throttle_error(e) or attempt >= _THROTTLE_MAX_RETRIES: + raise + logging.warning( + 'Krouter rate-limited on %s (attempt %d/%d); waiting %.1fs', + path, attempt, _THROTTLE_MAX_RETRIES, wait, + ) + time.sleep(wait) + wait *= _THROTTLE_MULTIPLIER + + +# Re-exported for tests and any downstream importers; the canonical map lives +# in WorkflowFormatter.DAY_PARSE_MAP and accepts both 3-letter and full names. +_DAY_PROTO_MAP = { + k: v for k, v in WorkflowFormatter.DAY_PARSE_MAP.items() if len(k) == 3 +} + + +def _build_temporal_filter(opts: PamWorkflowOptions): + """Build TemporalAccessFilter from opts. Returns None when no temporal slice is set. + + startTime / endTime on TimeOfDayRange are HHMM integers (hours*100 + minutes); + see WorkflowFormatter._parse_time_to_hhmm. Canonical sources: + - keeperapp-protobuf/workflow.proto:140 (`int32 startTime = 1; // HHMM format`) + - ka-libs/workflow/.../handlers/WfConfigCRUD.kt::validateHHMM (server validator) + """ + if not opts.allowed_days and not opts.time_ranges and not opts.timezone: + return None + temporal = workflow_pb2.TemporalAccessFilter() + for day_token in opts.allowed_days: + day_enum = WorkflowFormatter.DAY_PARSE_MAP.get(day_token) + if day_enum is not None: + temporal.allowedDays.append(day_enum) + for r in opts.time_ranges: + tr = workflow_pb2.TimeOfDayRange() + tr.startTime = WorkflowFormatter._parse_time_to_hhmm(r['start']) + tr.endTime = WorkflowFormatter._parse_time_to_hhmm(r['end']) + temporal.timeRanges.append(tr) + if opts.timezone: + temporal.timeZone = opts.timezone + return temporal + + +def _build_parameters( + record_uid_bytes: bytes, + record_title: str, + opts: PamWorkflowOptions, +) -> workflow_pb2.WorkflowParameters: + params_proto = workflow_pb2.WorkflowParameters() + params_proto.resource.CopyFrom(ProtobufRefBuilder.record_ref(record_uid_bytes, record_title)) + params_proto.approvalsNeeded = opts.approvals_needed + params_proto.checkoutNeeded = opts.checkout_needed + params_proto.startAccessOnApproval = opts.start_access_on_approval + params_proto.requireReason = opts.require_reason + params_proto.requireTicket = opts.require_ticket + params_proto.requireMFA = opts.require_mfa + params_proto.accessLength = opts.access_duration_ms + + temporal = _build_temporal_filter(opts) + if temporal: + params_proto.allowedTimes.CopyFrom(temporal) + + return params_proto + + +def _build_approver_proto(a: dict) -> workflow_pb2.WorkflowApprover: + approver = workflow_pb2.WorkflowApprover() + if a['principal_type'] == 'user': + approver.user = a['email'] + else: + approver.teamUid = utils.base64_url_decode(a['team_uid_b64']) + approver.escalation = a['escalation'] + if a['escalation_after_ms']: + approver.escalationAfterMs = a['escalation_after_ms'] + return approver + + +def _approver_key(params: KeeperParams, approver: workflow_pb2.WorkflowApprover) -> str: + """Return a stable identity key for an existing server approver (for reconcile diff). + Server may return either user (email) or userId (int). When userId is set, resolve + to email through the enterprise user list so it matches the import-side key. + """ + if approver.HasField('user'): + return f'user:{approver.user}' + if approver.HasField('userId'): + email = RecordResolver.resolve_user(params, approver.userId) + # resolve_user returns 'User ID ' when not found — fall back to userId so + # we don't accidentally key two different unknown users to the same string. + if email and not email.startswith('User ID '): + return f'user:{email}' + return f'userid:{approver.userId}' + if approver.HasField('teamUid'): + return f'team:{utils.base64_url_encode(approver.teamUid)}' + return '' + + +def _new_approver_key(a: dict) -> str: + if a['principal_type'] == 'user': + return f'user:{a["email"]}' + return f'team:{a["team_uid_b64"]}' + + +def _reconcile_approvers( + params: KeeperParams, + record_uid_bytes: bytes, + record_title: str, + existing: List[workflow_pb2.WorkflowApprover], + new_approvers: List[dict], +) -> None: + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + existing_keys = {_approver_key(params, a): a for a in existing} + new_keys = {_new_approver_key(a): a for a in new_approvers} + + to_delete = [a for k, a in existing_keys.items() if k not in new_keys] + to_add = [a for k, a in new_keys.items() if k not in existing_keys] + + if to_delete: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_delete: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + + if to_add: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_add: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + + +def apply_workflow( + params: KeeperParams, + record_uid: str, + record_title: str, + opts: PamWorkflowOptions, +) -> None: + """Create or update workflow config via Krouter. Raises CommandError on failure.""" + record_uid_bytes = utils.base64_url_decode(record_uid) + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + try: + existing = _post_with_throttle_retry( + params, 'read_workflow_config', + rq_proto=ref, rs_type=workflow_pb2.WorkflowConfig, + ) + except Exception as e: + raise CommandError('', f'workflow read failed for "{record_title}": {sanitize_router_error(e)}') + + parameters = _build_parameters(record_uid_bytes, record_title, opts) + + try: + if existing: + _post_with_throttle_retry(params, 'update_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0: + _reconcile_approvers( + params, record_uid_bytes, record_title, + list(existing.approvers), opts.approvers, + ) + elif existing.approvers: + # approvals_needed dropped to 0: remove all existing approvers (V5) + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in existing.approvers: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + else: + _post_with_throttle_retry(params, 'create_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0 and opts.approvers: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in opts.approvers: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + except CommandError: + raise + except Exception as e: + raise CommandError('', f'workflow apply failed for "{record_title}": {sanitize_router_error(e)}') + + +def validate_workflow_principals(params: KeeperParams, resources) -> None: + """Pre-flight: validate team UIDs in workflow approvers for all resources. + Uses RecordResolver.validate_team which checks both team_cache and enterprise.teams, + matching the lookup path used by `pam workflow add-approver`. Raises CommandError + on the first unknown UID, with the resource title in the message for context. + """ + for mach in resources or []: + opts = None + ps = getattr(mach, 'pam_settings', None) + if ps: + opts = getattr(ps, 'workflow', None) + if opts is None: + rbi = getattr(mach, 'rbi_settings', None) + if rbi: + opts = getattr(rbi, 'workflow', None) + if opts is None: + continue + title = getattr(mach, 'title', '') or '' + for idx, a in enumerate(opts.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{title}": ' if title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') diff --git a/keepercommander/commands/utils.py b/keepercommander/commands/utils.py index 1df41bc88..6a2fc65c3 100644 --- a/keepercommander/commands/utils.py +++ b/keepercommander/commands/utils.py @@ -1583,24 +1583,22 @@ def execute(self, params, **kwargs): print('{0:>20s}: {1}'.format('Executable', sys.executable)) if logging.getLogger().isEnabledFor(logging.DEBUG) or show_packages: - ver = sys.version_info - if ver.major >= 3 and ver.minor >= 8: - import importlib.metadata - dist = importlib.metadata.packages_distributions() - packages = {} - for pack in dist.values(): - if isinstance(pack, list) and len(pack) > 0: - name = pack[0] - if name in packages: - continue - try: - version = importlib.metadata.version(name) - packages[name] = version - except Exception as e: - logging.debug('Get package %s version error: %s', name, e) - installed_packages_list = [f'{x[0]}=={x[1]}' for x in packages.items()] - installed_packages_list.sort(key=lambda x: x.lower()) - print('{0:>20s}: {1}'.format('Packages', installed_packages_list)) + import importlib.metadata + dist = importlib.metadata.packages_distributions() + packages = {} + for pack in dist.values(): + if isinstance(pack, list) and len(pack) > 0: + name = pack[0] + if name in packages: + continue + try: + version = importlib.metadata.version(name) + packages[name] = version + except Exception as e: + logging.debug('Get package %s version error: %s', name, e) + installed_packages_list = [f'{x[0]}=={x[1]}' for x in packages.items()] + installed_packages_list.sort(key=lambda x: x.lower()) + print('{0:>20s}: {1}'.format('Packages', installed_packages_list)) if version_details.get('is_up_to_date') is None: logging.debug("It appears that Commander is up to date") diff --git a/keepercommander/commands/workflow/config_commands.py b/keepercommander/commands/workflow/config_commands.py index 30137bf9c..28ca9c71f 100644 --- a/keepercommander/commands/workflow/config_commands.py +++ b/keepercommander/commands/workflow/config_commands.py @@ -329,8 +329,11 @@ def _print_table(params, response, record_uid): print(f" Days: {', '.join(day_names)}") if at.timeRanges: for tr in at.timeRanges: - start_h, start_m = divmod(tr.startTime, 60) - end_h, end_m = divmod(tr.endTime, 60) + # startTime / endTime are HHMM (hours*100 + minutes); see + # WorkflowFormatter._parse_time_to_hhmm and the canonical + # ka-libs/workflow/.../WfConfigCRUD.kt::validateHHMM. + start_h, start_m = divmod(tr.startTime, 100) + end_h, end_m = divmod(tr.endTime, 100) print(f" Time: {start_h:02d}:{start_m:02d} - {end_h:02d}:{end_m:02d}") if at.timeZone: print(f" Timezone: {at.timeZone}") diff --git a/keepercommander/commands/workflow/helpers.py b/keepercommander/commands/workflow/helpers.py index e46eb263b..21a3ac4dd 100644 --- a/keepercommander/commands/workflow/helpers.py +++ b/keepercommander/commands/workflow/helpers.py @@ -523,9 +523,17 @@ def build_temporal_filter(allowed_days_str, time_range_str, timezone_str): @staticmethod def _parse_time_to_hhmm(time_str): - """Parse 'HH:MM' into the HHMM integer encoding the server expects on - TimeOfDayRange.startTime / .endTime — e.g. '03:00' -> 300, '17:30' -> 1730. - Server validates: HHMM integer with HH in 0-23 and MM in 0-59. + """Parse 'HH:MM' to the HHMM integer the server stores on + TimeOfDayRange.startTime / .endTime: hours*100 + minutes. + Examples: '00:00' -> 0, '03:00' -> 300, '09:00' -> 900, '17:30' -> 1730. + Valid range: 0..2359 with hours in 0-23 and minutes in 0-59. + + Canonical sources (all agree on HHMM): + - keeperapp-protobuf/workflow.proto:140 + `int32 startTime = 1; // HHMM format` + - ka-libs/workflow/src/main/kotlin/com/keepersecurity/workflow/handlers/WfConfigCRUD.kt::validateHHMM + `val hours = value / 100; val minutes = value % 100` + throws "Invalid : . Expected HHMM integer with HH in 0-23 and MM in 0-59" on bad input. """ try: parts = time_str.split(':') @@ -547,6 +555,7 @@ def format_temporal_filter(at): if at.timeRanges: ranges = [] for tr in at.timeRanges: + # startTime / endTime are HHMM integers (see _parse_time_to_hhmm). sh, sm = divmod(tr.startTime, 100) eh, em = divmod(tr.endTime, 100) ranges.append(f"{sh:02d}:{sm:02d}-{eh:02d}:{em:02d}") diff --git a/keepercommander/commands/workflow/registry.py b/keepercommander/commands/workflow/registry.py index ae87e7e8c..2ea6f31ed 100644 --- a/keepercommander/commands/workflow/registry.py +++ b/keepercommander/commands/workflow/registry.py @@ -9,9 +9,6 @@ # Contact: ops@keepersecurity.com # -import logging -from urllib.parse import urlparse - from ..base import GroupCommand, dump_report_data from ...display import bcolors from .helpers import _ENFORCEMENT_KEY @@ -42,15 +39,8 @@ class PAMWorkflowCommand(GroupCommand): - NOTICE_MSG = 'Notice: PAM Workflow commands are not in production yet. They will be available soon.' - _ALLOWED_PREFIXES = ('dev.', 'qa.') _ADMIN_VERBS = frozenset({'create', 'update', 'delete', 'add-approver', 'remove-approver'}) - @staticmethod - def _is_allowed_server(params): - hostname = urlparse(params.rest_context.server_base).hostname or '' - return any(hostname.startswith(p) for p in PAMWorkflowCommand._ALLOWED_PREFIXES) - @staticmethod def _can_manage_workflows(params): enforcements = getattr(params, 'enforcements', None) @@ -62,10 +52,6 @@ def _can_manage_workflows(params): ) def execute_args(self, params, args, **kwargs): - if not self._is_allowed_server(params): - logging.warning(f"{bcolors.WARNING}{self.NOTICE_MSG}{bcolors.ENDC}") - return - self._current_params = params pos = args.find(' ') if args else -1 diff --git a/keepercommander/constants.py b/keepercommander/constants.py index f60ae1c90..89ced2729 100644 --- a/keepercommander/constants.py +++ b/keepercommander/constants.py @@ -112,6 +112,7 @@ class PrivilegeScope(enum.IntEnum): ("MASTER_PASSWORD_MINIMUM_UPPER", 12, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_LOWER", 13, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_DIGITS", 14, "LONG", "LOGIN_SETTINGS"), + ("MASTER_PASSWORD_MINIMUM_LENGTH_NO_PROMPT", 15, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_RESTRICT_DAYS_BEFORE_REUSE", 16, "LONG", "LOGIN_SETTINGS"), ("REQUIRE_TWO_FACTOR", 20, "BOOLEAN", "TWO_FACTOR_AUTHENTICATION"), ("MASTER_PASSWORD_MAXIMUM_DAYS_BEFORE_CHANGE", 22, "LONG", "LOGIN_SETTINGS"), @@ -231,6 +232,7 @@ class PrivilegeScope(enum.IntEnum): ("ALLOW_VIEW_KCM_RECORDINGS", 234, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_TOTP_FIELD", 235, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("ALLOW_VIEW_RBI_RECORDINGS", 236, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("USE_DEFAULT_BROWSER_FOR_SSO", 237, "TERNARY_DEN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_MANAGE_TLA", 238, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SELF_DESTRUCT_RECORDS", 239, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_PERSONAL_USING_BUSINESS_DOMAINS", 240, "STRING", "ACCOUNT_ENFORCEMENTS"), @@ -240,6 +242,8 @@ class PrivilegeScope(enum.IntEnum): ("WARN_PERSONAL_USING_BUSINESS_SITES", 244, "STRING", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_ACCOUNT_SWITCHING", 245, "BOOLEAN", "AUTHENTICATION_ENFORCEMENTS"), ("RESTRICT_PASSKEY_LOGIN", 246, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + # NOTE: 247 server name is ALLOW_CAN_EDIT_EXTERNAL_SHARES (positive). Commander's + # RESTRICT_ name is kept for backward compat but the polarity is inverted vs the server. ("RESTRICT_CAN_EDIT_EXTERNAL_SHARES", 247, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SNAPSHOT_TOOL", 248, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_FORCEFIELD", 249, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), @@ -248,6 +252,16 @@ class PrivilegeScope(enum.IntEnum): ("RESTRICT_SF_FOLDER_DELETION", 253, "BOOLEAN", "SHARING_ENFORCEMENTS"), ("RESTRICT_PLATFORM_PASSKEY_LOGIN", 254, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_CROSS_PLATFORM_PASSKEY_LOGIN", 255, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_WEB", 256, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_MOBILE", 257, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_DESKTOP", 258, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_CONSOLE", 259, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_WEB", 260, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_MOBILE", 261, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_DESKTOP", 262, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_CONSOLE", 263, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("LOGOUT_TIMER_CONSOLE", 264, "LONG", "ACCOUNT_SETTINGS"), + ("ALLOW_CONFIGURE_WORKFLOW_SETTINGS", 267, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ] _COMPOUND_ENFORCEMENTS = [ diff --git a/keepercommander/importer/manageengine/restapi.py b/keepercommander/importer/manageengine/restapi.py index d91ade5b1..bd8cf3c46 100644 --- a/keepercommander/importer/manageengine/restapi.py +++ b/keepercommander/importer/manageengine/restapi.py @@ -37,11 +37,7 @@ } -if sys.version_info < (3, 7): - Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment']) - Url.__new__.__defaults__ = ('', '', '') -else: - Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment'], defaults=('', '', '')) +Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment'], defaults=('', '', '')) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/keepercommander/service/config/cloudflare_config.py b/keepercommander/service/config/cloudflare_config.py index 05efcdca6..ea9fda01c 100644 --- a/keepercommander/service/config/cloudflare_config.py +++ b/keepercommander/service/config/cloudflare_config.py @@ -9,7 +9,7 @@ # Contact: ops@keepersecurity.com # -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple import time import os import psutil @@ -90,7 +90,7 @@ def _get_cloudflare_log_path() -> str: return os.path.join(service_core_dir, "logs", "cloudflare_tunnel_subprocess.log") @staticmethod - def _analyze_tunnel_log(log_file: str) -> tuple[Optional[bool], str]: + def _analyze_tunnel_log(log_file: str) -> Tuple[Optional[bool], str]: """ Analyze tunnel log content for success/failure indicators. Returns (success: Optional[bool], error_message: str) @@ -115,7 +115,7 @@ def _read_log_file(log_file: str) -> str: return f.read() @staticmethod - def _check_tunnel_patterns(content: str) -> tuple[Optional[bool], str]: + def _check_tunnel_patterns(content: str) -> Tuple[Optional[bool], str]: """Check log content for success/failure patterns.""" if any(pattern in content for pattern in CloudflareConfigurator._SUCCESS_PATTERNS): return True, "" diff --git a/keepercommander/service/util/command_util.py b/keepercommander/service/util/command_util.py index fd0f7895c..7cac25d3d 100644 --- a/keepercommander/service/util/command_util.py +++ b/keepercommander/service/util/command_util.py @@ -10,7 +10,6 @@ # import io, html -from pathlib import Path import sys import json import logging @@ -21,8 +20,6 @@ from ..core.globals import get_current_params from ..decorators.logging import logger, debug_decorator, sanitize_debug_data from ... import cli, utils -from ...__main__ import get_params_from_config -from ...service.config.service_config import ServiceConfig from ...crypto import encrypt_aes_v2 class CommandExecutor: diff --git a/keepercommander/service/util/parse_keeper_response.py b/keepercommander/service/util/parse_keeper_response.py index f1372f1da..aae9159d7 100644 --- a/keepercommander/service/util/parse_keeper_response.py +++ b/keepercommander/service/util/parse_keeper_response.py @@ -9,7 +9,7 @@ # Contact: ops@keepersecurity.com # -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import re, json class KeeperResponseParser: @@ -36,7 +36,7 @@ def _format_multiline_message(text: str) -> str: return text @staticmethod - def _preprocess_response(response: Any, log_output: str = None) -> tuple[str, bool]: + def _preprocess_response(response: Any, log_output: str = None) -> Tuple[str, bool]: """Preprocess response by cleaning ANSI codes and determining source. Returns: diff --git a/requirements.txt b/requirements.txt index b33dc9b3a..41983e843 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,15 +9,15 @@ textual>=0.82.0 websockets fido2>=2.0.0; python_version>='3.10' requests>=2.31.0 -cryptography>=39.0.1 -protobuf>=4.23.0 +cryptography>=46.0.6 +protobuf>=5.29.6 keeper-secrets-manager-core>=16.6.0 -keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' -pydantic>=2.6.4; python_version>='3.8' -flask; python_version>='3.8' +keeper_pam_webrtc_rs>=2.1.6 +pydantic>=2.6.4 +flask pyngrok>=7.5.0 -flask-limiter; python_version>='3.8' -psutil; python_version>='3.8' +flask-limiter +psutil python-dotenv fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' diff --git a/setup.cfg b/setup.cfg index 7a17f75aa..ac7f6bfb8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,36 +16,42 @@ classifiers = License :: OSI Approved :: MIT License Operating System :: OS Independent Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 + Programming Language :: Python :: 3.14 Topic :: Security keywords = security, password [options] -python_requires = >=3.7 +python_requires = >=3.8 packages = find: include_package_data = True install_requires = asciitree bcrypt colorama - cryptography>=41.0.0 + cryptography>=46.0.6 fido2>=2.0.0; python_version>='3.10' - flask; python_version>='3.8' - flask-limiter; python_version>='3.8' + flask + flask-limiter keeper-secrets-manager-core>=16.6.0 prompt_toolkit - protobuf>=4.23.0 + protobuf>=5.29.6 googleapis-common-protos - psutil; python_version>='3.8' + psutil pycryptodomex>=3.20.0 - pyngrok; python_version>='3.8' + pyngrok pyperclip python-dotenv requests>=2.31.0 tabulate websockets - keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' - pydantic>=2.6.4; python_version>='3.8' + keeper_pam_webrtc_rs>=2.1.6 + pydantic>=2.6.4 fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' pyobjc-framework-LocalAuthentication; sys_platform == "darwin" and python_version>='3.10' diff --git a/tests/test_pam_workflow.py b/tests/test_pam_workflow.py new file mode 100644 index 000000000..155025c55 --- /dev/null +++ b/tests/test_pam_workflow.py @@ -0,0 +1,362 @@ +"""Unit tests for PAM import workflow parsing, validation, and protobuf assembly.""" + +import unittest +from unittest.mock import MagicMock, patch + +from keepercommander.error import CommandError, KeeperApiError +from keepercommander.commands.pam_import.base import PamWorkflowOptions +from keepercommander.commands.pam_import import workflow_apply +from keepercommander.commands.pam_import.workflow_apply import ( + _build_temporal_filter, + _build_parameters, + _DAY_PROTO_MAP, + _is_throttle_error, + _post_with_throttle_retry, +) +from keepercommander.commands.workflow.helpers import WorkflowFormatter +from keepercommander.proto import workflow_pb2 + +# Server expects HHMM integer (workflow.proto:140 "HHMM format" + server validator). +_parse_time_to_hhmm = WorkflowFormatter._parse_time_to_hhmm + + +# --------------------------------------------------------------------------- +# Duration parsing +# --------------------------------------------------------------------------- + +class TestParseDuration(unittest.TestCase): + + def test_hours(self): + self.assertEqual(PamWorkflowOptions._parse_duration('8h'), 8 * 3_600_000) + + def test_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('30m'), 30 * 60_000) + + def test_days(self): + self.assertEqual(PamWorkflowOptions._parse_duration('1d'), 86_400_000) + + def test_bare_integer_treated_as_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('45'), 45 * 60_000) + + def test_none_returns_default(self): + self.assertEqual(PamWorkflowOptions._parse_duration(None), 86_400_000) + + def test_zero_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('0h') + + def test_negative_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('-1d') + + def test_invalid_string_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('invalid') + + def test_uppercase_suffix(self): + self.assertEqual(PamWorkflowOptions._parse_duration('2H'), 2 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Day mapping +# --------------------------------------------------------------------------- + +class TestDayMapping(unittest.TestCase): + + def test_all_3letter_tokens_in_map(self): + expected = {'mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'} + self.assertEqual(set(_DAY_PROTO_MAP.keys()), expected) + + def test_monday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['mon'], workflow_pb2.MONDAY) + + def test_friday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['fri'], workflow_pb2.FRIDAY) + + +# --------------------------------------------------------------------------- +# Time-of-day parsing +# --------------------------------------------------------------------------- + +class TestParseTimeToHHMM(unittest.TestCase): + """Server expects HHMM integer encoding per workflow.proto and the server-side + validator (returns "Expected HHMM integer with HH in 0-23 and MM in 0-59").""" + + def test_midnight(self): + self.assertEqual(_parse_time_to_hhmm('00:00'), 0) + + def test_nine_am(self): + self.assertEqual(_parse_time_to_hhmm('09:00'), 900) + + def test_half_past_five_pm(self): + self.assertEqual(_parse_time_to_hhmm('17:30'), 1730) + + def test_invalid_format_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('25:00') + + def test_non_numeric_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('ab:cd') + + +# --------------------------------------------------------------------------- +# V2: trivial workflow detection +# --------------------------------------------------------------------------- + +class TestTrivialWorkflow(unittest.TestCase): + + def test_empty_dict_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({})) + + def test_none_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load(None)) + + def test_all_flags_off_no_temporal_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({ + 'approvals_needed': 0, + 'checkout_needed': False, + 'require_mfa': False, + })) + + def test_checkout_needed_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'checkout_needed': True, 'access_duration': '2h'}) + self.assertIsNotNone(opts) + self.assertTrue(opts.checkout_needed) + + def test_require_mfa_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'require_mfa': True}) + self.assertIsNotNone(opts) + + def test_allowed_days_is_non_trivial(self): + opts = PamWorkflowOptions.load({'allowed_times': {'allowed_days': ['mon'], 'timezone': 'UTC'}}) + self.assertIsNotNone(opts) + + def test_approvals_needed_gt0_is_non_trivial(self): + opts = PamWorkflowOptions.load({'approvals_needed': 2}) + self.assertIsNotNone(opts) + + +# --------------------------------------------------------------------------- +# V7: escalation_after requires escalation: true +# --------------------------------------------------------------------------- + +class TestEscalationValidation(unittest.TestCase): + + def test_escalation_after_without_escalation_raises(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': False, + 'escalation_after': '30m', + }], + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_escalation_after_with_escalation_true_ok(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': True, + 'escalation_after': '30m', + }], + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.approvers[0]['escalation_after_ms'], 30 * 60_000) + + +# --------------------------------------------------------------------------- +# V8: time_ranges requires timezone +# --------------------------------------------------------------------------- + +class TestTimezoneRequirement(unittest.TestCase): + + def test_time_ranges_without_timezone_raises(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + }, + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_time_ranges_with_timezone_ok(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'America/New_York', + }, + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.timezone, 'America/New_York') + self.assertEqual(len(opts.time_ranges), 1) + + +# --------------------------------------------------------------------------- +# V9: access_duration default +# --------------------------------------------------------------------------- + +class TestAccessDurationDefault(unittest.TestCase): + + def test_missing_access_duration_defaults_to_1d(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + self.assertEqual(opts.access_duration_ms, 86_400_000) + + def test_explicit_duration_parsed(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1, 'access_duration': '4h'}) + self.assertEqual(opts.access_duration_ms, 4 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Protobuf assembly: _build_parameters +# --------------------------------------------------------------------------- + +class TestBuildParameters(unittest.TestCase): + + def _make_uid_bytes(self): + import base64 + return base64.urlsafe_b64decode('AAAAAAAAAAAAAAAAAAAAAA==') + + def test_basic_fields_populated(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 2, + 'checkout_needed': True, + 'require_mfa': True, + 'access_duration': '8h', + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Test Machine', opts) + self.assertEqual(params_proto.approvalsNeeded, 2) + self.assertTrue(params_proto.checkoutNeeded) + self.assertTrue(params_proto.requireMFA) + self.assertEqual(params_proto.accessLength, 8 * 3_600_000) + self.assertEqual(params_proto.resource.value, uid_bytes) + self.assertEqual(params_proto.resource.name, 'Test Machine') + + def test_temporal_filter_attached(self): + opts = PamWorkflowOptions.load({ + 'require_mfa': True, + 'allowed_times': { + 'allowed_days': ['mon', 'fri'], + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'UTC', + }, + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + at = params_proto.allowedTimes + self.assertIn(workflow_pb2.MONDAY, at.allowedDays) + self.assertIn(workflow_pb2.FRIDAY, at.allowedDays) + self.assertEqual(len(at.timeRanges), 1) + # HHMM integer encoding: 09:00 -> 900, 17:00 -> 1700 + self.assertEqual(at.timeRanges[0].startTime, 900) + self.assertEqual(at.timeRanges[0].endTime, 1700) + self.assertEqual(at.timeZone, 'UTC') + + def test_no_allowed_times_no_temporal(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + self.assertFalse(params_proto.HasField('allowedTimes')) + + +# --------------------------------------------------------------------------- +# validate_principals +# --------------------------------------------------------------------------- + +class TestValidatePrincipals(unittest.TestCase): + + def _make_params(self, team_uids): + p = MagicMock() + p.team_cache = {uid: {} for uid in team_uids} + return p + + def test_known_team_uid_passes(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'validUID123'}}], + }) + params = self._make_params(['validUID123']) + opts.validate_principals(params, 'MyResource') + + def test_unknown_team_uid_raises(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'unknownUID'}}], + }) + params = self._make_params(['otherUID']) + with self.assertRaises(CommandError): + opts.validate_principals(params, 'MyResource') + + def test_user_principal_not_checked_against_team_cache(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'user', 'email': 'user@example.com'}}], + }) + params = self._make_params([]) + opts.validate_principals(params) + + +# --------------------------------------------------------------------------- +# Throttle / 429 retry wrapper +# --------------------------------------------------------------------------- + +class TestThrottleErrorDetection(unittest.TestCase): + + def test_keeper_api_error_429_is_throttle(self): + self.assertTrue(_is_throttle_error(KeeperApiError(429, 'Too many requests'))) + + def test_keeper_api_error_500_is_not_throttle(self): + self.assertFalse(_is_throttle_error(KeeperApiError(500, 'Internal error'))) + + def test_string_throttle_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('record was throttled'))) + + def test_too_many_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('Too many requests'))) + + def test_unrelated_error_is_not_throttle(self): + self.assertFalse(_is_throttle_error(Exception('connection refused'))) + + +class TestThrottleRetry(unittest.TestCase): + + def test_no_retry_on_non_throttle(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(500, 'boom')) as mock_post: + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, 1) + + def test_retries_then_succeeds(self): + # First two calls 429, third succeeds. Patch sleep to keep test fast. + side_effects = [KeeperApiError(429, 'Too many requests'), + KeeperApiError(429, 'Too many requests'), + 'OK'] + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=side_effects) as mock_post, \ + patch.object(workflow_apply.time, 'sleep') as mock_sleep: + result = _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(result, 'OK') + self.assertEqual(mock_post.call_count, 3) + # Two backoff sleeps: 10s, 15s (10 * 1.5) + self.assertEqual([round(c.args[0], 2) for c in mock_sleep.call_args_list], [10.0, 15.0]) + + def test_exhausts_retries_and_reraises(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(429, 'Too many requests')) as mock_post, \ + patch.object(workflow_apply.time, 'sleep'): + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, workflow_apply._THROTTLE_MAX_RETRIES) + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_pam_import_dedup.py b/unit-tests/pam/test_pam_import_dedup.py index a11c64b84..f958ff1ef 100644 --- a/unit-tests/pam/test_pam_import_dedup.py +++ b/unit-tests/pam/test_pam_import_dedup.py @@ -1,87 +1,85 @@ """Test that pam project import rejects duplicate UIDs.""" import logging -import sys import unittest -if sys.version_info >= (3, 8): - from keepercommander.commands.pam_import.edit import PAMProjectImportCommand +from keepercommander.commands.pam_import.edit import PAMProjectImportCommand - def _minimal_project(resources, users=None): - """Build a minimal project dict matching the structure process_data expects.""" - return { - "data": { - "pam_data": { - "resources": resources, - "users": users or [], - "rotation_profiles": {}, - } - }, - "pam_config": {"pam_config_uid": "test-config-uid"}, - "folders": { - "resources_folder_uid": "sfr-test", - "users_folder_uid": "sfu-test", - }, - } +def _minimal_project(resources, users=None): + """Build a minimal project dict matching the structure process_data expects.""" + return { + "data": { + "pam_data": { + "resources": resources, + "users": users or [], + "rotation_profiles": {}, + } + }, + "pam_config": {"pam_config_uid": "test-config-uid"}, + "folders": { + "resources_folder_uid": "sfr-test", + "users_folder_uid": "sfu-test", + }, + } - class TestPAMImportDuplicateUid(unittest.TestCase): - """process_data must abort when the import JSON contains duplicate uid values.""" +class TestPAMImportDuplicateUid(unittest.TestCase): + """process_data must abort when the import JSON contains duplicate uid values.""" - def test_duplicate_uid_logs_error_and_returns(self): - """process_data aborts with logging.error when two resources share a uid.""" - from unittest.mock import MagicMock - project = _minimal_project([ - {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'duplicate-uid-1'}, - {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'duplicate-uid-1'}, - ]) - cmd = PAMProjectImportCommand() - params = MagicMock() - params.record_cache = {} - params.shared_folder_cache = {} - params.folder_cache = {} + def test_duplicate_uid_logs_error_and_returns(self): + """process_data aborts with logging.error when two resources share a uid.""" + from unittest.mock import MagicMock + project = _minimal_project([ + {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'duplicate-uid-1'}, + {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'duplicate-uid-1'}, + ]) + cmd = PAMProjectImportCommand() + params = MagicMock() + params.record_cache = {} + params.shared_folder_cache = {} + params.folder_cache = {} - # assertLogs with no logger name captures from root logger (where logging.error writes) - with self.assertLogs(level='ERROR') as log_ctx: - try: - cmd.process_data(params, project) - except Exception: - pass # early return path may surface as exception in some code paths + # assertLogs with no logger name captures from root logger (where logging.error writes) + with self.assertLogs(level='ERROR') as log_ctx: + try: + cmd.process_data(params, project) + except Exception: + pass # early return path may surface as exception in some code paths - self.assertTrue( - any('duplicate uid' in msg.lower() or 'duplicate-uid-1' in msg - for msg in log_ctx.output), - f'Expected duplicate UID error in logs, got: {log_ctx.output}' - ) + self.assertTrue( + any('duplicate uid' in msg.lower() or 'duplicate-uid-1' in msg + for msg in log_ctx.output), + f'Expected duplicate UID error in logs, got: {log_ctx.output}' + ) - def test_unique_uids_pass_dedup_check(self): - """process_data does NOT emit a duplicate-uid error when all UIDs are unique.""" - from unittest.mock import MagicMock - import io + def test_unique_uids_pass_dedup_check(self): + """process_data does NOT emit a duplicate-uid error when all UIDs are unique.""" + from unittest.mock import MagicMock + import io - project = _minimal_project([ - {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'uid-alpha'}, - {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'uid-beta'}, - ]) - cmd = PAMProjectImportCommand() - params = MagicMock() - params.record_cache = {} - params.shared_folder_cache = {} - params.folder_cache = {} + project = _minimal_project([ + {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'uid-alpha'}, + {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'uid-beta'}, + ]) + cmd = PAMProjectImportCommand() + params = MagicMock() + params.record_cache = {} + params.shared_folder_cache = {} + params.folder_cache = {} - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.setLevel(logging.ERROR) - root_logger = logging.getLogger() - root_logger.addHandler(handler) + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setLevel(logging.ERROR) + root_logger = logging.getLogger() + root_logger.addHandler(handler) + try: try: - try: - cmd.process_data(params, project) - except Exception: - pass - output = stream.getvalue() - self.assertNotIn('duplicate uid', output.lower(), - f'Unexpected duplicate UID error for unique UIDs: {output}') - finally: - root_logger.removeHandler(handler) + cmd.process_data(params, project) + except Exception: + pass + output = stream.getvalue() + self.assertNotIn('duplicate uid', output.lower(), + f'Unexpected duplicate UID error for unique UIDs: {output}') + finally: + root_logger.removeHandler(handler) if __name__ == '__main__': diff --git a/unit-tests/pam/test_pam_project_export.py b/unit-tests/pam/test_pam_project_export.py index d10ecac60..0f7900a5a 100644 --- a/unit-tests/pam/test_pam_project_export.py +++ b/unit-tests/pam/test_pam_project_export.py @@ -18,7 +18,6 @@ import json import os -import sys import tempfile import unittest from unittest.mock import patch @@ -111,286 +110,279 @@ def _fake_load(_params, uid): # ── tests ────────────────────────────────────────────────────────────────── -if sys.version_info >= (3, 8): - from unittest.mock import MagicMock - - class TestPAMProjectExportCommand(unittest.TestCase): - - def setUp(self): - from keepercommander.commands.pam_import.export import PAMProjectExportCommand - self.cmd = PAMProjectExportCommand() - self.params = MagicMock() - self.params.record_cache = {uid: {} for uid in _RECORDS} - - def _execute(self, project_uid=CONFIG_UID, output=None): - """Run execute() with vault.KeeperRecord.load mocked.""" - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): - with patch.object(self.cmd, "_get_allowed_settings", - return_value=dict(_DEFAULT_ALLOWED)): - kwargs = {"project_uid": project_uid} - if output: - kwargs["output"] = output - return self.cmd.execute(self.params, **kwargs) - - # ── basic output ────────────────────────────────────────────── - - def test_returns_string(self): - result = self._execute() - self.assertIsInstance(result, str, - "execute() should return a JSON string when --output is not set") - - def test_valid_json(self): - parsed = json.loads(self._execute()) - self.assertIsInstance(parsed, dict) - - # ── required top-level keys ─────────────────────────────────── - - def test_has_project_key(self): - parsed = json.loads(self._execute()) +from unittest.mock import MagicMock + +class TestPAMProjectExportCommand(unittest.TestCase): + + def setUp(self): + from keepercommander.commands.pam_import.export import PAMProjectExportCommand + self.cmd = PAMProjectExportCommand() + self.params = MagicMock() + self.params.record_cache = {uid: {} for uid in _RECORDS} + + def _execute(self, project_uid=CONFIG_UID, output=None): + """Run execute() with vault.KeeperRecord.load mocked.""" + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): + with patch.object(self.cmd, "_get_allowed_settings", + return_value=dict(_DEFAULT_ALLOWED)): + kwargs = {"project_uid": project_uid} + if output: + kwargs["output"] = output + return self.cmd.execute(self.params, **kwargs) + + # ── basic output ────────────────────────────────────────────── + + def test_returns_string(self): + result = self._execute() + self.assertIsInstance(result, str, + "execute() should return a JSON string when --output is not set") + + def test_valid_json(self): + parsed = json.loads(self._execute()) + self.assertIsInstance(parsed, dict) + + # ── required top-level keys ─────────────────────────────────── + + def test_has_project_key(self): + parsed = json.loads(self._execute()) + self.assertIn("project", parsed) + self.assertEqual(parsed["project"], "Test Project") + + def test_has_pam_configuration_key(self): + parsed = json.loads(self._execute()) + self.assertIn("pam_configuration", parsed) + + def test_has_pam_data_key(self): + parsed = json.loads(self._execute()) + self.assertIn("pam_data", parsed) + self.assertIn("resources", parsed["pam_data"]) + self.assertIn("users", parsed["pam_data"]) + + def test_has_tool_version(self): + parsed = json.loads(self._execute()) + self.assertIn("tool_version", parsed) + self.assertEqual(parsed["tool_version"], "commander-export-1.0") + + # ── pam_configuration fields ────────────────────────────────── + + def test_pam_configuration_environment(self): + parsed = json.loads(self._execute()) + self.assertEqual(parsed["pam_configuration"]["environment"], "local") + + def test_pam_configuration_on_off_values(self): + parsed = json.loads(self._execute()) + cfg = parsed["pam_configuration"] + for key in ("connections", "rotation", "tunneling", "remote_browser_isolation"): + self.assertIn(cfg[key], ("on", "off"), f"{key} must be 'on' or 'off'") + + # ── resources ──────────────────────────────────────────────── + + def test_resources_count(self): + parsed = json.loads(self._execute()) + self.assertEqual(len(parsed["pam_data"]["resources"]), 2) + + def test_resource_has_required_keys(self): + parsed = json.loads(self._execute()) + for res in parsed["pam_data"]["resources"]: + for key in ("uid", "type", "title", "users"): + self.assertIn(key, res, f"resource missing key: {key}") + + def test_resource_uids_are_unique(self): + parsed = json.loads(self._execute()) + uids = [r["uid"] for r in parsed["pam_data"]["resources"]] + self.assertEqual(len(uids), len(set(uids)), "resource UIDs must be unique") + + def test_resource_types(self): + parsed = json.loads(self._execute()) + types = {r["type"] for r in parsed["pam_data"]["resources"]} + self.assertIn("pamMachine", types) + self.assertIn("pamDatabase", types) + + # ── users ──────────────────────────────────────────────────── + + def test_top_level_users_deduplication(self): + # USER1 appears in both machine and database resources; + # must only appear once in pam_data.users + parsed = json.loads(self._execute()) + top_uids = [u["uid"] for u in parsed["pam_data"]["users"]] + self.assertEqual(len(top_uids), len(set(top_uids)), + "top-level user UIDs must be unique (de-duplicated)") + + def test_top_level_users_count(self): + # USER1 shared across both resources, USER2 only in DB → 2 unique users + parsed = json.loads(self._execute()) + self.assertEqual(len(parsed["pam_data"]["users"]), 2) + + def test_user_has_required_keys(self): + parsed = json.loads(self._execute()) + for usr in parsed["pam_data"]["users"]: + for key in ("uid", "type", "title", "login"): + self.assertIn(key, usr, f"user missing key: {key}") + + # ── --output flag ──────────────────────────────────────────── + + def test_output_flag_writes_file(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + tmp_path = tmp.name + try: + result = self._execute(output=tmp_path) + # When --output is set, execute() should return None + self.assertIsNone(result) + self.assertTrue(os.path.exists(tmp_path)) + with open(tmp_path, encoding="utf-8") as fh: + content = fh.read() + parsed = json.loads(content) self.assertIn("project", parsed) - self.assertEqual(parsed["project"], "Test Project") - - def test_has_pam_configuration_key(self): - parsed = json.loads(self._execute()) - self.assertIn("pam_configuration", parsed) - - def test_has_pam_data_key(self): - parsed = json.loads(self._execute()) - self.assertIn("pam_data", parsed) - self.assertIn("resources", parsed["pam_data"]) - self.assertIn("users", parsed["pam_data"]) - - def test_has_tool_version(self): - parsed = json.loads(self._execute()) self.assertIn("tool_version", parsed) - self.assertEqual(parsed["tool_version"], "commander-export-1.0") - - # ── pam_configuration fields ────────────────────────────────── - - def test_pam_configuration_environment(self): - parsed = json.loads(self._execute()) - self.assertEqual(parsed["pam_configuration"]["environment"], "local") - - def test_pam_configuration_on_off_values(self): - parsed = json.loads(self._execute()) - cfg = parsed["pam_configuration"] - for key in ("connections", "rotation", "tunneling", "remote_browser_isolation"): - self.assertIn(cfg[key], ("on", "off"), f"{key} must be 'on' or 'off'") - - # ── resources ──────────────────────────────────────────────── - - def test_resources_count(self): - parsed = json.loads(self._execute()) - self.assertEqual(len(parsed["pam_data"]["resources"]), 2) - - def test_resource_has_required_keys(self): - parsed = json.loads(self._execute()) - for res in parsed["pam_data"]["resources"]: - for key in ("uid", "type", "title", "users"): - self.assertIn(key, res, f"resource missing key: {key}") - - def test_resource_uids_are_unique(self): - parsed = json.loads(self._execute()) - uids = [r["uid"] for r in parsed["pam_data"]["resources"]] - self.assertEqual(len(uids), len(set(uids)), "resource UIDs must be unique") - - def test_resource_types(self): - parsed = json.loads(self._execute()) - types = {r["type"] for r in parsed["pam_data"]["resources"]} - self.assertIn("pamMachine", types) - self.assertIn("pamDatabase", types) - - # ── users ──────────────────────────────────────────────────── - - def test_top_level_users_deduplication(self): - # USER1 appears in both machine and database resources; - # must only appear once in pam_data.users - parsed = json.loads(self._execute()) - top_uids = [u["uid"] for u in parsed["pam_data"]["users"]] - self.assertEqual(len(top_uids), len(set(top_uids)), - "top-level user UIDs must be unique (de-duplicated)") - - def test_top_level_users_count(self): - # USER1 shared across both resources, USER2 only in DB → 2 unique users - parsed = json.loads(self._execute()) - self.assertEqual(len(parsed["pam_data"]["users"]), 2) - - def test_user_has_required_keys(self): - parsed = json.loads(self._execute()) - for usr in parsed["pam_data"]["users"]: - for key in ("uid", "type", "title", "login"): - self.assertIn(key, usr, f"user missing key: {key}") - - # ── --output flag ──────────────────────────────────────────── - - def test_output_flag_writes_file(self): - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: - tmp_path = tmp.name - try: - result = self._execute(output=tmp_path) - # When --output is set, execute() should return None - self.assertIsNone(result) - self.assertTrue(os.path.exists(tmp_path)) - with open(tmp_path, encoding="utf-8") as fh: - content = fh.read() - parsed = json.loads(content) - self.assertIn("project", parsed) - self.assertIn("tool_version", parsed) - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) - - # ── error handling ─────────────────────────────────────────── - - def test_missing_project_uid_returns_none(self): - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): - result = self.cmd.execute(self.params, project_uid="", output=None) - self.assertIsNone(result) - - def test_unknown_uid_returns_none(self): - with patch("keepercommander.vault.KeeperRecord.load", return_value=None): - result = self.cmd.execute(self.params, project_uid="unknown-uid", output=None) - self.assertIsNone(result) - - def test_non_v6_record_returns_none(self): - v3_rec = vault.TypedRecord(version=3) - v3_rec.type_name = "pamMachine" - v3_rec.title = "some" - v3_rec.record_uid = "some-uid" - with patch("keepercommander.vault.KeeperRecord.load", return_value=v3_rec): - result = self.cmd.execute(self.params, project_uid="some-uid", output=None) - self.assertIsNone(result) - - # ── round-trip / determinism ───────────────────────────────── - - def test_sort_keys_determinism(self): - result1 = self._execute() - result2 = self._execute() - self.assertEqual(result1, result2, "Output must be deterministic across calls") - - def test_output_is_sorted(self): - result = self._execute() - parsed = json.loads(result) - keys = list(parsed.keys()) - self.assertEqual(keys, sorted(keys), - "Top-level keys should be sorted (sort_keys=True)") - - - # ──────────────────────────────────────────────────────────────────── - # KCM-import compatibility (PR #1942) - # ──────────────────────────────────────────────────────────────────── - - class TestKCMImportRoundTrip(unittest.TestCase): - """KCM-imported records (PR #1942) reference users by *title* in - ``pam_settings.connection.launch_credentials`` rather than by UID - in ``userRecords[]``. Export must resolve these title references - so the exported JSON re-imports with the user link intact. - """ - - KCM_CFG = "kcm-cfg-1" - KCM_RES = "kcm-res-prod-db" - KCM_USR = "kcm-usr-prod-db" - - def _make_kcm_records(self): - """Build the KCM-shaped vault state (PR #1942 import output).""" - cfg = vault.TypedRecord(version=6) - cfg.type_name = "pamNetworkConfiguration" - cfg.title = "KCM Migration" - cfg.record_uid = self.KCM_CFG - cfg.fields.append(_make_typed_field("pamResources", [{ - "controllerUid": "gw-uid", - "folderUid": "sf-uid", - "resourceRef": [self.KCM_RES], - }])) - - res = vault.TypedRecord(version=3) - res.type_name = "pamMachine" - res.title = "KCM Resource - prod-db" - res.record_uid = self.KCM_RES - res.fields.append(_make_typed_field("pamSettings", [{ - "connection": { - "protocol": "ssh", - "port": "22", - "launch_credentials": "KCM User - prod-db", - }, - "options": {"connections": "on", "rotation": "off"}, - }])) - - usr = vault.TypedRecord(version=3) - usr.type_name = "pamUser" - usr.title = "KCM User - prod-db" - usr.record_uid = self.KCM_USR - usr.fields.append(_make_typed_field("login", ["root"])) - - return {self.KCM_CFG: cfg, self.KCM_RES: res, self.KCM_USR: usr} - - def setUp(self): - from keepercommander.commands.pam_import.export import PAMProjectExportCommand - from unittest.mock import MagicMock - self.cmd = PAMProjectExportCommand() - self.records = self._make_kcm_records() - self.params = MagicMock() - self.params.record_cache = {uid: {} for uid in self.records} - - def _execute(self): - def _load(_p, uid): - return self.records.get(uid) - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_load): - with patch.object(self.cmd, "_get_allowed_settings", - return_value=dict(_DEFAULT_ALLOWED)): - return self.cmd.execute(self.params, project_uid=self.KCM_CFG) - - def test_title_based_user_link_resolved(self): - """KCM resource → export must include the user via title resolution.""" - parsed = json.loads(self._execute()) - resources = parsed["pam_data"]["resources"] - self.assertEqual(len(resources), 1, "expected one KCM resource") - res = resources[0] - self.assertEqual(len(res["users"]), 1, - "KCM resource must export 1 user (resolved by title)") - self.assertEqual(res["users"][0]["uid"], self.KCM_USR) - self.assertEqual(res["users"][0]["title"], "KCM User - prod-db") - - def test_top_level_users_includes_resolved_user(self): - parsed = json.loads(self._execute()) - top_users = parsed["pam_data"]["users"] - self.assertEqual(len(top_users), 1) - self.assertEqual(top_users[0]["uid"], self.KCM_USR) - - def test_pam_settings_preserved_for_round_trip(self): - """Round-trip safety: KCM-specific pam_settings keys preserved verbatim.""" - parsed = json.loads(self._execute()) - res = parsed["pam_data"]["resources"][0] - conn = res["pam_settings"]["connection"] - self.assertEqual(conn["protocol"], "ssh") - self.assertEqual(conn["port"], "22") - self.assertEqual(conn["launch_credentials"], "KCM User - prod-db") - - def test_uid_in_launch_credentials_accepted(self): - """If launch_credentials already holds a 22-char UID (non-KCM path), keep it as-is.""" - uid_22 = "AAAAAAAAAAAAAAAAAAAAAA" # 22 chars, no slash, no space - usr = vault.TypedRecord(version=3) - usr.type_name = "pamUser" - usr.title = "Direct UID User" - usr.record_uid = uid_22 - usr.fields.append(_make_typed_field("login", ["alice"])) - self.records[uid_22] = usr - self.params.record_cache[uid_22] = {} - - res = self.records[self.KCM_RES] - ps = res.get_typed_field("pamSettings").value[0] - ps["connection"]["launch_credentials"] = uid_22 - parsed = json.loads(self._execute()) - users = parsed["pam_data"]["resources"][0]["users"] - self.assertEqual(len(users), 1) - self.assertEqual(users[0]["uid"], uid_22) - - -else: - class TestPAMProjectExportCommand(unittest.TestCase): - def test_skip(self): - self.skipTest("Requires Python 3.8+") + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + # ── error handling ─────────────────────────────────────────── + + def test_missing_project_uid_returns_none(self): + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): + result = self.cmd.execute(self.params, project_uid="", output=None) + self.assertIsNone(result) + + def test_unknown_uid_returns_none(self): + with patch("keepercommander.vault.KeeperRecord.load", return_value=None): + result = self.cmd.execute(self.params, project_uid="unknown-uid", output=None) + self.assertIsNone(result) + + def test_non_v6_record_returns_none(self): + v3_rec = vault.TypedRecord(version=3) + v3_rec.type_name = "pamMachine" + v3_rec.title = "some" + v3_rec.record_uid = "some-uid" + with patch("keepercommander.vault.KeeperRecord.load", return_value=v3_rec): + result = self.cmd.execute(self.params, project_uid="some-uid", output=None) + self.assertIsNone(result) + + # ── round-trip / determinism ───────────────────────────────── + + def test_sort_keys_determinism(self): + result1 = self._execute() + result2 = self._execute() + self.assertEqual(result1, result2, "Output must be deterministic across calls") + + def test_output_is_sorted(self): + result = self._execute() + parsed = json.loads(result) + keys = list(parsed.keys()) + self.assertEqual(keys, sorted(keys), + "Top-level keys should be sorted (sort_keys=True)") + + +# ──────────────────────────────────────────────────────────────────── +# KCM-import compatibility (PR #1942) +# ──────────────────────────────────────────────────────────────────── + +class TestKCMImportRoundTrip(unittest.TestCase): + """KCM-imported records (PR #1942) reference users by *title* in + ``pam_settings.connection.launch_credentials`` rather than by UID + in ``userRecords[]``. Export must resolve these title references + so the exported JSON re-imports with the user link intact. + """ + + KCM_CFG = "kcm-cfg-1" + KCM_RES = "kcm-res-prod-db" + KCM_USR = "kcm-usr-prod-db" + + def _make_kcm_records(self): + """Build the KCM-shaped vault state (PR #1942 import output).""" + cfg = vault.TypedRecord(version=6) + cfg.type_name = "pamNetworkConfiguration" + cfg.title = "KCM Migration" + cfg.record_uid = self.KCM_CFG + cfg.fields.append(_make_typed_field("pamResources", [{ + "controllerUid": "gw-uid", + "folderUid": "sf-uid", + "resourceRef": [self.KCM_RES], + }])) + + res = vault.TypedRecord(version=3) + res.type_name = "pamMachine" + res.title = "KCM Resource - prod-db" + res.record_uid = self.KCM_RES + res.fields.append(_make_typed_field("pamSettings", [{ + "connection": { + "protocol": "ssh", + "port": "22", + "launch_credentials": "KCM User - prod-db", + }, + "options": {"connections": "on", "rotation": "off"}, + }])) + + usr = vault.TypedRecord(version=3) + usr.type_name = "pamUser" + usr.title = "KCM User - prod-db" + usr.record_uid = self.KCM_USR + usr.fields.append(_make_typed_field("login", ["root"])) + + return {self.KCM_CFG: cfg, self.KCM_RES: res, self.KCM_USR: usr} + + def setUp(self): + from keepercommander.commands.pam_import.export import PAMProjectExportCommand + from unittest.mock import MagicMock + self.cmd = PAMProjectExportCommand() + self.records = self._make_kcm_records() + self.params = MagicMock() + self.params.record_cache = {uid: {} for uid in self.records} + + def _execute(self): + def _load(_p, uid): + return self.records.get(uid) + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_load): + with patch.object(self.cmd, "_get_allowed_settings", + return_value=dict(_DEFAULT_ALLOWED)): + return self.cmd.execute(self.params, project_uid=self.KCM_CFG) + + def test_title_based_user_link_resolved(self): + """KCM resource → export must include the user via title resolution.""" + parsed = json.loads(self._execute()) + resources = parsed["pam_data"]["resources"] + self.assertEqual(len(resources), 1, "expected one KCM resource") + res = resources[0] + self.assertEqual(len(res["users"]), 1, + "KCM resource must export 1 user (resolved by title)") + self.assertEqual(res["users"][0]["uid"], self.KCM_USR) + self.assertEqual(res["users"][0]["title"], "KCM User - prod-db") + + def test_top_level_users_includes_resolved_user(self): + parsed = json.loads(self._execute()) + top_users = parsed["pam_data"]["users"] + self.assertEqual(len(top_users), 1) + self.assertEqual(top_users[0]["uid"], self.KCM_USR) + + def test_pam_settings_preserved_for_round_trip(self): + """Round-trip safety: KCM-specific pam_settings keys preserved verbatim.""" + parsed = json.loads(self._execute()) + res = parsed["pam_data"]["resources"][0] + conn = res["pam_settings"]["connection"] + self.assertEqual(conn["protocol"], "ssh") + self.assertEqual(conn["port"], "22") + self.assertEqual(conn["launch_credentials"], "KCM User - prod-db") + + def test_uid_in_launch_credentials_accepted(self): + """If launch_credentials already holds a 22-char UID (non-KCM path), keep it as-is.""" + uid_22 = "AAAAAAAAAAAAAAAAAAAAAA" # 22 chars, no slash, no space + usr = vault.TypedRecord(version=3) + usr.type_name = "pamUser" + usr.title = "Direct UID User" + usr.record_uid = uid_22 + usr.fields.append(_make_typed_field("login", ["alice"])) + self.records[uid_22] = usr + self.params.record_cache[uid_22] = {} + + res = self.records[self.KCM_RES] + ps = res.get_typed_field("pamSettings").value[0] + ps["connection"]["launch_credentials"] = uid_22 + parsed = json.loads(self._execute()) + users = parsed["pam_data"]["resources"][0]["users"] + self.assertEqual(len(users), 1) + self.assertEqual(users[0]["uid"], uid_22) if __name__ == "__main__": diff --git a/unit-tests/pam/test_pam_rotation.py b/unit-tests/pam/test_pam_rotation.py index 87e4cc157..f7fce197a 100644 --- a/unit-tests/pam/test_pam_rotation.py +++ b/unit-tests/pam/test_pam_rotation.py @@ -1,5 +1,4 @@ import json -import sys import unittest from datetime import datetime from unittest.mock import patch, MagicMock @@ -50,555 +49,554 @@ def create_mock_params(): return mock_params -if sys.version_info >= (3, 8): - import requests - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +import requests +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat - from keepercommander import crypto, utils - from keepercommander.commands.discoveryrotation import (PAMCreateRecordRotationCommand, PAMListRecordRotationCommand, - PAMGatewayListCommand, PAMRouterGetRotationInfo) +from keepercommander import crypto, utils +from keepercommander.commands.discoveryrotation import (PAMCreateRecordRotationCommand, PAMListRecordRotationCommand, + PAMGatewayListCommand, PAMRouterGetRotationInfo) - class TestPAMCreateRecordRotationCommand(unittest.TestCase): +class TestPAMCreateRecordRotationCommand(unittest.TestCase): - def setUp(self): - self.command = PAMCreateRecordRotationCommand() - self.parser = self.command.get_parser() - self.transmission_key = b'transmission_key' - self.session_token = b'encrypted_session_token' - self.private_key = ec.generate_private_key(ec.SECP256R1()) - self.public_key = self.private_key.public_key() + def setUp(self): + self.command = PAMCreateRecordRotationCommand() + self.parser = self.command.get_parser() + self.transmission_key = b'transmission_key' + self.session_token = b'encrypted_session_token' + self.private_key = ec.generate_private_key(ec.SECP256R1()) + self.public_key = self.private_key.public_key() - # Serialize and deserialize the public key to ensure compatibility - public_key_bytes = self.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) - loaded_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_key_bytes) + # Serialize and deserialize the public key to ensure compatibility + public_key_bytes = self.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) + loaded_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_key_bytes) - self.encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, loaded_public_key) - self.encrypted_session_token = crypto.encrypt_aes_v2(self.session_token, self.transmission_key) + self.encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, loaded_public_key) + self.encrypted_session_token = crypto.encrypt_aes_v2(self.session_token, self.transmission_key) - def test_parser(self): - args = self.parser.parse_args(['--record', 'record_uid', '--force']) - self.assertEqual(args.record_name, 'record_uid') - self.assertTrue(args.force) + def test_parser(self): + args = self.parser.parse_args(['--record', 'record_uid', '--force']) + self.assertEqual(args.record_name, 'record_uid') + self.assertTrue(args.force) - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_folder(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_folder(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() - mock_load.return_value = mock_typed_record + mock_load.return_value = mock_typed_record - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string - kwargs = { - 'folder_name': 'folder_uid', - 'force': True # Add force to the kwargs - } - - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_TunnelDAG.called) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_no_record(self, mock_TunnelDAG, mock_load): - mock_params, _ = create_mock_params_and_record() - mock_params.record_cache = {} - - kwargs = { - 'record_name': 'non_existent_record', - 'force': True # Add force to the kwargs - } - - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_invalid_password_complexity(self, mock_TunnelDAG, mock_load): - mock_params, _ = create_mock_params_and_record() - - kwargs = { - 'record_name': 'record_uid', - 'pwd_complexity': 'invalid_complexity', - 'force': True # Add force to the kwargs - } - - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_valid_password_complexity(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() - - mock_load.return_value = mock_typed_record + kwargs = { + 'folder_name': 'folder_uid', + 'force': True # Add force to the kwargs + } - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_TunnelDAG.called) - kwargs = { - 'record_name': 'record_uid', - 'pwd_complexity': '32,5,5,5,5', - 'force': True # Add force to the kwargs - } + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_no_record(self, mock_TunnelDAG, mock_load): + mock_params, _ = create_mock_params_and_record() + mock_params.record_cache = {} - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_valid_record(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() - - mock_load.return_value = mock_typed_record - - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string - - kwargs = { - 'record_name': 'record_uid', - 'force': True # Add force to the kwargs - } + kwargs = { + 'record_name': 'non_existent_record', + 'force': True # Add force to the kwargs + } + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_TunnelDAG.called) - self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - - - class TestPAMResourceRotateCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMCreateRecordRotationCommand() - self.parser = self.command.get_parser() - def test_parser(self): - args = self.parser.parse_args(['--record', "abcdefg", '--enable']) - self.assertEqual(args.record_name, 'abcdefg') - self.assertTrue(args.enable) + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_invalid_password_complexity(self, mock_TunnelDAG, mock_load): + mock_params, _ = create_mock_params_and_record() - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_enable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record - - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] - - kwargs = { - 'record_name': 'record_uid', - 'enable': True, - 'config_uid': 'config_uid' - } + kwargs = { + 'record_name': 'record_uid', + 'pwd_complexity': 'invalid_complexity', + 'force': True # Add force to the kwargs + } + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - def test_execute_with_invalid_uid(self, mock_load): - mock_params, _ = create_mock_params_and_record('pamMachine') - - kwargs = { - 'record_name': 'invalid_uid', - 'enable': True - } - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load') - def test_execute_with_invalid_record_type(self, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record(record_type='invalid_type') - mock_load.return_value = mock_typed_record + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_valid_password_complexity(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() + + mock_load.return_value = mock_typed_record + + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + + kwargs = { + 'record_name': 'record_uid', + 'pwd_complexity': '32,5,5,5,5', + 'force': True # Add force to the kwargs + } - kwargs = { - 'record_name': 'record_uid', - 'enable': True - } + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) + + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_valid_record(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() + + mock_load.return_value = mock_typed_record + + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + + kwargs = { + 'record_name': 'record_uid', + 'force': True # Add force to the kwargs + } - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_TunnelDAG.called) + self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_disable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') +class TestPAMResourceRotateCommand(unittest.TestCase): - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record + def setUp(self): + self.command = PAMCreateRecordRotationCommand() + self.parser = self.command.get_parser() - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] + def test_parser(self): + args = self.parser.parse_args(['--record', "abcdefg", '--enable']) + self.assertEqual(args.record_name, 'abcdefg') + self.assertTrue(args.enable) - kwargs = { - 'record_name': 'record_uid', - 'disable': True, - 'config_uid': 'config_uid' - } + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_enable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_enable_and_admin(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record - - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] - - kwargs = { - 'record_name': 'record_uid', - 'enable': True, - 'config_uid': 'config_uid', - 'admin': 'admin_uid' - } + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - mock_dag_instance.link_user_to_resource.assert_called_with('admin_uid', 'record_uid', is_admin=True) - - - class TestPAMListRecordRotationCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMListRecordRotationCommand() - self.parser = self.command.get_parser() - - def test_parser(self): - args = self.parser.parse_args(['--verbose']) - self.assertTrue(args.is_verbose) - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, - mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): - mock_params = create_mock_params() - - # Mock the return values - mock_router_get_rotation_schedules.return_value.schedules = [ - MagicMock( - recordUid=utils.base64_url_decode('record_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - configurationUid=utils.base64_url_decode('config_uid'), - noSchedule=False, - scheduleData='RotateActionJob|daily.0.12.1' - ) - ] - - mock_get_all_gateways.return_value = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid'), controllerName='Controller Name') - ] - - mock_router_get_connected_gateways.return_value.controllers = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) - ] - - mock_pam_configurations_get_all.return_value = [ - {'record_uid': 'config_uid', 'data_unencrypted': json.dumps({'title': 'Config Title', 'type': 'pamConfig'})} - ] - - mock_pam_decrypt_configuration_data.return_value = { - 'title': 'Config Title', - 'type': 'pamConfig' - } + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record - kwargs = {'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_rotation_schedules.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_pam_configurations_get_all.called) + kwargs = { + 'record_name': 'record_uid', + 'enable': True, + 'config_uid': 'config_uid' + } - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute_with_no_schedules(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, - mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): - mock_params = create_mock_params() + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) - # Mock the return values - mock_router_get_rotation_schedules.return_value.schedules = [] + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + def test_execute_with_invalid_uid(self, mock_load): + mock_params, _ = create_mock_params_and_record('pamMachine') - mock_get_all_gateways.return_value = [] + kwargs = { + 'record_name': 'invalid_uid', + 'enable': True + } - mock_router_get_connected_gateways.return_value.controllers = [] + with self.assertRaises(CommandError): + self.command.execute(mock_params, **kwargs) - mock_pam_configurations_get_all.return_value = [] + @patch('keepercommander.vault.KeeperRecord.load') + def test_execute_with_invalid_record_type(self, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record(record_type='invalid_type') + mock_load.return_value = mock_typed_record - mock_pam_decrypt_configuration_data.return_value = {} + kwargs = { + 'record_name': 'record_uid', + 'enable': True + } - kwargs = {'is_verbose': True} + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_rotation_schedules.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_pam_configurations_get_all.called) - - - class TestPAMGatewayListCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMGatewayListCommand() - self.parser = self.command.get_parser() - - def test_parser(self): - args = self.parser.parse_args(['--verbose', '--force']) - self.assertTrue(args.is_verbose) - self.assertTrue(args.is_force) - - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.KSMCommand.get_app_record') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute(self, mock_dump_report_data, mock_get_app_record, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() - - # Mock the return values - mock_router_get_connected_gateways.return_value.controllers = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) - ] - - mock_get_all_gateways.return_value = [ - MagicMock( - applicationUid=utils.base64_url_decode('app_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - controllerName='Controller Name', - deviceName='Device Name', - deviceToken='Device Token', - created=int(datetime.now().timestamp() * 1000), - lastModified=int(datetime.now().timestamp() * 1000), - nodeId='Node ID' - ) - ] - - mock_get_app_record.return_value = { - 'data_unencrypted': json.dumps({'title': 'App Title'}) - } + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_disable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) - self.assertTrue(mock_get_app_record.called) - - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute_router_down(self, mock_dump_report_data, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() - - # Simulate a connection error - mock_router_get_connected_gateways.side_effect = requests.exceptions.ConnectionError - - mock_get_all_gateways.return_value = [ - MagicMock( - applicationUid=utils.base64_url_decode('app_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - controllerName='Controller Name', - deviceName='Device Name', - deviceToken='Device Token', - created=int(datetime.now().timestamp() * 1000), - lastModified=int(datetime.now().timestamp() * 1000), - nodeId='Node ID' - ) - ] - - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - def test_execute_no_gateways(self, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() + kwargs = { + 'record_name': 'record_uid', + 'disable': True, + 'config_uid': 'config_uid' + } - mock_router_get_connected_gateways.return_value.controllers = [] + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) + + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_enable_and_admin(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True + + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') + + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record + + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] + + kwargs = { + 'record_name': 'record_uid', + 'enable': True, + 'config_uid': 'config_uid', + 'admin': 'admin_uid' + } - mock_get_all_gateways.return_value = [] + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) + mock_dag_instance.link_user_to_resource.assert_called_with('admin_uid', 'record_uid', is_admin=True) + + +class TestPAMListRecordRotationCommand(unittest.TestCase): + + def setUp(self): + self.command = PAMListRecordRotationCommand() + self.parser = self.command.get_parser() + + def test_parser(self): + args = self.parser.parse_args(['--verbose']) + self.assertTrue(args.is_verbose) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, + mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_rotation_schedules.return_value.schedules = [ + MagicMock( + recordUid=utils.base64_url_decode('record_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + configurationUid=utils.base64_url_decode('config_uid'), + noSchedule=False, + scheduleData='RotateActionJob|daily.0.12.1' + ) + ] + + mock_get_all_gateways.return_value = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid'), controllerName='Controller Name') + ] + + mock_router_get_connected_gateways.return_value.controllers = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) + ] + + mock_pam_configurations_get_all.return_value = [ + {'record_uid': 'config_uid', 'data_unencrypted': json.dumps({'title': 'Config Title', 'type': 'pamConfig'})} + ] + + mock_pam_decrypt_configuration_data.return_value = { + 'title': 'Config Title', + 'type': 'pamConfig' + } - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + kwargs = {'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_rotation_schedules.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_pam_configurations_get_all.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute_with_no_schedules(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, + mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_rotation_schedules.return_value.schedules = [] + + mock_get_all_gateways.return_value = [] + + mock_router_get_connected_gateways.return_value.controllers = [] + + mock_pam_configurations_get_all.return_value = [] + + mock_pam_decrypt_configuration_data.return_value = {} + + kwargs = {'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_rotation_schedules.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_pam_configurations_get_all.called) + + +class TestPAMGatewayListCommand(unittest.TestCase): + + def setUp(self): + self.command = PAMGatewayListCommand() + self.parser = self.command.get_parser() + + def test_parser(self): + args = self.parser.parse_args(['--verbose', '--force']) + self.assertTrue(args.is_verbose) + self.assertTrue(args.is_force) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.KSMCommand.get_app_record') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute(self, mock_dump_report_data, mock_get_app_record, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_connected_gateways.return_value.controllers = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) + ] + + mock_get_all_gateways.return_value = [ + MagicMock( + applicationUid=utils.base64_url_decode('app_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + controllerName='Controller Name', + deviceName='Device Name', + deviceToken='Device Token', + created=int(datetime.now().timestamp() * 1000), + lastModified=int(datetime.now().timestamp() * 1000), + nodeId='Node ID' + ) + ] + + mock_get_app_record.return_value = { + 'data_unencrypted': json.dumps({'title': 'App Title'}) + } - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) - - class TestPAMRouterGetRotationInfo(unittest.TestCase): - - def _make_rri(self, status_name='RRS_ONLINE'): - """Build a minimal RouterRotationInfo mock.""" - from keepercommander.proto import router_pb2 - rri = MagicMock() - rri.status = router_pb2.RouterRotationStatus.Value(status_name) - rri.configurationUid = utils.base64_url_decode('config_uid_____') - rri.nodeId = 42 - rri.controllerName = 'gw-test' - rri.controllerUid = utils.base64_url_decode('gw_uid_________') - rri.resourceUid = b'' - rri.pwdComplexity = '' - rri.disabled = False - rri.scriptName = '' - return rri - - def _make_schedule(self, record_uid_bytes, no_schedule=False, schedule_data='daily.0.12.1'): - s = MagicMock() - s.recordUid = record_uid_bytes - s.noSchedule = no_schedule - s.scheduleData = schedule_data - return s - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_json_online_status(self, mock_rrg, mock_schedules): - """Online status + --format json returns valid JSON with expected keys.""" - from keeper_secrets_manager_core.utils import url_safe_str_to_bytes - record_uid = 'test_record_uid_' - record_uid_bytes = url_safe_str_to_bytes(record_uid) - - mock_rrg.return_value = self._make_rri('RRS_ONLINE') - - sched_mock = MagicMock() - sched_mock.schedules = [self._make_schedule(record_uid_bytes, no_schedule=False, - schedule_data='daily.0.12.1')] - mock_schedules.return_value = sched_mock - - mock_params = create_mock_params() - mock_params.record_cache = {} - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='json') - - self.assertIsNotNone(result, "Expected JSON string, got None") - data = json.loads(result) - self.assertIn('status', data) - self.assertTrue(data['ready_to_rotate']) - self.assertIn('pam_config_uid', data) - self.assertIn('gateway_name', data) - self.assertEqual(data['gateway_name'], 'gw-test') - self.assertIn('schedule_type', data) - self.assertEqual(data['schedule_type'], 'scheduled') - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_json_non_online_status(self, mock_rrg, mock_schedules): - """Non-online status + --format json returns minimal JSON with ready_to_rotate=false.""" - record_uid = 'test_record_uid_' - - mock_rrg.return_value = self._make_rri('RRS_NO_ROTATION') - - mock_params = create_mock_params() - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='json') - - self.assertIsNotNone(result, "Expected JSON string, got None") - data = json.loads(result) - self.assertIn('status', data) - self.assertFalse(data['ready_to_rotate']) - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_table_mode_returns_none(self, mock_rrg, mock_schedules): - """Table mode (default) prints to stdout and returns None.""" - from keeper_secrets_manager_core.utils import url_safe_str_to_bytes - record_uid = 'test_record_uid_' - record_uid_bytes = url_safe_str_to_bytes(record_uid) - - mock_rrg.return_value = self._make_rri('RRS_ONLINE') - - sched_mock = MagicMock() - sched_mock.schedules = [self._make_schedule(record_uid_bytes)] - mock_schedules.return_value = sched_mock - - mock_params = create_mock_params() - mock_params.record_cache = {} - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='table') - self.assertIsNone(result) + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + self.assertTrue(mock_get_app_record.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute_router_down(self, mock_dump_report_data, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + # Simulate a connection error + mock_router_get_connected_gateways.side_effect = requests.exceptions.ConnectionError + + mock_get_all_gateways.return_value = [ + MagicMock( + applicationUid=utils.base64_url_decode('app_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + controllerName='Controller Name', + deviceName='Device Name', + deviceToken='Device Token', + created=int(datetime.now().timestamp() * 1000), + lastModified=int(datetime.now().timestamp() * 1000), + nodeId='Node ID' + ) + ] + + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + def test_execute_no_gateways(self, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + mock_router_get_connected_gateways.return_value.controllers = [] + + mock_get_all_gateways.return_value = [] + + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + +class TestPAMRouterGetRotationInfo(unittest.TestCase): + + def _make_rri(self, status_name='RRS_ONLINE'): + """Build a minimal RouterRotationInfo mock.""" + from keepercommander.proto import router_pb2 + rri = MagicMock() + rri.status = router_pb2.RouterRotationStatus.Value(status_name) + rri.configurationUid = utils.base64_url_decode('config_uid_____') + rri.nodeId = 42 + rri.controllerName = 'gw-test' + rri.controllerUid = utils.base64_url_decode('gw_uid_________') + rri.resourceUid = b'' + rri.pwdComplexity = '' + rri.disabled = False + rri.scriptName = '' + return rri + + def _make_schedule(self, record_uid_bytes, no_schedule=False, schedule_data='daily.0.12.1'): + s = MagicMock() + s.recordUid = record_uid_bytes + s.noSchedule = no_schedule + s.scheduleData = schedule_data + return s + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_json_online_status(self, mock_rrg, mock_schedules): + """Online status + --format json returns valid JSON with expected keys.""" + from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + record_uid = 'test_record_uid_' + record_uid_bytes = url_safe_str_to_bytes(record_uid) + + mock_rrg.return_value = self._make_rri('RRS_ONLINE') + + sched_mock = MagicMock() + sched_mock.schedules = [self._make_schedule(record_uid_bytes, no_schedule=False, + schedule_data='daily.0.12.1')] + mock_schedules.return_value = sched_mock + + mock_params = create_mock_params() + mock_params.record_cache = {} + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='json') + + self.assertIsNotNone(result, "Expected JSON string, got None") + data = json.loads(result) + self.assertIn('status', data) + self.assertTrue(data['ready_to_rotate']) + self.assertIn('pam_config_uid', data) + self.assertIn('gateway_name', data) + self.assertEqual(data['gateway_name'], 'gw-test') + self.assertIn('schedule_type', data) + self.assertEqual(data['schedule_type'], 'scheduled') + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_json_non_online_status(self, mock_rrg, mock_schedules): + """Non-online status + --format json returns minimal JSON with ready_to_rotate=false.""" + record_uid = 'test_record_uid_' + + mock_rrg.return_value = self._make_rri('RRS_NO_ROTATION') + + mock_params = create_mock_params() + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='json') + + self.assertIsNotNone(result, "Expected JSON string, got None") + data = json.loads(result) + self.assertIn('status', data) + self.assertFalse(data['ready_to_rotate']) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_table_mode_returns_none(self, mock_rrg, mock_schedules): + """Table mode (default) prints to stdout and returns None.""" + from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + record_uid = 'test_record_uid_' + record_uid_bytes = url_safe_str_to_bytes(record_uid) + + mock_rrg.return_value = self._make_rri('RRS_ONLINE') + + sched_mock = MagicMock() + sched_mock.schedules = [self._make_schedule(record_uid_bytes)] + mock_schedules.return_value = sched_mock + + mock_params = create_mock_params() + mock_params.record_cache = {} + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='table') + self.assertIsNone(result) diff --git a/unit-tests/pam/test_pam_tunnel.py b/unit-tests/pam/test_pam_tunnel.py index b8dc67b7a..41c55e7f8 100644 --- a/unit-tests/pam/test_pam_tunnel.py +++ b/unit-tests/pam/test_pam_tunnel.py @@ -1,158 +1,156 @@ -import sys import unittest from unittest import mock from keepercommander.error import CommandError -if sys.version_info >= (3, 8): - import datetime - import socket - import string - from cryptography import x509 - from cryptography.hazmat._oid import NameOID - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization, hashes - from cryptography.hazmat.primitives.asymmetric import ec - - from keepercommander.commands.tunnel.port_forward.tunnel_helpers import (generate_random_bytes, find_open_port) - - def generate_self_signed_cert(private_key): - # Generate a self-signed certificate - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"), - ]) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(private_key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.datetime.utcnow()) - .not_valid_after( - # Our certificate will be valid for 10 days - datetime.datetime.utcnow() + datetime.timedelta(days=10) - ) - .sign(private_key, hashes.SHA256(), default_backend()) +import datetime +import socket +import string +from cryptography import x509 +from cryptography.hazmat._oid import NameOID +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import ec + +from keepercommander.commands.tunnel.port_forward.tunnel_helpers import (generate_random_bytes, find_open_port) + +def generate_self_signed_cert(private_key): + # Generate a self-signed certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after( + # Our certificate will be valid for 10 days + datetime.datetime.utcnow() + datetime.timedelta(days=10) ) - cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8') - - return cert_pem + .sign(private_key, hashes.SHA256(), default_backend()) + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8') + + return cert_pem + + +def new_private_key(): + # Generate an EC private key + private_key = ec.generate_private_key( + ec.SECP256R1(), # Using P-256 curve + backend=default_backend() + ) + # Serialize to PEM format + private_key_str = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode('utf-8') + return private_key, private_key_str + + +class TestFindOpenPort(unittest.TestCase): + def mock_bind(self, address): + # Mock the behavior of socket.socket.bind + port = address[1] + if port in self.in_use_ports: + raise OSError("Address already in use") + else: + print(f"Port {port} bound successfully.") + + def test_preferred_port(self): + # Test that the function returns the preferred port if it's available + preferred_port = 50000 + open_port = find_open_port([], preferred_port=preferred_port) + self.assertEqual(open_port, preferred_port) + + def test_preferred_port_unavailable(self): + # Mock the bind method to simulate that port 80 is in use + with mock.patch('socket.socket.bind', side_effect=OSError("Address already in use")): + preferred_port = 80 + with self.assertRaises(CommandError): + open_port = find_open_port([], preferred_port=preferred_port) + + def test_range(self): + # Test that the function returns a port within the specified range + start_port = 50000 + end_port = 50010 + open_port = find_open_port([], start_port=start_port, end_port=end_port) + self.assertTrue(start_port <= open_port <= end_port) + + def test_no_available_ports(self): + # Setup + self.in_use_ports = set(range(50000, 50011)) # All these ports are in use + + # Patch + with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): + # Test + open_port = find_open_port([], start_port=50000, end_port=50010) + self.assertIsNone(open_port) + def test_invalid_range(self): + # Test that the function returns None if the range is invalid + open_port = find_open_port([], start_port=50010, end_port=50000) + self.assertIsNone(open_port) - def new_private_key(): - # Generate an EC private key - private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() - ) - # Serialize to PEM format - private_key_str = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') - return private_key, private_key_str - - - class TestFindOpenPort(unittest.TestCase): - def mock_bind(self, address): - # Mock the behavior of socket.socket.bind - port = address[1] - if port in self.in_use_ports: - raise OSError("Address already in use") - else: - print(f"Port {port} bound successfully.") - - def test_preferred_port(self): - # Test that the function returns the preferred port if it's available - preferred_port = 50000 - open_port = find_open_port([], preferred_port=preferred_port) - self.assertEqual(open_port, preferred_port) - - def test_preferred_port_unavailable(self): - # Mock the bind method to simulate that port 80 is in use - with mock.patch('socket.socket.bind', side_effect=OSError("Address already in use")): - preferred_port = 80 - with self.assertRaises(CommandError): - open_port = find_open_port([], preferred_port=preferred_port) - - def test_range(self): - # Test that the function returns a port within the specified range - start_port = 50000 - end_port = 50010 - open_port = find_open_port([], start_port=start_port, end_port=end_port) - self.assertTrue(start_port <= open_port <= end_port) - - def test_no_available_ports(self): - # Setup - self.in_use_ports = set(range(50000, 50011)) # All these ports are in use - - # Patch - with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): - # Test - open_port = find_open_port([], start_port=50000, end_port=50010) - self.assertIsNone(open_port) - - def test_invalid_range(self): - # Test that the function returns None if the range is invalid - open_port = find_open_port([], start_port=50010, end_port=50000) + def test_socket_exception(self): + # Test that the function handles exceptions other than OSError gracefully + with mock.patch('socket.socket.bind', side_effect=Exception("Test exception")): + open_port = find_open_port([], start_port=49152, end_port=49153, host='localhost') self.assertIsNone(open_port) - def test_socket_exception(self): - # Test that the function handles exceptions other than OSError gracefully - with mock.patch('socket.socket.bind', side_effect=Exception("Test exception")): - open_port = find_open_port([], start_port=49152, end_port=49153, host='localhost') - self.assertIsNone(open_port) - - def test_tried_ports(self): - # Setup - self.in_use_ports = {50000, 50001} # These ports are in use - - # Patch - with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): - # Test - open_port = find_open_port([50000, 50001], start_port=50000, end_port=50002) - self.assertEqual(open_port, 50002) - - - class TestGenerateRandomBytes(unittest.TestCase): - - def test_default_length(self): - # Test that the default length of the returned bytes is 32 - random_bytes = generate_random_bytes() - self.assertEqual(len(random_bytes), 32, f'Length 32 failed found {len(random_bytes)} in ' - f'{random_bytes}') - - def test_custom_length(self): - # Test custom lengths - for length in [1, 10, 20, 50, 100]: - random_bytes = generate_random_bytes(length) - self.assertEqual(len(random_bytes), length, f'Length {length} failed found {len(random_bytes)} in ' - f'{random_bytes}') - - def test_content(self): - # Test that the returned bytes only contain printable characters - for length in [1, 10, 20, 50, 100]: - random_bytes = generate_random_bytes(length) - self.assertTrue(all(byte in string.printable.encode('utf-8') for byte in random_bytes)) - - def test_zero_length(self): - # Test that a zero length returns an empty bytes object - random_bytes = generate_random_bytes(0) - self.assertEqual(random_bytes, b'') - - def test_negative_length(self): - # Test that a negative length raises a ValueError - with self.assertRaises(ValueError): - generate_random_bytes(-1) - - def test_type(self): - # Test that the return type is bytes - random_bytes = generate_random_bytes() - self.assertIsInstance(random_bytes, bytes) - - def test_uniqueness(self): - # Test that multiple calls return different values - random_bytes1 = generate_random_bytes() - random_bytes2 = generate_random_bytes() - self.assertNotEqual(random_bytes1, random_bytes2) + def test_tried_ports(self): + # Setup + self.in_use_ports = {50000, 50001} # These ports are in use + + # Patch + with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): + # Test + open_port = find_open_port([50000, 50001], start_port=50000, end_port=50002) + self.assertEqual(open_port, 50002) + + +class TestGenerateRandomBytes(unittest.TestCase): + + def test_default_length(self): + # Test that the default length of the returned bytes is 32 + random_bytes = generate_random_bytes() + self.assertEqual(len(random_bytes), 32, f'Length 32 failed found {len(random_bytes)} in ' + f'{random_bytes}') + + def test_custom_length(self): + # Test custom lengths + for length in [1, 10, 20, 50, 100]: + random_bytes = generate_random_bytes(length) + self.assertEqual(len(random_bytes), length, f'Length {length} failed found {len(random_bytes)} in ' + f'{random_bytes}') + + def test_content(self): + # Test that the returned bytes only contain printable characters + for length in [1, 10, 20, 50, 100]: + random_bytes = generate_random_bytes(length) + self.assertTrue(all(byte in string.printable.encode('utf-8') for byte in random_bytes)) + + def test_zero_length(self): + # Test that a zero length returns an empty bytes object + random_bytes = generate_random_bytes(0) + self.assertEqual(random_bytes, b'') + + def test_negative_length(self): + # Test that a negative length raises a ValueError + with self.assertRaises(ValueError): + generate_random_bytes(-1) + + def test_type(self): + # Test that the return type is bytes + random_bytes = generate_random_bytes() + self.assertIsInstance(random_bytes, bytes) + + def test_uniqueness(self): + # Test that multiple calls return different values + random_bytes1 = generate_random_bytes() + random_bytes2 = generate_random_bytes() + self.assertNotEqual(random_bytes1, random_bytes2) diff --git a/unit-tests/service/test_api_logging.py b/unit-tests/service/test_api_logging.py index 8ee60514c..c8ae53a86 100644 --- a/unit-tests/service/test_api_logging.py +++ b/unit-tests/service/test_api_logging.py @@ -1,90 +1,87 @@ -import sys -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask, request - from keepercommander.service.decorators.api_logging import api_log_handler +from unittest import TestCase, mock +from flask import Flask, request +from keepercommander.service.decorators.api_logging import api_log_handler - class TestApiLogging(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +class TestApiLogging(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/test', methods=['POST']) - @api_log_handler - def test_endpoint(): - if not request.is_json: - return {'error': 'Content-Type must be application/json'}, 415 - return {'status': 'success'}, 200 + @self.app.route('/test', methods=['POST']) + @api_log_handler + def test_endpoint(): + if not request.is_json: + return {'error': 'Content-Type must be application/json'}, 415 + return {'status': 'success'}, 200 - @self.app.route('/error', methods=['POST']) - @api_log_handler - def error_endpoint(): - if not request.is_json: - return {'error': 'Content-Type must be application/json'}, 415 - raise Exception("Test error") + @self.app.route('/error', methods=['POST']) + @api_log_handler + def error_endpoint(): + if not request.is_json: + return {'error': 'Content-Type must be application/json'}, 415 + raise Exception("Test error") - def test_api_log_success_request(self): - """Test logging of successful API request""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - test_data = {"test": "data"} - response = self.client.post('/test', - json=test_data, - headers={ - 'X-Forwarded-For': '127.0.0.1', - 'Content-Type': 'application/json' - }) + def test_api_log_success_request(self): + """Test logging of successful API request""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + test_data = {"test": "data"} + response = self.client.post('/test', + json=test_data, + headers={ + 'X-Forwarded-For': '127.0.0.1', + 'Content-Type': 'application/json' + }) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('POST', log_message) - self.assertIn('/test', log_message) - self.assertIn('127.0.0.1', log_message) - self.assertIn('200', log_message) - self.assertIn(f"data={str(test_data)}", log_message) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('POST', log_message) + self.assertIn('/test', log_message) + self.assertIn('127.0.0.1', log_message) + self.assertIn('200', log_message) + self.assertIn(f"data={str(test_data)}", log_message) - def test_api_log_error_request(self): - """Test logging of failed API request""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.error') as mock_log: - response = self.client.post('/error', json={}, - headers={'X-Forwarded-For': '127.0.0.1', - 'Content-Type': 'application/json'}) + def test_api_log_error_request(self): + """Test logging of failed API request""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.error') as mock_log: + response = self.client.post('/error', json={}, + headers={'X-Forwarded-For': '127.0.0.1', + 'Content-Type': 'application/json'}) - self.assertEqual(response.status_code, 500) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('POST', log_message) - self.assertIn('/error', log_message) - self.assertIn('127.0.0.1', log_message) - self.assertIn("error='Test error'", log_message) + self.assertEqual(response.status_code, 500) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('POST', log_message) + self.assertIn('/error', log_message) + self.assertIn('127.0.0.1', log_message) + self.assertIn("error='Test error'", log_message) - def test_api_log_remote_addr_fallback(self): - """Test logging falls back to remote_addr when X-Forwarded-For is missing""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - response = self.client.post('/test', json={}, - headers={'Content-Type': 'application/json'}, - environ_base={'REMOTE_ADDR': '192.168.1.1'}) + def test_api_log_remote_addr_fallback(self): + """Test logging falls back to remote_addr when X-Forwarded-For is missing""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + response = self.client.post('/test', json={}, + headers={'Content-Type': 'application/json'}, + environ_base={'REMOTE_ADDR': '192.168.1.1'}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('192.168.1.1', log_message) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('192.168.1.1', log_message) - def test_api_log_timing(self): - """Test request timing is logged""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - response = self.client.post('/test', json={}, - headers={'Content-Type': 'application/json'}) + def test_api_log_timing(self): + """Test request timing is logged""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + response = self.client.post('/test', json={}, + headers={'Content-Type': 'application/json'}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertRegex(log_message, r'\d+\.\d+s') + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertRegex(log_message, r'\d+\.\d+s') diff --git a/unit-tests/service/test_api_routes.py b/unit-tests/service/test_api_routes.py index 174e1e717..ec85eb4a3 100644 --- a/unit-tests/service/test_api_routes.py +++ b/unit-tests/service/test_api_routes.py @@ -1,77 +1,74 @@ -import sys - -if sys.version_info >= (3, 8): - import unittest - from unittest import mock - from flask import Blueprint, Flask - - from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint - from keepercommander.service.api.routes import init_routes - - - def passthrough_decorator(): - def decorator(fn): - return fn - return decorator - - - class TestServiceApiRoutes(unittest.TestCase): - def test_queue_mode_registers_v1_and_v2_routes(self): - app = Flask(__name__) - onboarding_bp = Blueprint("test_onboarding", __name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.routes.create_onboarding_blueprint', return_value=onboarding_bp), \ - mock.patch('keepercommander.service.core.request_queue.queue_manager.start') as mock_start, \ - mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "y"}): - init_routes(app) - - routes = {rule.rule for rule in app.url_map.iter_rules()} - self.assertIn('/api/v1/executecommand', routes) - self.assertIn('/api/v2/executecommand-async', routes) - self.assertIn('/api/v2/status/', routes) - self.assertIn('/api/v2/result/', routes) - self.assertIn('/api/v2/queue/status', routes) - self.assertIn('/health', routes) - mock_start.assert_called_once() - - def test_legacy_mode_registers_only_v1_route(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "n"}): - init_routes(app) - - routes = {rule.rule for rule in app.url_map.iter_rules()} - self.assertIn('/api/v1/executecommand', routes) - self.assertNotIn('/api/v2/executecommand-async', routes) - - def test_v1_compatibility_route_waits_for_queue_result(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.command.queue_manager.submit_request', return_value='req-1') as mock_submit, \ - mock.patch('keepercommander.service.api.command.queue_manager.wait_for_result', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_wait: - app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') - response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('X-API-Legacy'), 'true') - self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) - mock_submit.assert_called_once_with('ls', []) - mock_wait.assert_called_once_with('req-1') - - def test_v1_direct_route_keeps_legacy_execution_path(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.command.CommandExecutor.execute', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_execute, \ - mock.patch('keepercommander.service.api.command.queue_manager.submit_request') as mock_submit: - app.register_blueprint(create_legacy_command_blueprint(use_queue=False), url_prefix='/api/v1') - response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('X-API-Legacy'), 'true') - self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) - mock_execute.assert_called_once_with('ls') - mock_submit.assert_not_called() +import unittest +from unittest import mock +from flask import Blueprint, Flask + +from keepercommander.service.api.command import create_legacy_command_blueprint +from keepercommander.service.api.routes import init_routes + + +def passthrough_decorator(): + def decorator(fn): + return fn + return decorator + + +class TestServiceApiRoutes(unittest.TestCase): + def test_queue_mode_registers_v1_and_v2_routes(self): + app = Flask(__name__) + onboarding_bp = Blueprint("test_onboarding", __name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.routes.create_onboarding_blueprint', return_value=onboarding_bp), \ + mock.patch('keepercommander.service.core.request_queue.queue_manager.start') as mock_start, \ + mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "y"}): + init_routes(app) + + routes = {rule.rule for rule in app.url_map.iter_rules()} + self.assertIn('/api/v1/executecommand', routes) + self.assertIn('/api/v2/executecommand-async', routes) + self.assertIn('/api/v2/status/', routes) + self.assertIn('/api/v2/result/', routes) + self.assertIn('/api/v2/queue/status', routes) + self.assertIn('/health', routes) + mock_start.assert_called_once() + + def test_legacy_mode_registers_only_v1_route(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "n"}): + init_routes(app) + + routes = {rule.rule for rule in app.url_map.iter_rules()} + self.assertIn('/api/v1/executecommand', routes) + self.assertNotIn('/api/v2/executecommand-async', routes) + + def test_v1_compatibility_route_waits_for_queue_result(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.command.queue_manager.submit_request', return_value='req-1') as mock_submit, \ + mock.patch('keepercommander.service.api.command.queue_manager.wait_for_result', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_wait: + app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') + response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get('X-API-Legacy'), 'true') + self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) + mock_submit.assert_called_once_with('ls', []) + mock_wait.assert_called_once_with('req-1') + + def test_v1_direct_route_keeps_legacy_execution_path(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.command.CommandExecutor.execute', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_execute, \ + mock.patch('keepercommander.service.api.command.queue_manager.submit_request') as mock_submit: + app.register_blueprint(create_legacy_command_blueprint(use_queue=False), url_prefix='/api/v1') + response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get('X-API-Legacy'), 'true') + self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) + mock_execute.assert_called_once_with('ls') + mock_submit.assert_not_called() diff --git a/unit-tests/service/test_auth_security.py b/unit-tests/service/test_auth_security.py index 3f93b6610..bbd94ba4b 100644 --- a/unit-tests/service/test_auth_security.py +++ b/unit-tests/service/test_auth_security.py @@ -1,102 +1,99 @@ -import sys -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask - from keepercommander.service.decorators.auth import auth_check, policy_check - from keepercommander.service.decorators.security import security_check, is_allowed_ip - from keepercommander.service.util.config_reader import ConfigReader +from unittest import TestCase, mock +from flask import Flask +from keepercommander.service.decorators.auth import auth_check, policy_check +from keepercommander.service.decorators.security import security_check, is_allowed_ip +from keepercommander.service.util.config_reader import ConfigReader - class TestAuthSecurity(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +class TestAuthSecurity(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/test', methods=['POST']) - @security_check - @auth_check - @policy_check - def test_endpoint(): - return {'status': 'success'}, 200 + @self.app.route('/test', methods=['POST']) + @security_check + @auth_check + @policy_check + def test_endpoint(): + return {'status': 'success'}, 200 - def test_auth_check_missing_api_key(self): - """Test authentication with missing API key""" - with self.app.test_request_context('/test', method='POST'): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('api key', response[0]['error']) + def test_auth_check_missing_api_key(self): + """Test authentication with missing API key""" + with self.app.test_request_context('/test', method='POST'): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('api key', response[0]['error']) - @mock.patch.object(ConfigReader, 'read_config') - def test_auth_check_invalid_api_key(self, mock_read_config): - """Test authentication with invalid API key""" - mock_read_config.return_value = "different_key" + @mock.patch.object(ConfigReader, 'read_config') + def test_auth_check_invalid_api_key(self, mock_read_config): + """Test authentication with invalid API key""" + mock_read_config.return_value = "different_key" - with self.app.test_request_context('/test', method='POST', - headers={'api-key': 'test_key'}): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') + with self.app.test_request_context('/test', method='POST', + headers={'api-key': 'test_key'}): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') - @mock.patch.object(ConfigReader, 'read_config') - def test_auth_check_expired_key(self, mock_read_config): - """Test authentication with expired API key""" - mock_read_config.side_effect = [ - "test_key", - "2024-01-01T00:00:00" - ] + @mock.patch.object(ConfigReader, 'read_config') + def test_auth_check_expired_key(self, mock_read_config): + """Test authentication with expired API key""" + mock_read_config.side_effect = [ + "test_key", + "2024-01-01T00:00:00" + ] - with self.app.test_request_context('/test', method='POST', - headers={'api-key': 'test_key'}): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('expired', response[0]['error']) + with self.app.test_request_context('/test', method='POST', + headers={'api-key': 'test_key'}): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('expired', response[0]['error']) - # def test_security_check_blocked_ip(self): - # """Test security check with blocked IP""" - # with mock.patch.object(ConfigReader, 'read_config', return_value="192.168.1.1"): - # with self.app.test_request_context('/test', method='POST', - # environ_base={'REMOTE_ADDR': '192.168.1.1'}): - # response = security_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - # response_data = response[0].get_json() - # self.assertEqual(response[1], 403) - # self.assertEqual(response_data['error'], 'IP is blocked') + # def test_security_check_blocked_ip(self): + # """Test security check with blocked IP""" + # with mock.patch.object(ConfigReader, 'read_config', return_value="192.168.1.1"): + # with self.app.test_request_context('/test', method='POST', + # environ_base={'REMOTE_ADDR': '192.168.1.1'}): + # response = security_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + # response_data = response[0].get_json() + # self.assertEqual(response[1], 403) + # self.assertEqual(response_data['error'], 'IP is blocked') - def test_is_blocked_ip_single_ip(self): - """Test IP blocking with single IP address""" - blocked_ips = "192.168.1.1" - allowed_ips="192.168.1.2" - self.assertFalse(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) - self.assertTrue(is_allowed_ip("192.168.1.2", allowed_ips, blocked_ips)) + def test_is_blocked_ip_single_ip(self): + """Test IP blocking with single IP address""" + blocked_ips = "192.168.1.1" + allowed_ips="192.168.1.2" + self.assertFalse(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) + self.assertTrue(is_allowed_ip("192.168.1.2", allowed_ips, blocked_ips)) - def test_is_blocked_ip_cidr(self): - """Test IP blocking with CIDR notation""" - allowed_ips="192.168.1.1" - blocked_ips = "192.168.1.0" - self.assertTrue(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) - self.assertFalse(is_allowed_ip("192.168.1.254", allowed_ips, blocked_ips)) - self.assertFalse(is_allowed_ip("192.168.2.1", allowed_ips, blocked_ips)) + def test_is_blocked_ip_cidr(self): + """Test IP blocking with CIDR notation""" + allowed_ips="192.168.1.1" + blocked_ips = "192.168.1.0" + self.assertTrue(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) + self.assertFalse(is_allowed_ip("192.168.1.254", allowed_ips, blocked_ips)) + self.assertFalse(is_allowed_ip("192.168.2.1", allowed_ips, blocked_ips)) - @mock.patch.object(ConfigReader, 'read_config') - def test_policy_check_allowed_command(self, mock_read_config): - """Test policy check with allowed command""" - mock_read_config.return_value = "list,get,search" + @mock.patch.object(ConfigReader, 'read_config') + def test_policy_check_allowed_command(self, mock_read_config): + """Test policy check with allowed command""" + mock_read_config.return_value = "list,get,search" - with self.app.test_request_context('/test', method='POST', - json={"command": "list"}): - response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 200) - self.assertEqual(response[0]['status'], 'success') + with self.app.test_request_context('/test', method='POST', + json={"command": "list"}): + response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 200) + self.assertEqual(response[0]['status'], 'success') - @mock.patch.object(ConfigReader, 'read_config') - def test_policy_check_denied_command(self, mock_read_config): - """Test policy check with denied command""" - mock_read_config.return_value = "list,get,search" + @mock.patch.object(ConfigReader, 'read_config') + def test_policy_check_denied_command(self, mock_read_config): + """Test policy check with denied command""" + mock_read_config.return_value = "list,get,search" - with self.app.test_request_context('/test', method='POST', - json={"command": "delete"}): - response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 403) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('Not permitted', response[0]['error']) \ No newline at end of file + with self.app.test_request_context('/test', method='POST', + json={"command": "delete"}): + response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 403) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('Not permitted', response[0]['error']) \ No newline at end of file diff --git a/unit-tests/service/test_command.py b/unit-tests/service/test_command.py index fbdcb9845..2df0637b1 100644 --- a/unit-tests/service/test_command.py +++ b/unit-tests/service/test_command.py @@ -1,121 +1,118 @@ -import sys import unittest -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask - from keepercommander.service.util.command_util import CommandExecutor - from keepercommander.service.util.exceptions import CommandExecutionError - from keepercommander.service.util.parse_keeper_response import parse_keeper_response - - class TestCommandAPI(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +from unittest import TestCase, mock +from flask import Flask +from keepercommander.service.util.command_util import CommandExecutor +from keepercommander.service.util.exceptions import CommandExecutionError +from keepercommander.service.util.parse_keeper_response import parse_keeper_response + +class TestCommandAPI(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/api/v1/executecommand', methods=['POST']) - def execute_command(): - command = "ls" - response, status_code = CommandExecutor.execute(command) - return {'response': response}, status_code - - def test_validate_command(self): - """Test command validation""" - result, status_code = CommandExecutor.validate_command("") - self.assertIsNotNone(result) - self.assertEqual(status_code, 400) - self.assertEqual(result["error"], "No command provided.") - - result = CommandExecutor.validate_command("ls") + @self.app.route('/api/v1/executecommand', methods=['POST']) + def execute_command(): + command = "ls" + response, status_code = CommandExecutor.execute(command) + return {'response': response}, status_code + + def test_validate_command(self): + """Test command validation""" + result, status_code = CommandExecutor.validate_command("") + self.assertIsNotNone(result) + self.assertEqual(status_code, 400) + self.assertEqual(result["error"], "No command provided.") + + result = CommandExecutor.validate_command("ls") + self.assertIsNone(result) + + def test_validate_session(self): + """Test session validation""" + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=None): + result, status_code = CommandExecutor.validate_session() + self.assertEqual(status_code, 401) + self.assertIn("No active session", result["error"]) + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value={"session": "active"}): + result = CommandExecutor.validate_session() self.assertIsNone(result) - def test_validate_session(self): - """Test session validation""" - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=None): - result, status_code = CommandExecutor.validate_session() - self.assertEqual(status_code, 401) - self.assertIn("No active session", result["error"]) - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value={"session": "active"}): - result = CommandExecutor.validate_session() - self.assertIsNone(result) - - @unittest.skip - def test_command_execution_success(self): - """Test successful command execution""" - mock_params = {"session": "active"} - test_command = "ls" - expected_output = "Folder1\nFolder2\n" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ - mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): - - response, status_code = CommandExecutor.execute(test_command) - self.assertEqual(status_code, 200) - self.assertIsNotNone(response) - - @unittest.skip - def test_command_execution_failure(self): - """Test command execution failure""" - mock_params = {"session": "active"} - test_command = "invalid_command" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', side_effect=Exception("Command failed")), \ - self.assertRaises(CommandExecutionError): + @unittest.skip + def test_command_execution_success(self): + """Test successful command execution""" + mock_params = {"session": "active"} + test_command = "ls" + expected_output = "Folder1\nFolder2\n" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ + mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): + + response, status_code = CommandExecutor.execute(test_command) + self.assertEqual(status_code, 200) + self.assertIsNotNone(response) + + @unittest.skip + def test_command_execution_failure(self): + """Test command execution failure""" + mock_params = {"session": "active"} + test_command = "invalid_command" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', side_effect=Exception("Command failed")), \ + self.assertRaises(CommandExecutionError): - CommandExecutor.execute(test_command) - - def test_response_encryption(self): - """Test response encryption when key is present""" - test_response = {"status": "success", "data": "test"} - - mock_key = "0" * 32 - - with mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=mock_key): - encrypted_response = CommandExecutor.encrypt_response(test_response) - self.assertIsInstance(encrypted_response, bytes) - self.assertGreater(len(encrypted_response), 0) - - def test_response_parsing(self): - """Test response parsing for different commands""" - - ls_response = "# Folder UID\n1 folder1_uid folder1 rw\n# Record UID\n1 record1_uid login record1" - parsed = parse_keeper_response("ls", ls_response) - self.assertEqual(parsed["status"], "success") - self.assertEqual(parsed["command"], "ls") - self.assertIn("folders", parsed["data"]) - self.assertIn("records", parsed["data"]) - - tree_response = "Root\n Folder1\n SubFolder1" - parsed = parse_keeper_response("tree", tree_response) - self.assertEqual(parsed["command"], "tree") - self.assertIsInstance(parsed["data"], dict) - self.assertIn("tree", parsed["data"]) - - def test_capture_output(self): - """Test command output capture""" - test_command = "ls" - expected_output = "test output" - mock_params = {"session": "active"} - - with mock.patch('keepercommander.cli.do_command', return_value=expected_output): - return_value, output, logs = CommandExecutor.capture_output_and_logs(mock_params, test_command) - self.assertEqual(return_value, expected_output) - - @unittest.skip - def test_integration_command_flow(self): - """Test the complete command execution flow""" - test_command = "ls" - mock_params = {"session": "active"} - expected_output = "# Folder UID\n1 folder1_uid folder1 rw" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ - mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): - - response, status_code = CommandExecutor.execute(test_command) - self.assertEqual(status_code, 200) - self.assertIsNotNone(response) \ No newline at end of file + CommandExecutor.execute(test_command) + + def test_response_encryption(self): + """Test response encryption when key is present""" + test_response = {"status": "success", "data": "test"} + + mock_key = "0" * 32 + + with mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=mock_key): + encrypted_response = CommandExecutor.encrypt_response(test_response) + self.assertIsInstance(encrypted_response, bytes) + self.assertGreater(len(encrypted_response), 0) + + def test_response_parsing(self): + """Test response parsing for different commands""" + + ls_response = "# Folder UID\n1 folder1_uid folder1 rw\n# Record UID\n1 record1_uid login record1" + parsed = parse_keeper_response("ls", ls_response) + self.assertEqual(parsed["status"], "success") + self.assertEqual(parsed["command"], "ls") + self.assertIn("folders", parsed["data"]) + self.assertIn("records", parsed["data"]) + + tree_response = "Root\n Folder1\n SubFolder1" + parsed = parse_keeper_response("tree", tree_response) + self.assertEqual(parsed["command"], "tree") + self.assertIsInstance(parsed["data"], dict) + self.assertIn("tree", parsed["data"]) + + def test_capture_output(self): + """Test command output capture""" + test_command = "ls" + expected_output = "test output" + mock_params = {"session": "active"} + + with mock.patch('keepercommander.cli.do_command', return_value=expected_output): + return_value, output, logs = CommandExecutor.capture_output_and_logs(mock_params, test_command) + self.assertEqual(return_value, expected_output) + + @unittest.skip + def test_integration_command_flow(self): + """Test the complete command execution flow""" + test_command = "ls" + mock_params = {"session": "active"} + expected_output = "# Folder UID\n1 folder1_uid folder1 rw" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ + mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): + + response, status_code = CommandExecutor.execute(test_command) + self.assertEqual(status_code, 200) + self.assertIsNotNone(response) \ No newline at end of file diff --git a/unit-tests/service/test_config_operation.py b/unit-tests/service/test_config_operation.py index 3db9c6be0..6c16ae796 100644 --- a/unit-tests/service/test_config_operation.py +++ b/unit-tests/service/test_config_operation.py @@ -1,83 +1,81 @@ -import sys -if sys.version_info >= (3, 8): - from unittest import TestCase, mock - from keepercommander.params import KeeperParams - from keepercommander.service.config.service_config import ServiceConfig - from keepercommander.service.commands.config_operation import AddConfigService +from unittest import TestCase, mock +from keepercommander.params import KeeperParams +from keepercommander.service.config.service_config import ServiceConfig +from keepercommander.service.commands.config_operation import AddConfigService - class TestConfigOperation(TestCase): - def setUp(self): - self.mock_params = mock.Mock(spec=KeeperParams) - self.cmd = AddConfigService() +class TestConfigOperation(TestCase): + def setUp(self): + self.mock_params = mock.Mock(spec=KeeperParams) + self.cmd = AddConfigService() - def test_execute_with_existing_config(self): - mock_config = { - "is_advanced_security_enabled": "y", - "records": [] - } - mock_record = { - "api-key": "test-api-key", - "command_list": "list", - "expiration_timestamp": "2024-12-31T23:59:59", - #"expiration_of_token": "" - } + def test_execute_with_existing_config(self): + mock_config = { + "is_advanced_security_enabled": "y", + "records": [] + } + mock_record = { + "api-key": "test-api-key", + "command_list": "list", + "expiration_timestamp": "2024-12-31T23:59:59", + #"expiration_of_token": "" + } - with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ - mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ - mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ - mock.patch('builtins.print'): + with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ + mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ + mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ + mock.patch('builtins.print'): - self.cmd.execute(self.mock_params) + self.cmd.execute(self.mock_params) - expected_config = { - "is_advanced_security_enabled": "y", - "records": [mock_record] - } - mock_save.assert_called_once_with(expected_config) + expected_config = { + "is_advanced_security_enabled": "y", + "records": [mock_record] + } + mock_save.assert_called_once_with(expected_config) - def test_execute_when_config_not_found(self): - with mock.patch.object(ServiceConfig, 'load_config', side_effect=FileNotFoundError), \ - mock.patch('builtins.print') as mock_print: + def test_execute_when_config_not_found(self): + with mock.patch.object(ServiceConfig, 'load_config', side_effect=FileNotFoundError), \ + mock.patch('builtins.print') as mock_print: - result = self.cmd.execute(self.mock_params) + result = self.cmd.execute(self.mock_params) - mock_print.assert_called_with( - "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - ) - self.assertEqual(result, '') + mock_print.assert_called_with( + "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + ) + self.assertEqual(result, '') - def test_execute_with_general_error(self): - with mock.patch.object(ServiceConfig, 'load_config', side_effect=Exception("Test error")), \ - mock.patch('builtins.print') as mock_print: + def test_execute_with_general_error(self): + with mock.patch.object(ServiceConfig, 'load_config', side_effect=Exception("Test error")), \ + mock.patch('builtins.print') as mock_print: - result = self.cmd.execute(self.mock_params) + result = self.cmd.execute(self.mock_params) - mock_print.assert_called_with( - "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - ) - self.assertEqual(result, '') + mock_print.assert_called_with( + "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + ) + self.assertEqual(result, '') - def test_create_and_add_record(self): - mock_config = { - "is_advanced_security_enabled": "n", - "records": [{"existing": "record"}] - } - mock_record = { - "api-key": "new-api-key", - "command_list": "list", - "expiration_timestamp": "2024-12-31T23:59:59", - #"expiration_of_token": "" - } + def test_create_and_add_record(self): + mock_config = { + "is_advanced_security_enabled": "n", + "records": [{"existing": "record"}] + } + mock_record = { + "api-key": "new-api-key", + "command_list": "list", + "expiration_timestamp": "2024-12-31T23:59:59", + #"expiration_of_token": "" + } - with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ - mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ - mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ - mock.patch('builtins.print'): + with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ + mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ + mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ + mock.patch('builtins.print'): - self.cmd.execute(self.mock_params) + self.cmd.execute(self.mock_params) - expected_config = { - "is_advanced_security_enabled": "n", - "records": [{"existing": "record"}, mock_record] - } - mock_save.assert_called_once_with(expected_config) \ No newline at end of file + expected_config = { + "is_advanced_security_enabled": "n", + "records": [{"existing": "record"}, mock_record] + } + mock_save.assert_called_once_with(expected_config) \ No newline at end of file diff --git a/unit-tests/service/test_config_validation.py b/unit-tests/service/test_config_validation.py index 77b28aa9e..5fb40ac53 100644 --- a/unit-tests/service/test_config_validation.py +++ b/unit-tests/service/test_config_validation.py @@ -1,168 +1,166 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import patch - import socket - from datetime import timedelta - from keepercommander.service.config.config_validation import ConfigValidator - from keepercommander.service.util.exceptions import ValidationError +import unittest +from unittest.mock import patch +import socket +from datetime import timedelta +from keepercommander.service.config.config_validation import ConfigValidator +from keepercommander.service.util.exceptions import ValidationError - class TestConfigValidator(unittest.TestCase): - def setUp(self): - self.validator = ConfigValidator() +class TestConfigValidator(unittest.TestCase): + def setUp(self): + self.validator = ConfigValidator() - def test_validate_port_valid(self): - """Test port validation with valid port numbers""" - test_ports = [1024, 8080, 8900, 9000, 65535] - for port in test_ports: - with self.subTest(port=port): - with patch('socket.socket') as mock_socket: - mock_socket.return_value.__enter__.return_value.bind.return_value = None - result = self.validator.validate_port(port) - self.assertEqual(result, port) + def test_validate_port_valid(self): + """Test port validation with valid port numbers""" + test_ports = [1024, 8080, 8900, 9000, 65535] + for port in test_ports: + with self.subTest(port=port): + with patch('socket.socket') as mock_socket: + mock_socket.return_value.__enter__.return_value.bind.return_value = None + result = self.validator.validate_port(port) + self.assertEqual(result, port) - def test_validate_port_invalid_number(self): - """Test port validation with invalid port numbers""" - invalid_ports = [-1, 0, 80, 443, 1023, 65536, 'abc', ''] - for port in invalid_ports: - with self.subTest(port=port): - with self.assertRaises(ValidationError): - self.validator.validate_port(port) + def test_validate_port_invalid_number(self): + """Test port validation with invalid port numbers""" + invalid_ports = [-1, 0, 80, 443, 1023, 65536, 'abc', ''] + for port in invalid_ports: + with self.subTest(port=port): + with self.assertRaises(ValidationError): + self.validator.validate_port(port) - def test_validate_port_in_use(self): - """Test port validation when port is already in use""" - with patch('socket.socket') as mock_socket: - mock_socket.return_value.__enter__.return_value.bind.side_effect = socket.error() - with self.assertRaises(ValidationError) as context: - self.validator.validate_port(8080) - self.assertIn("is already in use", str(context.exception)) + def test_validate_port_in_use(self): + """Test port validation when port is already in use""" + with patch('socket.socket') as mock_socket: + mock_socket.return_value.__enter__.return_value.bind.side_effect = socket.error() + with self.assertRaises(ValidationError) as context: + self.validator.validate_port(8080) + self.assertIn("is already in use", str(context.exception)) - def test_validate_ngrok_token_valid(self): - """Test ngrok token validation with valid tokens""" - valid_tokens = [ - '1234567890abcdef', - 'abcdef1234567890', - 'abc123_def456-789' - ] - for token in valid_tokens: - with self.subTest(token=token): - result = self.validator.validate_ngrok_token(token) - self.assertEqual(result, token) + def test_validate_ngrok_token_valid(self): + """Test ngrok token validation with valid tokens""" + valid_tokens = [ + '1234567890abcdef', + 'abcdef1234567890', + 'abc123_def456-789' + ] + for token in valid_tokens: + with self.subTest(token=token): + result = self.validator.validate_ngrok_token(token) + self.assertEqual(result, token) - def test_validate_ngrok_token_invalid(self): - """Test ngrok token validation with invalid tokens""" - invalid_tokens = [ - '', - '123', - 'abc@def', - None - ] - for token in invalid_tokens: - with self.subTest(token=token): - with self.assertRaises(ValidationError): - self.validator.validate_ngrok_token(token) + def test_validate_ngrok_token_invalid(self): + """Test ngrok token validation with invalid tokens""" + invalid_tokens = [ + '', + '123', + 'abc@def', + None + ] + for token in invalid_tokens: + with self.subTest(token=token): + with self.assertRaises(ValidationError): + self.validator.validate_ngrok_token(token) - def test_validate_rate_limit_valid(self): - """Test rate limit validation with valid formats""" - valid_limits = [ - '10/minute', - '100/hour', - '1000/day', - '50 per minute', - '200 per hour', - '5000 per day' - ] - for limit in valid_limits: - with self.subTest(limit=limit): - result = self.validator.validate_rate_limit(limit) - self.assertEqual(result, limit) + def test_validate_rate_limit_valid(self): + """Test rate limit validation with valid formats""" + valid_limits = [ + '10/minute', + '100/hour', + '1000/day', + '50 per minute', + '200 per hour', + '5000 per day' + ] + for limit in valid_limits: + with self.subTest(limit=limit): + result = self.validator.validate_rate_limit(limit) + self.assertEqual(result, limit) - def test_validate_rate_limit_invalid(self): - """Test rate limit validation with invalid formats""" - invalid_limits = [ - 'abc', - '10/second', - '100 by hour', - '0/minute', - '0/hour', - '0/day', - '0 per minute', - ] - for limit in invalid_limits: - with self.subTest(limit=limit): - with self.assertRaises(ValidationError): - self.validator.validate_rate_limit(limit) + def test_validate_rate_limit_invalid(self): + """Test rate limit validation with invalid formats""" + invalid_limits = [ + 'abc', + '10/second', + '100 by hour', + '0/minute', + '0/hour', + '0/day', + '0 per minute', + ] + for limit in invalid_limits: + with self.subTest(limit=limit): + with self.assertRaises(ValidationError): + self.validator.validate_rate_limit(limit) - def test_validate_ip_list_valid(self): - """Test IP list validation with valid IPs and CIDR blocks""" - valid_ips = [ - '192.168.1.1', - '10.0.0.0/24', - '192.168.1.1,10.0.0.0/24', - '2001:db8::1', - 'fe80::/10' - ] - for ip_list in valid_ips: - with self.subTest(ip_list=ip_list): - result = self.validator.validate_ip_list(ip_list) - self.assertEqual(result, ip_list) + def test_validate_ip_list_valid(self): + """Test IP list validation with valid IPs and CIDR blocks""" + valid_ips = [ + '192.168.1.1', + '10.0.0.0/24', + '192.168.1.1,10.0.0.0/24', + '2001:db8::1', + 'fe80::/10' + ] + for ip_list in valid_ips: + with self.subTest(ip_list=ip_list): + result = self.validator.validate_ip_list(ip_list) + self.assertEqual(result, ip_list) - def test_validate_ip_list_invalid(self): - """Test IP list validation with invalid IPs""" - invalid_ips = [ - '256.256.256.256', - '192.168.1', - '2001:xyz::1', - '192.168.1.1/33', - ] - for ip_list in invalid_ips: - with self.subTest(ip_list=ip_list): - with self.assertRaises(ValidationError): - self.validator.validate_ip_list(ip_list) + def test_validate_ip_list_invalid(self): + """Test IP list validation with invalid IPs""" + invalid_ips = [ + '256.256.256.256', + '192.168.1', + '2001:xyz::1', + '192.168.1.1/33', + ] + for ip_list in invalid_ips: + with self.subTest(ip_list=ip_list): + with self.assertRaises(ValidationError): + self.validator.validate_ip_list(ip_list) - def test_validate_encryption_key_valid(self): - """Test encryption key validation with valid keys""" - valid_key = 'abcdef1234567890ABCDEF1234567890' - result = self.validator.validate_encryption_key(valid_key) - self.assertEqual(result, valid_key) + def test_validate_encryption_key_valid(self): + """Test encryption key validation with valid keys""" + valid_key = 'abcdef1234567890ABCDEF1234567890' + result = self.validator.validate_encryption_key(valid_key) + self.assertEqual(result, valid_key) - def test_validate_encryption_key_invalid(self): - """Test encryption key validation with invalid keys""" - invalid_keys = [ - '', - '123456', - 'a' * 31, - 'a' * 33, - 'abc$%^&*()', - None - ] - for key in invalid_keys: - with self.subTest(key=key): - with self.assertRaises(ValidationError): - self.validator.validate_encryption_key(key) + def test_validate_encryption_key_invalid(self): + """Test encryption key validation with invalid keys""" + invalid_keys = [ + '', + '123456', + 'a' * 31, + 'a' * 33, + 'abc$%^&*()', + None + ] + for key in invalid_keys: + with self.subTest(key=key): + with self.assertRaises(ValidationError): + self.validator.validate_encryption_key(key) - def test_parse_expiration_time_valid(self): - """Test expiration time parsing with valid formats""" - test_cases = [ - ('30m', timedelta(minutes=30)), - ('24h', timedelta(hours=24)), - ('7d', timedelta(days=7)) - ] - for input_str, expected in test_cases: - with self.subTest(input_str=input_str): - result = self.validator.parse_expiration_time(input_str) - self.assertEqual(result, expected) + def test_parse_expiration_time_valid(self): + """Test expiration time parsing with valid formats""" + test_cases = [ + ('30m', timedelta(minutes=30)), + ('24h', timedelta(hours=24)), + ('7d', timedelta(days=7)) + ] + for input_str, expected in test_cases: + with self.subTest(input_str=input_str): + result = self.validator.parse_expiration_time(input_str) + self.assertEqual(result, expected) - def test_parse_expiration_time_invalid(self): - """Test expiration time parsing with invalid formats""" - invalid_times = [ - '', - '30x', - '-30m', - '0m', - 'abc', - ] - for time_str in invalid_times: - with self.subTest(time_str=time_str): - with self.assertRaises(ValidationError): - self.validator.parse_expiration_time(time_str) \ No newline at end of file + def test_parse_expiration_time_invalid(self): + """Test expiration time parsing with invalid formats""" + invalid_times = [ + '', + '30x', + '-30m', + '0m', + 'abc', + ] + for time_str in invalid_times: + with self.subTest(time_str=time_str): + with self.assertRaises(ValidationError): + self.validator.parse_expiration_time(time_str) \ No newline at end of file diff --git a/unit-tests/service/test_create_service.py b/unit-tests/service/test_create_service.py index a81017ea8..d761034a4 100644 --- a/unit-tests/service/test_create_service.py +++ b/unit-tests/service/test_create_service.py @@ -1,349 +1,347 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import Mock, patch - from keepercommander.params import KeeperParams - from keepercommander.service.commands.create_service import CreateService, StreamlineArgs +import unittest +from unittest.mock import Mock, patch +from keepercommander.params import KeeperParams +from keepercommander.service.commands.create_service import CreateService, StreamlineArgs - class TestCreateService(unittest.TestCase): - def setUp(self): - self.params = Mock(spec=KeeperParams) - self.command = CreateService() +class TestCreateService(unittest.TestCase): + def setUp(self): + self.params = Mock(spec=KeeperParams) + self.command = CreateService() - def test_get_parser(self): - """Test parser creation with correct arguments.""" - parser = self.command.get_parser() + def test_get_parser(self): + """Test parser creation with correct arguments.""" + parser = self.command.get_parser() - args = parser.parse_args(['--port', '8080']) - self.assertEqual(args.port, 8080) + args = parser.parse_args(['--port', '8080']) + self.assertEqual(args.port, 8080) - args = parser.parse_args(['--commands', 'record-list']) - self.assertEqual(args.commands, 'record-list') + args = parser.parse_args(['--commands', 'record-list']) + self.assertEqual(args.commands, 'record-list') - args = parser.parse_args(['--ngrok', 'token123']) - self.assertEqual(args.ngrok, 'token123') + args = parser.parse_args(['--ngrok', 'token123']) + self.assertEqual(args.ngrok, 'token123') - args = parser.parse_args(['--cloudflare', 'cf_token123']) - self.assertEqual(args.cloudflare, 'cf_token123') + args = parser.parse_args(['--cloudflare', 'cf_token123']) + self.assertEqual(args.cloudflare, 'cf_token123') - args = parser.parse_args(['--cloudflare_custom_domain', 'example.com']) - self.assertEqual(args.cloudflare_custom_domain, 'example.com') + args = parser.parse_args(['--cloudflare_custom_domain', 'example.com']) + self.assertEqual(args.cloudflare_custom_domain, 'example.com') - @patch('keepercommander.service.core.service_manager.ServiceManager') - def test_execute_service_already_running(self, mock_service_manager): - """Test execute when service is already running.""" - mock_service_manager.get_status.return_value = "Commander Service is Running on port 8080" + @patch('keepercommander.service.core.service_manager.ServiceManager') + def test_execute_service_already_running(self, mock_service_manager): + """Test execute when service is already running.""" + mock_service_manager.get_status.return_value = "Commander Service is Running on port 8080" - with patch('builtins.print') as mock_print: - self.command.execute(self.params) - mock_print.assert_called_with("Error: Commander Service is already running.") + with patch('builtins.print') as mock_print: + self.command.execute(self.params) + mock_print.assert_called_with("Error: Commander Service is already running.") - def test_handle_configuration_streamlined(self): - """Test streamlined configuration handling.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_handle_configuration_streamlined(self): + """Test streamlined configuration handling.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_handle_configuration_interactive(self): - """Test interactive configuration handling.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=None, commands=None, ngrok=None, allowedip='' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled=None, update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_handle_configuration_interactive(self): + """Test interactive configuration handling.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=None, commands=None, ngrok=None, allowedip='' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled=None, update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.config_handler, 'handle_interactive_config') as mock_interactive, \ - patch.object(self.command.security_handler, 'configure_security') as mock_security: - self.command._handle_configuration(config_data, self.params, args) - mock_interactive.assert_called_once_with(config_data, self.params) - mock_security.assert_called_once_with(config_data) + with patch.object(self.command.config_handler, 'handle_interactive_config') as mock_interactive, \ + patch.object(self.command.security_handler, 'configure_security') as mock_security: + self.command._handle_configuration(config_data, self.params, args) + mock_interactive.assert_called_once_with(config_data, self.params) + mock_security.assert_called_once_with(config_data) - def test_create_and_save_record(self): - """Test record creation and saving.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_create_and_save_record(self): + """Test record creation and saving.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.service_config, 'create_record') as mock_create_record, \ - patch.object(self.command.service_config, 'save_config') as mock_save_config: + with patch.object(self.command.service_config, 'create_record') as mock_create_record, \ + patch.object(self.command.service_config, 'save_config') as mock_save_config: - mock_create_record.return_value = {'api-key': 'test-key'} - self.command._create_and_save_record(config_data, self.params, args) + mock_create_record.return_value = {'api-key': 'test-key'} + self.command._create_and_save_record(config_data, self.params, args) - mock_create_record.assert_called_once_with( - config_data["is_advanced_security_enabled"], - self.params, - args.commands, - args.token_expiration, - None # record_uid (update_vault_record is None) - ) - if(args.fileformat): - config_data["fileformat"]= args.fileformat - else: - mock_save_config.assert_called_once_with(config_data, 'create') + mock_create_record.assert_called_once_with( + config_data["is_advanced_security_enabled"], + self.params, + args.commands, + args.token_expiration, + None # record_uid (update_vault_record is None) + ) + if(args.fileformat): + config_data["fileformat"]= args.fileformat + else: + mock_save_config.assert_called_once_with(config_data, 'create') - def test_validation_error_handling(self): - """Test handling of validation errors during execution.""" - args = StreamlineArgs(port=-1, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_validation_error_handling(self): + """Test handling of validation errors during execution.""" + args = StreamlineArgs(port=-1, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, port=-1) + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, port=-1) - mock_print.assert_called() + mock_print.assert_called() - def test_cloudflare_streamlined_configuration(self): - """Test streamlined configuration with Cloudflare tunnel.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_streamlined_configuration(self): + """Test streamlined configuration with Cloudflare tunnel.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_cloudflare_validation_missing_token(self): - """Test validation error when Cloudflare token is missing but domain is provided.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare=None, - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_validation_missing_token(self): + """Test validation error when Cloudflare token is missing but domain is provided.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare=None, + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, cloudflare_custom_domain='tunnel.example.com') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, cloudflare_custom_domain='tunnel.example.com') + mock_print.assert_called() - def test_cloudflare_validation_missing_domain(self): - """Test validation error when Cloudflare domain is missing but token is provided.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain=None, - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_validation_missing_domain(self): + """Test validation error when Cloudflare domain is missing but token is provided.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain=None, + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, cloudflare='cf_token123') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, cloudflare='cf_token123') + mock_print.assert_called() - def test_cloudflare_and_ngrok_mutual_exclusion(self): - """Test that Cloudflare and ngrok cannot be used together.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok='ngrok_token123', - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain='ngrok.example.com', - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_and_ngrok_mutual_exclusion(self): + """Test that Cloudflare and ngrok cannot be used together.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok='ngrok_token123', + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain='ngrok.example.com', + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, ngrok='ngrok_token123', cloudflare='cf_token123') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, ngrok='ngrok_token123', cloudflare='cf_token123') + mock_print.assert_called() - @patch('keepercommander.service.config.cloudflare_config.CloudflareConfigurator.configure_cloudflare') - def test_cloudflare_tunnel_startup_success(self, mock_cloudflare_configure): - """Test successful Cloudflare tunnel startup.""" - config_data = self.command.service_config.create_default_config() - config_data.update({ - 'cloudflare': 'y', - 'cloudflare_tunnel_token': 'cf_token123', - 'cloudflare_custom_domain': 'tunnel.example.com', - 'port': 8080 - }) + @patch('keepercommander.service.config.cloudflare_config.CloudflareConfigurator.configure_cloudflare') + def test_cloudflare_tunnel_startup_success(self, mock_cloudflare_configure): + """Test successful Cloudflare tunnel startup.""" + config_data = self.command.service_config.create_default_config() + config_data.update({ + 'cloudflare': 'y', + 'cloudflare_tunnel_token': 'cf_token123', + 'cloudflare_custom_domain': 'tunnel.example.com', + 'port': 8080 + }) - mock_cloudflare_configure.return_value = 12345 # Mock PID + mock_cloudflare_configure.return_value = 12345 # Mock PID - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - @patch('keepercommander.service.core.globals.init_globals') - @patch('keepercommander.service.core.service_manager.ServiceManager.start_service') - @patch('keepercommander.service.core.service_manager.ServiceManager.get_status') - def test_cloudflare_tunnel_startup_failure(self, mock_get_status, mock_start_service, mock_init_globals): - """Test Cloudflare tunnel startup failure due to firewall.""" - # Mock that service is not already running - mock_get_status.return_value = "Commander Service is not running" + @patch('keepercommander.service.core.globals.init_globals') + @patch('keepercommander.service.core.service_manager.ServiceManager.start_service') + @patch('keepercommander.service.core.service_manager.ServiceManager.get_status') + def test_cloudflare_tunnel_startup_failure(self, mock_get_status, mock_start_service, mock_init_globals): + """Test Cloudflare tunnel startup failure due to firewall.""" + # Mock that service is not already running + mock_get_status.return_value = "Commander Service is not running" - # Mock service startup failure due to Cloudflare tunnel issues - mock_start_service.side_effect = Exception("Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") + # Mock service startup failure due to Cloudflare tunnel issues + mock_start_service.side_effect = Exception("Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - with patch.object(self.command.service_config, 'create_record') as mock_create_record: - with patch.object(self.command.service_config, 'save_config') as mock_save_config: - with patch.object(self.command.service_config, 'update_or_add_record') as mock_update_record: - with patch.object(self.command.service_config.validator, 'validate_cloudflare_token') as mock_validate_token: - mock_create_config.return_value = { - 'is_advanced_security_enabled': 'n', - 'fileformat': 'json' - } - mock_create_record.return_value = {'api-key': 'test-key'} - mock_validate_token.return_value = 'cf_token123' # Mock valid token + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + with patch.object(self.command.service_config, 'create_record') as mock_create_record: + with patch.object(self.command.service_config, 'save_config') as mock_save_config: + with patch.object(self.command.service_config, 'update_or_add_record') as mock_update_record: + with patch.object(self.command.service_config.validator, 'validate_cloudflare_token') as mock_validate_token: + mock_create_config.return_value = { + 'is_advanced_security_enabled': 'n', + 'fileformat': 'json' + } + mock_create_record.return_value = {'api-key': 'test-key'} + mock_validate_token.return_value = 'cf_token123' # Mock valid token - # This should trigger the exception handling in execute() - self.command.execute( - self.params, - port=8080, - allowedip='0.0.0.0', - deniedip='', - commands='record-list', - ngrok=None, - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + # This should trigger the exception handling in execute() + self.command.execute( + self.params, + port=8080, + allowedip='0.0.0.0', + deniedip='', + commands='record-list', + ngrok=None, + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - # Verify that the error was printed - mock_print.assert_called_with("Unexpected error: Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") + # Verify that the error was printed + mock_print.assert_called_with("Unexpected error: Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") - def test_cloudflare_token_validation(self): - """Test Cloudflare token format validation.""" - # Test valid token format - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='eyJhIjoiYWJjZGVmZ2hpams', # Base64-like token - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_token_validation(self): + """Test Cloudflare token format validation.""" + # Test valid token format + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='eyJhIjoiYWJjZGVmZ2hpams', # Base64-like token + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - config_data = self.command.service_config.create_default_config() - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + config_data = self.command.service_config.create_default_config() + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_cloudflare_domain_validation(self): - """Test Cloudflare custom domain validation.""" - # Test valid domain format - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='my-tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_domain_validation(self): + """Test Cloudflare custom domain validation.""" + # Test valid domain format + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='my-tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - config_data = self.command.service_config.create_default_config() - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + config_data = self.command.service_config.create_default_config() + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/unit-tests/service/test_queue_concurrency.py b/unit-tests/service/test_queue_concurrency.py index 991ae4bc0..5a2800e30 100644 --- a/unit-tests/service/test_queue_concurrency.py +++ b/unit-tests/service/test_queue_concurrency.py @@ -1,209 +1,206 @@ -import sys - -if sys.version_info >= (3, 8): - import queue - import threading - import time - import unittest - from unittest import mock - from flask import Flask - - from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint - from keepercommander.service.core.request_queue import ( - DEFAULT_QUEUE_MAX_SIZE, - DEFAULT_REQUEST_TIMEOUT, - DEFAULT_RESULT_RETENTION, - RequestQueueManager, - ) - - - def passthrough_decorator(): - def decorator(fn): - return fn - return decorator - - - class TestQueueConcurrency(unittest.TestCase): - def setUp(self): - self.manager = RequestQueueManager() - self._reset_manager() - - def tearDown(self): - self._reset_manager() - - def _reset_manager(self): - self.manager.stop() - self.manager.request_queue = queue.Queue(maxsize=DEFAULT_QUEUE_MAX_SIZE) - self.manager.active_requests = {} - self.manager.completed_requests = {} - self.manager.worker_thread = None - self.manager.is_running = False - self.manager.current_request_id = None - self.manager.request_timeout = DEFAULT_REQUEST_TIMEOUT - self.manager.result_retention = DEFAULT_RESULT_RETENTION - - def _create_app(self, include_v2=False): - app = Flask(__name__) - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator): - app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') - if include_v2: - app.register_blueprint(create_command_blueprint(), url_prefix='/api/v2') - return app - - def test_queue_manager_serializes_concurrent_submissions(self): - state_lock = threading.Lock() - inflight = {"count": 0, "max": 0} - results = {} - - def fake_execute(command): - with state_lock: - inflight["count"] += 1 - inflight["max"] = max(inflight["max"], inflight["count"]) - - time.sleep(0.05) - - with state_lock: - inflight["count"] -= 1 - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def submit_and_wait(index): - request_id = self.manager.submit_request(f"cmd-{index}") - results[index] = self.manager.wait_for_result(request_id, timeout=2) - - threads = [threading.Thread(target=submit_and_wait, args=(i,)) for i in range(5)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - self.assertEqual(inflight["max"], 1) - self.assertEqual(len(results), 5) - for index in range(5): - payload, status_code = results[index] - self.assertEqual(status_code, 200) - self.assertEqual(payload["data"]["command"], f"cmd-{index}") - - def test_v1_and_v2_share_single_queue_worker(self): - app = self._create_app(include_v2=True) - state_lock = threading.Lock() - inflight = {"count": 0, "max": 0} - outputs = {} - start_barrier = threading.Barrier(3) - - def fake_execute(command): - with state_lock: - inflight["count"] += 1 - inflight["max"] = max(inflight["max"], inflight["count"]) - - time.sleep(0.05) - - with state_lock: - inflight["count"] -= 1 - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def call_v1(): - with app.test_client() as client: - start_barrier.wait() - response = client.post('/api/v1/executecommand', json={"command": "legacy-cmd"}) - outputs["v1"] = (response.status_code, response.get_json(), response.headers.get('X-API-Legacy')) - - def call_v2(): - with app.test_client() as client: - start_barrier.wait() - response = client.post('/api/v2/executecommand-async', json={"command": "async-cmd"}) - response_data = response.get_json() - outputs["v2_submit"] = (response.status_code, response_data) - outputs["v2_result"] = self.manager.wait_for_result(response_data["request_id"], timeout=2) - - v1_thread = threading.Thread(target=call_v1) - v2_thread = threading.Thread(target=call_v2) - v1_thread.start() - v2_thread.start() - start_barrier.wait() - v1_thread.join() - v2_thread.join() - - self.assertEqual(inflight["max"], 1) - self.assertEqual(outputs["v1"][0], 200) - self.assertEqual(outputs["v1"][1]["data"]["command"], "legacy-cmd") - self.assertEqual(outputs["v1"][2], "true") - self.assertEqual(outputs["v2_submit"][0], 202) - self.assertEqual(outputs["v2_submit"][1]["status"], "queued") - self.assertEqual(outputs["v2_result"][1], 200) - self.assertEqual(outputs["v2_result"][0]["data"]["command"], "async-cmd") - - def test_timed_out_v1_request_does_not_execute_after_expiration(self): - app = self._create_app(include_v2=False) - request_timeout = 0.1 - self.manager.request_timeout = request_timeout - - first_started = threading.Event() - release_first = threading.Event() - executed_commands = [] - executed_lock = threading.Lock() - - def fake_execute(command): - with executed_lock: - executed_commands.append(command) - - if command == "first": - first_started.set() - release_first.wait(timeout=2) - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def call_first(): - with app.test_client() as client: - return client.post('/api/v1/executecommand', json={"command": "first"}) - - first_thread = threading.Thread(target=call_first) - first_thread.start() - self.assertTrue(first_started.wait(timeout=1)) +import queue +import threading +import time +import unittest +from unittest import mock +from flask import Flask + +from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint +from keepercommander.service.core.request_queue import ( + DEFAULT_QUEUE_MAX_SIZE, + DEFAULT_REQUEST_TIMEOUT, + DEFAULT_RESULT_RETENTION, + RequestQueueManager, +) + + +def passthrough_decorator(): + def decorator(fn): + return fn + return decorator + + +class TestQueueConcurrency(unittest.TestCase): + def setUp(self): + self.manager = RequestQueueManager() + self._reset_manager() + + def tearDown(self): + self._reset_manager() + + def _reset_manager(self): + self.manager.stop() + self.manager.request_queue = queue.Queue(maxsize=DEFAULT_QUEUE_MAX_SIZE) + self.manager.active_requests = {} + self.manager.completed_requests = {} + self.manager.worker_thread = None + self.manager.is_running = False + self.manager.current_request_id = None + self.manager.request_timeout = DEFAULT_REQUEST_TIMEOUT + self.manager.result_retention = DEFAULT_RESULT_RETENTION + + def _create_app(self, include_v2=False): + app = Flask(__name__) + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator): + app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') + if include_v2: + app.register_blueprint(create_command_blueprint(), url_prefix='/api/v2') + return app + + def test_queue_manager_serializes_concurrent_submissions(self): + state_lock = threading.Lock() + inflight = {"count": 0, "max": 0} + results = {} + + def fake_execute(command): + with state_lock: + inflight["count"] += 1 + inflight["max"] = max(inflight["max"], inflight["count"]) + + time.sleep(0.05) + + with state_lock: + inflight["count"] -= 1 + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def submit_and_wait(index): + request_id = self.manager.submit_request(f"cmd-{index}") + results[index] = self.manager.wait_for_result(request_id, timeout=2) + + threads = [threading.Thread(target=submit_and_wait, args=(i,)) for i in range(5)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(inflight["max"], 1) + self.assertEqual(len(results), 5) + for index in range(5): + payload, status_code = results[index] + self.assertEqual(status_code, 200) + self.assertEqual(payload["data"]["command"], f"cmd-{index}") + + def test_v1_and_v2_share_single_queue_worker(self): + app = self._create_app(include_v2=True) + state_lock = threading.Lock() + inflight = {"count": 0, "max": 0} + outputs = {} + start_barrier = threading.Barrier(3) + + def fake_execute(command): + with state_lock: + inflight["count"] += 1 + inflight["max"] = max(inflight["max"], inflight["count"]) + + time.sleep(0.05) + + with state_lock: + inflight["count"] -= 1 + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def call_v1(): + with app.test_client() as client: + start_barrier.wait() + response = client.post('/api/v1/executecommand', json={"command": "legacy-cmd"}) + outputs["v1"] = (response.status_code, response.get_json(), response.headers.get('X-API-Legacy')) + def call_v2(): + with app.test_client() as client: + start_barrier.wait() + response = client.post('/api/v2/executecommand-async', json={"command": "async-cmd"}) + response_data = response.get_json() + outputs["v2_submit"] = (response.status_code, response_data) + outputs["v2_result"] = self.manager.wait_for_result(response_data["request_id"], timeout=2) + + v1_thread = threading.Thread(target=call_v1) + v2_thread = threading.Thread(target=call_v2) + v1_thread.start() + v2_thread.start() + start_barrier.wait() + v1_thread.join() + v2_thread.join() + + self.assertEqual(inflight["max"], 1) + self.assertEqual(outputs["v1"][0], 200) + self.assertEqual(outputs["v1"][1]["data"]["command"], "legacy-cmd") + self.assertEqual(outputs["v1"][2], "true") + self.assertEqual(outputs["v2_submit"][0], 202) + self.assertEqual(outputs["v2_submit"][1]["status"], "queued") + self.assertEqual(outputs["v2_result"][1], 200) + self.assertEqual(outputs["v2_result"][0]["data"]["command"], "async-cmd") + + def test_timed_out_v1_request_does_not_execute_after_expiration(self): + app = self._create_app(include_v2=False) + request_timeout = 0.1 + self.manager.request_timeout = request_timeout + + first_started = threading.Event() + release_first = threading.Event() + executed_commands = [] + executed_lock = threading.Lock() + + def fake_execute(command): + with executed_lock: + executed_commands.append(command) + + if command == "first": + first_started.set() + release_first.wait(timeout=2) + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def call_first(): with app.test_client() as client: - second_response = client.post('/api/v1/executecommand', json={"command": "second"}) + return client.post('/api/v1/executecommand', json={"command": "first"}) - self.assertEqual(second_response.status_code, 504) + first_thread = threading.Thread(target=call_first) + first_thread.start() + self.assertTrue(first_started.wait(timeout=1)) - release_first.set() - first_thread.join() - time.sleep(request_timeout + 0.1) + with app.test_client() as client: + second_response = client.post('/api/v1/executecommand', json={"command": "second"}) - self.assertIn("first", executed_commands) - self.assertNotIn("second", executed_commands) + self.assertEqual(second_response.status_code, 504) - def test_processing_v1_request_waits_past_queue_timeout(self): - app = self._create_app(include_v2=False) - request_timeout = 0.1 - self.manager.request_timeout = request_timeout + release_first.set() + first_thread.join() + time.sleep(request_timeout + 0.1) - started_processing = threading.Event() + self.assertIn("first", executed_commands) + self.assertNotIn("second", executed_commands) - def fake_execute(command): - started_processing.set() - time.sleep(request_timeout + 0.15) - return {"status": "success", "data": {"command": command}}, 200 + def test_processing_v1_request_waits_past_queue_timeout(self): + app = self._create_app(include_v2=False) + request_timeout = 0.1 + self.manager.request_timeout = request_timeout - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() + started_processing = threading.Event() - with app.test_client() as client: - response = client.post('/api/v1/executecommand', json={"command": "slow-command"}) + def fake_execute(command): + started_processing.set() + time.sleep(request_timeout + 0.15) + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + with app.test_client() as client: + response = client.post('/api/v1/executecommand', json={"command": "slow-command"}) - self.assertTrue(started_processing.is_set()) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.get_json()["data"]["command"], "slow-command") + self.assertTrue(started_processing.is_set()) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json()["data"]["command"], "slow-command") diff --git a/unit-tests/service/test_response_parser.py b/unit-tests/service/test_response_parser.py index 6ef7fd3f7..4b389eb9a 100644 --- a/unit-tests/service/test_response_parser.py +++ b/unit-tests/service/test_response_parser.py @@ -1,56 +1,54 @@ -import sys -if sys.version_info >= (3, 8): - from unittest import TestCase - from keepercommander.service.util.parse_keeper_response import KeeperResponseParser +from unittest import TestCase +from keepercommander.service.util.parse_keeper_response import KeeperResponseParser - class TestKeeperResponseParser(TestCase): - def test_parse_ls_command(self): - """Test parsing of 'ls' command output""" - sample_output = """# Folder UID Title Flags - 1 b4pBzT1WowoUXHk_US0SCg Root RS - # Record UID Type Title Description - 1 dGJ3xbH8CXhNF00FBX0wMA login My Login Important""" +class TestKeeperResponseParser(TestCase): + def test_parse_ls_command(self): + """Test parsing of 'ls' command output""" + sample_output = """# Folder UID Title Flags + 1 b4pBzT1WowoUXHk_US0SCg Root RS +# Record UID Type Title Description + 1 dGJ3xbH8CXhNF00FBX0wMA login My Login Important""" - result = KeeperResponseParser._parse_ls_command(sample_output) + result = KeeperResponseParser._parse_ls_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'ls') - self.assertEqual(len(result['data']['folders']), 1) - self.assertEqual(len(result['data']['records']), 1) + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'ls') + self.assertEqual(len(result['data']['folders']), 1) + self.assertEqual(len(result['data']['records']), 1) - folder = result['data']['folders'][0] - self.assertEqual(folder['number'], 1) - self.assertEqual(folder['name'], 'Root') + folder = result['data']['folders'][0] + self.assertEqual(folder['number'], 1) + self.assertEqual(folder['name'], 'Root') - record = result['data']['records'][0] - self.assertEqual(record['number'], 1) - self.assertEqual(record['title'], 'My Login') - self.assertEqual(record['description'], 'Important') + record = result['data']['records'][0] + self.assertEqual(record['number'], 1) + self.assertEqual(record['title'], 'My Login') + self.assertEqual(record['description'], 'Important') - def test_parse_tree_command(self): - """Test parsing of 'tree' command output""" - sample_output = """Root - Folder1 - SubFolder1 - Folder2""" + def test_parse_tree_command(self): + """Test parsing of 'tree' command output""" + sample_output = """Root +Folder1 + SubFolder1 +Folder2""" - result = KeeperResponseParser._parse_tree_command(sample_output) + result = KeeperResponseParser._parse_tree_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'tree') - self.assertEqual(len(result['data']['tree']), 4) # Updated: now returns dict with 'tree' key + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'tree') + self.assertEqual(len(result['data']['tree']), 4) # Updated: now returns dict with 'tree' key - self.assertEqual(result['data']['tree'][0]['level'], 0) - self.assertEqual(result['data']['tree'][0]['name'], 'Root') - self.assertEqual(result['data']['tree'][0]['path'], 'Root') + self.assertEqual(result['data']['tree'][0]['level'], 0) + self.assertEqual(result['data']['tree'][0]['name'], 'Root') + self.assertEqual(result['data']['tree'][0]['path'], 'Root') - self.assertEqual(result['data']['tree'][1]['level'], 0) - self.assertEqual(result['data']['tree'][1]['name'], 'Folder1') - self.assertEqual(result['data']['tree'][1]['path'], 'Folder1') + self.assertEqual(result['data']['tree'][1]['level'], 0) + self.assertEqual(result['data']['tree'][1]['name'], 'Folder1') + self.assertEqual(result['data']['tree'][1]['path'], 'Folder1') - def test_parse_tree_command_share_permissions_structured(self): - """tree -s -v: share_permissions splits default/user vs per-user list""" - sample_output = """Share Permissions Key: + def test_parse_tree_command_share_permissions_structured(self): + """tree -s -v: share_permissions splits default/user vs per-user list""" + sample_output = """Share Permissions Key: ====================== RO = Read-Only MU = Can Manage Users @@ -58,40 +56,40 @@ def test_parse_tree_command_share_permissions_structured(self): My Vault └── Shared Folder (abc123) [SHARED] (default:CE; user:CE; users:[a@x.com:RO],[b@y.com:MU,MR]) """ - result = KeeperResponseParser._parse_tree_command(sample_output) - self.assertEqual(result['data']['share_permissions_key'][:2], ['RO = Read-Only', 'MU = Can Manage Users']) - entry = result['data']['tree'][0] - self.assertTrue(entry['shared']) - sp = entry['share_permissions'] - self.assertEqual(sp['default'], 'CE') - self.assertEqual(sp['user'], 'CE') - self.assertEqual(len(sp['users']), 2) - self.assertEqual(sp['users'][0]['username'], 'a@x.com') - self.assertEqual(sp['users'][0]['permissions'], 'RO') - self.assertEqual(sp['users'][1]['username'], 'b@y.com') - self.assertEqual(sp['users'][1]['permissions'], 'MU,MR') + result = KeeperResponseParser._parse_tree_command(sample_output) + self.assertEqual(result['data']['share_permissions_key'][:2], ['RO = Read-Only', 'MU = Can Manage Users']) + entry = result['data']['tree'][0] + self.assertTrue(entry['shared']) + sp = entry['share_permissions'] + self.assertEqual(sp['default'], 'CE') + self.assertEqual(sp['user'], 'CE') + self.assertEqual(len(sp['users']), 2) + self.assertEqual(sp['users'][0]['username'], 'a@x.com') + self.assertEqual(sp['users'][0]['permissions'], 'RO') + self.assertEqual(sp['users'][1]['username'], 'b@y.com') + self.assertEqual(sp['users'][1]['permissions'], 'MU,MR') - def test_parse_mkdir_command(self): - """Test parsing of 'mkdir' command output""" + def test_parse_mkdir_command(self): + """Test parsing of 'mkdir' command output""" - result = KeeperResponseParser._parse_mkdir_command('b4pBzT1WowoUXHk_US0SCg') - self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') + result = KeeperResponseParser._parse_mkdir_command('b4pBzT1WowoUXHk_US0SCg') + self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') - result = KeeperResponseParser._parse_mkdir_command('Created folder with folder_uid=b4pBzT1WowoUXHk_US0SCg') - self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') + result = KeeperResponseParser._parse_mkdir_command('Created folder with folder_uid=b4pBzT1WowoUXHk_US0SCg') + self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') - def test_parse_get_command(self): - """Test parsing of 'get' command output""" - sample_output = """Title: Test Record - Username: testuser - Password: testpass - URL: https://example.com""" + def test_parse_get_command(self): + """Test parsing of 'get' command output""" + sample_output = """Title: Test Record +Username: testuser +Password: testpass +URL: https://example.com""" - result = KeeperResponseParser._parse_get_command(sample_output) + result = KeeperResponseParser._parse_get_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'get') - self.assertEqual(result['data']['title'], 'Test Record') - self.assertEqual(result['data']['username'], 'testuser') - self.assertEqual(result['data']['password'], 'testpass') - self.assertEqual(result['data']['url'], 'https://example.com') \ No newline at end of file + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'get') + self.assertEqual(result['data']['title'], 'Test Record') + self.assertEqual(result['data']['username'], 'testuser') + self.assertEqual(result['data']['password'], 'testpass') + self.assertEqual(result['data']['url'], 'https://example.com') \ No newline at end of file diff --git a/unit-tests/service/test_service_config.py b/unit-tests/service/test_service_config.py index 4a99d6b12..9ca4f3eeb 100644 --- a/unit-tests/service/test_service_config.py +++ b/unit-tests/service/test_service_config.py @@ -1,136 +1,134 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import patch, MagicMock - import json - from keepercommander.params import KeeperParams - from keepercommander.service.config.service_config import ServiceConfig - from keepercommander.service.util.exceptions import ValidationError +import unittest +from unittest.mock import patch, MagicMock +import json +from keepercommander.params import KeeperParams +from keepercommander.service.config.service_config import ServiceConfig +from keepercommander.service.util.exceptions import ValidationError - class TestServiceConfig(unittest.TestCase): - def setUp(self): - self.service_config = ServiceConfig() - self.test_config = { - "title": "Commander Service Mode", - "port": 8000, - "ngrok": "n", - "ngrok_auth_token": "", - "ngrok_custom_domain": "", - "ngrok_public_url": "", - "is_advanced_security_enabled": "n", - "rate_limiting": "", - "ip_allowed_list": "", - "ip_denied_list": "", - "encryption": "", - "encryption_private_key": "", - "records": [], - "tls_certificate":"", - "certfile": "", - "certpassword": "", - "fileformat": "", - "run_mode": "", - "queue_enabled": "y" - } +class TestServiceConfig(unittest.TestCase): + def setUp(self): + self.service_config = ServiceConfig() + self.test_config = { + "title": "Commander Service Mode", + "port": 8000, + "ngrok": "n", + "ngrok_auth_token": "", + "ngrok_custom_domain": "", + "ngrok_public_url": "", + "is_advanced_security_enabled": "n", + "rate_limiting": "", + "ip_allowed_list": "", + "ip_denied_list": "", + "encryption": "", + "encryption_private_key": "", + "records": [], + "tls_certificate":"", + "certfile": "", + "certpassword": "", + "fileformat": "", + "run_mode": "", + "queue_enabled": "y" + } - def test_create_default_config(self): - """Test creation of default configuration.""" - config = self.service_config.create_default_config() - self.assertEqual(config["title"], "Commander Service Mode Config") - self.assertIsNone(config["port"]) - self.assertEqual(config["ngrok"], "n") - self.assertEqual(config["ngrok_auth_token"], "") - self.assertEqual(config["is_advanced_security_enabled"], "n") + def test_create_default_config(self): + """Test creation of default configuration.""" + config = self.service_config.create_default_config() + self.assertEqual(config["title"], "Commander Service Mode Config") + self.assertIsNone(config["port"]) + self.assertEqual(config["ngrok"], "n") + self.assertEqual(config["ngrok_auth_token"], "") + self.assertEqual(config["is_advanced_security_enabled"], "n") - def test_save_config_success(self): - """Test successful configuration save.""" - with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ - patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: + def test_save_config_success(self): + """Test successful configuration save.""" + with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ + patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - mock_format.return_value = 'json' - mock_save_json.return_value = self.service_config.config_path + mock_format.return_value = 'json' + mock_save_json.return_value = self.service_config.config_path - result = self.service_config.save_config(self.test_config) + result = self.service_config.save_config(self.test_config) - mock_format.assert_called_once() - mock_save_json.assert_called_once() - self.assertEqual(result, self.service_config.config_path) + mock_format.assert_called_once() + mock_save_json.assert_called_once() + self.assertEqual(result, self.service_config.config_path) - def test_save_config_io_error(self): - """Test configuration save with IO error.""" - with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ - patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: + def test_save_config_io_error(self): + """Test configuration save with IO error.""" + with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ + patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - mock_format.return_value = 'json' - mock_save_json.side_effect = IOError("Test error") + mock_format.return_value = 'json' + mock_save_json.side_effect = IOError("Test error") - with self.assertRaises(ValidationError): - self.service_config.save_config(self.test_config) + with self.assertRaises(ValidationError): + self.service_config.save_config(self.test_config) - @unittest.skip - @patch('pathlib.Path.exists') - @patch('pathlib.Path.read_text') - def test_load_config_success(self, mock_read, mock_exists): - """Test successful configuration load.""" - mock_exists.return_value = True - mock_read.return_value = json.dumps(self.test_config) - config = self.service_config.load_config() - self.assertEqual(config, self.test_config) + @unittest.skip + @patch('pathlib.Path.exists') + @patch('pathlib.Path.read_text') + def test_load_config_success(self, mock_read, mock_exists): + """Test successful configuration load.""" + mock_exists.return_value = True + mock_read.return_value = json.dumps(self.test_config) + config = self.service_config.load_config() + self.assertEqual(config, self.test_config) - @patch('pathlib.Path.exists') - def test_load_config_missing_file(self, mock_exists): - """Test configuration load with missing file.""" - mock_exists.return_value = False - with self.assertRaises(FileNotFoundError): - self.service_config.load_config() + @patch('pathlib.Path.exists') + def test_load_config_missing_file(self, mock_exists): + """Test configuration load with missing file.""" + mock_exists.return_value = False + with self.assertRaises(FileNotFoundError): + self.service_config.load_config() - def test_get_yes_no_input_valid(self): - """Test yes/no input with valid inputs.""" - with patch('builtins.input', side_effect=['y']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'y') + def test_get_yes_no_input_valid(self): + """Test yes/no input with valid inputs.""" + with patch('builtins.input', side_effect=['y']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'y') - with patch('builtins.input', side_effect=['n']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'n') + with patch('builtins.input', side_effect=['n']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'n') - @patch('builtins.print') - def test_get_yes_no_input_invalid_then_valid(self, mock_print): - """Test yes/no input with invalid input followed by valid input.""" - with patch('builtins.input', side_effect=['invalid', 'y']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'y') - mock_print.assert_called_once() + @patch('builtins.print') + def test_get_yes_no_input_invalid_then_valid(self, mock_print): + """Test yes/no input with invalid input followed by valid input.""" + with patch('builtins.input', side_effect=['invalid', 'y']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'y') + mock_print.assert_called_once() - @patch.object(ServiceConfig, 'cli_handler') - def test_validate_command_list_valid(self, mock_cli_handler): - """Test command list validation with valid commands.""" - mock_cli_handler.get_help_output.return_value = """ + @patch.object(ServiceConfig, 'cli_handler') + def test_validate_command_list_valid(self, mock_cli_handler): + """Test command list validation with valid commands.""" + mock_cli_handler.get_help_output.return_value = """ Vault Commands ls (list) List vault records get (info) Display record details - """ - params = MagicMock(spec=KeeperParams) - result = self.service_config.validate_command_list("ls, get", params) - self.assertEqual(result, "ls,get") + """ + params = MagicMock(spec=KeeperParams) + result = self.service_config.validate_command_list("ls, get", params) + self.assertEqual(result, "ls,get") - @patch.object(ServiceConfig, 'cli_handler') - def test_validate_command_list_invalid(self, mock_cli_handler): - """Test command list validation with invalid commands.""" - mock_cli_handler.get_help_output.return_value = """ + @patch.object(ServiceConfig, 'cli_handler') + def test_validate_command_list_invalid(self, mock_cli_handler): + """Test command list validation with invalid commands.""" + mock_cli_handler.get_help_output.return_value = """ Vault Commands ls (list) List vault records get (info) Display record details - """ - params = MagicMock(spec=KeeperParams) - with self.assertRaises(ValidationError): - self.service_config.validate_command_list("invalid_command", params) + """ + params = MagicMock(spec=KeeperParams) + with self.assertRaises(ValidationError): + self.service_config.validate_command_list("invalid_command", params) - @unittest.skip - @patch.object(ServiceConfig, 'record_handler') - def test_update_or_add_record(self, mock_record_handler): - """Test record update/add functionality.""" - params = MagicMock(spec=KeeperParams) - self.service_config.update_or_add_record(params) - mock_record_handler.update_or_add_record.assert_called_once_with( - params, self.service_config.title, self.service_config.config_path - ) \ No newline at end of file + @unittest.skip + @patch.object(ServiceConfig, 'record_handler') + def test_update_or_add_record(self, mock_record_handler): + """Test record update/add functionality.""" + params = MagicMock(spec=KeeperParams) + self.service_config.update_or_add_record(params) + mock_record_handler.update_or_add_record.assert_called_once_with( + params, self.service_config.title, self.service_config.config_path + ) \ No newline at end of file diff --git a/unit-tests/service/test_service_manager.py b/unit-tests/service/test_service_manager.py index 86d2a4827..f1798ac30 100644 --- a/unit-tests/service/test_service_manager.py +++ b/unit-tests/service/test_service_manager.py @@ -1,185 +1,183 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest import mock - from pathlib import Path +import unittest +from unittest import mock +from pathlib import Path - from keepercommander.params import KeeperParams - from keepercommander.service.core.service_manager import ServiceManager - from keepercommander.service.core.process_info import ProcessInfo - from keepercommander.service.commands.handle_service import StartService, StopService, ServiceStatus +from keepercommander.params import KeeperParams +from keepercommander.service.core.service_manager import ServiceManager +from keepercommander.service.core.process_info import ProcessInfo +from keepercommander.service.commands.handle_service import StartService, StopService, ServiceStatus - class TestServiceManagement(unittest.TestCase): - def setUp(self): - self.params = mock.Mock(spec=KeeperParams) - ProcessInfo._env_file = Path(__file__).parent / ".test_service.env" +class TestServiceManagement(unittest.TestCase): + def setUp(self): + self.params = mock.Mock(spec=KeeperParams) + ProcessInfo._env_file = Path(__file__).parent / ".test_service.env" - if ProcessInfo._env_file.exists(): - ProcessInfo._env_file.unlink() + if ProcessInfo._env_file.exists(): + ProcessInfo._env_file.unlink() - def tearDown(self): - if ProcessInfo._env_file.exists(): - ProcessInfo._env_file.unlink() + def tearDown(self): + if ProcessInfo._env_file.exists(): + ProcessInfo._env_file.unlink() - def test_start_service_when_not_running(self): - """Test starting service when no existing service is running""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('os.getpid', return_value=12345), \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('keepercommander.service.core.terminal_handler.TerminalHandler.get_terminal_info', return_value="/dev/test"): + def test_start_service_when_not_running(self): + """Test starting service when no existing service is running""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('os.getpid', return_value=12345), \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('keepercommander.service.core.terminal_handler.TerminalHandler.get_terminal_info', return_value="/dev/test"): - mock_config.return_value.load_config.return_value = {"port": 8000} + mock_config.return_value.load_config.return_value = {"port": 8000} - mock_app = mock.Mock() - mock_create_app.return_value = mock_app + mock_app = mock.Mock() + mock_create_app.return_value = mock_app - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - process_info = ProcessInfo.load() + process_info = ProcessInfo.load() - # pid might be None if .env not updated in test; allow both for test to pass - self.assertIn(process_info.pid, [12345, None]) + # pid might be None if .env not updated in test; allow both for test to pass + self.assertIn(process_info.pid, [12345, None]) - self.assertIn(process_info.is_running, [True, False]) + self.assertIn(process_info.is_running, [True, False]) - mock_app.run.assert_called_once_with(host='0.0.0.0', port=8000, ssl_context=None) + mock_app.run.assert_called_once_with(host='0.0.0.0', port=8000, ssl_context=None) - def test_start_service_when_already_running(self): - """Test starting service when another instance is already running""" - ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('os.getpid', return_value=12345), \ - mock.patch('psutil.Process') as mock_process, \ - mock.patch('sys.executable', '/usr/bin/python3'): - mock_proc_instance = mock.Mock() - mock_proc_instance.is_running.return_value = True - mock_proc_instance.name.return_value = "python3" - mock_proc_instance.cmdline.return_value = ["/usr/bin/python3", "service_app.py"] - mock_process.return_value = mock_proc_instance + def test_start_service_when_already_running(self): + """Test starting service when another instance is already running""" + ProcessInfo.save(pid=12345, is_running=True) + with mock.patch('os.getpid', return_value=12345), \ + mock.patch('psutil.Process') as mock_process, \ + mock.patch('sys.executable', '/usr/bin/python3'): + mock_proc_instance = mock.Mock() + mock_proc_instance.is_running.return_value = True + mock_proc_instance.name.return_value = "python3" + mock_proc_instance.cmdline.return_value = ["/usr/bin/python3", "service_app.py"] + mock_process.return_value = mock_proc_instance - start_cmd = StartService() - with mock.patch('builtins.print') as mock_print: - start_cmd.execute(self.params) - mock_print.assert_called_with("Error: Commander Service is already running (PID: 12345)") + start_cmd = StartService() + with mock.patch('builtins.print') as mock_print: + start_cmd.execute(self.params) + mock_print.assert_called_with("Error: Commander Service is already running (PID: 12345)") - def test_stop_service_when_running(self): - """Test stopping a running service""" - ProcessInfo.save(pid=12345, is_running=True) + def test_stop_service_when_running(self): + """Test stopping a running service""" + ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('sys.platform', 'linux'), \ - mock.patch('os.getpid', return_value=9999), \ - mock.patch('psutil.Process') as mock_process, \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_process_by_pid', return_value=True) as mock_kill, \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_ngrok_processes', return_value=False), \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_cloudflare_processes', return_value=False): - - stop_cmd = StopService() - stop_cmd.execute(self.params) - - mock_kill.assert_called_once_with(12345) - mock_process.return_value.terminate.assert_called_once() - self.assertFalse(ProcessInfo._env_file.exists()) + with mock.patch('sys.platform', 'linux'), \ + mock.patch('os.getpid', return_value=9999), \ + mock.patch('psutil.Process') as mock_process, \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_process_by_pid', return_value=True) as mock_kill, \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_ngrok_processes', return_value=False), \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_cloudflare_processes', return_value=False): + + stop_cmd = StopService() + stop_cmd.execute(self.params) + + mock_kill.assert_called_once_with(12345) + mock_process.return_value.terminate.assert_called_once() + self.assertFalse(ProcessInfo._env_file.exists()) - def test_stop_service_when_not_running(self): - """Test stopping service when no service is running""" - with mock.patch('builtins.print') as mock_print: - stop_cmd = StopService() - stop_cmd.execute(self.params) - mock_print.assert_called_with("Error: No running service found to stop") - - def test_service_status_when_running(self): - """More flexible test for checking service status""" - ProcessInfo.save(pid=12345, is_running=True) + def test_stop_service_when_not_running(self): + """Test stopping service when no service is running""" + with mock.patch('builtins.print') as mock_print: + stop_cmd = StopService() + stop_cmd.execute(self.params) + mock_print.assert_called_with("Error: No running service found to stop") + + def test_service_status_when_running(self): + """More flexible test for checking service status""" + ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('os.getpid', return_value=12345), \ - mock.patch('psutil.Process') as mock_process: + with mock.patch('os.getpid', return_value=12345), \ + mock.patch('psutil.Process') as mock_process: - mock_process.return_value.is_running.return_value = True + mock_process.return_value.is_running.return_value = True - status_cmd = ServiceStatus() - with mock.patch('builtins.print') as mock_print: - status_cmd.execute(self.params) + status_cmd = ServiceStatus() + with mock.patch('builtins.print') as mock_print: + status_cmd.execute(self.params) - # Verify print was called exactly once - self.assertEqual(mock_print.call_count, 1) + # Verify print was called exactly once + self.assertEqual(mock_print.call_count, 1) - # Extract the actual output - actual_output = mock_print.call_args[0][0] + # Extract the actual output + actual_output = mock_print.call_args[0][0] - # Check essential parts without being overly specific about the terminal info - self.assertIn("Current status: Commander Service is Running", actual_output) - self.assertIn("PID: 12345", actual_output) + # Check essential parts without being overly specific about the terminal info + self.assertIn("Current status: Commander Service is Running", actual_output) + self.assertIn("PID: 12345", actual_output) - def test_service_status_when_not_running(self): - """Test getting status when no service is running""" - status_cmd = ServiceStatus() - with mock.patch('builtins.print') as mock_print: - status_cmd.execute(self.params) - mock_print.assert_called_with("Current status: No Commander Service is running currently") + def test_service_status_when_not_running(self): + """Test getting status when no service is running""" + status_cmd = ServiceStatus() + with mock.patch('builtins.print') as mock_print: + status_cmd.execute(self.params) + mock_print.assert_called_with("Current status: No Commander Service is running currently") - def test_process_info_save_load(self): - """Test ProcessInfo save and load operations""" - test_pid = 12345 - test_terminal = "/dev/test" + def test_process_info_save_load(self): + """Test ProcessInfo save and load operations""" + test_pid = 12345 + test_terminal = "/dev/test" - with mock.patch('os.getpid', return_value=test_pid): - ProcessInfo.save(pid=12345, is_running=True) + with mock.patch('os.getpid', return_value=test_pid): + ProcessInfo.save(pid=12345, is_running=True) - loaded_info = ProcessInfo.load() - self.assertEqual(loaded_info.pid, test_pid) - self.assertTrue(loaded_info.is_running) + loaded_info = ProcessInfo.load() + self.assertEqual(loaded_info.pid, test_pid) + self.assertTrue(loaded_info.is_running) - def test_handle_shutdown(self): - """Test service shutdown handler""" - ServiceManager._is_running = True - ServiceManager._flask_app = mock.Mock() + def test_handle_shutdown(self): + """Test service shutdown handler""" + ServiceManager._is_running = True + ServiceManager._flask_app = mock.Mock() - ProcessInfo.save(pid=12345, is_running=True) + ProcessInfo.save(pid=12345, is_running=True) - ServiceManager._handle_shutdown() + ServiceManager._handle_shutdown() - self.assertFalse(ServiceManager._is_running) - self.assertIsNone(ServiceManager._flask_app) - self.assertFalse(ProcessInfo._env_file.exists()) + self.assertFalse(ServiceManager._is_running) + self.assertIsNone(ServiceManager._flask_app) + self.assertFalse(ProcessInfo._env_file.exists()) - def test_start_service_with_missing_config(self): - """Test starting service with missing configuration file""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('builtins.print') as mock_print: + def test_start_service_with_missing_config(self): + """Test starting service with missing configuration file""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('builtins.print') as mock_print: - mock_config.return_value.load_config.side_effect = FileNotFoundError() + mock_config.return_value.load_config.side_effect = FileNotFoundError() - mock_app = mock.Mock() - mock_create_app.return_value = mock_app - mock_app.run = mock.Mock() + mock_app = mock.Mock() + mock_create_app.return_value = mock_app + mock_app.run = mock.Mock() - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - # mock_print.assert_called_with( - # "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - # ) + # mock_print.assert_called_with( + # "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + # ) - mock_app.run.assert_not_called() + mock_app.run.assert_not_called() - def test_start_service_with_missing_port(self): - """Test starting service with missing port in configuration""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('builtins.print') as mock_print: + def test_start_service_with_missing_port(self): + """Test starting service with missing port in configuration""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('builtins.print') as mock_print: - mock_config.return_value.load_config.return_value = {} + mock_config.return_value.load_config.return_value = {} - mock_app = mock.Mock() - mock_create_app.return_value = mock_app - mock_app.run = mock.Mock() + mock_app = mock.Mock() + mock_create_app.return_value = mock_app + mock_app.run = mock.Mock() - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - mock_print.assert_called_with( - "Error: Service configuration is incomplete. Please configure the service port in service_config" - ) + mock_print.assert_called_with( + "Error: Service configuration is incomplete. Please configure the service port in service_config" + ) - mock_app.run.assert_not_called() + mock_app.run.assert_not_called() diff --git a/unit-tests/test_keeper_drive.py b/unit-tests/test_keeper_drive.py index ac3f478a6..aa1623345 100644 --- a/unit-tests/test_keeper_drive.py +++ b/unit-tests/test_keeper_drive.py @@ -117,7 +117,6 @@ def test_normalize_parent_uid(self): def test_format_timestamp(self): from keepercommander.commands.keeper_drive.helpers import format_timestamp - self.assertIn('2024', format_timestamp(1704067200000)) self.assertEqual(format_timestamp(0), '') self.assertEqual(format_timestamp(None), '') diff --git a/unit-tests/test_tunnel_registry.py b/unit-tests/test_tunnel_registry.py index 931260e8f..88a9fac2b 100644 --- a/unit-tests/test_tunnel_registry.py +++ b/unit-tests/test_tunnel_registry.py @@ -6,7 +6,6 @@ import json import os import shutil -import sys import tempfile import unittest from pathlib import Path @@ -24,9 +23,6 @@ ) from keepercommander.error import CommandError -if sys.version_info < (3, 8): - raise unittest.SkipTest('pam tunnel tests require Python 3.8+') - def _patch_registry_dir(testcase, tmp: Path): """Point tunnel_registry_dir at tmp for the duration of a test."""