diff --git a/docs/README.md b/docs/README.md index ae5ef88..16fa39d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,3 +3,5 @@ Developer reference material for Device Connect. - **class-map.html** — Interactive class/module relationship diagram. Open in a browser to explore the architecture. +- **device-mandates.md** — Concise guide and runnable examples for mandate-protected device functions. +- **device-mandates-spec.md** — Implementation notes and acceptance criteria for the Device Mandates feature. diff --git a/docs/device-mandates-spec.md b/docs/device-mandates-spec.md new file mode 100644 index 0000000..7e817fd --- /dev/null +++ b/docs/device-mandates-spec.md @@ -0,0 +1,51 @@ +# Spec: Device Mandates + +## Objective + +Add an optional verifiable authorization layer for Device Connect RPC execution. A device function can declare that it requires a Device Mandate, and the runtime refuses to execute protected RPCs unless the caller presents a signed mandate that authorizes the target device, method, parameters, and validity window. + +The first implementation slice proves the contract end to end with a lightweight HMAC-backed mandate format suitable for local tests and demos. The verifier is intentionally small and pluggable so a later slice can replace or augment the credential format with UCAN, Biscuit, or a standards-track profile without changing the decorator or RPC metadata contract. + +## Commands + +- Edge tests: `pytest packages/device-connect-edge/tests -q` +- Agent tools tests: `pytest packages/device-connect-agent-tools/tests -q` +- Focused mandate tests: `pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q` + +## Project Structure + +- `packages/device-connect-edge/device_connect_edge/mandates.py`: mandate data helpers, signing, and verification. +- `packages/device-connect-edge/device_connect_edge/drivers/decorators.py`: `@requires_mandate` decorator metadata. +- `packages/device-connect-edge/device_connect_edge/device.py`: runtime enforcement before driver invocation. +- `packages/device-connect-edge/device_connect_edge/types.py`: function capability metadata for mandate requirements. +- `packages/device-connect-agent-tools/device_connect_agent_tools/tools.py`: pass mandate metadata through `_dc_meta`. + +## Testing Strategy + +Use test-driven slices: + +- Pure unit tests for signing, verification, time windows, device/method binding, numeric constraints, tamper detection, and replay denial. +- Runtime tests for protected RPC denial before driver execution and successful execution with a valid mandate. +- Agent-tools tests that verify `invoke`, `invoke_many`, `broadcast`, and legacy `invoke_device` attach mandate data inside `_dc_meta`. + +## Boundaries + +- Always: fail closed for protected methods; keep mandate support optional for unprotected methods; preserve existing unprotected RPC behavior. +- Ask first: adding non-stdlib crypto/credential dependencies; changing transport protocols; adding persistent receipt storage; modifying CI. +- Never: treat unsigned client-provided mandate dictionaries as valid; pass `_dc_meta` into user driver methods; weaken existing ACL/TLS/JWT checks. + +## Success Criteria + +- A driver can mark an RPC with `@requires_mandate(scope="actuation")`. +- Discovery/capability metadata shows mandate requirements for protected functions. +- Direct JSON-RPC and broadcast execution reject protected functions with no mandate, invalid signature, wrong device, wrong method, expired mandate, or out-of-range parameters. +- Direct JSON-RPC and broadcast execution allow a protected function with a valid closed mandate. +- Agent tools can attach mandate data to invoke paths through `_dc_meta`. +- Existing unprotected RPC tests continue to pass. + +## Open Questions + +- Which production credential format should be the default: UCAN, Biscuit, or a future AP2-compatible non-payment profile? +- Where should production principal keys live: OS keystore, HSM/KMS, commissioning bundle, or registry-backed trust store? +- Should execution receipts be persisted first in the server state store or emitted as signed events before storage is added? +- Should replay protection be in-memory per device for v0, or backed by the server state layer for distributed deployments? diff --git a/docs/device-mandates.md b/docs/device-mandates.md new file mode 100644 index 0000000..8cdd783 --- /dev/null +++ b/docs/device-mandates.md @@ -0,0 +1,108 @@ +# Device Mandates + +Device Mandates add a signed authorization envelope to sensitive RPCs. A driver marks a function with `@requires_mandate`, and the runtime denies that function unless the call includes a valid closed mandate in `_dc_meta.mandate`. + +Use mandates for actuation that can affect safety, access, cost, or physical state. Read-only functions usually should not require mandates. + +## Protect an RPC + +Decorate the RPC with `@requires_mandate`. The decorator may be placed above or below `@rpc()`. + +```python +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc + + +class SmartLockDriver(DeviceDriver): + device_type = "smart_lock" + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int = 10) -> dict: + return {"state": "unlocked", "duration_s": duration_s} +``` + +Discovery metadata for `unlock` includes: + +```json +{"mandate": {"required": true, "scope": "actuation"}} +``` + +## Create Mandates + +An open mandate is signed by the principal and delegates bounded authority to an agent. A closed mandate is signed by the agent for one concrete invocation. + +```python +from datetime import datetime, timedelta, timezone + +from device_connect_edge import create_closed_mandate, create_open_mandate + +now = datetime.now(timezone.utc) +principal_key = b"principal-demo-key" +agent_key = b"agent-demo-key" + +open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-front-door", + methods=["unlock"], + constraints={"duration_s": {"lte": 30}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=principal_key, +) + +closed_mandate = create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-front-door", + method="unlock", + params={"duration_s": 20}, + key=agent_key, + issued_at=now, +) +``` + +Pass the closed mandate through agent tools with the `mandate` argument: + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(lock-front-door).function(unlock)", + params={"duration_s": 20}, + mandate=closed_mandate, +) +``` + +## Valid and Invalid Use Cases + +Valid smart-lock use: unlock the front door for 20 seconds when the open mandate allows `unlock` on `lock-front-door` and constrains `duration_s <= 30`. + +Invalid smart-lock use: reuse that same mandate for `duration_s=60`, another device, another method, or changed parameters. The signature and constraint checks fail closed before the driver method runs. + +Valid heater use: set a room heater to 21.5 C when the open mandate allows `set_temperature` on `heater-living-room` and constrains `target_c` between 18 and 23. + +Invalid heater use: request `target_c=28` or replay a previously used closed mandate nonce. The verifier denies the call. + +See `packages/device-connect-edge/examples/device_mandates/mandate_examples.py` for runnable local examples of these cases. + +## Testing Commands + +Run the focused mandate tests: + +```bash +pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q +``` + +Run package test suites: + +```bash +pytest packages/device-connect-edge/tests -q +pytest packages/device-connect-agent-tools/tests -q +``` + +Run the examples: + +```bash +PYTHONPATH=packages/device-connect-edge python packages/device-connect-edge/examples/device_mandates/mandate_examples.py +``` diff --git a/docs/discovery.md b/docs/discovery.md new file mode 100644 index 0000000..ee33d60 --- /dev/null +++ b/docs/discovery.md @@ -0,0 +1,411 @@ +# Discovery + +Device Connect uses one selector grammar to address devices, functions, and +events. The same selector string drives discovery: it tells the system +**which** entities you mean. Labels attached to devices, functions, and +events provide the dimensions to filter on. + +This guide covers the labels schema, the selector grammar, and the two +tools that resolve selectors. + +## Labels + +Labels are key/value metadata. Values are strings or lists of strings. +Lists express composite identity (a smart camera that is both `camera` and +`inference`). + +Drivers declare labels in two places: + +```python +class SmartCamera(DeviceDriver): + labels = { + "category": ["camera", "inference"], + "location": "lab-A/optics-bench", + } + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_image(self, resolution: str = "1080p") -> dict: + ... + + @emit(labels={"modality": "motion"}) + async def state_change_detected(self, zone_id: str, state_class: str): + ... +``` + +### Well-known keys + +These keys carry conventional meaning. Custom keys are always allowed +alongside them. + +| Question | Key | Applies to | Example values | +| --- | --- | --- | --- | +| What is it? | `category` | device | `camera`, `robot`, `hub`, `sensor`, `actuator`, `inference` | +| Where is it? | `location` | device | `lab-A`, `zone-A/dock` (`/`-hierarchical, glob-able) | +| Read or write? | `direction` | function (RPC) | `read`, `write` | +| Is it dangerous? | `safety` | function + event | `critical`, `informational` | +| What kind of signal? | `modality` | function + event | `rgb`, `thermal`, `infrared`, `motion`, `4k`, ... | + +The RPC-vs-event distinction is structural (FunctionDef vs EventDef) and is +expressed by the selector scope, not by a label. + +### Drivers without label declarations + +Drivers that populate only the legacy `DeviceStatus.location` heartbeat +field are still discoverable by location: the value is mirrored into +`labels["location"]` at the discovery boundary so selector queries on +location work without a driver change. + +## Selector grammar + +``` +device() device-only +device().function() RPCs on a device subset +device().event() events on a device subset +function() all RPCs across the fleet +event() all events across the fleet +``` + +Inside `(...)`: + +- `key:value` - single-value match +- `key:[v1,v2]` - OR within a key (matches if the label value contains any + listed value; multi-valued labels match if any element is in the list) +- `key:pattern*` - anchored glob (`*`, `?`); `set_*` matches `set_threshold` + but not `unset_threshold`. Use `*set*` for substring. +- `k1:v1,k2:v2` - AND across keys +- bare string (no colon) - id/name match: `device(robot-001)`, + `function(capture_image)`. Globs allowed: `device(cam-*)`. +- `*` or empty - match all + +Keys inside `device(...)` resolve against device labels; keys inside +`function(...)` resolve against function labels; keys inside `event(...)` +resolve against event labels. The `.` chains: "narrow to these devices, +then narrow to these functions or events on them." + +### Selector examples + +``` +device(category:camera) all cameras +device(category:[camera,robot], location:lab-A/*) cameras or robots in lab-A +device(location:lab-A*) lab-A and any descendant +device(*).function(direction:write, modality:rgb) rgb-producing writes fleet-wide +device(*).event(modality:motion) all motion events +function(safety:critical) critical RPCs fleet-wide +function(estop) fleet emergency-stop targets +``` + +## Tools + +### Discovery + +#### `discover(selector, offset=0, limit=200)` + +Resolves a selector to matched entities. Returns devices, function tuples, +or event tuples depending on the selector scope. The response includes a +`label_histogram` so you can see which dimensions to narrow on next without +a separate call. + +`discover()` includes full schemas inline when the matched set is small, +and switches to a name-and-labels summary above +`DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is +configurable via environment variable. + +#### `discover_labels(key=None, offset=0, limit=50)` + +Returns the fleet label vocabulary. Use this first when you do not know +which dimensions are available. + +- With no `key`: returns top values per key across each axis (`device_keys`, + `function_keys`, `event_keys`). +- With a `key` like `"device.location"` or `"function.direction"`: + paginates the full value list for that one key. + +### Operations + +Calling a function on devices is one logical operation; the only choice +is whether the caller waits for replies and how they arrive. + +| Tool | Selector resolves to | Reply mode | +| --- | --- | --- | +| `invoke(selector, params)` | exactly one (device, function) tuple | sync, single result | +| `invoke_many(selector, params, timeout=)` | any number of (device, function) tuples | sync, aggregated | +| `broadcast(selector, params, where=, bindings=, fire_at=, on_late=)` | any number of (device, function) tuples | async; correlation-tagged replies stream as events | +| `subscribe(selector)` | events, or `"correlation:"` for broadcast replies | live stream (`Subscription` handle) | +| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | + +`invoke_many` runs every target's call in parallel and returns when each +target has finished or hit its per-target timeout (30 s default). Partial +failures do not abort siblings; the response carries both `results` and +`errors` lists. + +`broadcast` does the same fan-out asynchronously: the caller gets a +`correlation_id` immediately and replies stream back on a per-device +subject keyed by that id. Subscribe with `subscribe("correlation:")` +or block with `await_replies(correlation_id, timeout=...)`. + +### Edge-side `where` predicate + +`broadcast` accepts an optional `where` expression that runs at each +candidate device. The predicate is a CEL (Common Expression Language) +string and sees four variables: + +- `identity` — device-local identity dict (`device_id`, `device_type`, ...) +- `labels` — device labels (the same labels selectors filter on) +- `status` — device status (heartbeat-updated: `location`, `availability`, + `battery`, `online`, ...) +- `bindings` — the shared payload passed to `broadcast` (selection masks, + thresholds, lookup tables) + +```python +broadcast( + "device(category:camera).function(capture_image)", + params={"resolution": "4k"}, + where="status.battery > 50 && labels.location == 'lab-A'", +) +``` + +The `where` predicate is sandboxed by CEL (no I/O, no filesystem). The +predicate evaluator is an optional install: + +``` +pip install device-connect-agent-tools[predicate] +``` + +Without the extra, calling `broadcast(..., where=...)` returns an +`invalid_predicate` error immediately at the dispatcher; calls without a +`where` work unchanged. + +### Synchronized fan-out (`fire_at` + `on_late`) + +`broadcast` accepts an optional `fire_at` (wall-clock epoch seconds). +Each device holds the message and fires from its own clock at the +deadline. `on_late` controls behaviour when a device receives the +message past the deadline: + +- `"skip"` (default) — drop the call to preserve coherence. +- `"fire"` — execute immediately. + +```python +broadcast( + "device(category:phone).function(set_flashlight)", + params={"on": True, "color": "white"}, + fire_at=time.time() + 0.500, # 500 ms in the future + on_late="skip", +) +``` + +With NTP-synced devices the achieved spread is typically 5-10 ms +(clock-sync residual) rather than the 50-150 ms a naive fire-on-receipt +broadcast would produce. + +## Response envelope + +`discover` returns a stable envelope: + +```json +{ + "scope": "device_only", + "matched": 47, + "returned": 20, + "offset": 0, + "next_offset": 20, + "results": [...], + "label_histogram": { + "category": { + "values": {"camera": 312, "robot": 89, "sensor": 601}, + "multivalued": true, + "unique_devices": 1002 + } + } +} +``` + +Fields: + +- `scope` - one of `device_only`, `device_function`, `device_event`, + `function_only`, `event_only`. +- `matched` - total matched entities (across all pages). +- `returned` - rows in this page. +- `offset` / `next_offset` - pagination cursor; `next_offset` is `null` when + no more pages. +- `results` - per-page rows. Shape depends on scope (devices, function + tuples, or event tuples). +- `label_histogram` - per-key vocabulary across the matched set + (pre-pagination), so you can choose how to narrow next. On the device + axis, multi-valued keys also carry `unique_devices`. + +The hard ceiling on `limit` is 1000 to prevent runaway responses; ask for +more pages instead. + +## Error responses + +`discover` and `discover_labels` return errors as data inside the response +envelope rather than raising. The shape is stable so callers can branch on +the `code` programmatically and surface `message` to logs or users: + +```json +{ "matched": 0, "returned": 0, "offset": 0, "next_offset": null, + "results": [], + "error": { + "code": "selector_parse_error", + "message": "Unknown scope 'widgets' at position 0\n widgets(*)\n ^" + } +} +``` + +| Code | Cause | +| --- | --- | +| `invalid_selector` | Selector is not a string (or otherwise unusable as input) | +| `selector_parse_error` | Selector is a string but malformed | +| `connection_error` | Registry or messaging backend unavailable | +| `key_not_axis_qualified` | `discover_labels(key=...)` missing the `device.` / `function.` / `event.` prefix | +| `unknown_axis` | `discover_labels(key=...)` axis prefix not in `{device, function, event}` | + +## Worked examples + +### Browse the fleet vocabulary + +```python +from device_connect_agent_tools import connect, discover_labels + +connect() +vocab = discover_labels() +# {"total_devices": 1247, "total_functions": 7100, "total_events": 1292, +# "device_keys": {"category": {...}, "location": {...}}, +# "function_keys": {"direction": {...}, "modality": {...}, "safety": {...}}, +# "event_keys": {"modality": {...}}} + +# Drill into one dimension: +locations = discover_labels(key="device.location", limit=50) +``` + +### Find every camera in lab-A + +```python +from device_connect_agent_tools import discover + +result = discover("device(category:camera, location:lab-A/*)") +for d in result["results"]: + print(d["device_id"], d["labels"]) +``` + +### Find every write RPC on cameras, fleet-wide + +```python +result = discover("device(category:camera).function(direction:write)") +for row in result["results"]: + print(row["device_id"], row["name"]) +``` + +### Paginate a large result set + +```python +offset = 0 +while True: + page = discover("device(*)", offset=offset, limit=200) + for d in page["results"]: + process(d) + if page["next_offset"] is None: + break + offset = page["next_offset"] +``` + +### Invoke a single function + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(robot-001).function(grip_close)", + {"force_n": 10}, +) +# {"success": True, "device_id": "robot-001", "function": "grip_close", +# "result": {...}} +``` + +### Fan out across every camera in lab-A + +```python +from device_connect_agent_tools import invoke_many + +result = invoke_many( + "device(category:camera, location:lab-A).function(capture_image)", + {"resolution": "4k"}, +) +# {"candidates": 12, "matched": 12, "succeeded": 12, "failed": 0, +# "results": [...], "errors": []} +``` + +### Async fleet emergency stop + +```python +from device_connect_agent_tools import broadcast, await_replies + +result = broadcast("function(estop)") +# {"correlation_id": "br-7f3a91", "candidates": 240, ...} + +replies = await_replies(result["correlation_id"], timeout=5.0) +# list of {device_id, success, result|error, actually_fired_at} +``` + +### Synchronized actuation across a phone fleet + +```python +import time +from device_connect_agent_tools import broadcast + +mask = build_mask_from_scores(threshold=0.8) # caller-side selection +broadcast( + "device(category:phone, location:auditorium-A).function(set_flashlight)", + params={"on": True, "color": "white"}, + where="mask[seat_row][seat_col] == 1 && status.battery > 30", + bindings={"mask": mask}, + fire_at=time.time() + 0.5, + on_late="skip", +) +``` + +### Subscribe to motion events in lab-A + +```python +from device_connect_agent_tools import subscribe + +with subscribe("device(location:lab-A/*).event(modality:motion)") as sub: + for event in sub.iter(timeout=60.0): + handle(event) +``` + +## CLI + +The same selector syntax drives the operator CLIs. Every CLI command +maps to the matching Python tool call. + +``` +# Discovery (devctl) +devctl discover "" [--offset N] [--limit M] +devctl discover-labels [--key K] [--offset N] [--limit M] + +# Operations (statectl) +statectl invoke "" [--param k=v ...] +statectl invoke-many "" [--param k=v ...] [--timeout T] +statectl broadcast "" [--param k=v ...] [--where E] + [--bindings JSON] [--fire-at T] + [--on-late skip|fire] +statectl subscribe "" [--timeout T] [--until N] +statectl await [--timeout T] [--until N] +``` + +`--param k=v` accepts JSON-shaped values (numbers, booleans, arrays, +objects); everything else passes through as a string. So +`--param resolution=4k` and `--param zones='[1,2,3]'` both work +without quoting heroics. + +Each verb exits non-zero on tool-side errors so the verbs compose into +shell pipelines: + +``` +statectl broadcast "device(category:camera).function(capture_image)" \ + --param resolution=4k \ + | jq -r .correlation_id \ + | xargs statectl await --timeout 5 +``` diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index fb5b198..1a7c1e0 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -4,37 +4,46 @@ """Device Connect Tools — framework-agnostic SDK for Device Connect IoT. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: - from device_connect_agent_tools import connect, describe_fleet, list_devices + from device_connect_agent_tools import connect, discover, discover_labels, invoke connect() - fleet = describe_fleet() # bird's-eye summary (~200 tokens) - cameras = list_devices(device_type="camera") # compact roster - info = get_device_functions("camera-001") # full schemas for one device - result = invoke_device("camera-001", "capture_image", {"resolution": "1080p"}) + vocab = discover_labels() # fleet vocabulary + cams = discover("device(category:camera, location:zone-A/*)") # device roster + writes = discover("device(*).function(direction:write)") # function tuples + result = invoke("device(camera-001).function(capture_image)", + {"resolution": "1080p"}) -Strands: - from device_connect_agent_tools import connect - from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, - ) - from strands import Agent - - connect() - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) +The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` / +``invoke_device`` family remains available for one release as +advisory-deprecated wrappers -- prefer ``discover`` / ``discover_labels`` / +``invoke`` / ``invoke_many`` for new code. """ from device_connect_agent_tools.agent import DeviceConnectAgent from device_connect_agent_tools.connection import connect, disconnect, get_connection from device_connect_agent_tools.tools import ( + # Selector-driven discovery (preferred) + discover, + discover_labels, + # Selector-driven invocation (preferred) + invoke, + invoke_many, + broadcast, + # Selector-driven subscription + Subscription, + subscribe, + await_replies, + # Other invocation helpers + invoke_device_with_fallback, + get_device_status, + # Advisory-deprecated wrappers (one-release transition) describe_fleet, list_devices, get_device_functions, - discover_devices, invoke_device, - invoke_device_with_fallback, - get_device_status, + discover_devices, ) __all__ = [ @@ -44,14 +53,24 @@ "get_connection", # High-level agent "DeviceConnectAgent", - # Hierarchical discovery tools (recommended) + # Selector-driven discovery (preferred) + "discover", + "discover_labels", + # Selector-driven invocation (preferred) + "invoke", + "invoke_many", + "broadcast", + # Selector-driven subscription + "Subscription", + "subscribe", + "await_replies", + # Other invocation helpers + "invoke_device_with_fallback", + "get_device_status", + # Advisory-deprecated -- use discover / discover_labels / invoke instead "describe_fleet", "list_devices", "get_device_functions", - # Invocation tools "invoke_device", - "invoke_device_with_fallback", - "get_device_status", - # Backward-compatible (deprecated — use hierarchical tools instead) "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py b/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py index 335a5c4..1dd38d1 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py @@ -143,3 +143,60 @@ def group_devices( key = d.get(group_by) or "unknown" groups[key].append(summary) return {"groups": dict(sorted(groups.items())), "total": len(devices)} + + +# -- Label histograms --------------------------------------------------- + + +def _accumulate_label( + histogram: dict, multivalued_keys: set, label_key: str, label_value: Any +) -> None: + """Record one ``label_key -> label_value`` observation in ``histogram``. + + ``label_value`` may be a string or a list of strings. Lists are flagged + in ``multivalued_keys`` so the caller can annotate them in the response. + """ + if isinstance(label_value, list): + multivalued_keys.add(label_key) + for v in label_value: + histogram[label_key][str(v)] = histogram[label_key].get(str(v), 0) + 1 + else: + histogram[label_key][str(label_value)] = histogram[label_key].get(str(label_value), 0) + 1 + + +def label_histogram( + items: list[dict], *, count_unique: bool = False +) -> tuple: + """Build ``{key: {value: count}}`` histograms across item labels. + + Multi-valued labels (list values) increment the histogram for each + member -- a device with ``category: [camera, inference]`` adds 1 to + both ``camera`` and ``inference``. Keys observed with any list value + are surfaced via ``multivalued_keys`` so callers can annotate the + response. + + Args: + items: Records with optional ``labels`` field (devices, functions, + or events). + count_unique: When True, also tracks how many distinct items + declared each key. Useful only for the device axis, where a + multi-valued label can otherwise mask the unique-device count. + + Returns: + ``(histogram, multivalued_keys)`` when ``count_unique=False``; + ``(histogram, multivalued_keys, unique_per_key)`` when + ``count_unique=True``. + """ + histogram: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + multivalued: set[str] = set() + unique: dict[str, int] | None = defaultdict(int) if count_unique else None + for item in items: + labels = item.get("labels") or {} + for k, v in labels.items(): + if unique is not None: + unique[k] += 1 + _accumulate_label(histogram, multivalued, k, v) + flat = {k: dict(vals) for k, vals in histogram.items()} + if unique is not None: + return flat, multivalued, dict(unique) + return flat, multivalued diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py index a9d65f6..d54403e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py @@ -8,12 +8,12 @@ # Strands from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) # LangChain from device_connect_agent_tools.adapters.langchain import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) # Claude Agent SDK diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 70ad42f..9256968 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -4,7 +4,7 @@ """Claude Agent SDK adapter — exposes Device Connect tools to claude-agent-sdk. -Hierarchical discovery keeps LLM context small:: +Selector-driven discovery and invocation keep LLM context small:: import anyio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions, AssistantMessage, TextBlock @@ -42,11 +42,13 @@ async def main(): from claude_agent_sdk import tool, create_sdk_mcp_server from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -56,77 +58,156 @@ def _text(result: Any) -> dict[str, Any]: return {"content": [{"type": "text", "text": json.dumps(result, default=str)}]} -# Hierarchical discovery tools (recommended) +# Selector-driven discovery tools (recommended) @tool( - "describe_fleet", - "Get a high-level summary of all available devices, grouped by type and " - "location. Use this first to understand what is available, then call " - "list_devices to browse specific types or locations.", - {}, + "discover_labels", + "Browse the label vocabulary across the fleet. Returns label keys " + "(category, location, direction, modality, ...) with their values and " + "counts. Call with no arguments to see all keys, or with key=" + "'device.location' / 'function.direction' / etc. to paginate one key. " + "Use this first to learn what dimensions are available before calling " + "discover().", + {"key": str, "offset": int, "limit": int}, ) -async def describe_fleet(args: dict[str, Any]) -> dict[str, Any]: - return _text(_describe_fleet()) +async def discover_labels(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _discover_labels( + key=args.get("key"), + offset=int(args.get("offset", 0)), + limit=int(args.get("limit", 50)), + ) + ) @tool( - "list_devices", - "Browse available devices with filtering and pagination. Returns compact " - "device summaries (no full schemas). Use get_device_functions for details.", - { - "device_type": str, - "location": str, - "status": str, - "group_by": str, - "offset": int, - "limit": int, - }, + "discover", + "Resolve a selector to matched devices, functions, or events. Selector " + "grammar: device(), device().function(), " + "device().event(), function(), or " + "event(). Filters are key:value pairs (AND across keys with " + "commas, OR within a key with bracket lists, glob with *). Examples: " + "'device(category:camera, location:zone-A/*)', " + "'device(*).function(direction:write)', 'event(modality:motion)'. " + "Response includes a label_histogram (per-key vocabulary across the " + "matched set) so the agent can narrow next.", + {"selector": str, "offset": int, "limit": int}, ) -async def list_devices(args: dict[str, Any]) -> dict[str, Any]: +async def discover(args: dict[str, Any]) -> dict[str, Any]: return _text( - _list_devices( - device_type=args.get("device_type"), - location=args.get("location"), - status=args.get("status"), - group_by=args.get("group_by"), + _discover( + selector=args["selector"], offset=int(args.get("offset", 0)), - limit=int(args.get("limit", 20)), + limit=int(args.get("limit", 200)), ) ) +# Selector-driven invocation tools (recommended) + + @tool( - "get_device_functions", - "Get full function schemas for a specific device. Call this after " - "list_devices to see what a device can do and what parameters each " - "function accepts.", - {"device_id": str}, + "invoke", + "Call exactly one function on one device. The selector must resolve " + "to a single (device, function) tuple -- use device().function() " + "or function() scope. Returns {success, device_id, function, " + "result|error}. Use invoke_many for fan-out across multiple targets.", + { + "selector": str, "params": dict, "llm_reasoning": str, + "mandate": dict, + }, ) -async def get_device_functions(args: dict[str, Any]) -> dict[str, Any]: - return _text(_get_device_functions(device_id=args["device_id"])) +async def invoke(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _invoke( + selector=args["selector"], + params=args.get("params"), + llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), + ) + ) -# Invocation tools +@tool( + "invoke_many", + "Fan out a function call over a selector-resolved set of (device, " + "function) tuples in parallel. Partial-failure semantics: per-target " + "results and errors are returned even if some targets fail. Returns " + "{candidates, matched, succeeded, failed, results, errors}. Each " + "target gets a per-call timeout (default 30s).", + { + "selector": str, "params": dict, "timeout": float, + "max_concurrency": int, "llm_reasoning": str, "mandate": dict, + }, +) +async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _invoke_many( + selector=args["selector"], + params=args.get("params"), + timeout=float(args.get("timeout", 30.0)), + max_concurrency=int(args.get("max_concurrency", 32)), + llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), + ) + ) @tool( - "invoke_device", - "Call a function on a Device Connect device. Use get_device_functions " - "first to learn available functions and parameters.", - {"device_id": str, "function": str, "params": dict, "llm_reasoning": str}, + "broadcast", + "Async selector-driven fan-out. Returns immediately with a " + "correlation_id; replies stream on a per-device subject keyed by id. " + "Each candidate self-elects via the optional CEL `where` predicate " + "(evaluated at the edge against identity/labels/status/bindings) and " + "executes the function. Use fire_at (wall-clock epoch seconds) + " + "on_late (skip|fire) for synchronized fan-out. Pair with " + "await_replies(correlation_id) to collect outcomes.", + { + "selector": str, "params": dict, "where": str, "bindings": dict, + "fire_at": float, "on_late": str, "llm_reasoning": str, + "mandate": dict, + }, ) -async def invoke_device(args: dict[str, Any]) -> dict[str, Any]: +async def broadcast(args: dict[str, Any]) -> dict[str, Any]: return _text( - _invoke_device( - device_id=args["device_id"], - function=args["function"], + _broadcast( + selector=args["selector"], params=args.get("params"), + where=args.get("where"), + bindings=args.get("bindings"), + fire_at=args.get("fire_at"), + on_late=args.get("on_late", "skip"), llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), ) ) +@tool( + "await_replies", + "Collect replies for a broadcast() call. Subscribes to the " + "correlation reply subject, drains for up to `timeout` seconds (or " + "until `until` replies have arrived), then returns the list.", + { + "correlation_id": str, "timeout": float, "until": int, + "poll_interval": float, + }, +) +async def await_replies(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _await_replies( + correlation_id=args["correlation_id"], + timeout=float(args.get("timeout", 10.0)), + until=int(args["until"]) if args.get("until") is not None else None, + poll_interval=float(args.get("poll_interval", 0.05)), + ) + ) + + +# Other invocation helpers + + @tool( "invoke_device_with_fallback", "Call a function with automatic fallback across a list of device IDs. " @@ -153,13 +234,13 @@ async def get_device_status(args: dict[str, Any]) -> dict[str, Any]: return _text(_get_device_status(device_id=args["device_id"])) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) @tool( "discover_devices", - "Deprecated — use describe_fleet, list_devices, and get_device_functions " - "instead. Discover all devices with full function schemas.", + "Deprecated -- use discover() and discover_labels() instead. Discovers " + "all devices with full function schemas.", {"device_type": str, "refresh": bool}, ) async def discover_devices(args: dict[str, Any]) -> dict[str, Any]: @@ -179,10 +260,12 @@ def create_device_connect_server(name: str = "device-connect"): return create_sdk_mcp_server( name, tools=[ - describe_fleet, - list_devices, - get_device_functions, - invoke_device, + discover_labels, + discover, + invoke, + invoke_many, + broadcast, + await_replies, invoke_device_with_fallback, get_device_status, discover_devices, @@ -191,12 +274,14 @@ def create_device_connect_server(name: str = "device-connect"): __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", - "invoke_device", + "discover_labels", + "discover", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", + "discover_devices", "create_device_connect_server", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index 6e0b8a3..c18ed7e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -4,16 +4,16 @@ """LangChain adapter — wraps Device Connect tools as LangChain StructuredTools. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.langchain import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from langgraph.prebuilt import create_react_agent connect() - agent = create_react_agent(model, [describe_fleet, list_devices, get_device_functions, invoke_device]) + agent = create_react_agent(model, [discover_labels, discover, invoke, invoke_many]) Requires: pip install device-connect-agent-tools[langchain] """ @@ -21,34 +21,42 @@ from langchain_core.tools import StructuredTool from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Hierarchical discovery tools (recommended) -describe_fleet = StructuredTool.from_function(_describe_fleet) -list_devices = StructuredTool.from_function(_list_devices) -get_device_functions = StructuredTool.from_function(_get_device_functions) +# Selector-driven discovery (recommended) +discover_labels = StructuredTool.from_function(_discover_labels) +discover = StructuredTool.from_function(_discover) -# Invocation tools -invoke_device = StructuredTool.from_function(_invoke_device) +# Selector-driven invocation (recommended) +invoke = StructuredTool.from_function(_invoke) +invoke_many = StructuredTool.from_function(_invoke_many) +broadcast = StructuredTool.from_function(_broadcast) +await_replies = StructuredTool.from_function(_await_replies) + +# Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) get_device_status = StructuredTool.from_function(_get_device_status) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = StructuredTool.from_function(_discover_devices) __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", - "invoke_device", + "discover_labels", + "discover", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", + "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index 308c2a7..b68c16b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -4,16 +4,16 @@ """Strands adapter — wraps Device Connect tools with @strands.tool. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from strands import Agent connect() - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) + agent = Agent(tools=[discover_labels, discover, invoke, invoke_many]) agent("What devices are online?") Requires: pip install device-connect-agent-tools[strands] @@ -22,34 +22,42 @@ from strands import tool as strands_tool from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Hierarchical discovery tools (recommended) -describe_fleet = strands_tool(_describe_fleet) -list_devices = strands_tool(_list_devices) -get_device_functions = strands_tool(_get_device_functions) +# Selector-driven discovery (recommended) +discover_labels = strands_tool(_discover_labels) +discover = strands_tool(_discover) -# Invocation tools -invoke_device = strands_tool(_invoke_device) +# Selector-driven invocation (recommended) +invoke = strands_tool(_invoke) +invoke_many = strands_tool(_invoke_many) +broadcast = strands_tool(_broadcast) +await_replies = strands_tool(_await_replies) + +# Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) get_device_status = strands_tool(_get_device_status) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = strands_tool(_discover_devices) __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", - "invoke_device", + "discover_labels", + "discover", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", + "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py index 7b6c532..c5f5e67 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py @@ -60,10 +60,10 @@ async def prepare(self) -> Dict[str, Any]: from strands import Agent from strands.models import AnthropicModel from device_connect_agent_tools.adapters.strands import ( - describe_fleet, - list_devices, - get_device_functions, - invoke_device, + discover_labels, + discover, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, ) @@ -74,8 +74,8 @@ async def prepare(self) -> Dict[str, Any]: self._agent = Agent( model=AnthropicModel(model_id=self._model_id, max_tokens=self._max_tokens), tools=[ - describe_fleet, list_devices, get_device_functions, - invoke_device, invoke_device_with_fallback, get_device_status, + discover_labels, discover, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=system_prompt, ) @@ -92,8 +92,8 @@ def _build_system_prompt(self) -> str: """Build a system prompt from discovered devices. Uses a compact fleet summary instead of dumping all device schemas. - The agent can use describe_fleet(), list_devices(), and - get_device_functions() to drill into details as needed. + The agent can use discover_labels() and discover() to drill into + details as needed. """ # Build compact fleet summary (type counts + locations) from collections import defaultdict @@ -108,22 +108,31 @@ def _build_system_prompt(self) -> str: for dt, info in sorted(by_type.items()): locs = ", ".join(sorted(info["locations"])) type_lines.append(f" - {info['count']}x {dt} (at: {locs})") - fleet_summary = "\n".join(type_lines) or " (none yet — call describe_fleet() to refresh)" + fleet_summary = "\n".join(type_lines) or " (none yet -- call discover() to refresh)" return ( f"You are an AI agent connected to the Device Connect IoT network.\n\n" f"YOUR GOAL: {self.goal}\n\n" f"FLEET OVERVIEW ({len(self.devices)} devices):\n{fleet_summary}\n\n" f"DISCOVERY TOOLS:\n" - f" - describe_fleet() — fleet summary (what you see above)\n" - f" - list_devices(device_type=..., location=...) — browse devices\n" - f" - get_device_functions(device_id) — see what a device can do\n" - f" - invoke_device(device_id, function, params) — call a device function\n\n" + f" - discover_labels(key=None) -- fleet label vocabulary " + f"(category, location, direction, modality, ...)\n" + f" - discover(selector) -- resolve a selector to devices, " + f"functions, or events. Examples:\n" + f" device(category:camera, location:zone-A/*)\n" + f" device(robot-001).function(direction:write)\n" + f" function(safety:critical)\n\n" + f"INVOCATION TOOLS:\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\n\n" f"INSTRUCTIONS:\n" f"When you receive device events, you MUST:\n" f"1. Analyze the events\n" - f"2. Use get_device_functions() to check available functions if needed\n" - f"3. Use invoke_device() to interact with devices\n" + f"2. Use discover() with a function-scoped selector to check " + f"available functions if needed\n" + f"3. Use invoke() or invoke_many() to interact with devices\n" f"4. Report what you found and what actions you took\n\n" f"Always provide llm_reasoning when invoking devices to explain your decision.\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index 2dba8fc..4dce5fb 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -125,6 +125,16 @@ def flatten_device(raw: Dict[str, Any]) -> Dict[str, Any]: status = raw.get("status") or {} caps = raw.get("capabilities") or {} + # Mirror the legacy DeviceStatus.location field into labels["location"] + # when the driver did not declare it via DeviceCapabilities.labels. Drivers + # using only the legacy field would otherwise be invisible to selector + # queries on location. + legacy_location = raw.get("location") or status.get("location") + caps_labels = caps.get("labels") + merged_labels = caps_labels + if legacy_location and (not caps_labels or "location" not in caps_labels): + merged_labels = {**(caps_labels or {}), "location": legacy_location} + # NOTE: The raw ``capabilities`` dict is intentionally NOT included in # the flattened output. ``functions`` and ``events`` are extracted to # the top level for direct access. Including both would duplicate data @@ -132,11 +142,16 @@ def flatten_device(raw: Dict[str, Any]) -> Dict[str, Any]: return { "device_id": raw.get("device_id"), "device_type": raw.get("device_type") or identity.get("device_type"), - "location": raw.get("location") or status.get("location"), + "location": legacy_location, "status": status, "identity": identity, "functions": caps.get("functions", []), "events": caps.get("events", []), + # Discovery labels declared by the driver (DeviceCapabilities.labels), + # with status.location mirrored in when caps did not carry it. None + # when neither source provided any label -- discover() treats that + # as "no label-based match," not "matches everything." + "labels": merged_labels, } @@ -394,13 +409,33 @@ async def _async_invoke( # ── Broadcast ──────────────────────────────────────────────────── + def publish_broadcast(self, envelope: Dict[str, Any]) -> None: + """Publish a selector-driven broadcast envelope to the fanout subject. + + The envelope shape is documented in + ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; + every device subscribed to ``device-connect..broadcast`` + receives the message and self-elects via ``targets`` and + the optional ``where`` predicate. + """ + return self._run(self._async_publish_broadcast(envelope)) + + async def _async_publish_broadcast(self, envelope: Dict[str, Any]) -> None: + subject = f"device-connect.{self.zone}.broadcast" + await self._client.publish(subject, json.dumps(envelope).encode()) + def broadcast( self, function: str, params: Optional[Dict[str, Any]] = None, timeout: float = 5.0, ) -> List[Dict[str, Any]]: - """Invoke a function on all discovered devices and collect results.""" + """Invoke a function on all discovered devices and collect results. + + Sequential sync fan-out (one invoke per device). Predates the + selector-driven broadcast tool; left in place for callers that want + a simple "call this on everyone" without setting up subscriptions. + """ devices = self.list_devices() results = [] for d in devices: diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index 803160e..a9a8e5b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -4,40 +4,38 @@ """Device Connect device operations — framework-agnostic tool functions. -Hierarchical discovery tools that keep LLM context small: +Discovery is selector-driven. ``discover()`` and ``discover_labels()`` cover +both fleet-wide and entity-scoped queries; the older ``describe_fleet`` / +``list_devices`` / ``get_device_functions`` trio remains as advisory-deprecated +wrappers for one release while callers migrate. -1. ``describe_fleet()`` — bird's-eye summary (types, locations, counts) -2. ``list_devices(...)`` — paginated compact roster (no schemas) -3. ``get_device_functions(id)`` — full schemas for ONE device -4. ``invoke_device(...)`` — call a function on a device - -Plain Python functions with type hints and docstrings. Use them directly -or wrap with a framework adapter: - - # Plain Python - from device_connect_agent_tools import connect, describe_fleet, list_devices + from device_connect_agent_tools import connect, discover, discover_labels connect() - fleet = describe_fleet() - devices = list_devices(device_type="camera") - - # Strands - from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, - ) - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) + cams = discover("device(category:camera)") + rgb_writes = discover("device(*).function(direction:write, modality:rgb)") + vocab = discover_labels() """ from __future__ import annotations import logging import os +import time import uuid +import warnings from typing import Any +from device_connect_edge.selector import ( + Scope, + Selector, + SelectorParseError, + parse_selector, +) from device_connect_agent_tools.connection import get_connection from device_connect_agent_tools._normalize import ( full_device, compact_device, fuzzy_filter_by_type, extract_status, aggregate_fleet, group_devices, + label_histogram, ) logger = logging.getLogger(__name__) @@ -55,10 +53,1126 @@ ) SMALL_FLEET_THRESHOLD = 5 +# When ``discover()`` resolves a selector to this many functions or events +# or fewer, the response includes full schemas inline. Above the threshold +# it returns a compact ``(device_id, name, labels)`` summary so the agent +# can narrow further via ``discover_labels()`` or a tighter selector. +try: + DC_FUNCTION_THRESHOLD = min(max(int(os.getenv("DEVICE_CONNECT_FUNCTION_THRESHOLD", "20")), 0), 200) +except (ValueError, TypeError): + logger.warning( + "Invalid DEVICE_CONNECT_FUNCTION_THRESHOLD value %r, defaulting to 20", + os.getenv("DEVICE_CONNECT_FUNCTION_THRESHOLD"), + ) + DC_FUNCTION_THRESHOLD = 20 + +# Hard ceiling on per-call ``limit`` to prevent runaway responses in large +# fleets. A caller asking for limit=100000 still gets at most this many +# rows per page (with ``next_offset`` to continue). +DISCOVER_HARD_LIMIT = 1000 + +# Default limits per the discovery design (different defaults for the two +# tools because they answer different questions: ``discover`` returns rows, +# ``discover_labels`` returns vocabulary). +DEFAULT_DISCOVER_LIMIT = 200 +DEFAULT_DISCOVER_LABELS_LIMIT = 50 + # ── Shared helpers ────────────────────────────────────────────── +def _normalize_pagination(offset: int, limit: int, default_limit: int) -> tuple[int, int]: + """Clamp offset and limit to safe ranges. + + Negative offset rounds to 0, non-positive limit falls back to the default, + and limit is capped at ``DISCOVER_HARD_LIMIT``. + """ + safe_offset = max(0, int(offset or 0)) + if not limit or limit <= 0: + safe_limit = default_limit + else: + safe_limit = min(int(limit), DISCOVER_HARD_LIMIT) + return safe_offset, safe_limit + + +def _error(code: str, message: str) -> dict[str, str]: + """Build the canonical structured error object. + + Errors are returned as data (not raised) inside the response envelope. + The ``code`` is a stable, machine-readable string callers may switch on; + ``message`` is human-readable and may include positional detail (parse + caret, axis name, etc.) suitable for logging or surfacing to the user. + + Codes currently emitted: + - ``selector_parse_error`` selector string is malformed + - ``invalid_selector`` selector is not a usable input + (None, non-string, etc.) + - ``connection_error`` registry / messaging unavailable + - ``key_not_axis_qualified`` discover_labels key missing axis prefix + - ``unknown_axis`` discover_labels axis not in + {device, function, event} + """ + return {"code": code, "message": message} + + +def _empty_envelope( + scope: str | None = None, error: dict[str, str] | None = None +) -> dict[str, Any]: + """Build the canonical zero-result response envelope.""" + out: dict[str, Any] = { + "matched": 0, + "returned": 0, + "offset": 0, + "next_offset": None, + "results": [], + } + if scope is not None: + out["scope"] = scope + if error is not None: + out["error"] = error + return out + + +def _paginate(items: list, offset: int, limit: int) -> tuple[list, int | None]: + """Slice ``items`` to one page; return ``(page, next_offset)``.""" + end = offset + limit + page = items[offset:end] + next_offset = end if end < len(items) else None + return page, next_offset + + +def _device_summary_for_discover(d: dict, expand: bool) -> dict[str, Any]: + """Compact device row for ``discover()``, with labels surfaced.""" + summary = compact_device(d, expand) + summary["status"] = extract_status(d) + summary["labels"] = d.get("labels") + return summary + + +def _function_row(d: dict, fn: dict, expand: bool) -> dict[str, Any]: + """Build one row for a function-scoped discover result. + + Below the threshold, ``expand`` is True and the row includes the full + JSON Schema. Above threshold, only name + labels travel back so the + agent can narrow without paying for parameter schemas. + """ + name = fn.get("name") if isinstance(fn, dict) else fn + labels = fn.get("labels") if isinstance(fn, dict) else None + if expand and isinstance(fn, dict): + return { + "device_id": d.get("device_id"), + "name": name, + "description": fn.get("description", ""), + "parameters": fn.get("parameters", {}), + "labels": labels, + } + return { + "device_id": d.get("device_id"), + "name": name, + "labels": labels, + } + + +def _event_row(d: dict, ev: dict, expand: bool) -> dict[str, Any]: + """Build one row for an event-scoped discover result.""" + name = ev.get("name") if isinstance(ev, dict) else ev + labels = ev.get("labels") if isinstance(ev, dict) else None + if expand and isinstance(ev, dict): + return { + "device_id": d.get("device_id"), + "name": name, + "description": ev.get("description", ""), + "payload_schema": ev.get("payload_schema"), + "labels": labels, + } + return { + "device_id": d.get("device_id"), + "name": name, + "labels": labels, + } + + +# ── Selector-driven discovery (preferred) ──────────────────────── + + +def discover( + selector: str, + offset: int = 0, + limit: int = DEFAULT_DISCOVER_LIMIT, +) -> dict[str, Any]: + """Resolve a selector to matched devices, functions, or events. + + The selector DSL supports five scope shapes: + + device() all matching devices + device().function() RPCs on a device subset + device().event() events on a device subset + function() all RPCs across the fleet + event() all events across the fleet + + Inside ``(...)``: ``key:value``, ``key:[v1,v2]`` (OR within a key), + ``key:pattern*`` (glob), ``k1:v1,k2:v2`` (AND across keys), bare-string + id/name match, or ``*`` to match all. + + Args: + selector: A selector expression string. + offset: Pagination offset (rows skipped). + limit: Max rows per page (capped at DISCOVER_HARD_LIMIT). + + Returns: + A response envelope: + ``{"scope", "matched", "returned", "offset", "next_offset", "results", + "label_histogram"}``. + ``label_histogram`` is the per-key vocabulary across the **matched** + set (pre-pagination), not the returned page; on the device axis it + tracks unique device counts per key (``unique_devices``), on + function/event axes it counts occurrences (a function appearing on N + devices contributes N entries). + For function- and event-scoped selectors, ``results`` rows include + full schemas when the matched count is at or below + ``DC_FUNCTION_THRESHOLD``; otherwise rows are name-and-labels summaries. + + Example: + >>> discover("device(category:camera, location:zone-A/*)") + {"scope": "device_only", "matched": 4, ...} + >>> discover("device(*).function(direction:write, modality:rgb)") + {"scope": "device_function", "matched": 8, ...} + """ + safe_offset, safe_limit = _normalize_pagination(offset, limit, DEFAULT_DISCOVER_LIMIT) + + # Parse the selector at the system boundary; surface a clean error to + # the caller rather than raising into agent code. + if not isinstance(selector, str): + return _empty_envelope( + error=_error( + "invalid_selector", + f"Selector must be a string, got {type(selector).__name__}", + ) + ) + try: + sel: Selector = parse_selector(selector) + except SelectorParseError as e: + return _empty_envelope(error=_error("selector_parse_error", str(e))) + + try: + conn = get_connection() + devices = conn.list_devices() + except Exception as e: + logger.error("discover(%r) failed loading fleet: %s", selector, e) + return _empty_envelope( + scope=sel.scope.value, error=_error("connection_error", str(e)) + ) + + # Apply the device-axis filter (vacuously True when sel.device is None). + matched_devices = [ + d for d in devices + if sel.device is None + or sel.device.matches(d.get("device_id") or "", d.get("labels")) + ] + + # Branch on scope. Each branch produces (results_full, page, histogram, total). + if sel.scope == Scope.DEVICE_ONLY: + total = len(matched_devices) + page_devices, next_offset = _paginate(matched_devices, safe_offset, safe_limit) + expand = SMALL_FLEET_THRESHOLD > 0 and total <= SMALL_FLEET_THRESHOLD + results = [_device_summary_for_discover(d, expand) for d in page_devices] + histogram, multivalued, unique = label_histogram(matched_devices, count_unique=True) + formatted_histogram = _format_label_histogram(histogram, multivalued, unique) + return { + "scope": sel.scope.value, + "matched": total, + "returned": len(results), + "offset": safe_offset, + "next_offset": next_offset, + "results": results, + "label_histogram": formatted_histogram, + } + + # Function- or event-scoped selectors enumerate (device, entity) tuples. + is_function_scope = sel.scope in (Scope.DEVICE_FUNCTION, Scope.FUNCTION_ONLY) + entity_filter = sel.function if is_function_scope else sel.event + + matched_rows: list[tuple[dict, dict]] = [] + for d in matched_devices: + entities = d.get("functions" if is_function_scope else "events", []) + for entity in entities: + if not isinstance(entity, dict): + # Best-effort: lift bare-name list items into a stub dict so the + # filter can still match by name. + entity = {"name": str(entity), "labels": None} + if entity_filter is None or entity_filter.matches( + entity.get("name") or "", entity.get("labels") + ): + matched_rows.append((d, entity)) + + total = len(matched_rows) + page_rows, next_offset = _paginate(matched_rows, safe_offset, safe_limit) + expand = DC_FUNCTION_THRESHOLD > 0 and total <= DC_FUNCTION_THRESHOLD + if is_function_scope: + results = [_function_row(d, fn, expand) for d, fn in page_rows] + else: + results = [_event_row(d, ev, expand) for d, ev in page_rows] + + matched_entities = [entity for _, entity in matched_rows] + histogram, multivalued = label_histogram(matched_entities) + formatted_histogram = _format_label_histogram(histogram, multivalued) + + return { + "scope": sel.scope.value, + "matched": total, + "returned": len(results), + "offset": safe_offset, + "next_offset": next_offset, + "results": results, + "label_histogram": formatted_histogram, + } + + +def _format_label_histogram( + histogram: dict, + multivalued: set, + unique: dict | None = None, +) -> dict[str, Any]: + """Format a histogram for response, annotating multi-valued keys. + + Multi-valued keys are flagged so an agent reading + ``{camera: 312, inference: 200}`` knows the counts overlap. When + ``unique`` is supplied (device axis only), the per-key unique device + count is exposed as ``unique_devices`` so the agent can reconcile + histogram totals with the underlying device cardinality. + """ + out: dict[str, Any] = {} + for key, counts in histogram.items(): + entry: dict[str, Any] = { + # Sort values most-frequent first; alphabetical tie-break for stability. + "values": dict(sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))), + } + if key in multivalued: + entry["multivalued"] = True + if unique is not None and key in unique: + entry["unique_devices"] = unique[key] + out[key] = entry + return out + + +def discover_labels( + key: str | None = None, + offset: int = 0, + limit: int = DEFAULT_DISCOVER_LABELS_LIMIT, +) -> dict[str, Any]: + """Return the fleet's label vocabulary. + + Without ``key``: returns one entry per axis (``device_keys``, + ``function_keys``, ``event_keys``) with all keys and their top values. + With ``key`` (e.g. ``"device.location"``, ``"function.direction"``): + paginates the full value list for that one key. + + Args: + key: Optional dotted axis.key (``device.``, ``function.``, + ``event.``). When given, the response paginates that one key's + values rather than returning a multi-axis vocabulary. + offset: Pagination offset for the per-key value list. + limit: Max values per page when ``key`` is given (capped at + ``DISCOVER_HARD_LIMIT``). + + Returns: + Multi-axis form (no ``key``): + ``{"total_devices", "total_functions", "total_events", + "device_keys": {key: {"values": {...}, "multivalued"?: True, + "unique_devices"?: N}}, + "function_keys": {...}, "event_keys": {...}}`` + Per-key form (``key`` provided): + ``{"axis", "key", "matched", "returned", "offset", "next_offset", + "values", "multivalued"?: True}`` + """ + safe_offset, safe_limit = _normalize_pagination(offset, limit, DEFAULT_DISCOVER_LABELS_LIMIT) + + try: + conn = get_connection() + devices = conn.list_devices() + except Exception as e: + logger.error("discover_labels failed loading fleet: %s", e) + return _empty_envelope(error=_error("connection_error", str(e))) + + # Aggregate function and event entities once. + functions: list[dict] = [] + events: list[dict] = [] + for d in devices: + for fn in d.get("functions", []) or []: + if isinstance(fn, dict): + functions.append(fn) + for ev in d.get("events", []) or []: + if isinstance(ev, dict): + events.append(ev) + + dev_hist, dev_mv, dev_unique = label_histogram(devices, count_unique=True) + fn_hist, fn_mv = label_histogram(functions) + ev_hist, ev_mv = label_histogram(events) + + if key is None: + return { + "total_devices": len(devices), + "total_functions": len(functions), + "total_events": len(events), + "device_keys": _format_label_histogram(dev_hist, dev_mv, dev_unique), + "function_keys": _format_label_histogram(fn_hist, fn_mv), + "event_keys": _format_label_histogram(ev_hist, ev_mv), + } + + # Per-key form: split on the first dot to pick an axis. + if "." not in key: + return _empty_envelope( + error=_error( + "key_not_axis_qualified", + f"Key must be axis-qualified (device., function., event.): {key!r}", + ) + ) + axis, label_key = key.split(".", 1) + if axis == "device": + source, multivalued = dev_hist, dev_mv + total = len(devices) + elif axis == "function": + source, multivalued = fn_hist, fn_mv + total = len(functions) + elif axis == "event": + source, multivalued = ev_hist, ev_mv + total = len(events) + else: + return _empty_envelope( + error=_error( + "unknown_axis", + f"Unknown axis {axis!r} (expected device|function|event)", + ) + ) + + counts = source.get(label_key, {}) + sorted_values = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])) + page = sorted_values[safe_offset:safe_offset + safe_limit] + next_offset = safe_offset + safe_limit if safe_offset + safe_limit < len(sorted_values) else None + out: dict[str, Any] = { + "axis": axis, + "key": label_key, + "matched": len(sorted_values), + "returned": len(page), + "offset": safe_offset, + "next_offset": next_offset, + "values": dict(page), + "axis_total": total, + } + if label_key in multivalued: + out["multivalued"] = True + return out + + +# ── Selector-driven operations ─────────────────────────────────── + + +# Default per-target timeout for invoke_many fan-out. Configurable per call. +DEFAULT_INVOKE_TIMEOUT = 30.0 + +# Cap on parallel worker threads for invoke_many fan-out. Larger fleets can +# raise this via the ``max_concurrency`` argument; the default keeps thread +# overhead bounded while still parallelising typical 10-100 device fan-outs. +DEFAULT_INVOKE_CONCURRENCY = 32 + + +def _resolve_function_tuples( + selector: str, +) -> tuple[list[dict] | None, dict[str, Any] | None]: + """Resolve a selector to (device_id, function_name) tuples for invocation. + + Walks pagination so callers do not have to. Returns ``(rows, None)`` on + success or ``(None, error_envelope)`` if the selector failed to parse, + used a non-function scope, or the registry was unreachable. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in ( + Scope.DEVICE_FUNCTION.value, Scope.FUNCTION_ONLY.value, + ): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_invoke_scope", + "invoke/invoke_many require a function-scoped selector " + "(device(...).function(...) or function(...)); got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + return rows, None + + +def _shape_invoke_response( + response: dict[str, Any], + device_id: str, + function_name: str, +) -> dict[str, Any]: + """Normalize a JSON-RPC response into a {success, result|error} envelope. + + JSON-RPC error objects arrive as ``{"code": int, "message": str}`` from + the wire; this maps them to the structured ``{code: str, message: str}`` + error shape that the rest of the agent surface uses. + """ + if "error" in response: + err = response["error"] + if isinstance(err, dict): + code = str(err.get("code", "invoke_failed")) + message = str(err.get("message", err)) + else: + code, message = "invoke_failed", str(err) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": {"code": code, "message": message}, + } + return { + "success": True, + "device_id": device_id, + "function": function_name, + "result": response.get("result", {}), + } + + +def _clean_params_with_mandate( + params: dict[str, Any] | None, + mandate: dict[str, Any] | None, +) -> dict[str, Any]: + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + if mandate is None: + return clean + meta = clean.get("_dc_meta") + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise ValueError("_dc_meta must be an object when mandate is provided") + else: + meta = dict(meta) + meta["mandate"] = mandate + clean["_dc_meta"] = meta + return clean + + +def invoke( + selector: str, + params: dict[str, Any] | None = None, + llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Resolve a selector to one (device, function) tuple and invoke it. + + Use this when the call is unambiguous -- one device, one function. + The selector must use ``device().function()`` or + ``function()`` scope. + + Args: + selector: Selector expression resolving to exactly one function. + params: Function parameters dict. Do NOT put ``llm_reasoning`` + inside ``params``. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"success": True, "device_id": ..., "function": ..., + "result": ...}``. + On failure: ``{"success": False, "error": {"code": ..., + "message": ...}}``. Codes include the discover() codes plus + ``no_match`` (zero matches), ``ambiguous_match`` (multiple + matches), ``invalid_invoke_scope`` (selector did not target + functions), and ``invoke_failed`` (the device returned an error). + """ + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"success": False, "error": error_envelope["error"]} + + if not rows: + return { + "success": False, + "error": _error( + "no_match", + f"selector matched 0 functions: {selector!r}", + ), + } + if len(rows) > 1: + return { + "success": False, + "error": _error( + "ambiguous_match", + f"selector matched {len(rows)} functions, expected exactly 1: " + f"{selector!r}", + ), + "candidates": [ + {"device_id": r.get("device_id"), "function": r.get("name")} + for r in rows[:10] + ], + } + + row = rows[0] + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + + trace_id = f"trace-{uuid.uuid4().hex[:12]}" + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[%s] [%s::%s] Reason: %s", + trace_id, device_id, function_name, truncated, + ) + + try: + conn = get_connection() + clean = _clean_params_with_mandate(params, mandate) + response = conn.invoke(device_id, function_name, params=clean) + except Exception as e: + logger.error( + "[%s] %s::%s -> ERROR: %s", + trace_id, device_id, function_name, e, + ) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": _error("invoke_failed", str(e)), + } + return _shape_invoke_response(response, device_id, function_name) + + +def invoke_many( + selector: str, + params: dict[str, Any] | None = None, + timeout: float = DEFAULT_INVOKE_TIMEOUT, + max_concurrency: int = DEFAULT_INVOKE_CONCURRENCY, + llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Resolve a selector to (device, function) tuples and invoke each in parallel. + + Returns aggregated results with partial-failure semantics: a single + target's failure does not abort the rest. Each target gets ``timeout`` + seconds; the overall call returns once every target has finished or + timed out. + + Args: + selector: Function-scoped selector + (``device(...).function(...)`` or ``function(...)``). + params: Function parameters dict applied to every target. + timeout: Per-target timeout in seconds. + max_concurrency: Cap on parallel worker threads. + llm_reasoning: Decision rationale for observability. + + Returns: + ``{"candidates": N, "matched": N, "succeeded": S, "failed": F, + "results": [{device_id, function, result}, ...], + "errors": [{device_id, function, error}, ...]}``. + + ``candidates`` is the count returned by the selector resolver. + ``matched`` is the same value in this release; once edge-side + ``where`` predicates land, ``matched`` will narrow below + ``candidates`` to reflect post-predicate self-election. + + On selector parse / connection failure the envelope is returned + with all counts at zero plus a top-level ``error`` field. + """ + import concurrent.futures + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return { + "candidates": 0, "matched": 0, "succeeded": 0, "failed": 0, + "results": [], "errors": [], "error": error_envelope["error"], + } + + out: dict[str, Any] = { + "candidates": len(rows), + "matched": len(rows), + "succeeded": 0, + "failed": 0, + "results": [], + "errors": [], + } + if not rows: + return out + + workers = max(1, min(max_concurrency, len(rows))) + try: + clean = _clean_params_with_mandate(params, mandate) + except ValueError as e: + return { + "candidates": len(rows), "matched": len(rows), "succeeded": 0, "failed": len(rows), + "results": [], "errors": [], "error": _error("invalid_params", str(e)), + } + + def call_one(row: dict) -> dict[str, Any]: + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + try: + conn = get_connection() + response = conn.invoke( + device_id, function_name, params=clean, timeout=timeout, + ) + except Exception as e: + response = {"error": {"code": "invoke_failed", "message": str(e)}} + return _shape_invoke_response(response, device_id, function_name) + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[invoke_many::%d targets] Reason: %s", len(rows), truncated, + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as exe: + futures = [exe.submit(call_one, row) for row in rows] + for future in concurrent.futures.as_completed(futures): + shaped = future.result() + if shaped["success"]: + out["results"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "result": shaped["result"], + }) + out["succeeded"] += 1 + else: + out["errors"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "error": shaped["error"], + }) + out["failed"] += 1 + return out + + +def broadcast( + selector: str, + params: dict[str, Any] | None = None, + where: str | None = None, + bindings: dict[str, Any] | None = None, + fire_at: float | None = None, + on_late: str = "skip", + llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Async selector-driven fan-out. Returns immediately with a correlation id. + + Use ``broadcast`` when the caller does not want to block on the slowest + device. Each candidate self-elects via the optional ``where`` predicate + (CEL, evaluated at the edge against the device's identity, labels, live + status, and the shared ``bindings``) and emits its reply as an event on + a per-device subject keyed by ``correlation_id``:: + + device-connect...event.async_reply. + + Subscribe to those replies via ``subscribe('correlation:')`` or wait + for them with ``await_replies(correlation_id, timeout=...)``. + + Args: + selector: Function-scoped selector. The selector must resolve to a + single function name across the matched devices; if multiple + functions match, an ``ambiguous_function`` error is returned. + params: Function parameters dict applied to every target. + where: Optional CEL predicate evaluated at the edge per candidate + (e.g. ``"status.battery > 50"``, ``"mask[row][col] == 1"``). + Validated at the dispatcher before publication so syntax + errors return immediately rather than reaching the wire. + bindings: Shared payload merged into the predicate context as + ``bindings.``. Keep small (selection masks, thresholds, + top-K rankings); the same bytes ship to every device. + fire_at: Optional wall-clock epoch seconds. Each device holds the + message and fires its function from its own clock at + ``fire_at`` for synchronized fan-out. + on_late: Policy when a device receives a ``fire_at`` message after + the deadline. ``"skip"`` (default) drops the call; ``"fire"`` + executes immediately. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"correlation_id": "br-...", "candidates": N, + "selector": ..., "function": ...}``. + On failure: ``{"candidates": 0, "error": {"code", "message"}}`` + with codes including the discover() codes, + ``invalid_invoke_scope``, ``ambiguous_function``, + ``invalid_predicate``, and ``invalid_on_late``. + """ + if on_late not in ("skip", "fire"): + return { + "candidates": 0, + "error": _error( + "invalid_on_late", + f"on_late must be 'skip' or 'fire', got {on_late!r}", + ), + } + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"candidates": 0, "error": error_envelope["error"]} + + if not rows: + # Empty fan-out: still mint a correlation id so callers waiting on + # replies see a clean "no candidates" rather than a hang. + return { + "correlation_id": f"br-{uuid.uuid4().hex[:12]}", + "candidates": 0, + "selector": selector, + } + + # Broadcast assumes one function per call. If the selector resolves to + # multiple distinct functions, surface that as a structured error so + # the caller can either narrow the selector or split into multiple + # broadcasts. + function_names = {row.get("name") for row in rows if row.get("name")} + if len(function_names) != 1: + return { + "candidates": len(rows), + "error": _error( + "ambiguous_function", + f"selector resolved to {len(function_names)} distinct " + "functions; broadcast requires exactly one function per call: " + f"{sorted(function_names)!r}", + ), + } + function_name = next(iter(function_names)) + + # Compile-validate the where predicate before going to the wire so a + # syntax error short-circuits without bothering devices. + if where is not None: + try: + from device_connect_edge.predicate import compile_where + compile_where(where) + except Exception as e: + return { + "candidates": len(rows), + "error": _error("invalid_predicate", str(e)), + } + + correlation_id = f"br-{uuid.uuid4().hex[:12]}" + targets = sorted({ + row.get("device_id") for row in rows if row.get("device_id") + }) + try: + clean_params = _clean_params_with_mandate(params, mandate) + except ValueError as e: + return { + "candidates": len(targets), + "error": _error("invalid_params", str(e)), + } + + envelope: dict[str, Any] = { + "correlation_id": correlation_id, + "function": function_name, + "params": clean_params, + "targets": targets, + } + if where: + envelope["where"] = where + if bindings: + envelope["bindings"] = bindings + if fire_at is not None: + envelope["fire_at"] = float(fire_at) + envelope["on_late"] = on_late + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[broadcast::%s::%d targets] Reason: %s", + correlation_id, len(targets), truncated, + ) + + try: + conn = get_connection() + conn.publish_broadcast(envelope) + except Exception as e: + logger.error("broadcast publish failed: %s", e) + return { + "candidates": len(targets), + "error": _error("connection_error", str(e)), + } + + return { + "correlation_id": correlation_id, + "candidates": len(targets), + "selector": selector, + "function": function_name, + } + + +# ── Selector-driven subscription ───────────────────────────────── + + +# Sentinel used to recognise the broadcast-reply form of a subscribe +# selector (``correlation:``). Kept short so the selector reads +# naturally; the parser matches an exact prefix. +_CORRELATION_PREFIX = "correlation:" + + +class Subscription: + """A live subscription handle returned by :func:`subscribe`. + + Two selector forms produce a subscription: + + * ``"correlation:"`` -- replies from a prior :func:`broadcast` call, + keyed by ``correlation_id`` and routed across all devices that fired. + * Event-scoped selectors (``event()`` or + ``device(...).event()``) -- a multiplex of matching events + across the resolved candidate set. + + The handle exposes a sync ``read`` API that drains buffered messages. + Use as a context manager (or call :meth:`close`) to tear the + underlying messaging subscription down deterministically:: + + with subscribe("correlation:" + cid) as sub: + for reply in sub.iter(timeout=5.0): + process(reply) + """ + + def __init__(self, conn: Any, inbox_names: list[str]): + self._conn = conn + self._inbox_names = list(inbox_names) + self._closed = False + self._cursor = 0 # index into the concatenated message stream + + def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: + """Drain currently buffered messages without blocking. + + Returns parsed payload dicts (already JSON-decoded by the + connection's buffered subscription path). Subsequent calls return + only messages that arrived after the previous call. + + Race-safe against the messaging callback that appends to the same + inbox: each inbox is read by snapshotting its current length and + truncating only that prefix, so a message that arrives during + iteration stays buffered for the next ``read``. + """ + if self._closed: + return [] + out: list[dict[str, Any]] = [] + for name in self._inbox_names: + buf = self._conn._inbox.get(name) or [] + # Snapshot the consumed prefix length BEFORE iterating, then + # truncate by exactly that many items. Any message appended by + # the messaging callback between the snapshot and the truncation + # remains buffered for a subsequent ``read``. + n = len(buf) + for subject, payload in buf[:n]: + if not isinstance(payload, dict): + payload = {"raw": payload} + out.append({**payload, "_subject": subject}) + self._conn._inbox[name] = buf[n:] + if max_messages is not None: + out = out[:max_messages] + return out + + def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): + """Yield messages until ``timeout`` elapses with no new arrivals. + + ``timeout`` resets each time at least one message is yielded, so + callers can drain a steady stream without re-parameterising the + wait. Use ``read`` instead for one-shot draining. + """ + deadline = time.monotonic() + timeout + while not self._closed: + new = self.read() + if new: + for msg in new: + yield msg + deadline = time.monotonic() + timeout + continue + if time.monotonic() >= deadline: + return + time.sleep(poll_interval) + + def __iter__(self): + """Allow ``for msg in sub:`` with a default 30-second idle timeout. + + Delegates to :meth:`iter` with sensible defaults so the idiomatic + Python iteration form works. Use ``sub.iter(timeout=...)`` directly + when the default does not fit. + """ + return self.iter(timeout=30.0, poll_interval=0.05) + + def close(self) -> None: + """Tear down the underlying messaging subscriptions.""" + if self._closed: + return + self._closed = True + for name in self._inbox_names: + try: + self._conn.unsubscribe_buffered(name) + except Exception: # pragma: no cover - cleanup best effort + logger.debug("close: unsubscribe %s failed", name, exc_info=True) + + def __enter__(self) -> "Subscription": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +def _correlation_subjects(conn: Any, correlation_id: str) -> list[str]: + """Build the per-device wildcard reply subjects for a correlation id. + + The reply template is ``device-connect...event + .async_reply.``; ```` is single-token wildcarded + so a subscription receives replies from any device that fires the + broadcast without having to enumerate them up-front. + """ + return [ + f"device-connect.{conn.zone}.*.event.async_reply.{correlation_id}", + ] + + +def _event_subjects_for_selector(selector: str) -> tuple[list[str] | None, dict[str, Any] | None]: + """Resolve an event-scoped selector to per-device subjects. + + Returns ``(subjects, None)`` on success or ``(None, error_envelope)`` + if the selector failed to parse or used a non-event scope. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in (Scope.DEVICE_EVENT.value, Scope.EVENT_ONLY.value): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_subscribe_scope", + "subscribe requires an event-scoped selector " + "(device(...).event(...) or event(...)) or " + "'correlation:'; got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + + conn = get_connection() + subjects: list[str] = [] + seen: set[str] = set() + for row in rows: + device_id = row.get("device_id") or "" + event_name = row.get("name") or "" + if not device_id or not event_name: + continue + subj = f"device-connect.{conn.zone}.{device_id}.event.{event_name}" + if subj not in seen: + seen.add(subj) + subjects.append(subj) + return subjects, None + + +def subscribe(selector: str) -> Subscription: + """Subscribe to events or broadcast replies matching a selector. + + Args: + selector: One of: + - ``"correlation:"`` for broadcast replies of a prior call. + - An event-scoped selector (``event()`` or + ``device(...).event()``) for live event streams. + + Returns: + A :class:`Subscription` handle. Iterate with ``sub.iter(timeout)`` + or drain currently-buffered messages with ``sub.read()``. Always + close (or use ``with``) to tear the underlying subscription down. + + Raises: + ValueError on selector errors. The selector string is checked at + the boundary; downstream subscribe calls are not retried, so a + parse error fails fast. + """ + if not isinstance(selector, str) or not selector.strip(): + raise ValueError("subscribe selector must be a non-empty string") + + conn = get_connection() + if selector.startswith(_CORRELATION_PREFIX): + correlation_id = selector[len(_CORRELATION_PREFIX):].strip() + if not correlation_id: + raise ValueError( + "correlation form must be 'correlation:' with non-empty id" + ) + subjects = _correlation_subjects(conn, correlation_id) + inbox_prefix = f"sub-corr-{correlation_id}-{uuid.uuid4().hex[:8]}" + else: + subjects, error_envelope = _event_subjects_for_selector(selector) + if error_envelope is not None: + err = error_envelope.get("error") + msg = err.get("message", str(err)) if isinstance(err, dict) else str(err) + raise ValueError(msg) + if not subjects: + # Nothing to subscribe to. Return an idle Subscription so the + # caller's ``with subscribe(...) as sub: ...`` pattern still + # works without raising; ``read``/``iter`` will yield nothing. + return Subscription(conn, inbox_names=[]) + inbox_prefix = f"sub-evt-{uuid.uuid4().hex[:8]}" + + inbox_names: list[str] = [] + for i, subj in enumerate(subjects): + name = f"{inbox_prefix}-{i}" + conn.subscribe_buffered(subj, name=name) + inbox_names.append(name) + return Subscription(conn, inbox_names=inbox_names) + + +def await_replies( + correlation_id: str, + timeout: float = 10.0, + until: int | None = None, + poll_interval: float = 0.05, +) -> list[dict[str, Any]]: + """Block until ``timeout`` elapses or ``until`` replies have arrived. + + A sync helper for the common broadcast pattern: caller fires a + :func:`broadcast`, then waits for some replies. Builds a one-shot + subscription on the correlation reply subject, drains it, and tears + down before returning. + + Args: + correlation_id: The id returned by :func:`broadcast`. + timeout: Overall wall-clock limit in seconds. + until: Stop early once this many replies have been collected. + poll_interval: How often the helper polls the subscription buffer. + + Returns: + A list of reply payload dicts, each with at least + ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + if not correlation_id: + return [] + sub = subscribe(f"{_CORRELATION_PREFIX}{correlation_id}") + try: + replies: list[dict[str, Any]] = [] + deadline = time.monotonic() + timeout + while True: + new = sub.read() + replies.extend(new) + if until is not None and len(replies) >= until: + break + if time.monotonic() >= deadline: + break + time.sleep(poll_interval) + return replies + finally: + sub.close() + + # ── Hierarchical discovery tools ───────────────────────────────── @@ -81,7 +1195,18 @@ def describe_fleet() -> dict[str, Any]: Example: fleet = describe_fleet() # {"total_devices": 47, "by_type": {"camera": {"count": 12, ...}}, ...} + + .. deprecated:: + Prefer ``discover_labels()`` (vocabulary) and + ``discover("device(*)")`` (roster). This wrapper will be removed in + a future release. """ + warnings.warn( + "describe_fleet() is deprecated; use discover_labels() for vocabulary " + "or discover('device(*)') for the roster.", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() devices = conn.list_devices() @@ -134,7 +1259,18 @@ def list_devices( # Group by location result = list_devices(group_by="location") + + .. deprecated:: + Prefer ``discover("device(category:camera, location:zone-A/*)")`` -- + the selector DSL covers type/location/group-by uniformly. This + wrapper will be removed in a future release. """ + warnings.warn( + "list_devices() is deprecated; use discover() with a selector " + "(e.g. discover('device(category:camera)')).", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() devices = conn.list_devices(location=location) @@ -194,7 +1330,17 @@ def get_device_functions(device_id: str) -> dict[str, Any]: Example: info = get_device_functions("camera-001") # {"device_id": "camera-001", "functions": [{"name": "capture_image", ...}]} + + .. deprecated:: + Prefer ``discover("device().function(*)")``. This wrapper + will be removed in a future release. """ + warnings.warn( + "get_device_functions() is deprecated; use " + "discover('device().function(*)').", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() device = conn.get_device(device_id) @@ -213,23 +1359,22 @@ def invoke_device( function: str, params: dict[str, Any] | None = None, llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Call a function on a Device Connect device. + """Call a function on a Device Connect device (deprecated; use invoke()). Args: device_id: Target device ID (e.g., "robot-001", "camera-001"). - function: Function name to call (e.g., "start_cleaning", "capture_image"). - params: Function parameters as a dictionary. Check get_device_functions() for schemas. - Do NOT put llm_reasoning inside params. - llm_reasoning: Why you're calling this function — for observability. - - Example: - result = invoke_device( - device_id="robot-001", function="start_cleaning", - params={"zone": "zone-A"}, - llm_reasoning="Camera detected spill in zone-A" - ) + function: Function name to call. + params: Function parameters as a dictionary. + llm_reasoning: Why you are calling this function (for observability). """ + warnings.warn( + "invoke_device(device_id, function, ...) is deprecated; use " + "invoke('device().function()', params) instead.", + DeprecationWarning, + stacklevel=2, + ) trace_id = f"trace-{uuid.uuid4().hex[:12]}" if llm_reasoning: truncated = llm_reasoning[:200] + "..." if len(llm_reasoning) > 200 else llm_reasoning @@ -237,7 +1382,7 @@ def invoke_device( try: conn = get_connection() - clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + clean = _clean_params_with_mandate(params, mandate) response = conn.invoke(device_id, function, params=clean) if "error" in response: @@ -323,13 +1468,7 @@ def discover_devices( device_type: str | None = None, refresh: bool = False, ) -> list[dict[str, Any]]: - """Discover available devices (deprecated — use list_devices instead). - - Returns all devices with their function schemas. For large fleets, - prefer the hierarchical approach: - 1. describe_fleet() — see what's available - 2. list_devices(...) — browse with filters - 3. get_device_functions(id) — get schemas for one device + """Discover available devices (deprecated; use discover() instead). Args: device_type: Optional filter (e.g., "robot", "camera"). Fuzzy matching. @@ -338,6 +1477,12 @@ def discover_devices( Returns: List of devices with device_id, device_type, functions, events. """ + warnings.warn( + "discover_devices() is deprecated; use discover() with a selector " + "(e.g. discover('device(*)') or discover('device(category:camera)')).", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() # Invalidate cache when refresh is requested diff --git a/packages/device-connect-agent-tools/pyproject.toml b/packages/device-connect-agent-tools/pyproject.toml index ec0f198..606073c 100644 --- a/packages/device-connect-agent-tools/pyproject.toml +++ b/packages/device-connect-agent-tools/pyproject.toml @@ -37,6 +37,7 @@ strands = ["strands-agents>=1.0"] langchain = ["langchain-core>=0.2"] claude = ["claude-agent-sdk>=0.1"] mcp = ["fastmcp>=1.0"] +predicate = ["device-connect-edge[predicate]"] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", diff --git a/packages/device-connect-agent-tools/tests/test_agent_mandates.py b/packages/device-connect-agent-tools/tests/test_agent_mandates.py new file mode 100644 index 0000000..a889dec --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_agent_mandates.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Agent-tool tests for carrying Device Mandates in _dc_meta.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "lock-001", + "device_type": "lock", + "status": {"state": "online"}, + "identity": {"device_type": "lock"}, + "labels": {"category": "lock"}, + "functions": [ + { + "name": "unlock", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + ], + "events": [], + }, + { + "device_id": "lock-002", + "device_type": "lock", + "status": {"state": "online"}, + "identity": {"device_type": "lock"}, + "labels": {"category": "lock"}, + "functions": [ + { + "name": "unlock", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + ], + "events": [], + }, +] + + +MANDATE = {"format": "device-connect-hmac-v0", "closed": {"id": "closed-1"}} + + +def _conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.return_value = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn._published = [] + conn.publish_broadcast.side_effect = lambda env: conn._published.append(env) + return conn + + +def test_invoke_attaches_mandate_under_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + result = tools_mod.invoke( + "device(lock-001).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert result["success"] is True + sent = conn.invoke.call_args.kwargs["params"] + assert sent["duration_s"] == 30 + assert sent["_dc_meta"]["mandate"] == MANDATE + + +def test_invoke_many_attaches_mandate_to_each_call(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + tools_mod.invoke_many( + "device(category:lock).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert conn.invoke.call_count == 2 + for call in conn.invoke.call_args_list: + assert call.kwargs["params"]["_dc_meta"]["mandate"] == MANDATE + + +def test_broadcast_attaches_mandate_under_params_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + result = tools_mod.broadcast( + "device(category:lock).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert result["candidates"] == 2 + env = conn._published[0] + assert env["params"]["duration_s"] == 30 + assert env["params"]["_dc_meta"]["mandate"] == MANDATE + + +def test_legacy_invoke_device_attaches_mandate_under_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + tools_mod.invoke_device( + "lock-001", + "unlock", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + sent = conn.invoke.call_args.kwargs["params"] + assert sent["_dc_meta"]["mandate"] == MANDATE diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py new file mode 100644 index 0000000..e8d8831 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``broadcast`` tool. + +Uses the same labeled mock fleet (cam-001, cam-002, robot-001, sensor-001) +as the discover/invoke tests so selectors exercise real device, function, +and event names. +""" +import json +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + # Capture the published envelope for assertions. + published: list[dict] = [] + conn.publish_broadcast.side_effect = lambda env: published.append(env) + conn._published = published + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- broadcast ------------------------------------------------------ + + +class TestBroadcast: + def test_returns_correlation_id_and_candidates(self, mock_conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["correlation_id"].startswith("br-") + assert r["candidates"] == 2 + assert r["function"] == "capture_image" + assert "error" not in r + + def test_envelope_carries_function_and_targets(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k"}, + ) + env = mock_conn._published[0] + assert env["function"] == "capture_image" + assert env["params"] == {"resolution": "4k"} + assert sorted(env["targets"]) == ["cam-001", "cam-002"] + # No optional fields when caller did not set them. + assert "where" not in env + assert "bindings" not in env + assert "fire_at" not in env + assert "on_late" not in env + + def test_where_and_bindings_propagate_to_envelope(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + where="status.battery > 50", + bindings={"threshold": 80}, + ) + env = mock_conn._published[0] + assert env["where"] == "status.battery > 50" + assert env["bindings"] == {"threshold": 80} + + def test_fire_at_propagates_with_default_on_late(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123456789.0, + ) + env = mock_conn._published[0] + assert env["fire_at"] == 123456789.0 + assert env["on_late"] == "skip" + + def test_fire_at_with_explicit_on_late_fire(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123.0, on_late="fire", + ) + env = mock_conn._published[0] + assert env["on_late"] == "fire" + + def test_invalid_on_late_rejected(self, mock_conn): + r = tools_mod.broadcast( + "device(*).function(capture_image)", on_late="bogus", + ) + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_on_late" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_ambiguous_function_rejected(self, mock_conn): + # function(direction:read) resolves to multiple distinct functions + # (get_reading + dispatch_robot's get_status if it had read; here + # it just hits sensor's get_reading and possibly more). With our + # SAMPLE_DEVICES this matches just get_reading, so artificially + # broaden by picking a selector that crosses functions: + r = tools_mod.broadcast("device(*).function(*)") + assert r["candidates"] == 3 + assert r["error"]["code"] == "ambiguous_function" + + def test_zero_matches_returns_correlation_with_zero(self, mock_conn): + r = tools_mod.broadcast("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["correlation_id"].startswith("br-") + # No envelope was published (no targets). + assert mock_conn.publish_broadcast.call_count == 0 + + def test_invalid_scope_rejected(self, mock_conn): + r = tools_mod.broadcast("device(cam-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, mock_conn): + r = tools_mod.broadcast("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_invalid_predicate_rejected_before_publish(self, mock_conn): + # The predicate is compile-validated at the dispatcher; a syntax + # error short-circuits without publishing. + try: + import celpy # noqa: F401 + except ImportError: + pytest.skip("cel-python not installed") + r = tools_mod.broadcast( + "device(*).function(capture_image)", where="a > > b", + ) + assert r["error"]["code"] == "invalid_predicate" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_publish_failure_returns_connection_error(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + conn.publish_broadcast.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + env = mock_conn._published[0] + assert "llm_reasoning" not in env["params"] diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index ba6fafc..99ff17d 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -65,11 +65,13 @@ def _mock_sdk_and_connection(): TOOL_NAMES = ( - "describe_fleet", - "list_devices", - "get_device_functions", + "discover_labels", + "discover", "discover_devices", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", ) @@ -108,3 +110,68 @@ def test_create_server_bundles_all_tools(self): assert server["name"] == "device-connect" bundled = {t._tool_name for t in server["tools"]} assert bundled == set(TOOL_NAMES) + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_schemas_include_optional_mandate(self, name): + from device_connect_agent_tools.adapters import claude as adapter + + schema = getattr(adapter, name)._tool_schema + + assert schema["mandate"] is dict + + +class TestClaudeAdapterMandates: + @pytest.mark.asyncio + async def test_invoke_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object(adapter, "_invoke", return_value={"success": True}) as invoke: + await adapter.invoke( + { + "selector": "device(lock-001).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert invoke.call_args.kwargs["mandate"] == mandate + + @pytest.mark.asyncio + async def test_invoke_many_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object( + adapter, "_invoke_many", return_value={"succeeded": 1} + ) as invoke_many: + await adapter.invoke_many( + { + "selector": "device(category:lock).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert invoke_many.call_args.kwargs["mandate"] == mandate + + @pytest.mark.asyncio + async def test_broadcast_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object( + adapter, "_broadcast", return_value={"candidates": 1} + ) as broadcast: + await adapter.broadcast( + { + "selector": "device(category:lock).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert broadcast.call_args.kwargs["mandate"] == mandate diff --git a/packages/device-connect-agent-tools/tests/test_discover.py b/packages/device-connect-agent-tools/tests/test_discover.py new file mode 100644 index 0000000..efc7de5 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_discover.py @@ -0,0 +1,364 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``discover`` and ``discover_labels`` tools. + +Uses a labeled mock fleet (cam-001, robot-001, sensor-001) drawn from the +existing DC test driver vocabulary so every selector exercises real device, +function, and event names. +""" +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixture: labeled fleet --------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": ["camera", "inference"], "location": "zone-A/dock"}, + "functions": [ + { + "name": "capture_image", + "description": "Capture an image", + "parameters": {"type": "object"}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + {"name": "state_change_detected", "labels": None}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "zone-B/dock"}, + "functions": [ + { + "name": "capture_image", + "description": "Capture an image", + "parameters": {"type": "object"}, + "labels": {"direction": "write", "modality": ["rgb", "4k"]}, + }, + ], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "zone-A/yard"}, + "functions": [ + { + "name": "dispatch_robot", + "description": "Dispatch", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + { + "name": "get_status", + "description": "Status", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [ + {"name": "cleaning_finished", "labels": None}, + ], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + {"name": "get_reading", "parameters": {}, "labels": {"direction": "read"}}, + {"name": "set_threshold", "parameters": {}, "labels": {"direction": "write"}}, + {"name": "set_location", "parameters": {}, "labels": {"direction": "write"}}, + ], + "events": [ + {"name": "reading", "labels": None}, + {"name": "threshold_exceeded", "labels": {"safety": "informational"}}, + ], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- discover: device-only scope ----------------------------------- + + +class TestDiscoverDeviceOnly: + def test_match_by_category_label(self, mock_conn): + r = tools_mod.discover("device(category:camera)") + assert r["scope"] == "device_only" + assert r["matched"] == 2 + assert {row["device_id"] for row in r["results"]} == {"cam-001", "cam-002"} + + def test_multivalued_match_picks_composite_only(self, mock_conn): + # Only cam-001 has category:[camera, inference]. + r = tools_mod.discover("device(category:inference)") + assert r["matched"] == 1 + assert r["results"][0]["device_id"] == "cam-001" + + def test_or_within_key(self, mock_conn): + r = tools_mod.discover("device(category:[camera,robot])") + assert {row["device_id"] for row in r["results"]} == { + "cam-001", "cam-002", "robot-001" + } + + def test_glob_location(self, mock_conn): + r = tools_mod.discover("device(location:zone-A/*)") + assert {row["device_id"] for row in r["results"]} == {"cam-001", "robot-001"} + + def test_and_across_keys(self, mock_conn): + r = tools_mod.discover( + "device(category:[camera,robot], location:zone-A/*)" + ) + assert {row["device_id"] for row in r["results"]} == {"cam-001", "robot-001"} + + def test_match_all(self, mock_conn): + r = tools_mod.discover("device(*)") + assert r["matched"] == 4 + + def test_bare_id_match(self, mock_conn): + r = tools_mod.discover("device(cam-001)") + assert r["matched"] == 1 + assert r["results"][0]["device_id"] == "cam-001" + + def test_labels_surfaced_in_result(self, mock_conn): + r = tools_mod.discover("device(cam-001)") + assert r["results"][0]["labels"] == { + "category": ["camera", "inference"], + "location": "zone-A/dock", + } + + +# -- discover: function scope -------------------------------------- + + +class TestDiscoverFunctionScope: + def test_writes_fleet_wide(self, mock_conn): + r = tools_mod.discover("device(*).function(direction:write)") + assert r["scope"] == "device_function" + assert r["matched"] == 5 # capture x2, dispatch_robot, set_threshold, set_location + names = {row["name"] for row in r["results"]} + assert names == {"capture_image", "dispatch_robot", "set_threshold", "set_location"} + + def test_function_only_scope_by_name(self, mock_conn): + r = tools_mod.discover("function(get_reading)") + assert r["scope"] == "function_only" + assert r["matched"] == 1 + assert r["results"][0]["name"] == "get_reading" + assert r["results"][0]["device_id"] == "sensor-001" + + def test_anchored_glob_set_prefix(self, mock_conn): + r = tools_mod.discover("function(set_*)") + assert {row["name"] for row in r["results"]} == {"set_threshold", "set_location"} + + def test_below_threshold_returns_full_schemas(self, mock_conn): + r = tools_mod.discover("device(cam-001).function(*)") + assert r["matched"] == 1 + row = r["results"][0] + assert "parameters" in row + assert "description" in row + assert row["labels"] == {"direction": "write", "modality": "rgb"} + + def test_modality_or_within_key(self, mock_conn): + r = tools_mod.discover("device(*).function(modality:[rgb,thermal])") + assert r["matched"] == 2 + assert all(row["name"] == "capture_image" for row in r["results"]) + + def test_safety_critical_filter(self, mock_conn): + r = tools_mod.discover("function(safety:critical)") + assert r["matched"] == 1 + assert r["results"][0]["name"] == "dispatch_robot" + + def test_label_histogram_built(self, mock_conn): + r = tools_mod.discover("device(*).function(direction:write)") + hist = r["label_histogram"] + assert hist["direction"]["values"] == {"write": 5} + # modality is multi-valued on cam-002 (rgb + 4k) + modality = hist["modality"] + assert modality.get("multivalued") is True + assert modality["values"] == {"rgb": 2, "4k": 1} + + +# -- discover: event scope ----------------------------------------- + + +class TestDiscoverEventScope: + def test_event_by_modality(self, mock_conn): + r = tools_mod.discover("device(*).event(modality:rgb)") + assert r["scope"] == "device_event" + assert r["matched"] == 2 # cam-001 + cam-002 each emit object_detected + assert all(row["name"] == "object_detected" for row in r["results"]) + + def test_event_only_by_name(self, mock_conn): + r = tools_mod.discover("event(threshold_exceeded)") + assert r["scope"] == "event_only" + assert r["matched"] == 1 + + +# -- discover: pagination ------------------------------------------ + + +class TestDiscoverPagination: + def test_pagination_envelope(self, mock_conn): + r = tools_mod.discover("device(*)", limit=2) + assert r["matched"] == 4 + assert r["returned"] == 2 + assert r["offset"] == 0 + assert r["next_offset"] == 2 + + def test_offset_respected(self, mock_conn): + r = tools_mod.discover("device(*)", offset=2, limit=10) + assert r["offset"] == 2 + assert r["returned"] == 2 + assert r["next_offset"] is None + + def test_negative_offset_clamped(self, mock_conn): + r = tools_mod.discover("device(*)", offset=-5) + assert r["offset"] == 0 + + def test_hard_limit_caps_runaway_request(self, mock_conn): + r = tools_mod.discover("device(*)", limit=999_999) + # Hard ceiling is 1000; for 4 devices, the page just returns everything. + assert r["returned"] == 4 + + def test_zero_limit_falls_back_to_default(self, mock_conn): + r = tools_mod.discover("device(*)", limit=0) + # Default applies, all 4 fit in one page. + assert r["returned"] == 4 + + +# -- discover: errors ---------------------------------------------- + + +class TestDiscoverErrors: + def test_bad_selector_returns_error_envelope(self, mock_conn): + r = tools_mod.discover("not a selector at all") + assert r["error"]["code"] == "selector_parse_error" + assert r["matched"] == 0 + assert r["results"] == [] + + def test_unknown_scope_in_selector(self, mock_conn): + r = tools_mod.discover("widgets(*)") + assert "error" in r + assert r["error"]["code"] == "selector_parse_error" + assert "unknown scope" in r["error"]["message"].lower() + + def test_connection_failure_returns_error(self): + broken = MagicMock() + broken.list_devices.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=broken): + r = tools_mod.discover("device(*)") + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] + assert r["matched"] == 0 + + def test_non_string_selector(self, mock_conn): + r = tools_mod.discover(None) # type: ignore[arg-type] + assert r["error"]["code"] == "invalid_selector" + + +# -- discover_labels ------------------------------------------------ + + +class TestDiscoverLabels: + def test_multi_axis_default(self, mock_conn): + v = tools_mod.discover_labels() + assert v["total_devices"] == 4 + assert v["total_functions"] == 7 + assert v["total_events"] == 6 + assert "category" in v["device_keys"] + assert "direction" in v["function_keys"] + assert "modality" in v["event_keys"] + + def test_multivalued_annotation_on_device_category(self, mock_conn): + v = tools_mod.discover_labels() + cat = v["device_keys"]["category"] + assert cat["multivalued"] is True + # All 4 devices declared a category; cam-001 contributed to two values + # but unique_devices counts distinct devices. + assert cat["unique_devices"] == 4 + assert cat["values"] == {"camera": 2, "inference": 1, "robot": 1, "sensor": 1} + + def test_singleton_keys_not_flagged_multivalued(self, mock_conn): + v = tools_mod.discover_labels() + direction = v["function_keys"]["direction"] + assert direction.get("multivalued") is not True + + def test_per_key_pagination(self, mock_conn): + v = tools_mod.discover_labels(key="device.location") + assert v["axis"] == "device" + assert v["key"] == "location" + # 4 distinct location values, sorted by frequency desc then alpha + assert v["matched"] == 4 + assert list(v["values"].keys())[0] == "lab-B" # only single value with count 1, alpha tiebreak + + def test_per_key_function_axis(self, mock_conn): + v = tools_mod.discover_labels(key="function.direction") + assert v["axis"] == "function" + assert v["values"] == {"write": 5, "read": 2} + + def test_per_key_unknown_axis(self, mock_conn): + v = tools_mod.discover_labels(key="thing.bogus") + assert v["error"]["code"] == "unknown_axis" + + def test_per_key_missing_dot(self, mock_conn): + v = tools_mod.discover_labels(key="just_a_key") + assert "error" in v + assert v["error"]["code"] == "key_not_axis_qualified" + assert "axis-qualified" in v["error"]["message"] + + +# -- Deprecation warnings ------------------------------------------ + + +class TestDeprecationWarnings: + def test_describe_fleet_emits_warning(self, mock_conn, recwarn): + tools_mod.describe_fleet() + assert any("describe_fleet" in str(w.message) for w in recwarn.list) + + def test_list_devices_emits_warning(self, mock_conn, recwarn): + tools_mod.list_devices() + assert any("list_devices" in str(w.message) for w in recwarn.list) + + def test_get_device_functions_emits_warning(self, mock_conn, recwarn): + # get_device_functions calls conn.get_device which we haven't mocked; + # the warning is emitted before that call so we still observe it. + mock_conn.get_device = MagicMock(return_value={ + "device_id": "cam-001", "functions": [], "events": [], + "identity": {}, "status": {}, "capabilities": {}, + }) + # Force a fresh patch so get_device path is hit + with patch.object(tools_mod, "get_connection", return_value=mock_conn): + tools_mod.get_device_functions("cam-001") + assert any("get_device_functions" in str(w.message) for w in recwarn.list) diff --git a/packages/device-connect-agent-tools/tests/test_invoke.py b/packages/device-connect-agent-tools/tests/test_invoke.py new file mode 100644 index 0000000..aae1a83 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_invoke.py @@ -0,0 +1,336 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``invoke`` and ``invoke_many`` tools. + +Uses a small labeled fleet (cam-001, cam-002, robot-001, sensor-001) drawn +from the existing DC test driver vocabulary so every selector exercises +real device, function, and event names. +""" +import time +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixtures ------------------------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "lab-A"}, + "functions": [ + { + "name": "dispatch_robot", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +def _conn_with_invoke(invoke_side_effect): + """Return a mock Connection whose .invoke() applies ``invoke_side_effect``. + + ``invoke_side_effect`` is called with ``(device_id, function_name, + params, timeout)`` and must return a JSON-RPC response dict. + """ + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + + def _invoke(device_id, function_name, params=None, timeout=None): + return invoke_side_effect(device_id, function_name, params, timeout) + + conn.invoke.side_effect = _invoke + return conn + + +@pytest.fixture +def all_succeed_conn(): + def _ok(device_id, function_name, params, timeout): + return {"jsonrpc": "2.0", "id": "1", "result": { + "device_id": device_id, "function": function_name, "params": params, + }} + conn = _conn_with_invoke(_ok) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- invoke --------------------------------------------------------- + + +class TestInvoke: + def test_single_match_returns_success(self, all_succeed_conn): + r = tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p"}, + ) + assert r["success"] is True + assert r["device_id"] == "cam-001" + assert r["function"] == "capture_image" + assert r["result"]["params"] == {"resolution": "1080p"} + + def test_function_only_selector_with_unique_name(self, all_succeed_conn): + r = tools_mod.invoke("function(get_reading)") + assert r["success"] is True + assert r["device_id"] == "sensor-001" + assert r["function"] == "get_reading" + + def test_no_match_returns_no_match_error(self, all_succeed_conn): + r = tools_mod.invoke("device(*).function(does_not_exist)") + assert r["success"] is False + assert r["error"]["code"] == "no_match" + assert "does_not_exist" in r["error"]["message"] + + def test_ambiguous_match_returns_error_with_candidates(self, all_succeed_conn): + # capture_image exists on both cam-001 and cam-002. + r = tools_mod.invoke("function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "ambiguous_match" + assert "expected exactly 1" in r["error"]["message"] + ids = {c["device_id"] for c in r["candidates"]} + assert ids == {"cam-001", "cam-002"} + + def test_device_only_scope_rejected(self, all_succeed_conn): + # Device-only scope cannot resolve to a function. + r = tools_mod.invoke("device(robot-001)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_event_scope_rejected(self, all_succeed_conn): + r = tools_mod.invoke("event(reading)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke("not a selector") + assert r["success"] is False + assert r["error"]["code"] == "selector_parse_error" + + def test_non_string_selector_rejected(self, all_succeed_conn): + r = tools_mod.invoke(None) # type: ignore[arg-type] + assert r["success"] is False + assert r["error"]["code"] == "invalid_selector" + + def test_jsonrpc_error_maps_to_invoke_failed(self): + def _err(device_id, function_name, params, timeout): + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "device busy"}, + } + conn = _conn_with_invoke(_err) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(robot-001).function(dispatch_robot)") + assert r["success"] is False + assert r["error"]["code"] == "-32000" + assert r["error"]["message"] == "device busy" + assert r["device_id"] == "robot-001" + assert r["function"] == "dispatch_robot" + + def test_connection_exception_returns_invoke_failed(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(cam-001).function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "invoke_failed" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p", "llm_reasoning": "should not appear"}, + llm_reasoning="caller reasoning", + ) + # Inspect the params actually delivered to the wire: + sent = all_succeed_conn.invoke.call_args.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "1080p" + + +# -- invoke_many ---------------------------------------------------- + + +class TestInvokeMany: + def test_zero_matches_returns_empty_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["matched"] == 0 + assert r["succeeded"] == 0 + assert r["failed"] == 0 + assert r["results"] == [] + assert r["errors"] == [] + assert "error" not in r + + def test_all_succeed(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 2 + assert r["failed"] == 0 + ids = {row["device_id"] for row in r["results"]} + assert ids == {"cam-001", "cam-002"} + # Each result row is shaped {device_id, function, result}. + for row in r["results"]: + assert row["function"] == "capture_image" + assert "result" in row + + def test_partial_failure_shape(self): + def _half_fail(device_id, function_name, params, timeout): + if device_id == "cam-001": + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "down"}, + } + conn = _conn_with_invoke(_half_fail) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 1 + assert r["failed"] == 1 + assert {row["device_id"] for row in r["results"]} == {"cam-001"} + assert {row["device_id"] for row in r["errors"]} == {"cam-002"} + for row in r["errors"]: + assert row["error"]["code"] == "-32000" + assert row["error"]["message"] == "down" + + def test_invalid_scope_returns_error_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(robot-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke_many("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_per_target_timeout_passed_to_connection(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", timeout=7.5, + ) + # Every conn.invoke call should carry the same timeout. + for call in all_succeed_conn.invoke.call_args_list: + assert call.kwargs["timeout"] == 7.5 + + def test_max_concurrency_caps_thread_pool(self, all_succeed_conn): + # The fan-out group has 3 targets (capture_image x2 + dispatch_robot + # don't share name; pick a selector that resolves to multiple). Use + # function(direction:write) which selects 4 distinct rows. + r = tools_mod.invoke_many( + "function(direction:write)", max_concurrency=1, + ) + assert r["candidates"] >= 2 + assert r["succeeded"] == r["candidates"] + + def test_connection_exception_recorded_per_target(self): + # Mix: cam-001 succeeds, cam-002's call raises locally. + def _mixed(device_id, function_name, params, timeout): + if device_id == "cam-002": + raise RuntimeError("messaging blip") + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn = _conn_with_invoke(_mixed) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["succeeded"] == 1 + assert r["failed"] == 1 + cam002_err = next(e for e in r["errors"] if e["device_id"] == "cam-002") + assert cam002_err["error"]["code"] == "invoke_failed" + assert "messaging blip" in cam002_err["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + for call in all_succeed_conn.invoke.call_args_list: + sent = call.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "4k" + + +# -- _resolve_function_tuples --------------------------------------- + + +class TestResolveFunctionTuples: + def test_walks_all_pages(self, all_succeed_conn): + # Use a small DISCOVER_HARD_LIMIT temporarily. + with patch.object(tools_mod, "DISCOVER_HARD_LIMIT", 1): + rows, err = tools_mod._resolve_function_tuples( + "device(*).function(direction:write)" + ) + assert err is None + # 4 distinct (device, function) tuples for direction:write across the + # mock fleet (cam-001, cam-002, robot-001, sensor-001 set_threshold + # and set_location). With limit=1 per page, the resolver had to + # paginate through all of them. + assert len(rows) >= 2 + for row in rows: + assert "device_id" in row + assert "name" in row + + def test_propagates_discover_error(self, all_succeed_conn): + rows, err = tools_mod._resolve_function_tuples("not a selector") + assert rows is None + assert err is not None + assert err["error"]["code"] == "selector_parse_error" diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index 6b731d2..210930d 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -9,6 +9,7 @@ """ import sys +from inspect import signature from types import ModuleType from unittest.mock import MagicMock, patch @@ -69,19 +70,30 @@ def _mock_langchain_and_connection(): del sys.modules[key] +EXPECTED_TOOLS = { + "discover_labels", + "discover", + "invoke", + "invoke_many", + "broadcast", + "await_replies", + "invoke_device_with_fallback", + "get_device_status", + "discover_devices", +} + + class TestLangchainAdapterExports: def test_module_exports_all_tools(self): from device_connect_agent_tools.adapters import langchain as adapter - for name in ("discover_devices", "invoke_device", "invoke_device_with_fallback", - "get_device_status", "describe_fleet", "list_devices", "get_device_functions"): + for name in EXPECTED_TOOLS: assert hasattr(adapter, name), f"Missing export: {name}" def test_all_list(self): from device_connect_agent_tools.adapters import langchain as adapter - expected = {"discover_devices", "invoke_device", "invoke_device_with_fallback", "get_device_status", "list_devices", "get_device_functions", "describe_fleet"} - assert set(adapter.__all__) == expected + assert set(adapter.__all__) == EXPECTED_TOOLS def test_tools_are_structured_tool_instances(self): from device_connect_agent_tools.adapters import langchain as adapter @@ -100,3 +112,9 @@ def test_tool_descriptions_not_empty(self): for name in adapter.__all__: assert len(getattr(adapter, name).description) > 0, f"{name} has empty description" + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_tools_inherit_mandate_signature(self, name): + from device_connect_agent_tools.adapters import langchain as adapter + + assert "mandate" in signature(getattr(adapter, name)._func).parameters diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index 6a1ea6f..0943968 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -9,6 +9,7 @@ """ import sys +from inspect import signature from types import ModuleType from unittest.mock import MagicMock, patch @@ -52,19 +53,30 @@ def _mock_strands_and_connection(): sys.modules.pop("strands", None) +EXPECTED_TOOLS = { + "discover_labels", + "discover", + "invoke", + "invoke_many", + "broadcast", + "await_replies", + "invoke_device_with_fallback", + "get_device_status", + "discover_devices", +} + + class TestStrandsAdapterExports: def test_module_exports_all_tools(self): from device_connect_agent_tools.adapters import strands as adapter - for name in ("discover_devices", "invoke_device", "invoke_device_with_fallback", - "get_device_status", "describe_fleet", "list_devices", "get_device_functions"): + for name in EXPECTED_TOOLS: assert hasattr(adapter, name), f"Missing export: {name}" def test_all_list(self): from device_connect_agent_tools.adapters import strands as adapter - expected = {"discover_devices", "invoke_device", "invoke_device_with_fallback", "get_device_status", "list_devices", "get_device_functions", "describe_fleet"} - assert set(adapter.__all__) == expected + assert set(adapter.__all__) == EXPECTED_TOOLS def test_tools_are_callable(self): from device_connect_agent_tools.adapters import strands as adapter @@ -77,3 +89,9 @@ def test_tool_names_match(self): for name in adapter.__all__: assert getattr(adapter, name).__name__ == name, f"{name}.__name__ mismatch" + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_tools_inherit_mandate_signature(self, name): + from device_connect_agent_tools.adapters import strands as adapter + + assert "mandate" in signature(getattr(adapter, name).__wrapped__).parameters diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py new file mode 100644 index 0000000..a8b4be4 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven subscribe + await_replies tools. + +The tests stand up a fake Connection that mirrors the buffered-inbox API +the production class exposes (``subscribe_buffered`` / +``unsubscribe_buffered`` / ``get_inbox`` / ``_inbox`` dict). Real +messaging is not exercised here; integration tests cover the wire. +""" +from unittest.mock import patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, +] + + +class FakeConnection: + """Minimal fake of the agent-tools Connection used by Subscription.""" + + def __init__(self, devices=None, zone="default"): + self.zone = zone + self.devices = devices or [] + self._inbox: dict[str, list[tuple]] = {} + self.subscribed_subjects: list[str] = [] + self.unsubscribed_names: list[str] = [] + + def list_devices(self): + return list(self.devices) + + def subscribe_buffered(self, subject: str, name: str | None = None) -> str: + name = name or subject + self._inbox[name] = [] + self.subscribed_subjects.append(subject) + return name + + def unsubscribe_buffered(self, name: str) -> None: + self.unsubscribed_names.append(name) + self._inbox.pop(name, None) + + def get_inbox(self, name: str | None = None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + # Test helper: simulate a message landing on a given subject. + def deliver(self, subject: str, payload: dict): + for name, _ in list(self._inbox.items()): + self._inbox[name].append((subject, payload)) + + +@pytest.fixture +def fake_conn(): + conn = FakeConnection(devices=SAMPLE_DEVICES) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- subscribe ------------------------------------------------------ + + +class TestSubscribe: + def test_correlation_form_subscribes_to_reply_subject(self, fake_conn): + sub = tools_mod.subscribe("correlation:abc-123") + assert len(fake_conn.subscribed_subjects) == 1 + subj = fake_conn.subscribed_subjects[0] + assert subj == "device-connect.default.*.event.async_reply.abc-123" + sub.close() + assert fake_conn.unsubscribed_names + + def test_correlation_form_with_empty_id_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("correlation:") + + def test_event_selector_subscribes_per_device(self, fake_conn): + sub = tools_mod.subscribe("device(*).event(object_detected)") + # Two cameras emit object_detected -> two subjects subscribed. + assert len(fake_conn.subscribed_subjects) == 2 + for subj in fake_conn.subscribed_subjects: + assert subj.startswith("device-connect.default.") + assert subj.endswith(".event.object_detected") + sub.close() + + def test_event_selector_zero_matches_returns_idle(self, fake_conn): + sub = tools_mod.subscribe("event(no_such_event)") + assert fake_conn.subscribed_subjects == [] + # Idle subscription: read returns empty, close is a no-op. + assert sub.read() == [] + sub.close() + + def test_non_event_scope_rejected(self, fake_conn): + with pytest.raises(ValueError) as exc: + tools_mod.subscribe("device(cam-001)") + assert "subscribe requires" in str(exc.value) + + def test_empty_or_non_string_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("") + with pytest.raises(ValueError): + tools_mod.subscribe(None) # type: ignore[arg-type] + + +# -- Subscription --------------------------------------------------- + + +class TestSubscriptionHandle: + def test_read_drains_buffered_messages(self, fake_conn): + sub = tools_mod.subscribe("correlation:r1") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r1", + {"correlation_id": "r1", "device_id": "cam-001", "success": True}, + ) + msgs = sub.read() + assert len(msgs) == 1 + assert msgs[0]["device_id"] == "cam-001" + # Subject is stamped onto the payload for source attribution. + assert "_subject" in msgs[0] + # A second read returns nothing -- the buffer is drained. + assert sub.read() == [] + sub.close() + + def test_context_manager_closes(self, fake_conn): + with tools_mod.subscribe("correlation:r2") as sub: + assert sub.read() == [] + assert fake_conn.unsubscribed_names # close() ran + + def test_iter_yields_until_idle_timeout(self, fake_conn): + sub = tools_mod.subscribe("correlation:r3") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r3", + {"correlation_id": "r3", "device_id": "cam-001"}, + ) + # Short timeout; iter() should yield the buffered reply then exit + # once no new messages arrive within the idle window. + msgs = list(sub.iter(timeout=0.1, poll_interval=0.01)) + assert len(msgs) == 1 + sub.close() + + def test_for_loop_protocol_via_dunder_iter(self, fake_conn): + # ``for msg in sub:`` should drive __iter__ which delegates to iter() + # with a sensible default timeout. Break early so the test does not + # block on the 30s default. + sub = tools_mod.subscribe("correlation:r_iter") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_iter", + {"correlation_id": "r_iter", "device_id": "cam-001"}, + ) + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + break # one message is enough to confirm __iter__ wiring + sub.close() + assert len(gathered) == 1 + assert gathered[0]["device_id"] == "cam-001" + + def test_read_does_not_drop_messages_appended_during_iteration(self, fake_conn): + # Race-safety guard: simulate a callback that appends a fresh + # message between the read's snapshot and truncation. The message + # must still be visible on the next read(). + sub = tools_mod.subscribe("correlation:r_race") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-001", "ordinal": 1}, + ) + first = sub.read() + assert len(first) == 1 + # Now simulate a late-arriving append into the same inbox AFTER + # the previous read drained the prefix. + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-002", "ordinal": 2}, + ) + second = sub.read() + assert len(second) == 1 + assert second[0]["device_id"] == "cam-002" + sub.close() + + +# -- await_replies -------------------------------------------------- + + +class TestAwaitReplies: + def test_empty_correlation_id_returns_empty_list(self, fake_conn): + assert tools_mod.await_replies("") == [] + + def test_collects_replies_until_count(self, fake_conn): + # Pre-stage two replies on the to-be-subscribed subject. await_replies + # subscribes (drains nothing yet), then deliver more during the loop. + # We deliver up-front via the fake's deliver hook so the first poll + # picks them up. + def deliver_when_subscribed(subject, name=None): + n = FakeConnection.subscribe_buffered(fake_conn, subject, name) + # Pre-load a couple of replies so the first poll returns them. + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-001"}, + ) + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-002"}, + ) + return n + + with patch.object( + fake_conn, "subscribe_buffered", side_effect=deliver_when_subscribed, + ): + replies = tools_mod.await_replies( + "r4", timeout=2.0, until=2, poll_interval=0.01, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"cam-001", "cam-002"} + + def test_returns_after_timeout_with_partial(self, fake_conn): + # No replies delivered -> after timeout, returns empty list. + replies = tools_mod.await_replies( + "r5", timeout=0.1, poll_interval=0.01, + ) + assert replies == [] diff --git a/packages/device-connect-edge/device_connect_edge/__init__.py b/packages/device-connect-edge/device_connect_edge/__init__.py index 4812c51..b3a3b4b 100644 --- a/packages/device-connect-edge/device_connect_edge/__init__.py +++ b/packages/device-connect-edge/device_connect_edge/__init__.py @@ -43,6 +43,13 @@ async def alert(self, level: str, msg: str): FunctionDef, EventDef, ) +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + create_closed_mandate, + create_open_mandate, + verify_mandate, +) from device_connect_edge.discovery_provider import DiscoveryProvider from device_connect_edge.registry_client import RegistryClient from device_connect_edge.errors import ( @@ -66,6 +73,11 @@ async def alert(self, level: str, msg: str): "DeviceStatus", "FunctionDef", "EventDef", + "MandateInvocationContext", + "MandateVerificationResult", + "create_closed_mandate", + "create_open_mandate", + "verify_mandate", "DiscoveryProvider", "RegistryClient", "DeviceConnectError", diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 40d5c63..8768930 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -77,6 +77,11 @@ async def capture_image(self, resolution: str = "1080p") -> dict: DeviceIdentity, DeviceStatus, ) +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + verify_mandate, +) # Type checking imports for driver support if TYPE_CHECKING: @@ -163,6 +168,16 @@ def build_rpc_error(id_: str, code: int, msg: str) -> bytes: ).encode() +def _broadcast_error(message: str) -> Dict[str, str]: + code = "invoke_failed" + if ":" in message: + prefix, _, rest = message.partition(":") + if prefix.startswith("mandate_") or prefix in {"invalid_mandate", "unknown_mandate_key"}: + code = prefix + message = rest.strip() or message + return {"code": code, "message": message} + + class DeviceRuntime: """High-level runtime for Device Connect devices. @@ -259,6 +274,7 @@ def __init__( auto_commission: bool = True, commissioning_port: int = 5540, allow_insecure: Optional[bool] = None, + mandate_keys: Optional[Dict[str, Union[bytes, str]]] = None, ) -> None: # Store driver reference and connect driver to this device self._driver = driver @@ -349,6 +365,8 @@ def __init__( self.allow_insecure = os.getenv("DEVICE_CONNECT_ALLOW_INSECURE", "").lower() in ("1", "true", "yes") else: self.allow_insecure = allow_insecure + self._mandate_keys: Dict[str, Union[bytes, str]] = dict(mandate_keys or {}) + self._mandate_replay_cache: set[str] = set() self._factory_identity: Optional[dict] = None # Initialize logger and internal state early (before commissioning checks) @@ -1112,6 +1130,19 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): "device_connect.source_device": source_device or "", }, ): + mandate_result = self._verify_mandate_for_invocation( + method, params_dict, dc_meta, + ) + if not mandate_result.ok: + if reply_subject: + await self.messaging.publish( + reply_subject, + build_rpc_error( + payload["id"], -32041, + mandate_result.message or mandate_result.error_code or "mandate_denied", + ) + ) + return # Pass source_device to driver for logging (existing pattern) if source_device: params_dict["source_device"] = source_device @@ -1135,6 +1166,230 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): self._logger.info("Subscribed to commands on %s", subj) + def _verify_mandate_for_invocation( + self, + function_name: str, + params: Dict[str, Any], + dc_meta: Optional[Dict[str, Any]], + ) -> MandateVerificationResult: + """Verify mandate metadata when a function declares it is required.""" + if self._driver is None: + return MandateVerificationResult(ok=True) + method = self._driver._get_functions().get(function_name) + mandate_policy = getattr(method, "_mandate", None) + if not mandate_policy or not mandate_policy.get("required"): + return MandateVerificationResult(ok=True) + meta = dc_meta if isinstance(dc_meta, dict) else {} + return verify_mandate( + meta.get("mandate"), + context=MandateInvocationContext( + device_id=self.device_id, + method=function_name, + params=params, + ), + key_resolver=self._mandate_keys.get, + replay_cache=self._mandate_replay_cache, + ) + + + async def _broadcast_subscription(self) -> None: + """Subscribe to selector-driven broadcasts and self-elect to handle. + + Broadcast envelope shape (JSON over a fanout subject):: + + { + "correlation_id": "br-abc123", + "function": "capture_image", + "params": {"resolution": "4k"}, + "targets": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire + } + + On match, the device executes the function and emits a reply on + ``device-connect...event.async_reply.`` + with ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + + The envelope is processed in a tracked task so the subscription + loop does not block on ``fire_at`` sleeps or long-running driver + functions; subsequent broadcasts can continue to land while an + earlier one is in flight. + """ + subj = f"device-connect.{self.tenant}.broadcast" + + async def on_msg(data: bytes, reply_subject: Optional[str]): + try: + envelope = json.loads(data) + except Exception as e: + self._logger.debug("Broadcast: malformed envelope: %s", e) + return + + correlation_id = envelope.get("correlation_id") + if not correlation_id: + return + + # Cheap self-election: target gate (pre-resolved by the dispatcher + # from the selector). When absent or empty, treat as fleet-wide. + targets = envelope.get("targets") or [] + if targets and self.device_id not in targets: + return + + if not envelope.get("function"): + return + + # Hand off to a tracked task. The task owns the where evaluation, + # the fire_at sleep, and the driver call, so this callback returns + # immediately and the messaging subscription stays drained. + self._track_task(asyncio.create_task( + self._handle_broadcast_envelope(envelope, correlation_id) + )) + + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) + + + async def _handle_broadcast_envelope( + self, envelope: Dict[str, Any], correlation_id: str, + ) -> None: + """Process one broadcast envelope: evaluate where, honour fire_at, invoke, reply. + + Runs in its own task so a long-held ``fire_at`` or slow driver + function does not block the subscription callback from accepting + subsequent broadcasts. + """ + function_name = envelope.get("function") + params_dict = envelope.get("params", {}) or {} + dc_meta = params_dict.pop("_dc_meta", {}) + + # Step 1: where predicate against {identity, labels, status, bindings}. + # A failed compile or eval is treated as fail-closed (do not execute); + # the message is logged at WARNING with the correlation_id so an + # operator can correlate a silent skip with a misspelled label key. + where_expr = envelope.get("where") + if where_expr and not self._evaluate_where( + where_expr, envelope.get("bindings"), correlation_id, + ): + return + + # Step 2: fire_at hold. The on_late policy decides what to do when + # the message arrives past the deadline (skip preserves coherence; + # fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, + ) + return + if delay > 0: + await asyncio.sleep(delay) + + # Step 3: execute and reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + mandate_result = self._verify_mandate_for_invocation( + function_name, params_dict, dc_meta, + ) + if not mandate_result.ok: + code = mandate_result.error_code or "mandate_denied" + raise RuntimeError(f"{code}: {mandate_result.message}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload: Dict[str, Any] = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": _broadcast_error(str(e)), + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + + def _evaluate_where( + self, + where_expr: str, + bindings: Optional[Dict[str, Any]], + correlation_id: str, + ) -> bool: + """Compile and evaluate a where predicate; return True iff it passes. + + Returns False (do not execute) on compile or eval errors, logging + a warning so silent self-deselection is operator-visible. + """ + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror DeviceStatus.location into labels so ``labels.location`` + # works in predicates without the driver having to declare it + # explicitly. Matches the dispatcher-side flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + # DeviceIdentity is exposed by the driver, not by DeviceCapabilities; + # they are independent pydantic models. Read identity from the + # driver so extra fields (seat_row, seat_col, x-mhp metadata, ...) + # reach the predicate context. Splice in device_id which lives on + # the runtime so predicates can write + # ``identity.device_id == "..."`` naturally. + identity_dict: Dict[str, Any] = {"device_id": self.device_id} + driver_identity = ( + getattr(self._driver, "identity", None) if self._driver else None + ) + if driver_identity is not None and hasattr(driver_identity, "model_dump"): + identity_dict.update(driver_identity.model_dump()) + context = { + "identity": identity_dict, + "labels": labels, + "status": status_dict, + "bindings": bindings or {}, + } + return bool(predicate.evaluate(context)) + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return False + + async def _event_dispatch_loop(self) -> None: """Send queued events, retrying on failure.""" @@ -1372,6 +1627,13 @@ async def run(self) -> None: # Subscribe to commands BEFORE capability routines so log order makes sense await self._cmd_subscription() + # Subscribe to fleet broadcasts (best-effort; broadcast is opt-in for + # callers, so failure here should not block command handling). + try: + await self._broadcast_subscription() + except Exception as e: # pragma: no cover - best effort logging + self._logger.warning("Broadcast subscription failed: %s", e) + # Start capability routines if driver supports them (CapabilityDriverMixin) # This must happen after registration so events don't fire before device is registered if hasattr(self._driver, 'start_capability_routines'): diff --git a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py index 3abdf1d..ac996b8 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py @@ -37,6 +37,7 @@ async def motion_detected(self, zone: str, confidence: float): emit, before_emit, periodic, + requires_mandate, build_function_schema, build_event_schema, ) @@ -55,6 +56,7 @@ async def motion_detected(self, zone: str, confidence: float): "emit", "before_emit", "periodic", + "requires_mandate", "on", "build_function_schema", "build_event_schema", diff --git a/packages/device-connect-edge/device_connect_edge/drivers/base.py b/packages/device-connect-edge/device_connect_edge/drivers/base.py index 2f5228b..8c01d4b 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/base.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/base.py @@ -70,7 +70,7 @@ async def disconnect(self) -> None: import logging import time from abc import ABC -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from device_connect_edge.types import ( FunctionDef, @@ -129,6 +129,14 @@ class DeviceDriver(ABC): # starting background tasks. Example: depends_on = ("robot", "speaker") depends_on: Tuple[str, ...] = () + # Override in subclasses to attach discovery metadata to the device. Carried on + # DeviceCapabilities. Values may be a single string or a list of strings (composite + # identity). Well-known keys: category (camera|robot|hub|sensor|actuator|inference), + # location (e.g. 'warehouse1/loading-dock'). Custom keys are allowed. + # Example: + # labels = {"category": ["camera", "inference"], "location": "warehouse1/dock-3"} + labels: Optional[Dict[str, Union[str, List[str]]]] = None + # Type alias for event callback EventCallback = Callable[[str, Dict[str, Any]], Any] @@ -249,7 +257,8 @@ def capabilities(self) -> DeviceCapabilities: return DeviceCapabilities( description=self.__class__.__doc__ or "", functions=self.functions, - events=self.events + events=self.events, + labels=self.labels, ) @property @@ -347,7 +356,7 @@ async def invoke(self, function_name: str, **params: Any) -> Any: # Properties to skip during attribute scanning to avoid recursion _SKIP_ATTRS = frozenset([ "capabilities", "functions", "events", "identity", "status", - "device_type" + "device_type", "labels" ]) def _collect_functions(self) -> List[FunctionDef]: @@ -379,11 +388,15 @@ def _collect_functions(self) -> List[FunctionDef]: func_name = getattr(attr, "_function_name", attr_name) description = getattr(attr, "_description", "") parameters = build_function_schema(attr) + labels = getattr(attr, "_labels", None) + mandate = getattr(attr, "_mandate", None) functions.append(FunctionDef( name=func_name, description=description, parameters=parameters, + labels=labels, + mandate=mandate, tags=[] )) @@ -418,11 +431,13 @@ def _collect_events(self) -> List[EventDef]: event_name = getattr(attr, "_event_name", attr_name) description = getattr(attr, "_event_description", "") payload_schema = build_event_schema(attr) + labels = getattr(attr, "_labels", None) events.append(EventDef( name=event_name, description=description, payload_schema=payload_schema, + labels=labels, tags=[] )) diff --git a/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py b/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py index 824add6..87102e4 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py @@ -512,6 +512,9 @@ def _register_functions(self, loaded: LoadedCapability) -> None: "description": description, "parameters": parameters, } + mandate = getattr(attr, "_mandate", None) + if mandate is not None: + loaded.function_schemas[func_name]["mandate"] = mandate # Register with namespace prefix self._functions[f"{cap_id}.{func_name}"] = attr diff --git a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py index b59a3a5..aeade78 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py @@ -57,7 +57,7 @@ async def detection_loop(self): import re import time import uuid -from typing import Any, Callable, Dict, Optional, get_type_hints, get_origin, get_args +from typing import Any, Callable, Dict, List, Optional, Union, get_type_hints, get_origin, get_args from device_connect_edge.telemetry.tracer import get_tracer, get_current_trace_id, SpanKind, StatusCode from device_connect_edge.telemetry.metrics import get_metrics @@ -345,6 +345,7 @@ def _get_integration_logger(obj: Any) -> Optional[Callable[[dict], None]]: def rpc( name: Optional[str] = None, description: Optional[str] = None, + labels: Optional[Dict[str, Union[str, List[str]]]] = None, ) -> Callable: """Decorator to expose a method as an RPC-callable function. @@ -355,6 +356,10 @@ def rpc( Args: name: Override function name (default: method __name__) description: Override description (default: first line of docstring) + labels: Discovery metadata as key:value pairs. Values may be a single + string or a list of strings (composite identity). Well-known keys: + direction (read|write), safety (critical|informational), modality + (rgb|thermal|...). Custom keys are allowed. Returns: Decorated method with function metadata attached @@ -372,6 +377,11 @@ async def my_function(self, param: str = "default") -> dict: @rpc(name="customName", description="Custom description") async def another_function(self, x: int) -> dict: return {"x": x} + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_frame(self, resolution: str = "1080p") -> dict: + '''Capture a frame.''' + return {} """ def decorator(func: Callable) -> Callable: func_name = name or func.__name__ @@ -380,6 +390,7 @@ def decorator(func: Callable) -> Callable: summary, arg_docs = _parse_docstring(func.__doc__) func._description = description or summary func._arg_descriptions = arg_docs + func._labels = labels @functools.wraps(func) async def wrapper(self, *args, **kwargs): @@ -499,6 +510,8 @@ async def wrapper(self, *args, **kwargs): wrapper._function_name = func_name wrapper._description = func._description wrapper._arg_descriptions = func._arg_descriptions + wrapper._labels = func._labels + wrapper._mandate = getattr(func, "_mandate", None) wrapper._original_func = func # For schema extraction return wrapper @@ -506,9 +519,19 @@ async def wrapper(self, *args, **kwargs): return decorator +def requires_mandate(scope: str = "actuation") -> Callable: + """Mark an RPC method as requiring a valid Device Mandate.""" + def decorator(func: Callable) -> Callable: + func._mandate = {"required": True, "scope": scope} + return func + + return decorator + + def emit( name: Optional[str] = None, - description: Optional[str] = None + description: Optional[str] = None, + labels: Optional[Dict[str, Union[str, List[str]]]] = None, ) -> Callable: """Decorator to declare an event this driver can emit. @@ -524,6 +547,10 @@ def emit( Args: name: Override event name (default: method __name__) description: Event description (default: first line of docstring) + labels: Discovery metadata as key:value pairs. Values may be a single + string or a list of strings (composite identity). Well-known keys: + safety (critical|informational), modality (rgb|thermal|motion|...). + Custom keys are allowed. Returns: Decorated method that emits event when called @@ -550,6 +577,7 @@ def decorator(func: Callable) -> Callable: summary, arg_docs = _parse_docstring(func.__doc__) func._event_description = description or summary func._payload_descriptions = arg_docs + func._labels = labels @functools.wraps(func) async def wrapper(self, *args, **kwargs): @@ -624,6 +652,7 @@ async def wrapper(self, *args, **kwargs): wrapper._event_name = event_name wrapper._event_description = func._event_description wrapper._payload_descriptions = func._payload_descriptions + wrapper._labels = func._labels wrapper._original_func = func # For schema extraction return wrapper diff --git a/packages/device-connect-edge/device_connect_edge/mandates.py b/packages/device-connect-edge/device_connect_edge/mandates.py new file mode 100644 index 0000000..a2ebb03 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/mandates.py @@ -0,0 +1,247 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Device Mandate helpers. + +This module implements the first Device Mandate credential profile used by +Device Connect tests and demos. It is intentionally small and stdlib-only: +the public runtime contract is the mandate envelope and verifier interface, +while production credential formats can be added behind the same boundary. +""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Callable + + +MANDATE_FORMAT = "device-connect-hmac-v0" + + +@dataclass(frozen=True) +class MandateInvocationContext: + """Concrete invocation a closed mandate must authorize.""" + + device_id: str + method: str + params: dict[str, Any] + now: datetime | None = None + + +@dataclass(frozen=True) +class MandateVerificationResult: + """Verifier result for expected allow/deny outcomes.""" + + ok: bool + error_code: str | None = None + message: str = "" + + +KeyResolver = Callable[[str], bytes | str | None] + + +def create_open_mandate( + *, + principal: str, + agent: str, + device_id: str, + methods: list[str], + constraints: dict[str, Any] | None, + not_before: datetime, + not_after: datetime, + key: bytes | str, + mandate_id: str | None = None, +) -> dict[str, Any]: + """Create and sign an open mandate.""" + + payload = { + "id": mandate_id or f"open-{uuid.uuid4().hex[:12]}", + "principal": principal, + "agent": agent, + "device_id": device_id, + "methods": list(methods), + "constraints": constraints or {}, + "not_before": _format_dt(not_before), + "not_after": _format_dt(not_after), + } + return {**payload, "signature": _sign(payload, key)} + + +def create_closed_mandate( + *, + open_mandate: dict[str, Any], + agent: str, + device_id: str, + method: str, + params: dict[str, Any], + key: bytes | str, + issued_at: datetime, + mandate_id: str | None = None, + nonce: str | None = None, +) -> dict[str, Any]: + """Create and sign a closed mandate for one concrete invocation.""" + + payload = { + "format": MANDATE_FORMAT, + "id": mandate_id or f"closed-{uuid.uuid4().hex[:12]}", + "agent": agent, + "open_mandate": open_mandate, + "invocation": { + "device_id": device_id, + "method": method, + "params": params, + }, + "issued_at": _format_dt(issued_at), + "nonce": nonce or uuid.uuid4().hex, + } + return {**payload, "signature": _sign(payload, key)} + + +def verify_mandate( + mandate: dict[str, Any] | None, + *, + context: MandateInvocationContext, + key_resolver: KeyResolver, + replay_cache: set[str] | None = None, +) -> MandateVerificationResult: + """Verify that a closed mandate authorizes an invocation.""" + + if not mandate: + return _deny("mandate_required", "mandate_required: protected RPC needs a mandate") + if not isinstance(mandate, dict): + return _deny("invalid_mandate", "invalid_mandate: mandate must be an object") + if mandate.get("format") != MANDATE_FORMAT: + return _deny("invalid_mandate", "invalid_mandate: unsupported mandate format") + + open_mandate = mandate.get("open_mandate") + if not isinstance(open_mandate, dict): + return _deny("invalid_mandate", "invalid_mandate: missing open mandate") + + principal = open_mandate.get("principal") + agent = mandate.get("agent") + if not isinstance(principal, str) or not isinstance(agent, str): + return _deny("invalid_mandate", "invalid_mandate: missing principal or agent") + if open_mandate.get("agent") != agent: + return _deny("mandate_agent_denied", "mandate_agent_denied: agent mismatch") + + principal_key = key_resolver(principal) + agent_key = key_resolver(agent) + if principal_key is None or agent_key is None: + return _deny("unknown_mandate_key", "unknown_mandate_key: signer key unavailable") + if not _signature_valid(open_mandate, principal_key): + return _deny("invalid_mandate_signature", "invalid_mandate_signature: open mandate") + if not _signature_valid(mandate, agent_key): + return _deny("invalid_mandate_signature", "invalid_mandate_signature: closed mandate") + + now = _as_utc(context.now or datetime.now(timezone.utc)) + not_before = _parse_dt(str(open_mandate.get("not_before", ""))) + not_after = _parse_dt(str(open_mandate.get("not_after", ""))) + if not_before is None or not_after is None: + return _deny("invalid_mandate", "invalid_mandate: invalid validity window") + if now < not_before: + return _deny("mandate_not_yet_valid", "mandate_not_yet_valid") + if now > not_after: + return _deny("mandate_expired", "mandate_expired") + + if open_mandate.get("device_id") != context.device_id: + return _deny("mandate_device_denied", "mandate_device_denied") + if context.method not in (open_mandate.get("methods") or []): + return _deny("mandate_method_denied", "mandate_method_denied") + + invocation = mandate.get("invocation") or {} + if invocation.get("device_id") != context.device_id: + return _deny("mandate_device_denied", "mandate_device_denied") + if invocation.get("method") != context.method: + return _deny("mandate_method_denied", "mandate_method_denied") + if invocation.get("params") != context.params: + return _deny("mandate_params_denied", "mandate_params_denied") + + constraint_error = _check_constraints( + open_mandate.get("constraints") or {}, context.params, + ) + if constraint_error is not None: + return _deny("mandate_constraint_denied", constraint_error) + + nonce = mandate.get("nonce") + if replay_cache is not None and isinstance(nonce, str): + if nonce in replay_cache: + return _deny("mandate_replayed", "mandate_replayed") + replay_cache.add(nonce) + + return MandateVerificationResult(ok=True) + + +def _deny(code: str, message: str) -> MandateVerificationResult: + return MandateVerificationResult(ok=False, error_code=code, message=message) + + +def _sign(payload: dict[str, Any], key: bytes | str) -> str: + return hmac.new(_key_bytes(key), _canonical(payload), hashlib.sha256).hexdigest() + + +def _signature_valid(payload: dict[str, Any], key: bytes | str) -> bool: + expected = payload.get("signature") + if not isinstance(expected, str): + return False + unsigned = {k: v for k, v in payload.items() if k != "signature"} + return hmac.compare_digest(expected, _sign(unsigned, key)) + + +def _canonical(payload: dict[str, Any]) -> bytes: + return json.dumps( + payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + ).encode() + + +def _key_bytes(key: bytes | str) -> bytes: + return key if isinstance(key, bytes) else key.encode() + + +def _format_dt(value: datetime) -> str: + return _as_utc(value).isoformat().replace("+00:00", "Z") + + +def _parse_dt(value: str) -> datetime | None: + try: + return _as_utc(datetime.fromisoformat(value.replace("Z", "+00:00"))) + except ValueError: + return None + + +def _as_utc(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _check_constraints( + constraints: dict[str, Any], params: dict[str, Any], +) -> str | None: + for name, rules in constraints.items(): + if name not in params: + return f"mandate_constraint_denied: missing {name}" + value = params[name] + if not isinstance(rules, dict): + if value != rules: + return f"mandate_constraint_denied: {name}" + continue + for op, expected in rules.items(): + if op == "eq" and value != expected: + return f"mandate_constraint_denied: {name}" + if op == "lte" and not value <= expected: + return f"mandate_constraint_denied: {name}" + if op == "lt" and not value < expected: + return f"mandate_constraint_denied: {name}" + if op == "gte" and not value >= expected: + return f"mandate_constraint_denied: {name}" + if op == "gt" and not value > expected: + return f"mandate_constraint_denied: {name}" + if op == "in" and value not in expected: + return f"mandate_constraint_denied: {name}" + return None diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py new file mode 100644 index 0000000..5bf5ff6 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""CEL ``where`` predicate evaluator for self-election at the edge. + +A ``where`` predicate is a CEL (Common Expression Language) expression that +each candidate device evaluates against its own context to decide whether +to execute a fan-out call. The predicate sees four top-level variables: + + identity device-local identity dict (device_id, device_type, ...) + labels device labels (the same labels selectors filter on) + status device status (heartbeat-updated: location, availability, + battery, online, ...) + bindings shared payload supplied by the caller (selection masks, + thresholds, lookup tables) + +Examples:: + + battery > 50 + labels.category == "camera" && status.battery > 50 + mask[seat_row][seat_col] == 1 + bindings.threshold < status.temperature + +CEL is sandboxed by construction: no I/O, no filesystem, no exec. This +module wraps `cel-python` with lazy import so device-connect-edge does +not require it as a hard dependency. Install with the optional +``[predicate]`` extra:: + + pip install device-connect-edge[predicate] + +The evaluator is shared by the dispatcher (validates the expression +before broadcast) and the device runtime (evaluates per-call to decide +whether to execute the fan-out). +""" + +from __future__ import annotations + +from typing import Any, Mapping + + +class PredicateCompileError(ValueError): + """Raised when a ``where`` expression fails to compile. + + Carries the original cel-python error chained so callers can drill in + if they need the exact parse position. + """ + + +class PredicateEvalError(RuntimeError): + """Raised when an otherwise-valid predicate fails at evaluation time. + + Typical causes: missing context key, type mismatch (e.g. comparing a + string to an int), or arithmetic overflow. + """ + + +# Lazy import: ``cel-python`` is an optional extra. Importers of this module +# pay no cost unless they actually compile a predicate. +def _require_celpy(): + try: + import celpy # type: ignore[import-not-found] + return celpy + except ImportError as e: + raise PredicateCompileError( + "where predicates require the 'cel-python' package; " + "install with the [predicate] extra: " + "pip install 'device-connect-edge[predicate]'" + ) from e + + +def _to_cel(value: Any) -> Any: + """Recursively wrap a Python value as the matching CEL type. + + Native Python ints, strings, dicts, and lists arrive at the boundary + untyped; cel-python's evaluator expects its own typed wrappers + (``IntType``, ``MapType``, ``ListType``, ...). We wrap once at the + top of evaluation rather than asking callers to import celtypes. + """ + celpy = _require_celpy() + ct = celpy.celtypes + if value is None: + return None + if isinstance(value, bool): + return ct.BoolType(value) + if isinstance(value, int): + return ct.IntType(value) + if isinstance(value, float): + return ct.DoubleType(value) + if isinstance(value, str): + return ct.StringType(value) + if isinstance(value, (bytes, bytearray)): + return ct.BytesType(bytes(value)) + if isinstance(value, Mapping): + return ct.MapType({ + ct.StringType(str(k)): _to_cel(v) for k, v in value.items() + }) + if isinstance(value, (list, tuple)): + return ct.ListType([_to_cel(v) for v in value]) + # Fallback: stringify. Rare; happens for custom objects in the context. + return ct.StringType(str(value)) + + +class WherePredicate: + """A compiled ``where`` predicate, ready to evaluate against device context. + + Compile once (typically at the dispatcher when the call comes in or at + the edge when the broadcast envelope is received), then evaluate once + per candidate. Predicates are stateless and safe to reuse across calls. + """ + + __slots__ = ("expression", "_program") + + def __init__(self, expression: str, _program: Any): + self.expression = expression + self._program = _program + + def evaluate(self, context: Mapping[str, Any]) -> bool: + """Return ``True`` if the predicate holds for ``context``. + + ``context`` should be a flat mapping of variable name to Python + value. Common keys: ``identity``, ``labels``, ``status``, + ``bindings``. Missing keys are not auto-defaulted; if the + predicate references one, the call raises PredicateEvalError so + the caller can decide between fail-open and fail-closed. + """ + celpy = _require_celpy() + cel_context = {k: _to_cel(v) for k, v in context.items()} + try: + result = self._program.evaluate(cel_context) + except celpy.CELEvalError as e: + raise PredicateEvalError( + f"failed to evaluate where {self.expression!r}: {e}" + ) from e + return bool(result) + + +def compile_where(expression: str) -> WherePredicate: + """Compile a ``where`` expression into a reusable :class:`WherePredicate`. + + Raises :class:`PredicateCompileError` if cel-python is not installed + or the expression is malformed. + """ + celpy = _require_celpy() + if not isinstance(expression, str): + raise PredicateCompileError( + f"where expression must be a string, got {type(expression).__name__}" + ) + if not expression.strip(): + raise PredicateCompileError("where expression must be non-empty") + env = celpy.Environment() + try: + ast = env.compile(expression) + except Exception as e: + # cel-python surfaces parse errors via several exception classes + # depending on the failure mode (lark.UnexpectedToken, ValueError, + # CELParseError). Catch broadly and rewrap so callers only see + # PredicateCompileError. + raise PredicateCompileError( + f"failed to compile where {expression!r}: {e}" + ) from e + program = env.program(ast) + return WherePredicate(expression=expression, _program=program) diff --git a/packages/device-connect-edge/device_connect_edge/selector.py b/packages/device-connect-edge/device_connect_edge/selector.py new file mode 100644 index 0000000..f2218e3 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/selector.py @@ -0,0 +1,467 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Selector DSL for hierarchical device + function discovery. + +This module parses selector expressions used by the discovery, invocation, and +subscription APIs into a structured form that can be matched against device, +function, and event records. + +Placement note: this module is dependency-free (stdlib only) and is consumed +by callers outside this package -- notably the discovery tools in +``device_connect_agent_tools``. It lives here as the lowest common ancestor +in the package dependency graph, not as edge-runtime code; ``DeviceRuntime`` +and the driver framework do not import it. + +Grammar overview: + + device() # filter on device labels + device().function() # functions on a device subset (RPCs) + device().event() # events on a device subset + function() # all RPCs across the fleet + event() # all events across the fleet + +Inside ``(...)``: + + key:value single value match + key:[v1,v2] OR within a key (matches if label contains any value) + key:pattern* glob (``*``, ``?``) + k1:v1,k2:v2 AND across keys + bare-string id/name match: ``device(robot-001)`` + * match all +""" +from __future__ import annotations + +import fnmatch +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union + +# A label value is either a single string or a list of strings (composite identity). +LabelValue = Union[str, List[str]] +Labels = Dict[str, LabelValue] + + +class SelectorParseError(ValueError): + """Raised when a selector string cannot be parsed.""" + + def __init__(self, message: str, source: str = "", position: Optional[int] = None): + if position is not None and source: + caret = " " * position + "^" + full = f"{message} at position {position}\n {source}\n {caret}" + elif source: + full = f"{message}: {source!r}" + else: + full = message + super().__init__(full) + self.source = source + self.position = position + + +class Scope(str, Enum): + """Which entities a selector matches. + + DEVICE_ONLY - device(...) + DEVICE_FUNCTION - device(...).function(...) + DEVICE_EVENT - device(...).event(...) + FUNCTION_ONLY - function(...) + EVENT_ONLY - event(...) + """ + DEVICE_ONLY = "device_only" + DEVICE_FUNCTION = "device_function" + DEVICE_EVENT = "device_event" + FUNCTION_ONLY = "function_only" + EVENT_ONLY = "event_only" + + +@dataclass(frozen=True) +class KeyFilter: + """Filter on a single label key. + + Values are OR'd: any matching value is sufficient. Each value may contain + glob characters (``*`` and ``?``) per ``fnmatch`` semantics. + + ``children`` is reserved for grammar extensions (nested boolean + expressions, AND-within-key, negation) and is empty in the current + parser. Carrying the field on the dataclass now lets future versions + populate it without breaking the public type shape. + """ + key: str + values: Tuple[str, ...] + children: Tuple["KeyFilter", ...] = field(default_factory=tuple) + + def matches(self, label_value: Optional[LabelValue]) -> bool: + """True iff the label value satisfies this key filter. + + For multi-valued labels (list), passes if any element matches any of + this filter's values. + """ + if label_value is None: + return False + actual: Tuple[str, ...] + if isinstance(label_value, list): + actual = tuple(label_value) + else: + actual = (label_value,) + for pattern in self.values: + if "*" in pattern or "?" in pattern: + for a in actual: + if fnmatch.fnmatchcase(a, pattern): + return True + else: + if pattern in actual: + return True + return False + + +@dataclass(frozen=True) +class Filter: + """One axis of a selector - matches a single entity (device, function, or event). + + Combines an optional bare-string name match with AND-across-keys label + filters. An empty Filter (no name match, no key filters) matches every + entity, so ``*`` and empty parens both reduce to that case. + """ + name_match: Optional[str] = None + key_filters: Tuple[KeyFilter, ...] = field(default_factory=tuple) + + def matches(self, name: str, labels: Optional[Labels]) -> bool: + """True iff this filter matches the given entity.""" + if self.name_match is not None: + pattern = self.name_match + if "*" in pattern or "?" in pattern: + if not fnmatch.fnmatchcase(name, pattern): + return False + elif name != pattern: + return False + for kf in self.key_filters: + label_value = labels.get(kf.key) if labels else None + if not kf.matches(label_value): + return False + return True + + +@dataclass(frozen=True) +class Selector: + """Parsed selector expression. + + Each axis is an optional :class:`Filter`. A ``None`` axis is vacuously + True - ``matches_function`` on a device-only selector returns True so the + caller can write a single-pass enumeration without scope branching. + """ + scope: Scope + device: Optional[Filter] = None + function: Optional[Filter] = None + event: Optional[Filter] = None + raw: str = "" + + def matches_device(self, name: str, labels: Optional[Labels]) -> bool: + if self.device is None: + return True + return self.device.matches(name, labels) + + def matches_function(self, name: str, labels: Optional[Labels]) -> bool: + if self.function is None: + return True + return self.function.matches(name, labels) + + def matches_event(self, name: str, labels: Optional[Labels]) -> bool: + if self.event is None: + return True + return self.event.matches(name, labels) + + +# -- Parsing ------------------------------------------------------- + + +def _split_top_commas(body: str, source: str, base_offset: int) -> List[Tuple[str, int]]: + """Split a filter body on top-level commas. + + Respects ``[...]`` bracket nesting: commas inside brackets are part of the + value list, not term separators. Returns ``(term, abs_offset_of_term_start)`` + pairs to support precise error positioning. + """ + terms: List[Tuple[str, int]] = [] + depth = 0 + start = 0 + for i, ch in enumerate(body): + if ch == "[": + depth += 1 + elif ch == "]": + if depth == 0: + raise SelectorParseError( + "Unmatched ']'", source=source, position=base_offset + i + ) + depth -= 1 + elif ch == "," and depth == 0: + terms.append((body[start:i], base_offset + start)) + start = i + 1 + if depth != 0: + raise SelectorParseError( + "Unmatched '['", source=source, position=base_offset + body.rfind("[") + ) + terms.append((body[start:], base_offset + start)) + return terms + + +def _parse_value_part(value: str, source: str, base_offset: int) -> Tuple[str, ...]: + """Parse the right-hand side of ``key:``. + + Returns a tuple of value strings (one element for single value, multiple for + bracketed OR list). Each value may contain glob characters. + """ + value = value.strip() + if not value: + raise SelectorParseError( + "Empty value after ':'", source=source, position=base_offset + ) + if value.startswith("["): + if not value.endswith("]"): + raise SelectorParseError( + "Unclosed '['", source=source, position=base_offset + ) + inner = value[1:-1].strip() + if not inner: + raise SelectorParseError( + "Empty value list '[]'", source=source, position=base_offset + ) + # Bracket bodies are flat (Phase 2 grammar); split on commas, strip, reject empties + out: List[str] = [] + for raw in inner.split(","): + v = raw.strip() + if not v: + raise SelectorParseError( + "Empty value in list", source=source, position=base_offset + ) + if "[" in v or "]" in v: + raise SelectorParseError( + "Nested brackets are not supported in this DSL version", + source=source, + position=base_offset, + ) + out.append(v) + return tuple(out) + if "[" in value or "]" in value: + raise SelectorParseError( + "Stray bracket in value", source=source, position=base_offset + ) + return (value,) + + +_KEY_PATTERN = ("0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "_-.") + + +def _is_valid_key(key: str) -> bool: + """Label keys are conservative identifiers: alnum, '_', '-', '.'.""" + return bool(key) and all(c in _KEY_PATTERN for c in key) + + +def _parse_filter_body(body: str, source: str, base_offset: int) -> Filter: + """Parse the contents of one ``(...)`` block into a :class:`Filter`. + + Supports: + ``*`` or empty body -> match-all (empty Filter) + ``key:value`` -> single-value key filter + ``key:[v1,v2]`` -> OR within a key + ``key:pattern*`` -> glob value + ``k1:v1,k2:v2`` -> AND across keys + bare string -> name match (id/name) + bare + key:value -> name AND key constraints + """ + stripped = body.strip() + if not stripped or stripped == "*": + return Filter() + + name_match: Optional[str] = None + key_filters: List[KeyFilter] = [] + + for term, term_offset in _split_top_commas(body, source, base_offset): + # Account for leading whitespace inside the term when reporting positions. + leading = len(term) - len(term.lstrip()) + term_stripped = term.strip() + term_abs = term_offset + leading + if not term_stripped: + raise SelectorParseError( + "Empty term (extra comma?)", source=source, position=term_abs + ) + + # Find a top-level ':' (one not inside the value brackets) to classify + # bare-name vs key:value. + colon_pos = -1 + depth = 0 + for j, ch in enumerate(term_stripped): + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + elif ch == ":" and depth == 0: + colon_pos = j + break + + if colon_pos < 0: + # Bare term: name match or '*' + if term_stripped == "*": + continue # vacuous, contributes nothing + if name_match is not None: + raise SelectorParseError( + f"Multiple bare-name terms ({name_match!r} and {term_stripped!r})", + source=source, + position=term_abs, + ) + name_match = term_stripped + continue + + key = term_stripped[:colon_pos].strip() + value_part = term_stripped[colon_pos + 1:] + value_offset = term_abs + colon_pos + 1 + if not _is_valid_key(key): + raise SelectorParseError( + f"Invalid key {key!r} (allowed: alphanumeric, '_', '-', '.')", + source=source, + position=term_abs, + ) + values = _parse_value_part(value_part, source, value_offset) + key_filters.append(KeyFilter(key=key, values=values)) + + return Filter(name_match=name_match, key_filters=tuple(key_filters)) + + +_VALID_SCOPES = ("device", "function", "event") + + +def _consume_scope(s: str, source: str, start: int) -> Tuple[str, Filter, int]: + """Consume one ``()`` from ``s`` starting at ``start``. + + Returns ``(scope_name, filter, position_after_closing_paren)``. Skips + leading whitespace. + """ + i = start + n = len(s) + while i < n and s[i].isspace(): + i += 1 + name_start = i + while i < n and s[i] not in "( \t": + i += 1 + name = s[name_start:i] + if not name: + raise SelectorParseError( + "Expected scope name (device|function|event)", source=source, position=name_start + ) + if name not in _VALID_SCOPES: + raise SelectorParseError( + f"Unknown scope {name!r} (expected one of {_VALID_SCOPES})", + source=source, + position=name_start, + ) + while i < n and s[i].isspace(): + i += 1 + if i >= n or s[i] != "(": + raise SelectorParseError( + f"Expected '(' after scope {name!r}", source=source, position=i + ) + body_start = i + 1 + # Find matching ')', tracking [...] nesting so a stray ')' inside brackets + # would not be treated as the scope close. (Reserved chars rule out ')' + # in valid values, but be defensive.) + depth = 0 + last_open_bracket = -1 + j = body_start + while j < n: + ch = s[j] + if ch == "[": + depth += 1 + last_open_bracket = j + elif ch == "]": + depth -= 1 + elif ch == ")" and depth == 0: + break + j += 1 + if j >= n: + if depth > 0: + raise SelectorParseError( + "Unclosed '['", source=source, position=last_open_bracket + ) + raise SelectorParseError( + f"Unclosed '(' for scope {name!r}", source=source, position=body_start - 1 + ) + body = s[body_start:j] + flt = _parse_filter_body(body, source=source, base_offset=body_start) + return name, flt, j + 1 + + +def parse_selector(s: str) -> Selector: + """Parse a selector string into a :class:`Selector`. + + Examples:: + + parse_selector("device(category:camera)") + parse_selector("device(category:[camera,robot], location:warehouse1/*)") + parse_selector("device(*).function(direction:write)") + parse_selector("function(safety:critical)") + + Raises :class:`SelectorParseError` on malformed input. + """ + if not isinstance(s, str): + raise SelectorParseError(f"Selector must be a string, got {type(s).__name__}") + raw = s + if not s.strip(): + raise SelectorParseError("Empty selector", source=raw, position=0) + + name1, filter1, after1 = _consume_scope(s, source=raw, start=0) + + # Optional ".scope(...)" extension + i = after1 + n = len(s) + while i < n and s[i].isspace(): + i += 1 + + if i >= n: + # Single-scope selector + if name1 == "device": + return Selector(scope=Scope.DEVICE_ONLY, device=filter1, raw=raw) + if name1 == "function": + return Selector(scope=Scope.FUNCTION_ONLY, function=filter1, raw=raw) + if name1 == "event": + return Selector(scope=Scope.EVENT_ONLY, event=filter1, raw=raw) + # _consume_scope already validated name1 + raise SelectorParseError(f"Internal: unhandled scope {name1!r}", source=raw) + + if s[i] != ".": + raise SelectorParseError( + f"Unexpected character {s[i]!r} after scope", source=raw, position=i + ) + + name2, filter2, after2 = _consume_scope(s, source=raw, start=i + 1) + + # Trailing content? + j = after2 + while j < n and s[j].isspace(): + j += 1 + if j < n: + raise SelectorParseError( + f"Unexpected trailing content {s[j:]!r}", source=raw, position=j + ) + + if name1 != "device": + raise SelectorParseError( + f"Chained scopes must start with 'device', got {name1!r}", + source=raw, + position=0, + ) + if name2 == "function": + return Selector( + scope=Scope.DEVICE_FUNCTION, device=filter1, function=filter2, raw=raw + ) + if name2 == "event": + return Selector( + scope=Scope.DEVICE_EVENT, device=filter1, event=filter2, raw=raw + ) + raise SelectorParseError( + f"Cannot chain device(...).{name2}(...); expected 'function' or 'event'", + source=raw, + position=i + 1, + ) diff --git a/packages/device-connect-edge/device_connect_edge/types.py b/packages/device-connect-edge/device_connect_edge/types.py index 00296bb..e73c1ea 100644 --- a/packages/device-connect-edge/device_connect_edge/types.py +++ b/packages/device-connect-edge/device_connect_edge/types.py @@ -12,7 +12,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -51,6 +51,7 @@ class FunctionDef(BaseModel): }, "required": [] }, + labels={"direction": "write", "modality": ["rgb", "4k"]}, tags=["vision", "capture"] ) """ @@ -60,6 +61,19 @@ class FunctionDef(BaseModel): default_factory=lambda: {"type": "object", "properties": {}, "required": []}, description="JSON Schema for function parameters" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata as key:value pairs. Values may be a single string " + "or a list of strings (composite identity). Well-known keys: direction " + "(read|write), safety (critical|informational), modality (rgb|thermal|...). " + "Custom keys are allowed." + ) + mandate: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional execution authorization policy metadata. When present with " + "{'required': True}, the runtime requires a valid Device Mandate before " + "executing this function." + ) tags: List[str] = Field( default_factory=list, description="Tags for categorization (e.g., ['vision', 'capture'])" @@ -92,6 +106,13 @@ class EventDef(BaseModel): default=None, description="JSON Schema for event payload (optional)" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata as key:value pairs. Values may be a single string " + "or a list of strings (composite identity). Well-known keys: safety " + "(critical|informational), modality (rgb|thermal|motion|...). Custom keys " + "are allowed." + ) tags: List[str] = Field( default_factory=list, description="Tags for categorization" @@ -125,6 +146,14 @@ class DeviceCapabilities(BaseModel): default_factory=list, description="Events the device can emit" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata for the device as key:value pairs. Values may be a " + "single string or a list of strings (composite identity). Well-known keys: " + "category (camera|robot|hub|sensor|actuator|inference; multi-valued for " + "composite devices), location (e.g. 'warehouse1/loading-dock'; '/' for " + "hierarchy, multi-valued for mobile devices). Custom keys are allowed." + ) class DeviceIdentity(BaseModel): diff --git a/packages/device-connect-edge/examples/device_mandates/mandate_examples.py b/packages/device-connect-edge/examples/device_mandates/mandate_examples.py new file mode 100644 index 0000000..25eec19 --- /dev/null +++ b/packages/device-connect-edge/examples/device_mandates/mandate_examples.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Local examples for Device Mandates. + +Run from the repository root: + PYTHONPATH=packages/device-connect-edge python packages/device-connect-edge/examples/device_mandates/mandate_examples.py + +Focused tests: + pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + +from device_connect_edge import create_closed_mandate, create_open_mandate +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc +from device_connect_edge.mandates import MandateInvocationContext, verify_mandate + + +PRINCIPAL_KEY = b"principal-demo-key" +AGENT_KEY = b"agent-demo-key" + + +class SmartLockDriver(DeviceDriver): + """Smart lock with mandate-protected actuation.""" + + device_type = "smart_lock" + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int = 10) -> dict[str, Any]: + return {"state": "unlocked", "duration_s": duration_s} + + @rpc() + async def get_status(self) -> dict[str, str]: + return {"state": "locked"} + + +class HeaterDriver(DeviceDriver): + """Heater with mandate-protected setpoint changes.""" + + device_type = "heater" + + @rpc() + async def get_temperature(self) -> dict[str, float]: + return {"current_c": 20.5} + + @requires_mandate(scope="actuation") + @rpc() + async def set_temperature(self, target_c: float) -> dict[str, float]: + return {"target_c": target_c} + + +def key_resolver(principal: str) -> bytes | None: + return {"operator": PRINCIPAL_KEY, "agent-1": AGENT_KEY}.get(principal) + + +def closed_mandate( + *, + device_id: str, + method: str, + params: dict[str, Any], + constraints: dict[str, Any], +) -> dict[str, Any]: + now = datetime.now(timezone.utc) + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id=device_id, + methods=[method], + constraints=constraints, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id=device_id, + method=method, + params=params, + key=AGENT_KEY, + issued_at=now, + ) + + +def verify_example( + *, + label: str, + mandate: dict[str, Any] | None, + device_id: str, + method: str, + params: dict[str, Any], + replay_cache: set[str] | None = None, +) -> None: + result = verify_mandate( + mandate, + context=MandateInvocationContext( + device_id=device_id, + method=method, + params=params, + ), + key_resolver=key_resolver, + replay_cache=replay_cache, + ) + outcome = "allowed" if result.ok else f"denied ({result.error_code})" + print(f"{label}: {outcome}") + + +async def main() -> None: + lock = SmartLockDriver() + heater = HeaterDriver() + + unlock_policy = getattr(lock.unlock, "_mandate", None) + heater_policy = getattr(heater.set_temperature, "_mandate", None) + print(f"smart-lock unlock mandate policy: {unlock_policy}") + print(f"heater set_temperature mandate policy: {heater_policy}") + + valid_unlock_params = {"duration_s": 20} + valid_unlock = closed_mandate( + device_id="lock-front-door", + method="unlock", + params=valid_unlock_params, + constraints={"duration_s": {"lte": 30}}, + ) + verify_example( + label="valid smart-lock unlock", + mandate=valid_unlock, + device_id="lock-front-door", + method="unlock", + params=valid_unlock_params, + ) + verify_example( + label="invalid smart-lock duration", + mandate=valid_unlock, + device_id="lock-front-door", + method="unlock", + params={"duration_s": 60}, + ) + + valid_heat_params = {"target_c": 21.5} + valid_heat = closed_mandate( + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + constraints={"target_c": {"gte": 18, "lte": 23}}, + ) + replay_cache: set[str] = set() + verify_example( + label="valid heater setpoint", + mandate=valid_heat, + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + replay_cache=replay_cache, + ) + verify_example( + label="invalid heater replay", + mandate=valid_heat, + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + replay_cache=replay_cache, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/device-connect-edge/pyproject.toml b/packages/device-connect-edge/pyproject.toml index 27b5e88..58de4d1 100644 --- a/packages/device-connect-edge/pyproject.toml +++ b/packages/device-connect-edge/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ [project.optional-dependencies] zenoh = [] # Zenoh is now a core dependency; kept for backward compat +predicate = [ + "cel-python>=0.5.0", +] telemetry = [ "opentelemetry-api>=1.30.0", "opentelemetry-sdk>=1.30.0", diff --git a/packages/device-connect-edge/tests/test_capability_loader.py b/packages/device-connect-edge/tests/test_capability_loader.py index 0f7d22c..d962c70 100644 --- a/packages/device-connect-edge/tests/test_capability_loader.py +++ b/packages/device-connect-edge/tests/test_capability_loader.py @@ -62,6 +62,20 @@ async def custom(self, value: str = "default") -> dict: return {"value": value} """ +MANDATE_CAPABILITY_CODE = """\ +from device_connect_edge.drivers.decorators import requires_mandate, rpc + +class MandateCapability: + def __init__(self, device=None): + self.device = device + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int) -> dict: + \"\"\"Unlock with delegated authorization.\"\"\" + return {"unlocked": True} +""" + EMIT_CAPABILITY_CODE = """\ from device_connect_edge.drivers.decorators import rpc, emit @@ -206,6 +220,17 @@ async def test_function_schemas_populated(self, loader, tmp_path): assert "parameters" in schema assert "description" in schema + @pytest.mark.asyncio + async def test_function_schemas_include_mandate_metadata(self, loader, tmp_path): + _write_capability(tmp_path, "mandate-cap", "MandateCapability", MANDATE_CAPABILITY_CODE) + await loader.load_all() + + loaded = loader.get_capabilities()["mandate-cap"] + assert loaded.function_schemas["unlock"]["mandate"] == { + "required": True, + "scope": "actuation", + } + # -- Extracting @emit methods -- diff --git a/packages/device-connect-edge/tests/test_device_mandates.py b/packages/device-connect-edge/tests/test_device_mandates.py new file mode 100644 index 0000000..cda43fe --- /dev/null +++ b/packages/device-connect-edge/tests/test_device_mandates.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Runtime enforcement tests for mandate-protected RPCs.""" + +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock + +import pytest + +from device_connect_edge.device import DeviceRuntime +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc +from device_connect_edge.mandates import create_closed_mandate, create_open_mandate + + +PRINCIPAL_KEY = b"principal-secret" +AGENT_KEY = b"agent-secret" + + +class LockDriver(DeviceDriver): + device_type = "lock" + + def __init__(self): + super().__init__() + self.unlock_calls = 0 + + @requires_mandate(scope="actuation") + @rpc(labels={"direction": "write", "safety": "critical"}) + async def unlock(self, duration_s: int) -> dict: + """Unlock for a bounded duration.""" + self.unlock_calls += 1 + return {"unlocked": True, "duration_s": duration_s} + + @rpc(labels={"direction": "read"}) + async def get_status(self) -> dict: + """Return lock status.""" + return {"locked": True} + + +def _valid_mandate(params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(minutes=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + mandate_id="open-1", + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + mandate_id="closed-1", + nonce="nonce-1", + ) + + +def _runtime(driver: LockDriver) -> DeviceRuntime: + return DeviceRuntime( + driver=driver, + device_id="lock-001", + messaging_urls=["nats://localhost:4222"], + mandate_keys={"operator": PRINCIPAL_KEY, "agent-1": AGENT_KEY}, + ) + + +async def _invoke_callback(rt: DeviceRuntime, method: str, params: dict) -> dict: + rt.messaging = AsyncMock() + rt.messaging.subscribe = AsyncMock() + await rt._cmd_subscription() + on_msg = rt.messaging.subscribe.call_args[1]["callback"] + await on_msg( + json.dumps({ + "jsonrpc": "2.0", + "id": "req-1", + "method": method, + "params": params, + }).encode(), + reply_subject="reply.inbox.1", + ) + return json.loads(rt.messaging.publish.call_args[0][1]) + + +class TestRequiresMandateDecorator: + def test_capability_metadata_includes_mandate_requirement(self): + driver = LockDriver() + fn = next(f for f in driver.functions if f.name == "unlock") + assert fn.mandate == {"required": True, "scope": "actuation"} + + def test_unprotected_capability_has_no_mandate_requirement(self): + driver = LockDriver() + fn = next(f for f in driver.functions if f.name == "get_status") + assert fn.mandate is None + + +class TestCommandMandateEnforcement: + @pytest.mark.asyncio + async def test_protected_rpc_without_mandate_is_denied_before_driver_call(self): + driver = LockDriver() + response = await _invoke_callback( + _runtime(driver), "unlock", {"duration_s": 30}, + ) + + assert response["error"]["code"] == -32041 + assert "mandate_required" in response["error"]["message"] + assert driver.unlock_calls == 0 + + @pytest.mark.asyncio + async def test_protected_rpc_with_valid_mandate_executes(self): + driver = LockDriver() + response = await _invoke_callback( + _runtime(driver), + "unlock", + {"duration_s": 30, "_dc_meta": {"mandate": _valid_mandate()}}, + ) + + assert response["result"] == {"unlocked": True, "duration_s": 30} + assert driver.unlock_calls == 1 + + @pytest.mark.asyncio + async def test_unprotected_rpc_executes_without_mandate(self): + response = await _invoke_callback(_runtime(LockDriver()), "get_status", {}) + assert response["result"] == {"locked": True} + + @pytest.mark.asyncio + async def test_broadcast_protected_rpc_without_mandate_is_denied(self): + driver = LockDriver() + rt = _runtime(driver) + rt.messaging = AsyncMock() + + await rt._handle_broadcast_envelope( + { + "correlation_id": "br-1", + "function": "unlock", + "params": {"duration_s": 30}, + }, + "br-1", + ) + + payload = json.loads(rt.messaging.publish.call_args[0][1]) + assert payload["success"] is False + assert payload["error"]["code"] == "mandate_required" + assert driver.unlock_calls == 0 + + @pytest.mark.asyncio + async def test_broadcast_protected_rpc_with_valid_mandate_executes(self): + driver = LockDriver() + rt = _runtime(driver) + rt.messaging = AsyncMock() + + await rt._handle_broadcast_envelope( + { + "correlation_id": "br-1", + "function": "unlock", + "params": { + "duration_s": 30, + "_dc_meta": {"mandate": _valid_mandate()}, + }, + }, + "br-1", + ) + + payload = json.loads(rt.messaging.publish.call_args[0][1]) + assert payload["success"] is True + assert payload["result"] == {"unlocked": True, "duration_s": 30} + assert driver.unlock_calls == 1 diff --git a/packages/device-connect-edge/tests/test_drivers.py b/packages/device-connect-edge/tests/test_drivers.py index 9b1fb65..1d5f2de 100644 --- a/packages/device-connect-edge/tests/test_drivers.py +++ b/packages/device-connect-edge/tests/test_drivers.py @@ -11,7 +11,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock -from device_connect_edge.drivers import DeviceDriver, rpc, emit, build_function_schema, build_event_schema +from device_connect_edge.drivers import DeviceDriver, rpc, emit, requires_mandate, build_function_schema, build_event_schema from device_connect_edge.drivers.base import on from device_connect_edge.types import DeviceIdentity, DeviceStatus @@ -177,6 +177,110 @@ async def test_rpc_callable(self): result = await driver.do_something(value=5) assert result == {"result": 10} + +# -- Discovery labels (Phase 1) ------------------------------------ + +class TestRpcLabels: + def test_default_none(self): + @rpc() + async def f(self) -> dict: + """f.""" + return {} + + assert f._labels is None + + def test_explicit_labels(self): + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture(self, resolution: str = "1080p") -> dict: + """Capture.""" + return {} + + assert capture._labels == {"direction": "write", "modality": ["rgb", "4k"]} + + +class TestRequiresMandate: + def test_requires_mandate_above_rpc(self): + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self) -> dict: + return {} + + assert unlock._mandate == {"required": True, "scope": "actuation"} + + def test_requires_mandate_below_rpc(self): + @rpc() + @requires_mandate(scope="actuation") + async def unlock(self) -> dict: + return {} + + assert unlock._mandate == {"required": True, "scope": "actuation"} + + +class TestEmitLabels: + def test_default_none(self): + @emit() + async def heartbeat(self): + """heartbeat.""" + pass + + assert heartbeat._labels is None + + def test_explicit_labels(self): + @emit(labels={"modality": "motion", "safety": "informational"}) + async def motion_detected(self, zone: str): + """Motion.""" + pass + + assert motion_detected._labels == {"modality": "motion", "safety": "informational"} + + +class LabeledDriver(DeviceDriver): + """Driver with class-level labels and per-method labels.""" + device_type = "camera" + labels = { + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_frame(self, resolution: str = "1080p") -> dict: + """Capture a frame.""" + return {} + + @rpc() + async def ping(self) -> dict: + """Ping.""" + return {} + + @emit(labels={"modality": "motion", "safety": "informational"}) + async def motion_detected(self, zone: str, confidence: float): + """Motion in zone.""" + pass + + +class TestDriverLabels: + def test_class_level_labels_on_capabilities(self): + caps = LabeledDriver().capabilities + assert caps.labels == { + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + + def test_function_labels_propagated(self): + caps = LabeledDriver().capabilities + fns = {f.name: f for f in caps.functions} + assert fns["capture_frame"].labels == {"direction": "write", "modality": ["rgb", "4k"]} + assert fns["ping"].labels is None + + def test_event_labels_propagated(self): + caps = LabeledDriver().capabilities + evs = {e.name: e for e in caps.events} + assert evs["motion_detected"].labels == {"modality": "motion", "safety": "informational"} + + def test_no_class_labels_defaults_to_none(self): + # SampleDriver above does NOT define `labels` -- inherits None from DeviceDriver + assert SampleDriver().capabilities.labels is None + def test_capabilities_detected(self): """Driver should have functions and events detectable via introspection.""" driver = SampleDriver() diff --git a/packages/device-connect-edge/tests/test_mandate_verifier.py b/packages/device-connect-edge/tests/test_mandate_verifier.py new file mode 100644 index 0000000..0ce7f58 --- /dev/null +++ b/packages/device-connect-edge/tests/test_mandate_verifier.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Device Mandate signing and verification helpers.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from device_connect_edge.mandates import ( + MandateInvocationContext, + create_closed_mandate, + create_open_mandate, + verify_mandate, +) + + +PRINCIPAL_KEY = b"principal-secret" +AGENT_KEY = b"agent-secret" + + +def _keys(principal: str) -> bytes | None: + return { + "operator": PRINCIPAL_KEY, + "agent-1": AGENT_KEY, + }.get(principal) + + +def _valid_mandate(params: dict | None = None) -> dict: + now = datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(minutes=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + mandate_id="open-1", + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + mandate_id="closed-1", + nonce="nonce-1", + ) + + +def _context(**overrides) -> MandateInvocationContext: + base = { + "device_id": "lock-001", + "method": "unlock", + "params": {"duration_s": 30}, + "now": datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc), + } + base.update(overrides) + return MandateInvocationContext(**base) + + +def test_valid_closed_mandate_verifies(): + result = verify_mandate(_valid_mandate(), context=_context(), key_resolver=_keys) + assert result.ok is True + assert result.error_code is None + + +def test_missing_mandate_fails_closed(): + result = verify_mandate(None, context=_context(), key_resolver=_keys) + assert result.ok is False + assert result.error_code == "mandate_required" + + +def test_tampered_parameters_fail_signature_check(): + mandate = _valid_mandate() + mandate["invocation"]["params"]["duration_s"] = 45 + + result = verify_mandate(mandate, context=_context(), key_resolver=_keys) + + assert result.ok is False + assert result.error_code == "invalid_mandate_signature" + + +def test_wrong_device_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(device_id="other-lock"), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_device_denied" + + +def test_wrong_method_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(method="lock"), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_method_denied" + + +def test_expired_mandate_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(now=datetime(2026, 5, 11, 12, 10, tzinfo=timezone.utc)), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_expired" + + +def test_parameter_constraint_is_enforced(): + mandate = _valid_mandate(params={"duration_s": 75}) + result = verify_mandate( + mandate, + context=_context(params={"duration_s": 75}), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_constraint_denied" + + +def test_replay_cache_denies_reused_nonce(): + seen: set[str] = set() + mandate = _valid_mandate() + + first = verify_mandate( + mandate, context=_context(), key_resolver=_keys, replay_cache=seen, + ) + second = verify_mandate( + mandate, context=_context(), key_resolver=_keys, replay_cache=seen, + ) + + assert first.ok is True + assert second.ok is False + assert second.error_code == "mandate_replayed" diff --git a/packages/device-connect-edge/tests/test_predicate.py b/packages/device-connect-edge/tests/test_predicate.py new file mode 100644 index 0000000..dfaff81 --- /dev/null +++ b/packages/device-connect-edge/tests/test_predicate.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the CEL ``where`` predicate evaluator. + +These tests require the ``[predicate]`` extra (cel-python). They are +skipped automatically when cel-python is not installed so the rest of +the edge test suite stays runnable on minimal installs. +""" +from __future__ import annotations + +import pytest + +celpy = pytest.importorskip("celpy") + +from device_connect_edge.predicate import ( + PredicateCompileError, + PredicateEvalError, + WherePredicate, + compile_where, +) + + +# -- compile_where -------------------------------------------------- + + +class TestCompile: + def test_simple_comparison_compiles(self): + p = compile_where("battery > 50") + assert isinstance(p, WherePredicate) + assert p.expression == "battery > 50" + + def test_boolean_combination_compiles(self): + p = compile_where("a > 1 && b < 10 || c == 'x'") + assert isinstance(p, WherePredicate) + + def test_array_indexing_compiles(self): + p = compile_where("mask[row][col] == 1") + assert isinstance(p, WherePredicate) + + def test_label_dot_access_compiles(self): + p = compile_where("labels.category == 'camera'") + assert isinstance(p, WherePredicate) + + def test_empty_expression_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where("") + with pytest.raises(PredicateCompileError): + compile_where(" ") + + def test_non_string_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where(123) # type: ignore[arg-type] + + def test_malformed_expression_rejected(self): + with pytest.raises(PredicateCompileError) as exc: + compile_where("a > > b") + assert "failed to compile" in str(exc.value) + + +# -- evaluate ------------------------------------------------------- + + +class TestEvaluate: + def test_truthy_comparison(self): + p = compile_where("battery > 50") + assert p.evaluate({"battery": 80}) is True + assert p.evaluate({"battery": 30}) is False + + def test_label_match(self): + p = compile_where("labels.category == 'camera'") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_2d_mask_indexing(self): + # The mask-indexing case is the deciding example for picking CEL + # over JSONLogic; keep it as a regression guard. + p = compile_where("mask[row][col] == 1") + ctx = { + "mask": [[0, 1, 0], [1, 0, 0]], + "row": 0, + "col": 1, + } + assert p.evaluate(ctx) is True + ctx["col"] = 0 + assert p.evaluate(ctx) is False + + def test_combined_label_and_status(self): + p = compile_where("labels.category == 'camera' && status.battery > 50") + ctx = { + "labels": {"category": "camera"}, + "status": {"battery": 80}, + } + assert p.evaluate(ctx) is True + ctx["status"]["battery"] = 30 + assert p.evaluate(ctx) is False + ctx["labels"]["category"] = "robot" + ctx["status"]["battery"] = 80 + assert p.evaluate(ctx) is False + + def test_bindings_and_status_compose(self): + p = compile_where("status.temperature > bindings.threshold") + ctx = { + "status": {"temperature": 75.5}, + "bindings": {"threshold": 70.0}, + } + assert p.evaluate(ctx) is True + + def test_string_in_list(self): + p = compile_where("labels.category in ['camera', 'inference']") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_missing_variable_raises_eval_error(self): + p = compile_where("status.battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({}) + + def test_type_mismatch_raises_eval_error(self): + p = compile_where("battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({"battery": "not a number"}) + + def test_evaluator_is_reusable(self): + # Compile once, evaluate against many contexts. Reusability is the + # property that lets callers compile broadcast envelopes once at + # the dispatcher and ship them to N targets. + p = compile_where("battery > 50") + results = [p.evaluate({"battery": v}) for v in (10, 50, 51, 100)] + assert results == [False, False, True, True] diff --git a/packages/device-connect-edge/tests/test_selector.py b/packages/device-connect-edge/tests/test_selector.py new file mode 100644 index 0000000..d50a78e --- /dev/null +++ b/packages/device-connect-edge/tests/test_selector.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector DSL parser and matcher. + +Parses selector strings like +``device(category:camera, location:warehouse1/*).function(direction:write)`` +into a structured Selector and matches it against label dicts. +""" +import pytest + +from device_connect_edge.selector import ( + Filter, + KeyFilter, + Scope, + Selector, + SelectorParseError, + parse_selector, +) + + +# -- KeyFilter ----------------------------------------------------- + + +class TestKeyFilter: + def test_single_value_str_label(self): + kf = KeyFilter("direction", ("write",)) + assert kf.matches("write") + assert not kf.matches("read") + + def test_none_label_never_matches(self): + assert not KeyFilter("direction", ("write",)).matches(None) + + def test_list_label_any_member_matches(self): + kf = KeyFilter("category", ("camera",)) + assert kf.matches(["camera", "inference"]) + assert not kf.matches(["robot", "inference"]) + + def test_or_within_key(self): + kf = KeyFilter("category", ("camera", "robot")) + assert kf.matches("camera") + assert kf.matches("robot") + assert not kf.matches("hub") + assert kf.matches(["camera", "inference"]) + + def test_glob_value(self): + kf = KeyFilter("location", ("warehouse1/*",)) + assert kf.matches("warehouse1/loading-dock") + assert kf.matches("warehouse1/yard") + assert not kf.matches("warehouse2/dock") + + def test_subtree_glob_matches_exact_and_descendants(self): + # ``lab-A*`` matches both the exact location and any descendants. + kf = KeyFilter("location", ("lab-A*",)) + assert kf.matches("lab-A") + assert kf.matches("lab-A/optics-bench") + assert not kf.matches("lab-B") + + +# -- Filter -------------------------------------------------------- + + +class TestFilter: + def test_empty_filter_matches_anything(self): + f = Filter() + assert f.matches("anything", None) + assert f.matches("foo", {"k": "v"}) + + def test_name_match_exact(self): + f = Filter(name_match="robot-001") + assert f.matches("robot-001", None) + assert not f.matches("robot-002", None) + + def test_name_match_glob(self): + f = Filter(name_match="set_*") + assert f.matches("set_threshold", {}) + assert f.matches("set_location", {}) + assert not f.matches("get_reading", {}) + + def test_and_across_keys(self): + f = Filter( + key_filters=( + KeyFilter("category", ("camera",)), + KeyFilter("location", ("warehouse1/*",)), + ) + ) + assert f.matches("cam1", {"category": "camera", "location": "warehouse1/dock"}) + assert not f.matches("cam1", {"category": "camera", "location": "warehouse2/dock"}) + assert not f.matches("cam1", {"category": "robot", "location": "warehouse1/dock"}) + + def test_name_and_label_combined(self): + f = Filter( + name_match="set_*", + key_filters=(KeyFilter("direction", ("write",)),), + ) + assert f.matches("set_threshold", {"direction": "write"}) + assert not f.matches("set_threshold", {"direction": "read"}) + assert not f.matches("get_reading", {"direction": "write"}) + + def test_missing_label_means_no_match(self): + f = Filter(key_filters=(KeyFilter("safety", ("critical",)),)) + assert not f.matches("foo", {}) + assert not f.matches("foo", None) + + +# -- Selector vacuous axes ----------------------------------------- + + +class TestSelectorVacuous: + """Unset axes return True so callers can iterate without scope branching.""" + + def test_device_only_function_vacuous(self): + s = Selector(scope=Scope.DEVICE_ONLY, device=Filter()) + assert s.matches_function("anything", {"direction": "write"}) + assert s.matches_event("anything", None) + + def test_function_only_device_vacuous(self): + s = Selector(scope=Scope.FUNCTION_ONLY, function=Filter()) + assert s.matches_device("any-id", None) + + +# -- parse_selector: scope shapes --------------------------------- + + +class TestParseScope: + def test_device_only(self): + s = parse_selector("device(category:camera)") + assert s.scope == Scope.DEVICE_ONLY + assert s.device == Filter(key_filters=(KeyFilter("category", ("camera",)),)) + assert s.function is None + assert s.event is None + + def test_function_only(self): + s = parse_selector("function(safety:critical)") + assert s.scope == Scope.FUNCTION_ONLY + assert s.function.key_filters == (KeyFilter("safety", ("critical",)),) + + def test_event_only(self): + s = parse_selector("event(modality:motion)") + assert s.scope == Scope.EVENT_ONLY + assert s.event.key_filters == (KeyFilter("modality", ("motion",)),) + + def test_device_function(self): + s = parse_selector("device(*).function(direction:write)") + assert s.scope == Scope.DEVICE_FUNCTION + assert s.device == Filter() + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + def test_device_event(self): + s = parse_selector("device(*).event(modality:motion)") + assert s.scope == Scope.DEVICE_EVENT + + def test_bare_id_match(self): + s = parse_selector("device(robot-001)") + assert s.device.name_match == "robot-001" + + def test_function_name_match(self): + s = parse_selector("function(estop)") + assert s.function.name_match == "estop" + + def test_wildcard_matches_anything(self): + s = parse_selector("device(*)") + assert s.device == Filter() + + def test_raw_preserved(self): + sel = "device(category:camera)" + assert parse_selector(sel).raw == sel + + def test_whitespace_tolerated(self): + s = parse_selector( + " device( category : camera ) . function( direction : write ) " + ) + assert s.scope == Scope.DEVICE_FUNCTION + assert s.device.key_filters == (KeyFilter("category", ("camera",)),) + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + +# -- parse_selector: filter body grammar --------------------------- + + +class TestParseFilterBody: + def test_or_within_key(self): + s = parse_selector("device(category:[camera,robot])") + assert s.device.key_filters == (KeyFilter("category", ("camera", "robot")),) + + def test_and_across_keys(self): + s = parse_selector("device(category:camera, location:warehouse1/*)") + assert s.device.key_filters == ( + KeyFilter("category", ("camera",)), + KeyFilter("location", ("warehouse1/*",)), + ) + + def test_combined_or_and_glob(self): + s = parse_selector("device(category:[camera,robot], location:warehouse1/*)") + assert s.device.key_filters == ( + KeyFilter("category", ("camera", "robot")), + KeyFilter("location", ("warehouse1/*",)), + ) + + def test_bare_name_plus_keys(self): + s = parse_selector("device(temperature_sensor).function(direction:write, set_*)") + assert s.device.name_match == "temperature_sensor" + assert s.function.name_match == "set_*" + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + +# -- parse_selector: errors ---------------------------------------- + + +class TestParseErrors: + @pytest.mark.parametrize("bad,expected", [ + ("", "empty"), + (" ", "empty"), + ("device", "expected '('"), + ("device(", "unclosed"), + ("foo(x)", "unknown scope"), + ("function(*).device(*)", "must start with"), + ("device(*).device(*)", "expected 'function' or 'event'"), + ("device(*).function(*).event(*)", "unexpected trailing"), + ("device(*) extra", "unexpected character"), + ("device(robot-001, robot-002)", "multiple bare-name"), + ("device(key:)", "empty value"), + ("device(:value)", "invalid key"), + ("device(,)", "empty term"), + ("device(key:[)", "unclosed '['"), + ("device(key:[])", "empty value list"), + ("device(key:[a,])", "empty value in list"), + ("device(key:[[a]])", "nested"), + ("device(bad key:val)", "invalid key"), + ]) + def test_error_messages(self, bad, expected): + with pytest.raises(SelectorParseError) as exc: + parse_selector(bad) + assert expected.lower() in str(exc.value).lower() + + def test_non_string_input(self): + with pytest.raises(SelectorParseError): + parse_selector(123) # type: ignore[arg-type] + + def test_error_includes_position_caret(self): + with pytest.raises(SelectorParseError) as exc: + parse_selector("device(foo, bad key:v)") + msg = str(exc.value) + assert "device(foo, bad key:v)" in msg + assert "^" in msg + + +# -- Worked examples ----------------------------------------------- + + +class TestWorkedExamples: + """End-to-end parse + match using DC-native device kinds (camera, robot, + sensor) and the labels that drivers would carry.""" + + def test_all_cameras(self): + s = parse_selector("device(category:camera)") + assert s.matches_device("cam-001", {"category": "camera"}) + # composite identity: camera that also runs inference + assert s.matches_device("cam-002", {"category": ["camera", "inference"]}) + assert not s.matches_device("robot-001", {"category": "robot"}) + + def test_or_within_key_with_zone_filter(self): + # cameras or robots in zone-A + s = parse_selector("device(category:[camera,robot], location:zone-A/*)") + assert s.matches_device( + "cam-1", {"category": "camera", "location": "zone-A/loading-dock"} + ) + assert s.matches_device( + "robot-1", {"category": "robot", "location": "zone-A/yard"} + ) + assert not s.matches_device( + "hub-1", {"category": "hub", "location": "zone-A/dock"} + ) + assert not s.matches_device( + "cam-2", {"category": "camera", "location": "zone-B/dock"} + ) + + def test_zone_subtree(self): + # ``zone-A*`` glob matches both ``zone-A`` exactly and any descendant. + s = parse_selector("device(location:zone-A*)") + assert s.matches_device("d", {"location": "zone-A"}) + assert s.matches_device("d", {"location": "zone-A/dock"}) + assert not s.matches_device("d", {"location": "zone-B"}) + + def test_capture_writes_fleet_wide(self): + # ``capture_image`` is DC's canonical camera RPC. Filtering for write + # direction + rgb modality across the fleet picks it up. + s = parse_selector("device(*).function(direction:write, modality:rgb)") + assert s.scope == Scope.DEVICE_FUNCTION + assert s.matches_device("anything", None) + assert s.matches_function( + "capture_image", {"direction": "write", "modality": "rgb"} + ) + assert s.matches_function( + "capture_image", {"direction": "write", "modality": ["rgb", "4k"]} + ) + assert not s.matches_function( + "get_status", {"direction": "read", "modality": "rgb"} + ) + assert not s.matches_function( + "capture_image", {"direction": "write", "modality": "thermal"} + ) + + def test_object_detection_events_fleet_wide(self): + # The ``test_camera`` driver emits ``object_detected`` events; subscribe + # to it across the fleet via a bare-name event match. + s = parse_selector("device(*).event(object_detected)") + assert s.scope == Scope.DEVICE_EVENT + assert s.matches_event("object_detected", None) + assert not s.matches_event("state_change_detected", None) + + def test_critical_rpcs_fleetwide(self): + s = parse_selector("function(safety:critical)") + assert s.matches_function("estop", {"safety": "critical"}) + assert not s.matches_function("get_reading", {"safety": "informational"}) + + def test_estop_name_match_ignores_labels(self): + # Fleet-wide ESTOP target by reserved name, regardless of labels. + s = parse_selector("function(estop)") + assert s.matches_function("estop", None) + assert s.matches_function("estop", {"safety": "critical"}) + assert not s.matches_function("get_reading", {"safety": "critical"}) + + def test_chained_sensor_writes_with_name_glob(self): + # The ``temperature_sensor`` driver exposes ``set_threshold`` and + # ``set_location`` (writes) plus ``get_reading`` (read). The anchored + # glob ``set_*`` selects only the writers. + s = parse_selector( + "device(temperature_sensor).function(direction:write, set_*)" + ) + assert s.matches_device("temperature_sensor", None) + assert not s.matches_device("test_camera", None) + assert s.matches_function("set_threshold", {"direction": "write"}) + assert s.matches_function("set_location", {"direction": "write"}) + # Anchored glob: a function whose name does NOT start with ``set_`` + # never matches, regardless of direction. + assert not s.matches_function("get_reading", {"direction": "read"}) + # Right name shape, wrong direction -> rejected. + assert not s.matches_function("set_threshold", {"direction": "read"}) + + def test_substring_glob_finds_reading_in_either_direction(self): + # Anchored globs are the default; for substring intent callers wrap + # with ``*...*``. ``*reading*`` finds the sensor's getter and the event. + s = parse_selector("function(*reading*)") + assert s.matches_function("get_reading", {"direction": "read"}) + assert s.matches_function("readings_summary", None) + assert not s.matches_function("set_threshold", {"direction": "write"}) diff --git a/packages/device-connect-edge/tests/test_types.py b/packages/device-connect-edge/tests/test_types.py index 0d580fe..15c2ab0 100644 --- a/packages/device-connect-edge/tests/test_types.py +++ b/packages/device-connect-edge/tests/test_types.py @@ -10,6 +10,7 @@ DeviceStatus, FunctionDef, EventDef, + DeviceCapabilities, ) @@ -75,3 +76,74 @@ def test_create(self): parameters={"type": "object", "properties": {"zone": {"type": "string"}}}, ) assert event.name == "motion_detected" + + +class TestLabels: + """Discovery labels on FunctionDef, EventDef, DeviceCapabilities (Phase 1).""" + + def test_function_labels_default_none(self): + f = FunctionDef(name="ping") + assert f.labels is None + + def test_function_single_value_label(self): + f = FunctionDef(name="get_status", labels={"direction": "read"}) + assert f.labels == {"direction": "read"} + + def test_function_multivalued_label(self): + f = FunctionDef(name="capture", labels={"modality": ["rgb", "4k"]}) + assert f.labels == {"modality": ["rgb", "4k"]} + + def test_function_labels_roundtrip(self): + f = FunctionDef( + name="set_threshold", + labels={"direction": "write", "modality": ["rgb", "4k"], "safety": "critical"}, + ) + f2 = FunctionDef.model_validate_json(f.model_dump_json()) + assert f2.labels == f.labels + + def test_function_mandate_roundtrip(self): + f = FunctionDef( + name="unlock", + mandate={"required": True, "scope": "actuation"}, + ) + f2 = FunctionDef.model_validate_json(f.model_dump_json()) + assert f2.mandate == {"required": True, "scope": "actuation"} + + def test_event_labels_default_none(self): + e = EventDef(name="heartbeat") + assert e.labels is None + + def test_event_labels_roundtrip(self): + e = EventDef( + name="motion_detected", + labels={"modality": "motion", "safety": "informational"}, + ) + e2 = EventDef.model_validate_json(e.model_dump_json()) + assert e2.labels == e.labels + + def test_capabilities_labels_default_none(self): + c = DeviceCapabilities() + assert c.labels is None + + def test_capabilities_labels_composite_identity(self): + # category multi-valued for composite devices (camera + inference) + c = DeviceCapabilities( + labels={ + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + ) + assert c.labels["category"] == ["camera", "inference"] + assert c.labels["location"] == "warehouse1/loading-dock" + + def test_capabilities_labels_roundtrip(self): + c = DeviceCapabilities( + description="Smart cam", + functions=[FunctionDef(name="capture", labels={"direction": "write"})], + events=[EventDef(name="motion", labels={"modality": "motion"})], + labels={"category": ["camera"], "location": "warehouse1/dock-3"}, + ) + c2 = DeviceCapabilities.model_validate_json(c.model_dump_json()) + assert c2.labels == c.labels + assert c2.functions[0].labels == {"direction": "write"} + assert c2.events[0].labels == {"modality": "motion"} diff --git a/packages/device-connect-server/device_connect_server/devctl/cli.py b/packages/device-connect-server/device_connect_server/devctl/cli.py index 071b423..f73ec6a 100644 --- a/packages/device-connect-server/device_connect_server/devctl/cli.py +++ b/packages/device-connect-server/device_connect_server/devctl/cli.py @@ -574,9 +574,20 @@ def create_parser() -> argparse.ArgumentParser: p_reg.add_argument("--broker", default=None, help="Broker URL") p_reg.add_argument("--keepalive", action="store_true", help="Start heartbeat loop") - # discover command - p_discover = sub.add_parser("discover", help="Discover uncommissioned devices") - p_discover.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + # mdns-scan: discover uncommissioned devices on the local network. + # Renamed from the historical ``discover`` verb so the selector-driven + # ``discover`` below (which queries the fleet, not the local network) + # can take the natural name. + p_scan = sub.add_parser( + "mdns-scan", help="Discover uncommissioned devices via mDNS", + aliases=["scan"], + ) + p_scan.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + + # Selector-driven fleet discovery (new). Registers ``discover`` and + # ``discover-labels`` as parser entries. + from device_connect_server.devctl import selector_cli + selector_cli.register_subparsers(sub) # commission command p_commission = sub.add_parser("commission", help="Commission a device with PIN") @@ -617,9 +628,17 @@ def main(argv: Optional[List[str]] = None) -> None: loop.stop() print("\nbye!") - elif args.cmd == "discover": + elif args.cmd in ("mdns-scan", "scan"): asyncio.run(discover_devices(timeout=args.timeout)) + elif args.cmd == "discover": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover(args)) + + elif args.cmd == "discover-labels": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover_labels(args)) + elif args.cmd == "commission": asyncio.run( commission_device( diff --git a/packages/device-connect-server/device_connect_server/devctl/selector_cli.py b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py new file mode 100644 index 0000000..68a6637 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``devctl`` selector-driven discovery verbs. + +Thin wrappers around ``device_connect_agent_tools.discover`` and +``discover_labels`` so operators can drive the same selector grammar +from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Best-effort connect to the messaging backend. + + Reuses ``DEVICE_CONNECT_*`` and ``NATS_URL`` env vars when ``broker`` is + not given. Kept as a thin wrapper so all CLI verbs share the same + connect-or-fail semantics. + """ + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _pretty(data: Any) -> str: + """Render a JSON payload for terminal output.""" + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def run_discover(args: Any) -> int: + """Execute ``devctl discover ""``.""" + from device_connect_agent_tools import disconnect, discover + + _connect(getattr(args, "broker", None)) + try: + result = discover( + args.selector, + offset=int(args.offset or 0), + limit=int(args.limit or 200), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_discover_labels(args: Any) -> int: + """Execute ``devctl discover-labels [--key K]``.""" + from device_connect_agent_tools import disconnect, discover_labels + + _connect(getattr(args, "broker", None)) + try: + result = discover_labels( + key=args.key, + offset=int(args.offset or 0), + limit=int(args.limit or 50), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def register_subparsers(sub: Any) -> None: + """Attach the discover / discover-labels subparsers to a devctl parser.""" + p = sub.add_parser( + "discover", + help="Resolve a selector to devices, functions, or events", + ) + p.add_argument("selector", help="Selector expression (e.g. 'device(category:camera)')") + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=200, help="Page size") + + p = sub.add_parser( + "discover-labels", + help="Browse fleet label vocabulary", + ) + p.add_argument( + "--key", default=None, + help="Axis-qualified label key (e.g. 'device.location') for per-key pagination", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=50, help="Page size") diff --git a/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py new file mode 100644 index 0000000..3449ef6 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Execution receipt helpers for mandate-aware invokes.""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import os +import secrets +from datetime import datetime, timezone +from typing import Any + +_RECEIPTS: list[dict[str, Any]] = [] +_MAX_RECEIPTS = 1000 + + +def build_receipt( + *, + trace_id: str, + tenant: str, + actor: dict[str, Any], + device_id: str, + function: str, + params: dict[str, Any], + status: str, + elapsed_ms: int, + response: Any = None, + error: dict[str, Any] | None = None, + mandate: dict[str, Any] | None = None, + mandate_required: bool = False, + mandate_verified: bool = False, + mandate_error_code: str | None = None, +) -> dict[str, Any]: + receipt = { + "receipt_id": "rcpt-" + secrets.token_hex(8), + "trace_id": trace_id, + "tenant": tenant, + "actor": { + "token_id": actor.get("token_id"), + "username": actor.get("username"), + }, + "device_id": device_id, + "function": function, + "status": status, + "authorized": status != "denied", + "mandate": _mandate_summary( + mandate, + required=mandate_required, + verified=mandate_verified, + error_code=mandate_error_code, + ), + "params_sha256": hash_json(params), + "response_sha256": hash_json(response) if response is not None else None, + "error": error, + "elapsed_ms": elapsed_ms, + "issued_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + receipt["signature"] = sign_receipt(receipt) + return receipt + + +def record_receipt(receipt: dict[str, Any]) -> dict[str, Any]: + """Append a receipt to the process-local audit log.""" + _RECEIPTS.append(dict(receipt)) + if len(_RECEIPTS) > _MAX_RECEIPTS: + del _RECEIPTS[: len(_RECEIPTS) - _MAX_RECEIPTS] + return receipt + + +def get_receipt(receipt_id: str) -> dict[str, Any] | None: + for receipt in reversed(_RECEIPTS): + if receipt.get("receipt_id") == receipt_id: + return dict(receipt) + return None + + +def list_receipts( + *, + tenant: str | None = None, + device_id: str | None = None, + limit: int = 100, +) -> list[dict[str, Any]]: + safe_limit = max(1, min(int(limit or 100), 1000)) + out = [] + for receipt in reversed(_RECEIPTS): + if tenant is not None and receipt.get("tenant") != tenant: + continue + if device_id is not None and receipt.get("device_id") != device_id: + continue + out.append(dict(receipt)) + if len(out) >= safe_limit: + break + return out + + +def hash_json(value: Any) -> str: + payload = json.dumps( + value, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + default=str, + ).encode() + return hashlib.sha256(payload).hexdigest() + + +def sign_receipt(receipt: dict[str, Any]) -> str | None: + key = os.getenv("DC_RECEIPT_SIGNING_KEY") + if not key: + return None + unsigned = {k: v for k, v in receipt.items() if k != "signature"} + payload = json.dumps( + unsigned, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + default=str, + ).encode() + return hmac.new(key.encode(), payload, hashlib.sha256).hexdigest() + + +def _mandate_summary( + mandate: dict[str, Any] | None, + *, + required: bool, + verified: bool, + error_code: str | None, +) -> dict[str, Any]: + open_mandate = mandate.get("open_mandate") if isinstance(mandate, dict) else {} + return { + "required": required, + "verified": verified, + "id": mandate.get("id") if isinstance(mandate, dict) else None, + "open_mandate_id": open_mandate.get("id") if isinstance(open_mandate, dict) else None, + "principal": open_mandate.get("principal") if isinstance(open_mandate, dict) else None, + "agent": mandate.get("agent") if isinstance(mandate, dict) else None, + "error_code": error_code, + } diff --git a/packages/device-connect-server/device_connect_server/portal/services/mandates.py b/packages/device-connect-server/device_connect_server/portal/services/mandates.py new file mode 100644 index 0000000..c1b5d63 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/portal/services/mandates.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Server-side helpers for Device Mandates.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + verify_mandate, +) + + +_SERVER_MANDATE_REPLAY_CACHE: set[str] = set() + + +def get_function_mandate_policy( + device_doc: dict[str, Any] | None, + function: str, +) -> dict[str, Any] | None: + """Return mandate policy metadata for a function in a registry document.""" + capabilities = (device_doc or {}).get("capabilities") or {} + for fn in capabilities.get("functions") or []: + if fn.get("name") == function: + mandate = fn.get("mandate") + return mandate if isinstance(mandate, dict) else None + return None + + +def extract_mandate( + body: dict[str, Any], + params: dict[str, Any], +) -> dict[str, Any] | None: + """Extract mandate from top-level body or params._dc_meta.""" + mandate = body.get("mandate") + if isinstance(mandate, dict): + return mandate + dc_meta = params.get("_dc_meta") + if isinstance(dc_meta, dict) and isinstance(dc_meta.get("mandate"), dict): + return dc_meta["mandate"] + return None + + +def strip_dc_meta(params: dict[str, Any]) -> dict[str, Any]: + """Return user parameters only, excluding reserved Device Connect metadata.""" + return {k: v for k, v in params.items() if k != "_dc_meta"} + + +def attach_mandate( + params: dict[str, Any], + source_params: dict[str, Any], + mandate: dict[str, Any] | None, +) -> dict[str, Any]: + """Attach a mandate to params._dc_meta while preserving existing metadata.""" + out = dict(params) + existing_meta = source_params.get("_dc_meta") + meta = dict(existing_meta) if isinstance(existing_meta, dict) else {} + if mandate is not None: + meta["mandate"] = mandate + if meta: + out["_dc_meta"] = meta + return out + + +def verify_server_mandate( + *, + device_doc: dict[str, Any] | None, + device_id: str, + function: str, + params: dict[str, Any], + mandate: dict[str, Any] | None, +) -> MandateVerificationResult: + """Verify a mandate when policy requires it or a caller supplied one.""" + policy = get_function_mandate_policy(device_doc, function) + mandate_required = bool(policy and policy.get("required")) + if not mandate_required and mandate is None: + return MandateVerificationResult(ok=True) + return verify_mandate( + mandate, + context=MandateInvocationContext( + device_id=device_id, + method=function, + params=params, + ), + key_resolver=resolve_mandate_key, + replay_cache=_SERVER_MANDATE_REPLAY_CACHE, + ) + + +def resolve_mandate_key(principal_or_agent: str) -> bytes | str | None: + """Resolve principal/agent signing keys from DC_MANDATE_KEYS_JSON.""" + raw = os.getenv("DC_MANDATE_KEYS_JSON", "") + if not raw: + return None + try: + keys = json.loads(raw) + except json.JSONDecodeError: + return None + if not isinstance(keys, dict): + return None + key = keys.get(principal_or_agent) + return key if isinstance(key, str) else None diff --git a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py index 0aae37c..fbd2e04 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py +++ b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py @@ -27,6 +27,8 @@ from ..services import cli_auth as cli_auth_svc from ..services import credentials as credentials_svc +from ..services import execution_receipts as receipts_svc +from ..services import mandates as mandates_svc from ..services import registry_client, tokens as tokens_svc from ..services.backend import get_backend, validate_name @@ -74,6 +76,8 @@ def setup_routes(app: web.Application): r.add_post(PREFIX + "/devices/{device_id}/credentials:rotate", device_credentials_rotate) r.add_post(PREFIX + "/devices/{device_id}/invoke", device_invoke) r.add_post(PREFIX + "/invoke-with-fallback", invoke_with_fallback) + r.add_get(PREFIX + "/receipts", receipts_list) + r.add_get(PREFIX + "/receipts/{receipt_id}", receipt_get) r.add_get( PREFIX + "/devices/{device_id}/events/{event_name}/stream", device_event_stream, @@ -100,10 +104,14 @@ def _err( code: str, message: str, trace_id: str | None = None, + extra: dict[str, Any] | None = None, ) -> web.Response: + payload = {"success": False, "trace_id": trace_id or _trace_id(), + "error": {"code": code, "message": message}} + if extra: + payload.update(extra) return web.json_response( - {"success": False, "trace_id": trace_id or _trace_id(), - "error": {"code": code, "message": message}}, + payload, status=status, ) @@ -645,26 +653,96 @@ async def device_invoke(request: web.Request) -> web.Response: if not isinstance(params, dict): return _err(status=400, code="invalid_params", message="params must be an object", trace_id=trace) + clean_params = mandates_svc.strip_dc_meta(params) + mandate = mandates_svc.extract_mandate(body, params) timeout = _clamp_timeout(body.get("timeout")) reason = _truncate(body.get("reason") or body.get("llm_reasoning") or "", 500) + device_doc = _device_doc(tenant, device_id) + mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) + mandate_required = bool(mandate_policy and mandate_policy.get("required")) + mandate_result = mandates_svc.verify_server_mandate( + device_doc=device_doc, + device_id=full_name, + function=function, + params=clean_params, + mandate=mandate, + ) + if not mandate_result.ok: + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="denied", + elapsed_ms=0, + error={"code": mandate_result.error_code, "message": mandate_result.message}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=False, + mandate_error_code=mandate_result.error_code, + )) + _audit(request, "invoke_denied", trace_id=trace, device_id=full_name, + function=function, receipt_id=receipt["receipt_id"], + error=mandate_result.error_code) + return _err( + status=403, + code=mandate_result.error_code or "mandate_denied", + message=mandate_result.message or "mandate denied", + trace_id=trace, + extra={"receipt": receipt}, + ) + params_for_rpc = mandates_svc.attach_mandate(clean_params, params, mandate) backend = get_backend() started = time.monotonic() try: - result = await backend.rpc_invoke(tenant, full_name, function, params, timeout=timeout) + result = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="succeeded", + elapsed_ms=elapsed_ms, + response=result, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + )) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, - reason=_truncate(reason, 120)) + reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) return _ok({"device_id": full_name, "function": function, - "elapsed_ms": elapsed_ms, "response": result}, + "elapsed_ms": elapsed_ms, "response": result, + "receipt": receipt}, trace_id=trace) except Exception as e: elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="failed", + elapsed_ms=elapsed_ms, + error={"code": "invoke_failed", "message": str(e)}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + )) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=False, - reason=_truncate(reason, 120), error=str(e)) - return _err(status=502, code="invoke_failed", message=str(e), trace_id=trace) + reason=_truncate(reason, 120), error=str(e), + receipt_id=receipt["receipt_id"]) + return _err(status=502, code="invoke_failed", message=str(e), trace_id=trace, + extra={"receipt": receipt}) async def invoke_with_fallback(request: web.Request) -> web.Response: @@ -695,35 +773,177 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: return _err(status=400, code="missing_function", message="function is required", trace_id=trace) params = body.get("params") or {} + if not isinstance(params, dict): + return _err(status=400, code="invalid_params", message="params must be an object", + trace_id=trace) + clean_params = mandates_svc.strip_dc_meta(params) timeout = _clamp_timeout(body.get("timeout")) reason = _truncate(body.get("reason") or body.get("llm_reasoning") or "", 500) backend = get_backend() failures = [] + receipts = [] for idx, raw_id in enumerate(ids): full_name = _full_device_name(tenant, raw_id) + device_doc = _device_doc(tenant, raw_id) + mandate = _mandate_for_device(body, params, tenant, raw_id, full_name) + mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) + mandate_required = bool(mandate_policy and mandate_policy.get("required")) + mandate_result = mandates_svc.verify_server_mandate( + device_doc=device_doc, + device_id=full_name, + function=function, + params=clean_params, + mandate=mandate, + ) + if not mandate_result.ok: + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="denied", + elapsed_ms=0, + error={"code": mandate_result.error_code, "message": mandate_result.message}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=False, + mandate_error_code=mandate_result.error_code, + )) + receipts.append(receipt) + failures.append({ + "device_id": full_name, + "error": mandate_result.message, + "code": mandate_result.error_code, + "receipt": receipt, + }) + continue + params_for_rpc = mandates_svc.attach_mandate(clean_params, params, mandate) started = time.monotonic() try: - response = await backend.rpc_invoke(tenant, full_name, function, params, timeout=timeout) + response = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="succeeded", + elapsed_ms=elapsed_ms, + response=response, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + )) _audit(request, "invoke_fallback", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, - reason=_truncate(reason, 120)) + reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) return _ok( {"device_id": full_name, "function": function, "elapsed_ms": elapsed_ms, "response": response, + "receipt": receipt, "tried": [{"device_id": _full_device_name(tenant, x), "ok": (i == idx)} for i, x in enumerate(ids[: idx + 1])], - "failures": failures}, + "failures": failures, "receipts": receipts + [receipt]}, trace_id=trace, ) except Exception as e: - failures.append({"device_id": full_name, "error": str(e)}) + elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="failed", + elapsed_ms=elapsed_ms, + error={"code": "invoke_failed", "message": str(e)}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + )) + receipts.append(receipt) + failures.append({"device_id": full_name, "error": str(e), "receipt": receipt}) _audit(request, "invoke_fallback", trace_id=trace, function=function, success=False, reason=_truncate(reason, 120)) - return _err(status=502, code="all_failed", - message="All fallback devices failed", trace_id=trace) + all_denied = bool(failures) and all(_is_mandate_denial(f.get("code")) for f in failures) + return _err( + status=403 if all_denied else 502, + code="all_denied" if all_denied else "all_failed", + message="All fallback devices were denied" if all_denied else "All fallback devices failed", + trace_id=trace, + extra={"failures": failures, "receipts": receipts}, + ) + + +def _mandate_for_device( + body: dict[str, Any], + params: dict[str, Any], + tenant: str, + raw_id: str, + full_name: str, +) -> dict[str, Any] | None: + mandates = body.get("mandates") + if isinstance(mandates, dict): + for key in (full_name, raw_id, _full_device_name(tenant, raw_id)): + mandate = mandates.get(key) + if isinstance(mandate, dict): + return mandate + return mandates_svc.extract_mandate(body, params) + + +def _is_mandate_denial(code: Any) -> bool: + if not isinstance(code, str): + return False + return ( + code.startswith("mandate_") + or code in {"invalid_mandate", "invalid_mandate_signature", "unknown_mandate_key"} + ) + + +# ── execution receipts ───────────────────────────────────────────── + + +async def receipts_list(request: web.Request) -> web.Response: + trace = _trace_id() + _, err = _require_scope(request, "devices:read") + if err: + return err + tenant, err = _resolve_tenant(request) + if err: + return err + device_id = request.query.get("device_id") + full_device_id = _full_device_name(tenant, device_id) if device_id else None + try: + limit = int(request.query.get("limit", "100")) + except ValueError: + limit = 100 + receipts = receipts_svc.list_receipts( + tenant=tenant, + device_id=full_device_id, + limit=limit, + ) + return _ok({"receipts": receipts, "returned": len(receipts)}, trace_id=trace) + + +async def receipt_get(request: web.Request) -> web.Response: + trace = _trace_id() + _, err = _require_scope(request, "devices:read") + if err: + return err + tenant, err = _resolve_tenant(request) + if err: + return err + receipt = receipts_svc.get_receipt(request.match_info["receipt_id"]) + if receipt is None or receipt.get("tenant") != tenant: + return _err(status=404, code="not_found", message="Receipt not found", trace_id=trace) + return _ok({"receipt": receipt}, trace_id=trace) # ── event streaming (bounded) ────────────────────────────────────── diff --git a/packages/device-connect-server/device_connect_server/portal/views/devices.py b/packages/device-connect-server/device_connect_server/portal/views/devices.py index 84125ac..7f5bf1e 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/devices.py +++ b/packages/device-connect-server/device_connect_server/portal/views/devices.py @@ -320,7 +320,7 @@ async def download_starter_script(request: web.Request): """Device Connect — starter AI agent (Strands + OpenAI). Connects to Device Connect, discovers your fleet, and reacts to device -events by calling tools (list_devices, get_device_functions, invoke_device). +events by calling tools (discover_labels, discover, invoke, invoke_many). LLM inference runs through the Arm internal OpenAI proxy. Usage: @@ -403,8 +403,8 @@ async def prepare(self) -> Dict[str, Any]: from strands import Agent from strands.models.openai import OpenAIModel from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, - invoke_device, invoke_device_with_fallback, get_device_status, + discover_labels, discover, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ) result = await super().prepare() @@ -416,8 +416,8 @@ async def prepare(self) -> Dict[str, Any]: params={"max_tokens": self._max_tokens}, ), tools=[ - describe_fleet, list_devices, get_device_functions, - invoke_device, invoke_device_with_fallback, get_device_status, + discover_labels, discover, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=self._build_system_prompt(), ) @@ -441,22 +441,31 @@ def _build_system_prompt(self) -> str: for dt, info in sorted(by_type.items()): locs = ", ".join(sorted(info["locations"])) lines.append(f" - {info['count']}x {dt} (at: {locs})") - fleet = "\\n".join(lines) or " (none yet — call describe_fleet() to refresh)" + fleet = "\\n".join(lines) or " (none yet -- call discover() to refresh)" return ( f"You are an AI agent connected to the Device Connect IoT network.\\n\\n" f"YOUR GOAL: {self.goal}\\n\\n" f"FLEET OVERVIEW ({len(self.devices)} devices):\\n{fleet}\\n\\n" f"DISCOVERY TOOLS:\\n" - f" - describe_fleet() — fleet summary\\n" - f" - list_devices(device_type=..., location=...) — browse devices\\n" - f" - get_device_functions(device_id) — see what a device can do\\n" - f" - invoke_device(device_id, function, params) — call a device function\\n\\n" + f" - discover_labels(key=None) -- fleet label vocabulary " + f"(category, location, direction, modality, ...)\\n" + f" - discover(selector) -- resolve a selector to devices, " + f"functions, or events. Examples:\\n" + f" device(category:camera, location:zone-A/*)\\n" + f" device(robot-001).function(direction:write)\\n" + f" function(safety:critical)\\n\\n" + f"INVOCATION TOOLS:\\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\\n\\n" f"INSTRUCTIONS:\\n" f"When you receive device events, you MUST:\\n" f"1. Analyze the events\\n" - f"2. Use get_device_functions() to check available functions if needed\\n" - f"3. Use invoke_device() to interact with devices\\n" + f"2. Use discover() with a function-scoped selector to check " + f"available functions if needed\\n" + f"3. Use invoke() or invoke_many() to interact with devices\\n" f"4. Report what you found and what actions you took\\n\\n" f"Always provide llm_reasoning when invoking devices.\\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-server/device_connect_server/statectl/cli.py b/packages/device-connect-server/device_connect_server/statectl/cli.py index e1a03ef..161afdd 100644 --- a/packages/device-connect-server/device_connect_server/statectl/cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/cli.py @@ -408,6 +408,13 @@ def create_parser() -> argparse.ArgumentParser: # stats sub.add_parser("stats", help="Key counts by namespace") + # Selector-driven operations (invoke / invoke-many / broadcast / + # subscribe / await). These verbs do not touch etcd; they run over + # the messaging fabric. They live under statectl because they all + # change the live state of devices. + from device_connect_server.statectl import operations_cli + operations_cli.register_subparsers(sub) + return parser @@ -430,9 +437,25 @@ async def _run(args) -> None: await handler(client, args) +_OPERATIONS_DISPATCH = { + "invoke": "run_invoke", + "invoke-many": "run_invoke_many", + "broadcast": "run_broadcast", + "subscribe": "run_subscribe", + "await": "run_await", +} + + def main(): parser = create_parser() args = parser.parse_args() + if args.cmd in _OPERATIONS_DISPATCH: + # Operations verbs run over messaging, not etcd. Bypass the etcd + # client setup that the COMMANDS dispatch table assumes. + from device_connect_server.statectl import operations_cli + handler = getattr(operations_cli, _OPERATIONS_DISPATCH[args.cmd]) + sys.exit(handler(args)) + try: asyncio.run(_run(args)) except KeyboardInterrupt: diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py new file mode 100644 index 0000000..7ddc9ef --- /dev/null +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -0,0 +1,297 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``statectl`` selector-driven operations verbs. + +Thin wrappers around the agent-tools ``invoke`` / ``invoke_many`` / +``broadcast`` / ``subscribe`` / ``await_replies`` functions so operators +can fire selector-driven calls from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Connect to the messaging backend using the same env-or-broker rules + as devctl's selector verbs.""" + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _parse_param_kv(values: list[str] | None) -> dict[str, Any]: + """Parse ``--param k=v`` repeated args into a function-params dict. + + Values that look like JSON (``[...]``, ``{...}``, numbers, ``true`` / + ``false`` / ``null``) are decoded; everything else stays a string. This + matches what an operator would expect when typing + ``--param resolution=1080p --param tags='["a","b"]'``. + """ + out: dict[str, Any] = {} + for entry in values or []: + if "=" not in entry: + raise ValueError(f"--param must be 'k=v', got {entry!r}") + k, _, v = entry.partition("=") + k = k.strip() + if not k: + raise ValueError(f"--param has empty key in {entry!r}") + v_stripped = v.strip() + # JSON-decode obvious JSON-shaped values; fall back to raw string. + if ( + v_stripped.startswith(("[", "{", '"')) + or v_stripped in ("true", "false", "null") + or _looks_numeric(v_stripped) + ): + try: + out[k] = json.loads(v_stripped) + continue + except json.JSONDecodeError: + pass + out[k] = v + return out + + +def _looks_numeric(s: str) -> bool: + try: + float(s) + return True + except ValueError: + return False + + +def _pretty(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +# -- verbs ---------------------------------------------------------- + + +def run_invoke(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke + + _connect(getattr(args, "broker", None)) + try: + result = invoke( + args.selector, + params=_parse_param_kv(args.param), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if result.get("success") else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_invoke_many(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke_many + + _connect(getattr(args, "broker", None)) + try: + result = invoke_many( + args.selector, + params=_parse_param_kv(args.param), + timeout=float(args.timeout), + max_concurrency=int(args.max_concurrency), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + # Exit non-zero on a top-level error OR when any target failed, so + # shell pipelines can detect partial failure without parsing JSON. + if "error" in result: + return 1 + if result.get("failed", 0) > 0: + return 3 + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_broadcast(args: Any) -> int: + from device_connect_agent_tools import broadcast, disconnect + + bindings = None + if args.bindings: + try: + bindings = json.loads(args.bindings) + except json.JSONDecodeError as e: + print(f"--bindings must be valid JSON: {e}") + return 2 + + _connect(getattr(args, "broker", None)) + try: + result = broadcast( + args.selector, + params=_parse_param_kv(args.param), + where=args.where, + bindings=bindings, + fire_at=float(args.fire_at) if args.fire_at is not None else None, + on_late=args.on_late, + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_subscribe(args: Any) -> int: + """Stream events / replies for ``args.selector`` to stdout. + + Each message is printed as one JSON line so the output can be piped + into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses + or ``--until`` messages have been printed (whichever comes first). + Exit codes: + 0 one or more messages were printed + 4 idle-timeout reached with zero messages + 130 interrupted with Ctrl-C + """ + from device_connect_agent_tools import disconnect, subscribe + + _connect(getattr(args, "broker", None)) + count = 0 + try: + with subscribe(args.selector) as sub: + try: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + except KeyboardInterrupt: + # Clean exit on Ctrl-C: the ``with`` block tears the + # subscription down before this returns. + return 130 + return 0 if count > 0 else 4 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_await(args: Any) -> int: + from device_connect_agent_tools import await_replies, disconnect + + _connect(getattr(args, "broker", None)) + try: + replies = await_replies( + args.correlation_id, + timeout=float(args.timeout), + until=int(args.until) if args.until is not None else None, + ) + print(_pretty(replies)) + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +# -- parser wiring -------------------------------------------------- + + +def register_subparsers(sub: Any) -> None: + """Attach the operation subparsers to a statectl parser.""" + p = sub.add_parser("invoke", help="Call exactly one function on one device") + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "invoke-many", help="Fan out a call over a selector-resolved set", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--timeout", default=30.0, help="Per-target timeout (s)") + p.add_argument( + "--max-concurrency", default=32, dest="max_concurrency", + help="Parallel worker cap", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "broadcast", + help="Async fan-out; returns correlation_id", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument( + "--where", default=None, + help="CEL predicate evaluated at the edge per candidate", + ) + p.add_argument( + "--bindings", default=None, + help="JSON-encoded bindings dict (shared payload for the predicate)", + ) + p.add_argument( + "--fire-at", default=None, dest="fire_at", + help="Wall-clock epoch seconds for synchronized fan-out", + ) + p.add_argument( + "--on-late", choices=["skip", "fire"], default="skip", dest="on_late", + help="Policy when fire_at deadline has passed (default: skip)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "subscribe", help="Stream events or broadcast replies to stdout", + ) + p.add_argument( + "selector", + help="Event selector or 'correlation:' for broadcast replies", + ) + p.add_argument( + "--timeout", default=10.0, + help="Idle-silence timeout per message (s; resets on each arrival)", + ) + p.add_argument( + "--until", default=None, + help="Stop after this many messages are printed", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "await", help="Collect replies for a broadcast correlation_id", + ) + p.add_argument("correlation_id", help="Correlation id returned by broadcast") + p.add_argument("--timeout", default=10.0, help="Overall timeout (s)") + p.add_argument( + "--until", default=None, + help="Stop after this many replies have been collected", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py index 287d299..15d6119 100644 --- a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py @@ -11,6 +11,7 @@ from __future__ import annotations +from datetime import datetime, timedelta, timezone from unittest.mock import patch import pytest @@ -20,6 +21,7 @@ from device_connect_server.portal.app import auth_middleware from device_connect_server.portal.services import tokens as tokens_svc from device_connect_server.portal.views import agent_api +from device_connect_edge import create_closed_mandate, create_open_mandate # A registry doc with extra fields the API must surface untouched. @@ -58,6 +60,51 @@ "registry": {"registered_at": "2026-05-01T12:00:00+00:00"}, } +PROTECTED_LOCK = { + "device_id": "acme-lock-001", + "tenant": "acme", + "identity": {"device_type": "lock"}, + "status": {"online": True}, + "capabilities": { + "functions": [ + { + "name": "unlock", + "parameters": {"type": "object"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + {"name": "get_status", "parameters": {"type": "object"}}, + ], + "events": [], + }, +} + +PRINCIPAL_KEY = "principal-secret" +AGENT_KEY = "agent-secret" + + +def _closed_mandate(device_id: str = "acme-lock-001", params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id=device_id, + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id=device_id, + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + ) + @pytest.fixture def fake_record(): @@ -439,6 +486,9 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True} with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=None, + ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), ): @@ -451,6 +501,136 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): assert seen["timeout"] == agent_api.MAX_INVOKE_TIMEOUT_S +class TestInvokeMandates: + async def test_protected_function_with_valid_mandate_returns_receipt( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + seen = {} + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + seen["params"] = params + return {"ok": True} + + mandate = _closed_mandate() + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + side_effect=_lookup_device, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={ + "function": "unlock", + "params": {"duration_s": 30}, + "mandate": mandate, + }, + ) + + assert r.status == 200 + body = await r.json() + receipt = body["result"]["receipt"] + assert receipt["status"] == "succeeded" + assert receipt["mandate"]["verified"] is True + assert receipt["mandate"]["principal"] == "operator" + assert seen["params"]["_dc_meta"]["mandate"] == mandate + + async def test_protected_function_without_mandate_returns_denial_receipt( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + raise AssertionError("backend must not be called") + + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + side_effect=_lookup_device, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={"function": "unlock", "params": {"duration_s": 30}}, + ) + + assert r.status == 403 + body = await r.json() + assert body["error"]["code"] == "mandate_required" + assert body["receipt"]["status"] == "denied" + assert body["receipt"]["mandate"]["required"] is True + + async def test_existing_dc_meta_is_preserved_when_mandate_is_attached( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + seen = {} + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + seen["params"] = params + return {"ok": True} + + mandate = _closed_mandate() + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + side_effect=_lookup_device, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={ + "function": "unlock", + "params": { + "duration_s": 30, + "_dc_meta": {"traceparent": "trace"}, + }, + "mandate": mandate, + }, + ) + + assert r.status == 200 + assert seen["params"]["_dc_meta"]["traceparent"] == "trace" + assert seen["params"]["_dc_meta"]["mandate"] == mandate + + # ── invoke-with-fallback duplicate device id (regression) ───────── @@ -469,6 +649,9 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True, "attempt": len(attempts)} with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=None, + ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), ): diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py new file mode 100644 index 0000000..bd93275 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for portal mandate and receipt helpers.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from device_connect_edge import create_closed_mandate, create_open_mandate +from device_connect_server.portal.services import execution_receipts, mandates + + +PRINCIPAL_KEY = "principal-secret" +AGENT_KEY = "agent-secret" + + +DEVICE_DOC = { + "device_id": "acme-lock-001", + "capabilities": { + "functions": [ + { + "name": "unlock", + "mandate": {"required": True, "scope": "actuation"}, + }, + {"name": "get_status"}, + ] + }, +} + + +def _mandate(params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="acme-lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="acme-lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + ) + + +def test_get_function_mandate_policy(): + assert mandates.get_function_mandate_policy(DEVICE_DOC, "unlock") == { + "required": True, + "scope": "actuation", + } + assert mandates.get_function_mandate_policy(DEVICE_DOC, "get_status") is None + + +def test_extract_and_attach_mandate_preserves_existing_meta(): + mandate = _mandate() + params = {"duration_s": 30, "_dc_meta": {"traceparent": "trace"}} + + assert mandates.extract_mandate({"mandate": mandate}, params) == mandate + attached = mandates.attach_mandate( + mandates.strip_dc_meta(params), params, mandate, + ) + + assert attached["duration_s"] == 30 + assert attached["_dc_meta"]["traceparent"] == "trace" + assert attached["_dc_meta"]["mandate"] == mandate + + +def test_verify_server_mandate_validates_protected_function(monkeypatch): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + mandates._SERVER_MANDATE_REPLAY_CACHE.clear() + + result = mandates.verify_server_mandate( + device_doc=DEVICE_DOC, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + mandate=_mandate(), + ) + + assert result.ok is True + + +def test_verify_server_mandate_denies_missing_mandate(monkeypatch): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + + result = mandates.verify_server_mandate( + device_doc=DEVICE_DOC, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + mandate=None, + ) + + assert result.ok is False + assert result.error_code == "mandate_required" + + +def test_execution_receipt_hashes_payload_and_can_sign(monkeypatch): + monkeypatch.setenv("DC_RECEIPT_SIGNING_KEY", "receipt-secret") + + receipt = execution_receipts.build_receipt( + trace_id="trace-1", + tenant="acme", + actor={"token_id": "tok-1", "username": "alice"}, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + status="succeeded", + elapsed_ms=12, + response={"ok": True}, + mandate=_mandate(), + mandate_required=True, + mandate_verified=True, + ) + + assert receipt["receipt_id"].startswith("rcpt-") + assert receipt["params_sha256"] + assert receipt["response_sha256"] + assert receipt["signature"] + assert receipt["mandate"]["verified"] is True + + +def test_execution_receipt_log_lists_latest_by_tenant_device_and_limit(): + execution_receipts._RECEIPTS.clear() + + first = execution_receipts.record_receipt({ + "receipt_id": "rcpt-1", + "tenant": "acme", + "device_id": "acme-lock-001", + "status": "succeeded", + }) + second = execution_receipts.record_receipt({ + "receipt_id": "rcpt-2", + "tenant": "acme", + "device_id": "acme-heater-001", + "status": "denied", + }) + execution_receipts.record_receipt({ + "receipt_id": "rcpt-3", + "tenant": "other", + "device_id": "other-lock-001", + "status": "succeeded", + }) + + assert execution_receipts.get_receipt("rcpt-1") == first + assert execution_receipts.get_receipt("missing") is None + assert execution_receipts.list_receipts(tenant="acme") == [second, first] + assert execution_receipts.list_receipts( + tenant="acme", + device_id="acme-lock-001", + ) == [first] + assert execution_receipts.list_receipts(tenant="acme", limit=1) == [second] diff --git a/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py new file mode 100644 index 0000000..2ed2cf8 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Smoke tests for the selector-driven CLI verbs. + +Argument-parser shape only; the underlying tools (``discover``, +``invoke``, ``broadcast``, etc.) have their own unit and integration +tests. These guards catch parser-config regressions (missing positional, +typoed dest, alias drift). +""" +from __future__ import annotations + +import json + +import pytest + +from device_connect_server.devctl import cli as devctl_cli +from device_connect_server.devctl import selector_cli +from device_connect_server.statectl import cli as statectl_cli +from device_connect_server.statectl import operations_cli + + +# -- devctl --------------------------------------------------------- + + +class TestDevctlSelectorParser: + def test_discover_requires_selector(self): + parser = devctl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["discover"]) + + def test_discover_parses_selector(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover", "device(category:camera)"]) + assert args.cmd == "discover" + assert args.selector == "device(category:camera)" + assert args.offset == 0 + assert args.limit == 200 + + def test_discover_offset_limit_override(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover", "device(*)", "--offset", "100", "--limit", "50"] + ) + assert args.offset == 100 + assert args.limit == 50 + + def test_discover_labels_no_key(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover-labels"]) + assert args.cmd == "discover-labels" + assert args.key is None + assert args.limit == 50 + + def test_discover_labels_key_pagination(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover-labels", "--key", "device.location", "--limit", "20"] + ) + assert args.key == "device.location" + assert args.limit == 20 + + def test_legacy_discover_renamed_to_mdns_scan(self): + # The historical "discover" verb (mDNS scan) now lives under + # mdns-scan; the alias "scan" keeps it discoverable. + parser = devctl_cli.create_parser() + for verb in ("mdns-scan", "scan"): + args = parser.parse_args([verb]) + # Both aliases share the same args.cmd + assert args.cmd in ("mdns-scan", "scan") + + +# -- statectl ------------------------------------------------------- + + +class TestStatectlOperationsParser: + def test_invoke_requires_selector(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["invoke"]) + + def test_invoke_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke", "device(robot-001).function(grip_close)", + "--param", "force_n=10", + "--reason", "test", + ] + ) + assert args.cmd == "invoke" + assert args.selector == "device(robot-001).function(grip_close)" + assert args.param == ["force_n=10"] + assert args.reason == "test" + + def test_invoke_many_with_timeout(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke-many", + "function(safety:critical)", + "--timeout", "5", + "--max-concurrency", "8", + ] + ) + assert args.cmd == "invoke-many" + assert float(args.timeout) == 5.0 + assert int(args.max_concurrency) == 8 + + def test_broadcast_full_signature(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "broadcast", + "device(category:phone).function(set_flashlight)", + "--param", "on=true", + "--param", "color=white", + "--where", "labels.location == 'lab-A'", + "--bindings", '{"mask": [[0,1],[1,0]]}', + "--fire-at", "1700000000.0", + "--on-late", "fire", + ] + ) + assert args.cmd == "broadcast" + assert args.selector.startswith("device(category:phone)") + assert args.where == "labels.location == 'lab-A'" + assert args.on_late == "fire" + + def test_broadcast_rejects_unknown_on_late(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "broadcast", "device(*).function(do)", + "--on-late", "bogus", + ] + ) + + def test_subscribe_parses_correlation_form(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["subscribe", "correlation:br-abc123", "--until", "5"] + ) + assert args.cmd == "subscribe" + assert args.selector == "correlation:br-abc123" + assert int(args.until) == 5 + + def test_await_requires_correlation_id(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["await"]) + + def test_await_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["await", "br-abc123", "--timeout", "2.5", "--until", "10"] + ) + assert args.correlation_id == "br-abc123" + assert float(args.timeout) == 2.5 + assert int(args.until) == 10 + + +# -- parameter parsing ---------------------------------------------- + + +class TestParseParamKV: + def test_string_values_default(self): + result = operations_cli._parse_param_kv(["a=hello", "b=world"]) + assert result == {"a": "hello", "b": "world"} + + def test_numbers_decoded(self): + result = operations_cli._parse_param_kv(["count=5", "ratio=0.75"]) + assert result == {"count": 5, "ratio": 0.75} + + def test_booleans_decoded(self): + result = operations_cli._parse_param_kv(["on=true", "off=false"]) + assert result == {"on": True, "off": False} + + def test_json_array_decoded(self): + result = operations_cli._parse_param_kv(["zones=[1,2,3]"]) + assert result == {"zones": [1, 2, 3]} + + def test_json_object_decoded(self): + result = operations_cli._parse_param_kv(['nested={"a":1}']) + assert result == {"nested": {"a": 1}} + + def test_string_with_equals(self): + # The split is on the first '=', so values may contain further '='. + result = operations_cli._parse_param_kv(["query=a=b"]) + assert result == {"query": "a=b"} + + def test_invalid_form_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["no_equals_sign"]) + + def test_empty_key_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["=value"]) diff --git a/tests/drivers/camera.py b/tests/drivers/camera.py index 5a5804d..a2b659f 100644 --- a/tests/drivers/camera.py +++ b/tests/drivers/camera.py @@ -20,6 +20,7 @@ class TestCameraDriver(DeviceDriver): """Simulated camera for integration tests.""" device_type = "test_camera" + labels = {"category": "camera"} def __init__(self, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, location: str = "test-zone"): @@ -51,7 +52,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location) - @rpc() + @rpc(labels={"direction": "write", "modality": "rgb"}) async def capture_image(self, resolution: str = "1080p") -> dict: """Capture a simulated test image.""" await self.simulate_delay() @@ -64,12 +65,12 @@ async def capture_image(self, resolution: str = "1080p") -> dict: "device_id": getattr(self, "_device_id", "unknown"), } - @emit() + @emit(labels={"modality": "motion"}) async def state_change_detected(self, zone_id: str, state_class: str, details: Optional[str] = None): """State change detected in camera view.""" pass - @emit() + @emit(labels={"modality": "rgb"}) async def object_detected(self, label: str, confidence: float, bbox: Optional[list] = None): """Object detected in camera view.""" pass diff --git a/tests/drivers/robot.py b/tests/drivers/robot.py index be0e59e..5641a79 100644 --- a/tests/drivers/robot.py +++ b/tests/drivers/robot.py @@ -19,6 +19,7 @@ class TestRobotDriver(DeviceDriver): """Simulated cleaning robot for integration tests.""" device_type = "test_robot" + labels = {"category": "robot"} def __init__(self, clean_duration: float = 0.5, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, @@ -59,7 +60,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location) - @rpc() + @rpc(labels={"direction": "write", "safety": "critical"}) async def dispatch_robot(self, zone_id: str) -> dict: """Dispatch the robot to clean a zone.""" await self.simulate_delay() @@ -72,7 +73,7 @@ async def dispatch_robot(self, zone_id: str) -> dict: self._cleaning_task = asyncio.create_task(self._do_cleaning(zone_id)) return {"status": "accepted", "zone_id": zone_id, "estimated_duration": self._clean_duration} - @rpc() + @rpc(labels={"direction": "read"}) async def get_status(self) -> dict: """Get current robot status.""" await self.simulate_delay() diff --git a/tests/drivers/sensor.py b/tests/drivers/sensor.py index ba6baed..5632ce8 100644 --- a/tests/drivers/sensor.py +++ b/tests/drivers/sensor.py @@ -20,6 +20,7 @@ class TestSensorDriver(DeviceDriver): """Simulated temperature/humidity sensor for integration tests.""" device_type = "test_sensor" + labels = {"category": "sensor"} def __init__(self, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, location: str = "test-room", @@ -59,7 +60,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location, availability="available") - @rpc() + @rpc(labels={"direction": "read", "modality": "thermal"}) async def get_reading(self, unit: str = "celsius") -> dict: """Get current temperature and humidity reading.""" await self.simulate_delay() @@ -78,13 +79,13 @@ async def get_reading(self, unit: str = "celsius") -> dict: "device_id": getattr(self, "_device_id", "unknown"), } - @rpc() + @rpc(labels={"direction": "write", "safety": "critical"}) async def set_threshold(self, temperature: float, humidity: Optional[float] = None) -> dict: """Set alert thresholds.""" await self.simulate_delay() return {"status": "success", "temperature_threshold": temperature} - @rpc() + @rpc(labels={"direction": "write"}) async def set_location(self, location: str) -> dict: """Update the sensor's location.""" await self.simulate_delay() @@ -92,12 +93,12 @@ async def set_location(self, location: str) -> dict: self._location = location return {"status": "success", "old_location": old, "location": location} - @emit() + @emit(labels={"modality": "thermal"}) async def reading(self, temperature: float, humidity: float, unit: str = "celsius"): """Periodic sensor reading.""" pass - @emit() + @emit(labels={"safety": "critical"}) async def threshold_exceeded(self, temperature: float, humidity: float, exceeded: str): """Threshold exceeded alert.""" pass diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py new file mode 100644 index 0000000..975016e --- /dev/null +++ b/tests/tests/test_tools_broadcast.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven broadcast + correlation replies. + +End-to-end coverage for the async fan-out path: +- Dispatcher publishes a broadcast envelope on the fanout subject. +- Each device runtime self-elects via target_device_ids and the optional + CEL ``where`` predicate. +- Devices execute the function and emit a reply on the per-device async + reply subject keyed by correlation_id. +- ``await_replies`` collects replies for a bounded window. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.4 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_returns_correlation_and_replies_arrive( + device_spawner, messaging_url, +): + """broadcast() returns a correlation_id and matching devices reply on the + per-device async reply subject.""" + await device_spawner.spawn_camera("itest-bc-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bc-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bc-cam-1", "itest-bc-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bc-cam-*).function(capture_image)", + {"resolution": "720p"}, + ) + assert result["correlation_id"].startswith("br-") + assert result["candidates"] == 2 + assert result["function"] == "capture_image" + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=5.0, until=2, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bc-cam-1", "itest-bc-cam-2"} + for r in replies: + assert r["success"] is True + assert r["correlation_id"] == result["correlation_id"] + assert "actually_fired_at" in r + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_filters_at_edge(device_spawner, messaging_url): + """A CEL where predicate runs at each candidate; only matches reply.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcw-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-bcw-cam-b", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcw-cam-a", "itest-bcw-cam-b"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcw-cam-*).function(capture_image)", + {"resolution": "1080p"}, + "labels.location == 'lab-A'", # where predicate + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + # Only cam-a is in lab-A; cam-b silently self-deselects. + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcw-cam-a"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_synchronizes_fan_out( + device_spawner, messaging_url, +): + """fire_at causes each device to fire from its own clock at the deadline.""" + await device_spawner.spawn_camera("itest-bcf-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcf-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcf-cam-1", "itest-bcf-cam-2"}) + try: + # Schedule 0.5s in the future; on_late=skip so any tardy device drops + # the call rather than firing late and breaking the coherence. + scheduled = time.time() + 0.5 + result = await asyncio.to_thread( + broadcast, + "device(itest-bcf-cam-*).function(capture_image)", + None, None, None, + scheduled, # fire_at + "skip", # on_late + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, until=2, + ) + assert len(replies) == 2 + # actually_fired_at should be at-or-after the scheduled time on each. + for r in replies: + assert r["actually_fired_at"] >= scheduled - 0.05 # small slack + # Achieved spread should be tight (well under network jitter). + spread = max(r["actually_fired_at"] for r in replies) - min( + r["actually_fired_at"] for r in replies + ) + assert spread < 0.5, f"fire_at spread too wide: {spread:.3f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_late_with_skip_drops( + device_spawner, messaging_url, +): + """A fire_at in the past with on_late=skip yields no replies.""" + await device_spawner.spawn_camera("itest-bcl-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcl-cam"}) + try: + past = time.time() - 5.0 # already 5s late + result = await asyncio.to_thread( + broadcast, + "device(itest-bcl-cam).function(capture_image)", + None, None, None, past, "skip", + ) + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=1.5, + ) + assert replies == [] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_with_bindings(device_spawner, messaging_url): + """A where predicate that reads bindings. self-elects per-target.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcbnd-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcbnd-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-bcbnd-cam-1", "itest-bcbnd-cam-2"} + ) + try: + # Allowlist sent in bindings; the predicate uses bindings.allow to + # select. Devices not in the allowlist self-deselect silently. + result = await asyncio.to_thread( + broadcast, + "device(itest-bcbnd-cam-*).function(capture_image)", + None, + "identity.device_id in bindings.allow", + {"allow": ["itest-bcbnd-cam-1"]}, + ) + assert result["candidates"] == 2 + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcbnd-cam-1"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_await_replies_until_stops_early(device_spawner, messaging_url): + """``await_replies`` returns once ``until`` replies have arrived.""" + await device_spawner.spawn_camera("itest-awu-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-2", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-3", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-awu-cam-1", "itest-awu-cam-2", "itest-awu-cam-3"} + ) + try: + result = await asyncio.to_thread( + broadcast, "device(itest-awu-cam-*).function(capture_image)", + ) + assert result["candidates"] == 3 + # until=1 should let us return after the first reply arrives even + # though more are coming. + t0 = time.monotonic() + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], + timeout=5.0, until=1, poll_interval=0.02, + ) + elapsed = time.monotonic() - t0 + assert len(replies) >= 1 + # Sanity: returning early should be well under the timeout. + assert elapsed < 2.0, f"await_replies(until=1) took {elapsed:.2f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_iter_protocol(device_spawner, messaging_url): + """``for msg in sub:`` works via Subscription.__iter__.""" + await device_spawner.spawn_camera("itest-subiter-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-subiter-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices( + messaging_url, {"itest-subiter-cam-1", "itest-subiter-cam-2"} + ) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-subiter-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + # Exercise the bare ``for msg in sub:`` form (uses __iter__). + # Break after both expected replies arrive so the test stays + # bounded regardless of the default idle timeout. + with subscribe(f"correlation:{cid}") as sub: + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + if len(gathered) >= 2: + break + return gathered + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-subiter-cam-1", "itest-subiter-cam-2"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_event_selector_live_stream(device_spawner, messaging_url): + """subscribe(event()) receives live events from matching devices.""" + device, driver = await device_spawner.spawn_camera( + "itest-evsub-cam", location="lab-A", + ) + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-evsub-cam"}) + try: + with subscribe("device(itest-evsub-cam).event(object_detected)") as sub: + await asyncio.sleep(SETTLE_TIME) # let subscription warm up + await driver.trigger_event( + "object_detected", + {"label": "person", "confidence": 0.95}, + ) + msgs = await asyncio.to_thread( + list, sub.iter(timeout=2.0, poll_interval=0.05), + ) + # The event arrives via the JSON-RPC event subject; payload is + # under either ``params`` or top-level depending on transport. + matching = [ + m for m in msgs + if (m.get("params") or {}).get("label") == "person" + or m.get("label") == "person" + ] + assert matching, f"no object_detected events received: {msgs}" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_correlation_form(device_spawner, messaging_url): + """subscribe('correlation:') captures replies as they arrive.""" + await device_spawner.spawn_camera("itest-bcs-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcs-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-bcs-cam-1", "itest-bcs-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcs-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + with subscribe(f"correlation:{cid}") as sub: + # Drain over a short window. + return list(sub.iter(timeout=2.0, poll_interval=0.05)) + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcs-cam-1", "itest-bcs-cam-2"} + finally: + await asyncio.to_thread(disconnect) diff --git a/tests/tests/test_tools_invoke.py b/tests/tests/test_tools_invoke.py index 447f301..df9878b 100644 --- a/tests/tests/test_tools_invoke.py +++ b/tests/tests/test_tools_invoke.py @@ -2,40 +2,65 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Integration tests for device-connect-agent-tools invoke_device(). +"""Integration tests for selector-driven invocation tools. -Tests that the agent SDK can invoke device RPCs via the messaging backend. +Covers ``invoke()`` and ``invoke_many()`` against real devices registered +via the messaging backend. Exercises single-match, ambiguous-match, +selector-scope rejection, parallel fan-out, and partial-failure semantics +end-to-end. """ import asyncio -import pytest +import time +import pytest SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible.""" + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- invoke --------------------------------------------------------- @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_sensor_reading(device_spawner, messaging_url): - """invoke_device() should call sensor's get_reading and return result.""" + """invoke() calls sensor.get_reading and returns the reading payload.""" await device_spawner.spawn_sensor( - "itest-tools-invoke-sensor", initial_temp=23.5, initial_humidity=50.0, + "itest-inv-read-sensor", initial_temp=23.5, initial_humidity=50.0, ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-read-sensor"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-sensor", - function="get_reading", - params={"unit": "celsius"}, - llm_reasoning="Testing sensor read", + invoke, + "device(itest-inv-read-sensor).function(get_reading)", + {"unit": "celsius"}, + "Testing sensor read", ) - assert isinstance(result, dict) - assert result.get("success") is True or "temperature" in result.get("result", {}) + assert result["success"] is True + assert result["device_id"] == "itest-inv-read-sensor" + assert result["function"] == "get_reading" + assert "temperature" in result["result"] finally: await asyncio.to_thread(disconnect) @@ -43,25 +68,26 @@ async def test_invoke_sensor_reading(device_spawner, messaging_url): @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_url): - """invoke_device() should dispatch robot and trigger cleaning.""" + """invoke() dispatches the robot and the cleaning_finished event arrives.""" await device_spawner.spawn_robot( - "itest-tools-invoke-robot", clean_duration=0.3, + "itest-inv-robot", clean_duration=0.3, ) await asyncio.sleep(SETTLE_TIME) - async with event_capture.subscribe("device-connect.*.itest-tools-invoke-robot.event.*") as events: - from device_connect_agent_tools import connect, disconnect, invoke_device + async with event_capture.subscribe( + "device-connect.*.itest-inv-robot.event.*" + ) as events: + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-robot"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-robot", - function="dispatch_robot", - params={"zone_id": "zone-tools"}, - llm_reasoning="Testing robot dispatch via tools", + invoke, + "device(itest-inv-robot).function(dispatch_robot)", + {"zone_id": "zone-tools"}, + "Testing robot dispatch", ) - assert isinstance(result, dict) + assert result["success"] is True finally: await asyncio.to_thread(disconnect) @@ -71,42 +97,184 @@ async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_ur @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_unknown_device(messaging_url): - """invoke_device() on non-existent device should return error.""" - from device_connect_agent_tools import connect, disconnect, invoke_device +async def test_invoke_no_match_returns_no_match(device_spawner, messaging_url): + """A selector that resolves to zero functions returns ``no_match``.""" + await device_spawner.spawn_camera("itest-inv-nomatch-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="nonexistent-device-xyz", - function="ping", - llm_reasoning="Testing error handling", + invoke, + "device(itest-inv-nomatch-cam).function(does_not_exist)", + ) + assert result["success"] is False + assert result["error"]["code"] == "no_match" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_ambiguous_match_returns_error(device_spawner, messaging_url): + """A selector matching multiple (device, function) tuples returns an error.""" + await device_spawner.spawn_camera("itest-inv-amb-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-amb-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke + + await _wait_for_devices( + messaging_url, {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke, "device(itest-inv-amb-cam-*).function(capture_image)", + ) + assert result["success"] is False + assert result["error"]["code"] == "ambiguous_match" + cand_ids = {c["device_id"] for c in result["candidates"]} + assert {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} <= cand_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_device_only_scope_rejected(device_spawner, messaging_url): + """A device-only selector cannot resolve to a function.""" + await device_spawner.spawn_camera("itest-inv-scope-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(invoke, "device(itest-inv-scope-cam)") + assert result["success"] is False + assert result["error"]["code"] == "invalid_invoke_scope" + finally: + await asyncio.to_thread(disconnect) + + +# -- invoke_many ---------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_succeeds_across_devices(device_spawner, messaging_url): + """invoke_many() fans out a single function across multiple matching devices.""" + await device_spawner.spawn_camera("itest-inv-many-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-many-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-cam-*).function(capture_image)", + {"resolution": "720p"}, ) - assert isinstance(result, dict) - assert result.get("success") is False + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 2 + assert result["failed"] == 0 + ids = {row["device_id"] for row in result["results"]} + assert ids == {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} finally: await asyncio.to_thread(disconnect) @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_camera_capture(device_spawner, messaging_url): - """invoke_device() should capture image from camera.""" - await device_spawner.spawn_camera("itest-tools-invoke-cam") +async def test_invoke_many_partial_failure(device_spawner, messaging_url): + """A failing target is recorded in errors while siblings succeed.""" + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-1", location="lab-A", failure_rate=1.0, + ) + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-2", location="lab-A", + ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, + {"itest-inv-many-pf-cam-1", "itest-inv-many-pf-cam-2"}, + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-pf-cam-*).function(capture_image)", + ) + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 1 + assert result["failed"] == 1 + success_ids = {row["device_id"] for row in result["results"]} + error_ids = {row["device_id"] for row in result["errors"]} + assert success_ids == {"itest-inv-many-pf-cam-2"} + assert error_ids == {"itest-inv-many-pf-cam-1"} + for row in result["errors"]: + assert "code" in row["error"] + assert "message" in row["error"] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_zero_candidates(device_spawner, messaging_url): + """No matches yields an empty envelope, not an error.""" + await device_spawner.spawn_camera("itest-inv-many-zero-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke_many await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-cam", - function="capture_image", - params={"resolution": "720p"}, - llm_reasoning="Testing camera capture via tools", + invoke_many, + "device(itest-no-such-device).function(capture_image)", ) - assert isinstance(result, dict) + assert result["candidates"] == 0 + assert result["matched"] == 0 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + assert result["results"] == [] + assert result["errors"] == [] + assert "error" not in result + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_function_only_selector(device_spawner, messaging_url): + """function() selects the function across the whole fleet.""" + await device_spawner.spawn_sensor( + "itest-inv-many-fo-sensor", initial_temp=20.0, + ) + await device_spawner.spawn_camera("itest-inv-many-fo-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-fo-cam", "itest-inv-many-fo-sensor"} + ) + try: + result = await asyncio.to_thread(invoke_many, "function(get_reading)") + ids = {row["device_id"] for row in result["results"]} + assert "itest-inv-many-fo-sensor" in ids + # Camera does not have get_reading; should not be in results. + assert "itest-inv-many-fo-cam" not in ids finally: await asyncio.to_thread(disconnect) diff --git a/tests/tests/test_tools_selector.py b/tests/tests/test_tools_selector.py new file mode 100644 index 0000000..7711393 --- /dev/null +++ b/tests/tests/test_tools_selector.py @@ -0,0 +1,570 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven discovery tools. + +Covers ``discover()`` and ``discover_labels()`` against real devices +registered via the messaging backend. Exercises the full selector grammar +end-to-end across all five scope shapes (device / device.function / +device.event / function / event), label filters (category, location, +direction, modality, safety), pagination, and the legacy-location mirror. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible. + + Returns the list of flattened device dicts. Caller is responsible for + disconnecting. + """ + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- discover: device-only scope --------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_wildcard_returns_all_devices(device_spawner, messaging_url): + """``discover('device(*)')`` returns the full roster.""" + await device_spawner.spawn_camera("itest-sel-all-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-all-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-all-cam", "itest-sel-all-sensor"}) + try: + result = await asyncio.to_thread(discover, "device(*)") + assert result["scope"] == "device_only" + assert result["matched"] >= 2 + ids = {d["device_id"] for d in result["results"]} + assert {"itest-sel-all-cam", "itest-sel-all-sensor"} <= ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_device_id(device_spawner, messaging_url): + """A bare-id selector resolves to one device.""" + await device_spawner.spawn_camera("itest-sel-id-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-id-cam"}) + try: + result = await asyncio.to_thread(discover, "device(itest-sel-id-cam)") + assert result["scope"] == "device_only" + assert result["matched"] == 1 + assert result["results"][0]["device_id"] == "itest-sel-id-cam" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_id_glob(device_spawner, messaging_url): + """Bare-id selectors accept globs (anchored fnmatch).""" + await device_spawner.spawn_camera("itest-sel-glob-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-sel-glob-cam-2", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-glob-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-glob-cam-1", "itest-sel-glob-cam-2", "itest-sel-glob-sensor"}, + ) + try: + result = await asyncio.to_thread(discover, "device(itest-sel-glob-cam-*)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-glob-cam-1" in ids + assert "itest-sel-glob-cam-2" in ids + assert "itest-sel-glob-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_category_label(device_spawner, messaging_url): + """``device(category:camera)`` returns only cameras (label-based).""" + await device_spawner.spawn_camera("itest-sel-cat-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-cat-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-cat-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-cat-cam", "itest-sel-cat-robot", "itest-sel-cat-sensor"}, + ) + try: + result = await asyncio.to_thread(discover, "device(category:camera)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-cat-cam" in ids + assert "itest-sel-cat-robot" not in ids + assert "itest-sel-cat-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_or_within_category(device_spawner, messaging_url): + """Bracket lists OR within a key: cameras or robots, not sensors.""" + await device_spawner.spawn_camera("itest-sel-or-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-or-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-or-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-or-cam", "itest-sel-or-robot", "itest-sel-or-sensor"}, + ) + try: + result = await asyncio.to_thread( + discover, "device(category:[camera,robot])" + ) + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-or-cam" in ids + assert "itest-sel-or-robot" in ids + assert "itest-sel-or-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_and_across_category_and_location( + device_spawner, messaging_url +): + """Comma is AND across keys: category=camera AND location=lab-A.""" + await device_spawner.spawn_camera("itest-sel-and-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-sel-and-cam-b", location="lab-B") + await device_spawner.spawn_robot("itest-sel-and-robot-a", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-and-cam-a", "itest-sel-and-cam-b", "itest-sel-and-robot-a"}, + ) + try: + result = await asyncio.to_thread( + discover, "device(category:camera, location:lab-A)" + ) + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-and-cam-a" in ids + assert "itest-sel-and-cam-b" not in ids # wrong location + assert "itest-sel-and-robot-a" not in ids # wrong category + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_location_via_legacy_mirror(device_spawner, messaging_url): + """Legacy ``DeviceStatus.location`` is mirrored into ``labels['location']``. + + The flatten_device location-mirror lifts ``status.location`` into + ``labels['location']`` when ``capabilities.labels`` does not declare + one, so selector queries on location work even for drivers that only + populate the legacy heartbeat field. + """ + await device_spawner.spawn_camera("itest-sel-mirror-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-mirror-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, {"itest-sel-mirror-cam", "itest-sel-mirror-sensor"} + ) + try: + result = await asyncio.to_thread(discover, "device(location:lab-A)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-mirror-cam" in ids + assert "itest-sel-mirror-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: function-scoped -------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_scope_per_device(device_spawner, messaging_url): + """``device().function(*)`` returns a device's RPC roster.""" + await device_spawner.spawn_camera("itest-sel-fn-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-fn-cam"}) + try: + result = await asyncio.to_thread( + discover, "device(itest-sel-fn-cam).function(*)" + ) + assert result["scope"] == "device_function" + names = {row.get("name") for row in result["results"]} + assert "capture_image" in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_by_name_fleet_wide(device_spawner, messaging_url): + """``device(*).function()`` returns ``(device, function)`` tuples.""" + await device_spawner.spawn_camera("itest-sel-fnname-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-sel-fnname-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, {"itest-sel-fnname-cam-1", "itest-sel-fnname-cam-2"} + ) + try: + result = await asyncio.to_thread( + discover, "device(*).function(capture_image)" + ) + device_ids = {row["device_id"] for row in result["results"]} + assert {"itest-sel-fnname-cam-1", "itest-sel-fnname-cam-2"} <= device_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_by_direction_label(device_spawner, messaging_url): + """``device(*).function(direction:write)`` matches on FunctionDef labels.""" + await device_spawner.spawn_camera("itest-sel-dir-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-dir-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-dir-cam", "itest-sel-dir-sensor"}) + try: + result = await asyncio.to_thread( + discover, "device(*).function(direction:write)" + ) + names = {row.get("name") for row in result["results"]} + # camera.capture_image (write), sensor.set_threshold (write), + # sensor.set_location (write) + assert "capture_image" in names + assert "set_threshold" in names + assert "get_reading" not in names # direction:read + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_safety_critical(device_spawner, messaging_url): + """``function(safety:critical)`` returns critical RPCs fleet-wide.""" + await device_spawner.spawn_robot("itest-sel-crit-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-crit-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-crit-robot", "itest-sel-crit-sensor"}) + try: + result = await asyncio.to_thread(discover, "function(safety:critical)") + assert result["scope"] == "function_only" + names = {row.get("name") for row in result["results"]} + # robot.dispatch_robot, sensor.set_threshold are safety:critical + assert "dispatch_robot" in names + assert "set_threshold" in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_and_labels(device_spawner, messaging_url): + """``function(direction:write, modality:rgb)`` ANDs across function labels.""" + await device_spawner.spawn_camera("itest-sel-fnand-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-fnand-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-fnand-cam", "itest-sel-fnand-sensor"}) + try: + result = await asyncio.to_thread( + discover, "function(direction:write, modality:rgb)" + ) + names = {row.get("name") for row in result["results"]} + # only camera.capture_image is direction:write AND modality:rgb + assert names == {"capture_image"} or "capture_image" in names + assert "set_threshold" not in names # write but no modality:rgb + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: event-scoped ----------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_by_name_fleet_wide(device_spawner, messaging_url): + """``event()`` returns events fleet-wide.""" + await device_spawner.spawn_camera("itest-sel-evname-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evname-cam"}) + try: + result = await asyncio.to_thread(discover, "event(object_detected)") + assert result["scope"] == "event_only" + device_ids = {row["device_id"] for row in result["results"]} + assert "itest-sel-evname-cam" in device_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_by_modality_label(device_spawner, messaging_url): + """``device(*).event(modality:rgb)`` matches on EventDef labels.""" + await device_spawner.spawn_camera("itest-sel-evmod-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-evmod-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evmod-cam", "itest-sel-evmod-sensor"}) + try: + result = await asyncio.to_thread( + discover, "device(*).event(modality:rgb)" + ) + names = {row.get("name") for row in result["results"]} + # camera.object_detected has modality:rgb + assert "object_detected" in names + # sensor.reading has modality:thermal, not rgb + assert "reading" not in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_safety_critical(device_spawner, messaging_url): + """``event(safety:critical)`` finds the sensor.threshold_exceeded event.""" + await device_spawner.spawn_sensor("itest-sel-evcrit-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evcrit-sensor"}) + try: + result = await asyncio.to_thread(discover, "event(safety:critical)") + names = {row.get("name") for row in result["results"]} + assert "threshold_exceeded" in names + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: pagination & errors ---------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_pagination(device_spawner, messaging_url): + """``offset`` and ``limit`` produce stable, non-overlapping pages.""" + ids = {f"itest-sel-page-cam-{i}" for i in range(3)} + for did in sorted(ids): + await device_spawner.spawn_camera(did, location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, ids) + try: + page1 = await asyncio.to_thread( + discover, "device(category:camera)", 0, 2 + ) + page2 = await asyncio.to_thread( + discover, "device(category:camera)", page1["next_offset"] or 0, 2 + ) + assert page1["returned"] <= 2 + page1_ids = {d["device_id"] for d in page1["results"]} + page2_ids = {d["device_id"] for d in page2["results"]} + assert not (page1_ids & page2_ids), "pages should not overlap" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_offset_past_end_returns_empty(device_spawner, messaging_url): + """An offset beyond ``matched`` returns an empty page with ``next_offset=None``.""" + await device_spawner.spawn_camera("itest-sel-oob-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-oob-cam"}) + try: + result = await asyncio.to_thread(discover, "device(*)", 9999, 50) + assert result["returned"] == 0 + assert result["results"] == [] + assert result["next_offset"] is None + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_invalid_selector_returns_error(device_spawner, messaging_url): + """A bad selector returns an error-as-data envelope, not a raise.""" + await device_spawner.spawn_camera("itest-sel-err-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, discover + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(discover, "device(") + assert result["error"]["code"] == "selector_parse_error" + assert result["matched"] == 0 + assert result["results"] == [] + finally: + await asyncio.to_thread(disconnect) + + +# -- discover_labels() ----------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_includes_category(device_spawner, messaging_url): + """Vocabulary surfaces ``category`` from device-level labels.""" + await device_spawner.spawn_camera("itest-sel-vcat-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-vcat-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vcat-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices( + messaging_url, + {"itest-sel-vcat-cam", "itest-sel-vcat-robot", "itest-sel-vcat-sensor"}, + ) + try: + result = await asyncio.to_thread(discover_labels) + cat = result["device_keys"].get("category") + assert cat is not None + values = cat["values"] + assert "camera" in values + assert "robot" in values + assert "sensor" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_includes_location_via_mirror( + device_spawner, messaging_url +): + """Vocabulary surfaces ``location`` even when only ``DeviceStatus.location`` is set.""" + await device_spawner.spawn_camera("itest-sel-vloc-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vloc-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices(messaging_url, {"itest-sel-vloc-cam", "itest-sel-vloc-sensor"}) + try: + result = await asyncio.to_thread(discover_labels) + loc = result["device_keys"].get("location") + assert loc is not None + values = loc["values"] + assert "lab-A" in values + assert "lab-B" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_function_direction_histogram( + device_spawner, messaging_url +): + """Function-axis vocabulary surfaces ``direction`` with read/write counts.""" + await device_spawner.spawn_camera("itest-sel-vdir-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vdir-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices(messaging_url, {"itest-sel-vdir-cam", "itest-sel-vdir-sensor"}) + try: + result = await asyncio.to_thread(discover_labels) + direction = result["function_keys"].get("direction") + assert direction is not None + values = direction["values"] + assert "read" in values + assert "write" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_per_key_pagination(device_spawner, messaging_url): + """``discover_labels(key='device.category')`` paginates one key's values.""" + await device_spawner.spawn_camera("itest-sel-vpg-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-vpg-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vpg-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices( + messaging_url, + {"itest-sel-vpg-cam", "itest-sel-vpg-robot", "itest-sel-vpg-sensor"}, + ) + try: + result = await asyncio.to_thread(discover_labels, "device.category") + assert result["axis"] == "device" + assert result["key"] == "category" + assert "values" in result + # at least camera, robot, sensor are present + assert {"camera", "robot", "sensor"} <= set(result["values"].keys()) + finally: + await asyncio.to_thread(disconnect)