diff --git a/include/arbiter/arbiter_model.h b/include/arbiter/arbiter_model.h index 36d2967..3c5111c 100644 --- a/include/arbiter/arbiter_model.h +++ b/include/arbiter/arbiter_model.h @@ -56,6 +56,7 @@ enum ARBITER_op { ARBITER_OP_CHANGED, ARBITER_OP_DELTA_GT, ARBITER_OP_DELTA_LT, + ARBITER_OP_HYSTERESIS = 13, }; /** Action types. */ @@ -76,6 +77,14 @@ enum ARBITER_cond_group { ARBITER_COND_NOT, }; +/** + * Maximum number of conditions that support per-condition state (hysteresis). + * Kept small to avoid dynamic allocation; sized for typical safety models. + */ +#ifndef CONFIG_ARBITER_MAX_HYSTERESIS_CONDITIONS +#define CONFIG_ARBITER_MAX_HYSTERESIS_CONDITIONS 32 +#endif + /** Expression operators for compute engine. */ enum ARBITER_expr_op { ARBITER_EXPR_ADD = 0, /**< target = left + right */ @@ -93,6 +102,7 @@ enum ARBITER_expr_op { ARBITER_EXPR_SCALE, /**< target = (left * right) / scale (fixed-point) */ ARBITER_EXPR_ASSIGN, /**< target = left (copy fact or literal) */ ARBITER_EXPR_ACCUMULATE, /**< target = target + (left * right) / scale */ + ARBITER_EXPR_LOOKUP = 15, /**< target = table_lookup(table[scale], left) */ }; /** Fact definition (compiled model table entry). */ @@ -125,6 +135,7 @@ struct ARBITER_condition_def { arbiter_index_t fact_id; enum ARBITER_op op; int32_t value; + int32_t aux_value; /**< Secondary threshold (e.g. falling edge for hysteresis). */ enum ARBITER_cond_group group; arbiter_index_t group_index; arbiter_index_t next; @@ -163,6 +174,13 @@ struct ARBITER_rule_def { #endif }; +/** Lookup table definition for interpolation. */ +struct ARBITER_table_def { + uint16_t count; /**< Number of entries in the table. */ + const int32_t *keys; /**< Sorted input key values. */ + const int32_t *values; /**< Output values (same count as keys). */ +}; + /** Complete compiled model. */ struct ARBITER_model { const char *name; @@ -181,6 +199,8 @@ struct ARBITER_model { const struct ARBITER_action_def *actions; const struct ARBITER_expr_def *expressions; const char **mode_names; + const struct ARBITER_table_def *tables; /**< Lookup tables. */ + uint16_t table_count; #if defined(CONFIG_ARBITER_FPGA_OFFLOAD) && CONFIG_ARBITER_FPGA_OFFLOAD const struct ARBITER_hw_offload_ops *offload_ops; #endif diff --git a/lib/arbiter_eval.c b/lib/arbiter_eval.c index 5b5b3cd..f46122e 100644 --- a/lib/arbiter_eval.c +++ b/lib/arbiter_eval.c @@ -69,10 +69,14 @@ ARBITER_ALWAYS_INLINE int32_t resolve_operand( * No pointer-to-pointer indirection -- values[] and timestamp are * passed directly so the compiler can keep them in registers. */ +/* Per-condition hysteresis state bitmask (static — survives across evals). */ +static uint32_t hyst_state[CONFIG_ARBITER_MAX_HYSTERESIS_CONDITIONS / 32 + 1]; + ARBITER_ALWAYS_INLINE bool eval_condition( const struct ARBITER_condition_def *__restrict cond, const struct ARBITER_fact_value *__restrict values, - arbiter_index_t vcount, uint32_t snap_ts) + arbiter_index_t vcount, uint32_t snap_ts, + arbiter_index_t cond_index) { if (unlikely(cond->fact_id >= vcount)) { return false; @@ -133,6 +137,43 @@ ARBITER_ALWAYS_INLINE bool eval_condition( case ARBITER_OP_NOT_IN: return val != cond->value; + /* Hysteresis: rising = value, falling = aux_value. + * State persists in a static bitmask across evaluations. + */ + case ARBITER_OP_HYSTERESIS: { + const int32_t rising = cond->value; + const int32_t falling = cond->aux_value; + bool prev_state = false; + + if (likely(cond_index < + CONFIG_ARBITER_MAX_HYSTERESIS_CONDITIONS)) { + prev_state = (hyst_state[cond_index / 32] >> + (cond_index & 31)) & 1u; + } + + bool result; + + if (val >= rising) { + result = true; + } else if (val <= falling) { + result = false; + } else { + result = prev_state; + } + + if (likely(cond_index < + CONFIG_ARBITER_MAX_HYSTERESIS_CONDITIONS)) { + if (result) { + hyst_state[cond_index / 32] |= + (1u << (cond_index & 31)); + } else { + hyst_state[cond_index / 32] &= + ~(1u << (cond_index & 31)); + } + } + return result; + } + default: return false; } @@ -157,7 +198,7 @@ ARBITER_ALWAYS_INLINE bool eval_condition_group( /* Fast path: single condition -- skip loop entirely */ if (likely(count == 1)) { bool r = eval_condition(&conds[start], values, - vcount, snap_ts); + vcount, snap_ts, start); return (group == ARBITER_COND_NOT) ? !r : r; } @@ -165,7 +206,8 @@ ARBITER_ALWAYS_INLINE bool eval_condition_group( if (likely(group == ARBITER_COND_ALL)) { for (arbiter_index_t i = 0; i < count; i++) { if (!eval_condition(&conds[start + i], values, - vcount, snap_ts)) { + vcount, snap_ts, + start + i)) { return false; } } @@ -175,7 +217,8 @@ ARBITER_ALWAYS_INLINE bool eval_condition_group( if (group == ARBITER_COND_ANY) { for (arbiter_index_t i = 0; i < count; i++) { if (eval_condition(&conds[start + i], values, - vcount, snap_ts)) { + vcount, snap_ts, + start + i)) { return true; } } @@ -183,7 +226,7 @@ ARBITER_ALWAYS_INLINE bool eval_condition_group( } /* ARBITER_COND_NOT: invert single child */ - return !eval_condition(&conds[start], values, vcount, snap_ts); + return !eval_condition(&conds[start], values, vcount, snap_ts, start); } /* ── Expression evaluator ─────────────────────────────────────── */ @@ -195,10 +238,56 @@ ARBITER_ALWAYS_INLINE bool eval_condition_group( * Switch cases ordered by frequency: ASSIGN and simple arithmetic * first (PID, Kalman models hit these 80%+ of the time). */ +/** + * Linear interpolation in a lookup table. + * Clamps to table endpoints when input is outside range. + */ +ARBITER_ALWAYS_INLINE int32_t table_lookup( + const struct ARBITER_table_def *__restrict tbl, + int32_t input) +{ + if (unlikely(tbl == NULL || tbl->count == 0)) { + return 0; + } + const uint16_t n = tbl->count; + const int32_t *__restrict keys = tbl->keys; + const int32_t *__restrict vals = tbl->values; + + /* Clamp below minimum */ + if (input <= keys[0]) { + return vals[0]; + } + /* Clamp above maximum */ + if (input >= keys[n - 1]) { + return vals[n - 1]; + } + /* Binary-ish scan for bracket (tables are small, linear is fine) */ + for (uint16_t i = 1; i < n; i++) { + if (input <= keys[i]) { + /* Linear interpolation between [i-1] and [i] */ + int32_t k0 = keys[i - 1]; + int32_t k1 = keys[i]; + int32_t v0 = vals[i - 1]; + int32_t v1 = vals[i]; + int32_t dk = k1 - k0; + + if (dk == 0) { + return v0; + } + /* lerp: v0 + (v1-v0)*(input-k0)/(k1-k0) */ + int64_t num = (int64_t)(v1 - v0) * + (int64_t)(input - k0); + return v0 + (int32_t)(num / dk); + } + } + return vals[n - 1]; +} + ARBITER_ALWAYS_INLINE void eval_expression( const struct ARBITER_expr_def *__restrict expr, struct ARBITER_fact_value *__restrict values, - arbiter_index_t vcount) + arbiter_index_t vcount, + const struct ARBITER_model *__restrict model) { const arbiter_index_t tid = expr->target_fact_id; @@ -281,6 +370,19 @@ ARBITER_ALWAYS_INLINE void eval_expression( case ARBITER_EXPR_SHIFT_L: result = left << (right & 31); break; + case ARBITER_EXPR_LOOKUP: { + /* scale field stores the table index */ + const uint16_t tbl_idx = (uint16_t)expr->scale; + + if (likely(model->tables != NULL && + tbl_idx < model->table_count)) { + result = table_lookup( + &model->tables[tbl_idx], left); + } else { + result = 0; + } + break; + } default: return; } @@ -423,11 +525,12 @@ int ARBITER_eval(const struct ARBITER_model *model, for (arbiter_index_t i = 0; i < ec; i++) { const arbiter_index_t ei = es + i; - if (likely(ei < expr_count)) { - eval_expression( - &exprs[ei], - values, vcount); - } + if (likely(ei < expr_count)) { + eval_expression( + &exprs[ei], + values, vcount, + model); + } } ops += ec; } diff --git a/python/arbiter/canonical.py b/python/arbiter/canonical.py index ac09c9a..70a1d1b 100644 --- a/python/arbiter/canonical.py +++ b/python/arbiter/canonical.py @@ -21,6 +21,8 @@ class CanonicalModel: actions: list[dict[str, Any]] modes: list[dict[str, Any]] expressions: list[dict[str, Any]] = field(default_factory=list) + tables: list[dict[str, Any]] = field(default_factory=list) + table_id_map: dict[str, int] = field(default_factory=dict) states: list[dict[str, Any]] = field(default_factory=list) transitions: list[dict[str, Any]] = field(default_factory=list) hazards: list[dict[str, Any]] = field(default_factory=list) @@ -105,6 +107,22 @@ def canonicalize(data: dict[str, Any]) -> CanonicalModel: annotated["_expr_count"] = len(rule_exprs) rules.append(annotated) + # Flatten tables + tables_raw = data.get("tables", []) + tables: list[dict[str, Any]] = [] + table_id_map: dict[str, int] = {} + if isinstance(tables_raw, list): + tables_sorted = sorted(tables_raw, key=lambda t: t.get("id", "") if isinstance(t, dict) else "") + for idx, tbl in enumerate(tables_sorted): + if isinstance(tbl, dict) and "id" in tbl: + table_id_map[tbl["id"]] = idx + tables.append({ + "id": tbl["id"], + "index": idx, + "keys": [int(k) for k in tbl.get("keys", [])], + "values": [int(v) for v in tbl.get("values", [])], + }) + # Flatten states and transitions (REQ-ARCH-039) states_flat, transitions_flat, state_id_map = _flatten_states( data.get("states", []), action_id_map, fact_id_map, conditions, @@ -119,6 +137,8 @@ def canonicalize(data: dict[str, Any]) -> CanonicalModel: actions=actions, modes=modes, expressions=expressions, + tables=tables, + table_id_map=table_id_map, states=states_flat, transitions=transitions_flat, hazards=data.get("hazards", []), @@ -147,6 +167,7 @@ def canonicalize(data: dict[str, Any]) -> CanonicalModel: "min": "min", "max": "max", "clamp": "clamp", "shift_r": "shift_r", "shift_l": "shift_l", "scale": "scale", "accumulate": "accumulate", + "lookup": "lookup", } @@ -194,7 +215,7 @@ def _flatten_expressions( op = _EXPR_OP_ALIASES.get(expr.get("op", "assign"), "assign") scale = int(expr.get("scale", 1)) - out.append({ + entry: dict[str, Any] = { "target_fact_id": target_id, "op": op, "left_fact_id": left_fact_id, @@ -202,7 +223,11 @@ def _flatten_expressions( "right_fact_id": right_fact_id, "right_literal": right_literal, "scale": scale, - }) + } + # Lookup: store table name for late binding (resolved by emitter) + if op == "lookup" and "table" in expr: + entry["table"] = expr["table"] + out.append(entry) return out @@ -221,13 +246,19 @@ def _flatten_conditions( for cond in group: if not isinstance(cond, dict): continue - flat = { + flat: dict[str, Any] = { "group": group_type, "fact": cond.get("fact", ""), "fact_id": fact_id_map.get(cond.get("fact", ""), 0), "op": cond.get("op", "=="), "value": cond.get("value", 0), } + # Hysteresis: map rising → value, falling → aux_value + if cond.get("op") == "hysteresis": + flat["value"] = int(cond.get("rising", 0)) + flat["aux_value"] = int(cond.get("falling", 0)) + flat["rising"] = flat["value"] + flat["falling"] = flat["aux_value"] conditions.append(flat) diff --git a/python/arbiter/emit_blob.py b/python/arbiter/emit_blob.py index aa5a5bd..dc3cd67 100644 --- a/python/arbiter/emit_blob.py +++ b/python/arbiter/emit_blob.py @@ -48,7 +48,7 @@ # Wire sizes for packed structs (all little-endian, uint16 indices) _FACT_ELEM_SIZE = 16 # id(2) + type(1) + pad(1) + range_min(4) + range_max(4) + default(4) + stale(2) + safety(1) + pad(1) => rearranged below _RULE_ELEM_SIZE = 20 -_COND_ELEM_SIZE = 12 +_COND_ELEM_SIZE = 16 # added aux_value (int32) _EXPR_ELEM_SIZE = 20 _ACTION_ELEM_SIZE = 12 @@ -63,6 +63,7 @@ "==": 0, "!=": 1, "<": 2, "<=": 3, ">": 4, ">=": 5, "in": 6, "not_in": 7, "stale": 8, "not_stale": 9, "changed": 10, "delta_gt": 11, "delta_lt": 12, + "hysteresis": 13, } _COND_GROUP_MAP = {"all": 0, "any": 1, "not": 2} _EXPR_OP_MAP = { @@ -70,6 +71,7 @@ "abs": 5, "negate": 6, "min": 7, "max": 8, "clamp": 9, "shift_r": 10, "shift_l": 11, "scale": 12, "assign": 13, "accumulate": 14, + "lookup": 15, } _ACTION_TYPE_MAP = { "callback": 0, "log": 1, "notify": 2, "set_fact": 3, @@ -191,11 +193,12 @@ def _pack_rules(model: CanonicalModel) -> bytes: def _pack_conditions(model: CanonicalModel) -> bytes: """Pack condition definitions. - Wire layout per condition (12 bytes): + Wire layout per condition (16 bytes): fact_id: uint16 LE op: uint8 group: uint8 value: int32 LE + aux_value: int32 LE group_index: uint16 LE next: uint16 LE """ @@ -207,7 +210,10 @@ def _pack_conditions(model: CanonicalModel) -> bytes: val = c.get("value", 0) if isinstance(val, bool): val = 1 if val else 0 - buf += struct.pack(" bytes: if model.rules: sections.append((SECTION_RULES, rules_data, len(model.rules), rule_elem)) - cond_elem = 12 + cond_elem = 16 if model.conditions: sections.append((SECTION_CONDITIONS, cond_data, len(model.conditions), cond_elem)) diff --git a/python/arbiter/emit_c.py b/python/arbiter/emit_c.py index c5f7638..2545beb 100644 --- a/python/arbiter/emit_c.py +++ b/python/arbiter/emit_c.py @@ -16,6 +16,7 @@ "stale": "ARBITER_OP_STALE", "not_stale": "ARBITER_OP_NOT_STALE", "changed": "ARBITER_OP_CHANGED", "delta_gt": "ARBITER_OP_DELTA_GT", "delta_lt": "ARBITER_OP_DELTA_LT", + "hysteresis": "ARBITER_OP_HYSTERESIS", } _TYPE_MAP = { @@ -45,6 +46,7 @@ "shift_r": "ARBITER_EXPR_SHIFT_R", "shift_l": "ARBITER_EXPR_SHIFT_L", "scale": "ARBITER_EXPR_SCALE", "accumulate": "ARBITER_EXPR_ACCUMULATE", "assign": "ARBITER_EXPR_ASSIGN", + "lookup": "ARBITER_EXPR_LOOKUP", } @@ -158,9 +160,12 @@ def emit_c_source(model: CanonicalModel, header_name: str = "arbiter_model.h", val = c.get("value", 0) if isinstance(val, bool): val = 1 if val else 0 + aux_val = c.get("aux_value", 0) + if isinstance(aux_val, bool): + aux_val = 1 if aux_val else 0 lines.append( f"\t{{ .fact_id = {c.get('fact_id', 0)}, .op = {op_enum}, " - f".value = {val}, .group = {grp}, " + f".value = {val}, .aux_value = {aux_val}, .group = {grp}, " f".group_index = 0, .next = UINT16_MAX }}," ) if not model.conditions: @@ -261,6 +266,33 @@ def emit_c_source(model: CanonicalModel, header_name: str = "arbiter_model.h", lines.append("") # If no expressions, we emit nothing here and use NULL directly below. + # Tables — emit key/value arrays and table_def array + tables = getattr(model, "tables", []) + table_id_map = getattr(model, "table_id_map", {}) + if tables: + for tbl in tables: + tid = tbl["id"].replace(".", "_") + keys = tbl["keys"] + vals = tbl["values"] + lines.append(f"static const int32_t table_{tid}_keys[] = {{ {', '.join(str(k) for k in keys)} }};") + lines.append(f"static const int32_t table_{tid}_values[] = {{ {', '.join(str(v) for v in vals)} }};") + lines.append("") + lines.append("static const struct ARBITER_table_def model_tables[] = {") + for tbl in tables: + tid = tbl["id"].replace(".", "_") + count = len(tbl["keys"]) + lines.append(f"\t{{ .count = {count}, .keys = table_{tid}_keys, .values = table_{tid}_values }},") + lines.append("};") + lines.append("") + + # Resolve table references in expressions + if tables and expressions: + for e in expressions: + if e.get("op") == "lookup" and "table" in e: + tbl_name = e["table"] + tbl_idx = table_id_map.get(tbl_name, 0) + e["scale"] = tbl_idx + # Mode names if model.modes and emit_trace_strings: lines.append("static const char *model_mode_names[] = {") @@ -280,6 +312,8 @@ def emit_c_source(model: CanonicalModel, header_name: str = "arbiter_model.h", expr_count_total = len(getattr(model, "expressions", [])) exprs_field = "model_expressions" if expressions else "NULL" + tables_field = "model_tables" if tables else "NULL" + table_count = len(tables) lines.extend([ "const struct ARBITER_model ARBITER_generated_model = {", f'\t.name = "{model.name}",', @@ -297,6 +331,8 @@ def emit_c_source(model: CanonicalModel, header_name: str = "arbiter_model.h", "\t.actions = model_actions,", f"\t.expressions = {exprs_field},", "\t.mode_names = model_mode_names,", + f"\t.tables = {tables_field},", + f"\t.table_count = {table_count},", "};", "", ]) diff --git a/python/arbiter/evaluator.py b/python/arbiter/evaluator.py index 81a6770..bf54779 100644 --- a/python/arbiter/evaluator.py +++ b/python/arbiter/evaluator.py @@ -97,6 +97,26 @@ def _saturate32(value: int) -> int: return value +def _table_lookup(keys: list[int], values: list[int], input_val: int) -> int: + """Linear interpolation lookup matching the C engine.""" + n = len(keys) + if n == 0: + return 0 + if input_val <= keys[0]: + return values[0] + if input_val >= keys[n - 1]: + return values[n - 1] + for i in range(1, n): + if input_val <= keys[i]: + k0, k1 = keys[i - 1], keys[i] + v0, v1 = values[i - 1], values[i] + dk = k1 - k0 + if dk == 0: + return v0 + return v0 + (v1 - v0) * (input_val - k0) // dk + return values[n - 1] + + # --------------------------------------------------------------------------- # Evaluator # --------------------------------------------------------------------------- @@ -135,6 +155,9 @@ def __init__(self, model_data: dict[str, Any]) -> None: # Snapshot timestamp (set by caller for staleness tests). self._snapshot_ts: int = 0 + # Hysteresis per-condition state (persistent across evals). + self._hyst_state: dict[str, bool] = {} + # Faults (persistent across evals until cleared). self._raised_faults: set[str] = set() @@ -363,6 +386,20 @@ def _eval_single_condition( delta = abs(fact_val - prev) return delta < int(threshold) + if op == "hysteresis": + rising = int(cond.get("rising", cond.get("value", 0))) + falling = int(cond.get("falling", cond.get("aux_value", 0))) + key = f"{fact_name}:{rising}:{falling}" + prev_state = self._hyst_state.get(key, False) + if fact_val >= rising: + result = True + elif fact_val <= falling: + result = False + else: + result = prev_state + self._hyst_state[key] = result + return result + return False # unknown operator @staticmethod @@ -401,6 +438,12 @@ def _exec_expression(self, expr: dict[str, Any]) -> None: ) scale = expr.get("scale", 1) + # Lookup: resolve table name to index if not yet resolved + if op == "lookup" and "table" in expr: + table_id_map = getattr(self._model, "table_id_map", {}) + tbl_idx = table_id_map.get(expr["table"], 0) + scale = tbl_idx + result = self._compute_op(op, target_name, left, right, scale) self._fact_values[target_name] = result @@ -480,6 +523,19 @@ def _compute_op( wide = left * right return _saturate32(current + int(wide / scale)) + if op == "lookup": + # scale field holds the table index + tbl_idx = scale + tables = getattr(self._model, "tables", []) + if 0 <= tbl_idx < len(tables): + tbl = tables[tbl_idx] + return _saturate32(_table_lookup( + tbl.get("keys", []), + tbl.get("values", []), + left, + )) + return 0 + return 0 # unknown op def _resolve_operand(self, fact_id: int, literal: int) -> int: diff --git a/schema/arb.schema.json b/schema/arb.schema.json index 9789c8c..3249851 100644 --- a/schema/arb.schema.json +++ b/schema/arb.schema.json @@ -92,6 +92,11 @@ "type": "array", "items": { "$ref": "#/$defs/state" } }, + "tables": { + "type": "array", + "items": { "$ref": "#/$defs/table" }, + "description": "Lookup tables for interpolation." + }, "runtime": { "type": "object" }, "tests": { "type": "array" }, "metadata": { "type": "object" } @@ -126,9 +131,11 @@ "fact": { "type": "string" }, "op": { "type": "string", - "enum": ["==", "!=", "<", "<=", ">", ">=", "in", "not_in", "stale", "not_stale", "changed", "delta_gt", "delta_lt"] + "enum": ["==", "!=", "<", "<=", ">", ">=", "in", "not_in", "stale", "not_stale", "changed", "delta_gt", "delta_lt", "hysteresis"] }, - "value": {} + "value": {}, + "rising": { "type": "integer", "description": "Rising threshold for hysteresis operator." }, + "falling": { "type": "integer", "description": "Falling threshold for hysteresis operator." } } }, "rule": { @@ -173,6 +180,25 @@ } } }, + "table": { + "type": "object", + "required": ["id", "keys", "values"], + "properties": { + "id": { "type": "string" }, + "keys": { + "type": "array", + "items": { "type": "integer" }, + "minItems": 2, + "description": "Sorted input key values for the lookup." + }, + "values": { + "type": "array", + "items": { "type": "integer" }, + "minItems": 2, + "description": "Output values corresponding to each key." + } + } + }, "expression": { "type": "object", "required": ["target", "op"], @@ -180,13 +206,14 @@ "target": { "type": "string", "description": "Fact id to write the result to." }, "op": { "type": "string", - "enum": ["add", "sub", "mul", "div", "mod", "abs", "negate", "min", "max", "clamp", "shift_r", "shift_l", "scale", "assign", "accumulate"] + "enum": ["add", "sub", "mul", "div", "mod", "abs", "negate", "min", "max", "clamp", "shift_r", "shift_l", "scale", "assign", "accumulate", "lookup"] }, "left": { "type": "string", "description": "Left operand fact id." }, "left_literal": { "type": "integer", "description": "Left operand literal value." }, "right": { "type": "string", "description": "Right operand fact id." }, "right_literal": { "type": "integer", "description": "Right operand literal value." }, - "scale": { "type": "integer", "description": "Divisor for scale/accumulate, or hi bound for clamp." } + "scale": { "type": "integer", "description": "Divisor for scale/accumulate, or hi bound for clamp." }, + "table": { "type": "string", "description": "Table id for lookup expressions." } } }, "action": { diff --git a/tests/python/test_hysteresis_lookup.py b/tests/python/test_hysteresis_lookup.py new file mode 100644 index 0000000..a770d25 --- /dev/null +++ b/tests/python/test_hysteresis_lookup.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: MIT +"""Tests for hysteresis condition operator and lookup table support.""" + +from __future__ import annotations + +import pytest + +from arbiter.evaluator import ArbiterEvaluator, _table_lookup + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _model(facts=None, rules=None, actions=None, modes=None, tables=None, *, name="test_model"): + m = { + "arb_version": 0.1, + "model": name, + "target": {"rtos": "zephyr"}, + "facts": facts or [], + "rules": rules or [], + "actions": actions or [], + "modes": modes or [], + } + if tables: + m["tables"] = tables + return m + + +def _fact(fid, ftype="int32", **kwargs): + return {"id": fid, "type": ftype, **kwargs} + + +def _rule(rid, when=None, then=None, rclass="inference"): + r = {"id": rid, "class": rclass} + if when is not None: + r["when"] = when + if then is not None: + r["then"] = then + return r + + +# =================================================================== +# HYSTERESIS OPERATOR +# =================================================================== + + +class TestHysteresis: + """Test the hysteresis condition operator.""" + + def test_rising_edge_triggers(self): + """Value >= rising triggers condition to true.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("temp", 85) + r = ev.eval() + assert r.fired_rules == ["r"] + + def test_falling_edge_clears(self): + """Value <= falling clears condition to false.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + # First trigger high + ev.set_fact("temp", 85) + r1 = ev.eval() + assert r1.fired_rules == ["r"] + # Then drop below falling + ev.set_fact("temp", 55) + r2 = ev.eval() + assert r2.fired_rules == [] + + def test_deadband_holds_true(self): + """Value between falling and rising holds previous true state.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + # Trigger high + ev.set_fact("temp", 85) + ev.eval() + # Drop into deadband — should stay true + ev.set_fact("temp", 70) + r = ev.eval() + assert r.fired_rules == ["r"] + + def test_deadband_holds_false(self): + """Value between falling and rising holds previous false state.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + # Start in deadband — never triggered → false + ev.set_fact("temp", 70) + r = ev.eval() + assert r.fired_rules == [] + + def test_exact_rising_threshold(self): + """Value exactly at rising threshold triggers true.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("temp", 80) + r = ev.eval() + assert r.fired_rules == ["r"] + + def test_exact_falling_threshold(self): + """Value exactly at falling threshold clears to false.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + # Trigger + ev.set_fact("temp", 85) + ev.eval() + # Drop to exactly falling + ev.set_fact("temp", 60) + r = ev.eval() + assert r.fired_rules == [] + + def test_hysteresis_pid_enable_disable(self): + """PID-like model: hysteresis controls enable/disable.""" + m = _model( + facts=[_fact("speed"), _fact("enabled", "bool")], + rules=[ + _rule("enable_pid", when={"all": [ + {"fact": "speed", "op": "hysteresis", "rising": 1000, "falling": 500}, + ]}, then={"compute": [ + {"target": "enabled", "op": "assign", "left_literal": 1}, + ]}), + ], + ) + ev = ArbiterEvaluator(m) + # Below both thresholds + ev.set_fact("speed", 400) + ev.eval() + assert ev._fact_values["enabled"] == 0 + + # Above rising + ev.set_fact("speed", 1200) + ev.eval() + assert ev._fact_values["enabled"] == 1 + + def test_hysteresis_persists_across_evals(self): + """Hysteresis state persists across multiple eval() cycles.""" + m = _model( + facts=[_fact("temp")], + rules=[_rule("r", when={"all": [ + {"fact": "temp", "op": "hysteresis", "rising": 80, "falling": 60}, + ]})], + ) + ev = ArbiterEvaluator(m) + # Trigger + ev.set_fact("temp", 90) + assert ev.eval().fired_rules == ["r"] + # Deadband + ev.set_fact("temp", 75) + assert ev.eval().fired_rules == ["r"] + ev.set_fact("temp", 65) + assert ev.eval().fired_rules == ["r"] + # Drop below falling + ev.set_fact("temp", 59) + assert ev.eval().fired_rules == [] + # Deadband again — stays false + ev.set_fact("temp", 70) + assert ev.eval().fired_rules == [] + + +# =================================================================== +# LOOKUP TABLE — _table_lookup helper +# =================================================================== + + +class TestTableLookupHelper: + """Test the _table_lookup helper directly.""" + + def test_exact_key(self): + assert _table_lookup([0, 25, 50, 75, 100], [33000, 10000, 3300, 1200, 470], 25) == 10000 + + def test_exact_first_key(self): + assert _table_lookup([0, 50, 100], [0, 500, 1000], 0) == 0 + + def test_exact_last_key(self): + assert _table_lookup([0, 50, 100], [0, 500, 1000], 100) == 1000 + + def test_interpolation_midpoint(self): + # Between 0→0 and 100→1000, at key 50 → 500 + assert _table_lookup([0, 100], [0, 1000], 50) == 500 + + def test_interpolation_quarter(self): + # Between 0→0 and 100→1000, at key 25 → 250 + assert _table_lookup([0, 100], [0, 1000], 25) == 250 + + def test_below_min_clamps(self): + assert _table_lookup([10, 50, 100], [100, 500, 1000], -5) == 100 + + def test_above_max_clamps(self): + assert _table_lookup([10, 50, 100], [100, 500, 1000], 200) == 1000 + + def test_ntc_curve(self): + """NTC thermistor curve: interpolation between known points.""" + keys = [0, 25, 50, 75, 100] + values = [33000, 10000, 3300, 1200, 470] + # At 0 → 33000 + assert _table_lookup(keys, values, 0) == 33000 + # At 100 → 470 + assert _table_lookup(keys, values, 100) == 470 + # Midpoint between 0 and 25: (33000 + 10000) / 2 ≈ 21500 + # Actually: 33000 + (10000-33000)*(12-0)/(25-0) = 33000 + (-23000*12/25) = 33000 - 11040 = 21960 + result = _table_lookup(keys, values, 12) + assert 21000 < result < 22500 # approximate + + def test_empty_table(self): + assert _table_lookup([], [], 50) == 0 + + def test_descending_values(self): + """Tables with descending values (inverse relationship).""" + keys = [0, 50, 100] + values = [1000, 500, 0] + assert _table_lookup(keys, values, 25) == 750 # (1000+500)/2 = 750 + + +# =================================================================== +# LOOKUP TABLE — via evaluator +# =================================================================== + + +class TestLookupTableEvaluator: + """Test lookup table support through the evaluator.""" + + def test_lookup_at_exact_key(self): + m = _model( + facts=[_fact("temperature"), _fact("resistance")], + tables=[{ + "id": "ntc_curve", + "keys": [0, 25, 50, 75, 100], + "values": [33000, 10000, 3300, 1200, 470], + }], + rules=[_rule("r", then={"compute": [ + {"target": "resistance", "op": "lookup", "table": "ntc_curve", "left": "temperature"}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("temperature", 25) + ev.eval() + assert ev._fact_values["resistance"] == 10000 + + def test_lookup_interpolation(self): + m = _model( + facts=[_fact("input"), _fact("output")], + tables=[{ + "id": "linear", + "keys": [0, 100], + "values": [0, 1000], + }], + rules=[_rule("r", then={"compute": [ + {"target": "output", "op": "lookup", "table": "linear", "left": "input"}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("input", 50) + ev.eval() + assert ev._fact_values["output"] == 500 + + def test_lookup_below_min(self): + m = _model( + facts=[_fact("input"), _fact("output")], + tables=[{ + "id": "tbl", + "keys": [10, 100], + "values": [100, 1000], + }], + rules=[_rule("r", then={"compute": [ + {"target": "output", "op": "lookup", "table": "tbl", "left": "input"}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("input", -10) + ev.eval() + assert ev._fact_values["output"] == 100 # clamped to first value + + def test_lookup_above_max(self): + m = _model( + facts=[_fact("input"), _fact("output")], + tables=[{ + "id": "tbl", + "keys": [10, 100], + "values": [100, 1000], + }], + rules=[_rule("r", then={"compute": [ + {"target": "output", "op": "lookup", "table": "tbl", "left": "input"}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("input", 200) + ev.eval() + assert ev._fact_values["output"] == 1000 # clamped to last value + + def test_lookup_with_condition(self): + """Lookup only runs when rule fires.""" + m = _model( + facts=[_fact("enable", "bool"), _fact("input"), _fact("output")], + tables=[{ + "id": "tbl", + "keys": [0, 100], + "values": [0, 1000], + }], + rules=[_rule("r", + when={"all": [{"fact": "enable", "op": "==", "value": 1}]}, + then={"compute": [ + {"target": "output", "op": "lookup", "table": "tbl", "left": "input"}, + ]}, + )], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("enable", False) + ev.set_fact("input", 50) + ev.eval() + assert ev._fact_values["output"] == 0 # rule didn't fire + + ev.set_fact("enable", True) + ev.eval() + assert ev._fact_values["output"] == 500 # rule fired + + def test_multiple_tables(self): + """Multiple tables can coexist in one model.""" + m = _model( + facts=[_fact("temp"), _fact("pressure"), _fact("r_temp"), _fact("r_press")], + tables=[ + {"id": "temp_tbl", "keys": [0, 100], "values": [0, 1000]}, + {"id": "press_tbl", "keys": [0, 100], "values": [0, 5000]}, + ], + rules=[_rule("r", then={"compute": [ + {"target": "r_temp", "op": "lookup", "table": "temp_tbl", "left": "temp"}, + {"target": "r_press", "op": "lookup", "table": "press_tbl", "left": "pressure"}, + ]})], + ) + ev = ArbiterEvaluator(m) + ev.set_fact("temp", 50) + ev.set_fact("pressure", 50) + ev.eval() + assert ev._fact_values["r_temp"] == 500 + assert ev._fact_values["r_press"] == 2500