diff --git a/include/arbiter/arbiter.h b/include/arbiter/arbiter.h index 159324a..6adf97a 100644 --- a/include/arbiter/arbiter.h +++ b/include/arbiter/arbiter.h @@ -149,6 +149,35 @@ int ARBITER_get_requested_actions(const struct ARBITER_result *result, */ uint32_t ARBITER_get_last_eval_op_count(const struct ARBITER_ctx *ctx); +/** + * @brief Check whether a model meets a minimum version requirement. + * + * @param model Compiled model to check. + * @param min_major Minimum required major version. + * @param min_minor Minimum required minor version. + * @return ARBITER_OK if the model meets the requirement, + * ARBITER_EMODEL if it does not, + * ARBITER_EINVAL if model is NULL. + */ +int ARBITER_check_version(const struct ARBITER_model *model, + uint8_t min_major, uint8_t min_minor); + +#if defined(CONFIG_ARBITER_HOT_SWAP) && CONFIG_ARBITER_HOT_SWAP +/** + * @brief Hot-swap a running context to a new model. + * + * Validates the new model, preserves current fact values where the fact + * index and type match, and atomically swaps the model pointer. + * + * @param ctx Initialized context. + * @param new_model New compiled model. + * @return ARBITER_OK on success, ARBITER_EINVAL on bad pointers, + * ARBITER_EMODEL if the new model fails validation. + */ +int ARBITER_hot_swap(struct ARBITER_ctx *ctx, + const struct ARBITER_model *new_model); +#endif /* CONFIG_ARBITER_HOT_SWAP */ + /** * @brief Get the arbiter version string. */ diff --git a/include/arbiter/arbiter_model.h b/include/arbiter/arbiter_model.h index 8dcf9e1..36d2967 100644 --- a/include/arbiter/arbiter_model.h +++ b/include/arbiter/arbiter_model.h @@ -166,6 +166,7 @@ struct ARBITER_rule_def { /** Complete compiled model. */ struct ARBITER_model { const char *name; + uint8_t version[3]; /**< Model version: [major, minor, patch]. */ const uint8_t model_hash[32]; const uint8_t schema_hash[32]; arbiter_index_t fact_count; diff --git a/lib/arbiter_blob.c b/lib/arbiter_blob.c index bb78762..0503d7c 100644 --- a/lib/arbiter_blob.c +++ b/lib/arbiter_blob.c @@ -1,5 +1,14 @@ /* SPDX-License-Identifier: MIT */ +/** + * @file arbiter_blob.c + * @brief Binary blob (.zrmb) model loader for arbiter. + * + * Parses and validates a ZRMB binary blob produced by emit_blob.py, + * reconstructing an ARBITER_model from the packed section data. + * No dynamic allocation — all storage is static. + */ + #include #include #include @@ -7,110 +16,487 @@ LOG_MODULE_DECLARE(arbiter, CONFIG_ARBITER_LOG_LEVEL); -#define ZRMB_MAGIC 0x424D525A /* "ZRMB" little-endian */ -#define ZRMB_VERSION 1 -#define ZRMB_HEADER_LEN 84 /* magic(4)+ver(2)+flags(2)+hlen(4)+tlen(4)+mhash(32)+shash(32)+crc(4) */ - -struct zrmb_header { - uint32_t magic; - uint16_t version; - uint16_t flags; - uint32_t header_len; - uint32_t total_len; - uint8_t model_hash[32]; - uint8_t schema_hash[32]; - uint32_t crc32; +/* ── ZRMB header layout ─────────────────────────────────────────── */ + +#define ZRMB_MAGIC_0 'Z' +#define ZRMB_MAGIC_1 'R' +#define ZRMB_MAGIC_2 'M' +#define ZRMB_MAGIC_3 'B' + +#define ZRMB_VERSION 1 +#define ZRMB_HEADER_LEN 84 + +/* Section types (must match emit_blob.py) */ +#define SECTION_FACTS 1 +#define SECTION_RULES 2 +#define SECTION_CONDITIONS 3 +#define SECTION_EXPRESSIONS 4 +#define SECTION_ACTIONS 5 +#define SECTION_STRINGS 6 +#define SECTION_MODES 7 +#define SECTION_TYPE_MAX 7 + +/* Wire sizes produced by emit_blob.py */ +#define WIRE_FACT_SIZE 16 +#define WIRE_RULE_SIZE 20 +#define WIRE_COND_SIZE 12 +#define WIRE_EXPR_SIZE 20 +#define WIRE_ACTION_SIZE 12 +#define WIRE_MODE_SIZE 2 + +/* Section table entry (8 bytes, packed little-endian) */ +struct blob_section_entry { + uint8_t type; + uint8_t pad; + uint16_t offset; + uint16_t count; + uint16_t elem_size; }; -/** - * Simple CRC-32 (IEEE 802.3 polynomial). - */ -static uint32_t compute_crc32(const uint8_t *data, size_t len) +#define SECTION_ENTRY_SIZE 8 + +/* ── Static storage (no malloc) ──────────────────────────────────── */ + +#ifndef CONFIG_ARBITER_MAX_RULES +#define CONFIG_ARBITER_MAX_RULES 64 +#endif + +#ifndef CONFIG_ARBITER_MAX_CONDITIONS +#define CONFIG_ARBITER_MAX_CONDITIONS 256 +#endif + +#ifndef CONFIG_ARBITER_MAX_EXPRESSIONS +#define CONFIG_ARBITER_MAX_EXPRESSIONS 256 +#endif + +#ifndef CONFIG_ARBITER_MAX_ACTIONS_PER_EVAL +#define CONFIG_ARBITER_MAX_ACTIONS_PER_EVAL 16 +#endif + +#ifndef CONFIG_ARBITER_MAX_MODES +#define CONFIG_ARBITER_MAX_MODES 16 +#endif + +static struct ARBITER_fact_def blob_facts[CONFIG_ARBITER_MAX_FACTS]; +static struct ARBITER_rule_def blob_rules[CONFIG_ARBITER_MAX_RULES]; +static struct ARBITER_condition_def blob_conditions[CONFIG_ARBITER_MAX_CONDITIONS]; +static struct ARBITER_expr_def blob_expressions[CONFIG_ARBITER_MAX_EXPRESSIONS]; +static struct ARBITER_action_def blob_actions[CONFIG_ARBITER_MAX_ACTIONS_PER_EVAL]; +static const char *blob_mode_names[CONFIG_ARBITER_MAX_MODES]; + +/* ── CRC-32 (ISO 3309 / ITU-T V.42) ─────────────────────────────── */ + +static uint32_t blob_crc32(const uint8_t *__restrict data, size_t len) { - uint32_t crc = 0xFFFFFFFF; + uint32_t crc = 0xFFFFFFFFU; for (size_t i = 0; i < len; i++) { crc ^= data[i]; - for (int j = 0; j < 8; j++) { - if (crc & 1) { - crc = (crc >> 1) ^ 0xEDB88320; + for (int bit = 0; bit < 8; bit++) { + if (crc & 1U) { + crc = (crc >> 1) ^ 0xEDB88320U; } else { crc >>= 1; } } } - return ~crc; + return crc ^ 0xFFFFFFFFU; +} + +/* ── Little-endian helpers ───────────────────────────────────────── */ + +static inline uint16_t read_u16(const uint8_t *p) +{ + return (uint16_t)p[0] | ((uint16_t)p[1] << 8); } -int ARBITER_blob_load(const uint8_t *blob, size_t blob_len, - struct ARBITER_model *model_out) +static inline uint32_t read_u32(const uint8_t *p) { - if (blob == NULL || model_out == NULL) { - return ARBITER_EINVAL; + return (uint32_t)p[0] + | ((uint32_t)p[1] << 8) + | ((uint32_t)p[2] << 16) + | ((uint32_t)p[3] << 24); +} + +static inline int32_t read_i32(const uint8_t *p) +{ + uint32_t u = read_u32(p); + int32_t s; + + memcpy(&s, &u, sizeof(s)); + return s; +} + +/* ── Section parsers ─────────────────────────────────────────────── */ + +static int parse_facts(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_FACTS)) { + LOG_ERR("blob: %u facts exceeds max %d", count, + CONFIG_ARBITER_MAX_FACTS); + return ARBITER_EMODEL; + } + + for (uint16_t i = 0; i < count; i++) { + const uint8_t *p = data + (size_t)i * WIRE_FACT_SIZE; + + blob_facts[i].id = read_u16(p); + blob_facts[i].type = (enum ARBITER_fact_type)p[2]; + blob_facts[i].safety_relevant = p[3] ? true : false; + blob_facts[i].range_min = read_i32(p + 4); + blob_facts[i].range_max = read_i32(p + 8); + blob_facts[i].default_value = read_i32(p + 12); + blob_facts[i].stale_after_ms = 0; +#if !defined(CONFIG_ARBITER_STRINGS) || CONFIG_ARBITER_STRINGS + blob_facts[i].name = NULL; +#endif + } + + m->facts = blob_facts; + m->fact_count = count; + return ARBITER_OK; +} + +static int parse_rules(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_RULES)) { + LOG_ERR("blob: %u rules exceeds max %d", count, + CONFIG_ARBITER_MAX_RULES); + return ARBITER_EMODEL; + } + + for (uint16_t i = 0; i < count; i++) { + const uint8_t *p = data + (size_t)i * WIRE_RULE_SIZE; + + blob_rules[i].id = read_u16(p); + blob_rules[i].rule_class = (enum ARBITER_rule_class)p[2]; + blob_rules[i].safety_critical = p[3] ? true : false; + blob_rules[i].condition_start = read_u16(p + 4); + blob_rules[i].condition_count = read_u16(p + 6); + blob_rules[i].action_start = read_u16(p + 8); + blob_rules[i].action_count = read_u16(p + 10); + blob_rules[i].expr_start = read_u16(p + 12); + blob_rules[i].expr_count = read_u16(p + 14); + blob_rules[i].safety_goal_id = read_u16(p + 16); + blob_rules[i].set_mode = read_u16(p + 18); +#if !defined(CONFIG_ARBITER_STRINGS) || CONFIG_ARBITER_STRINGS + blob_rules[i].name = NULL; + blob_rules[i].explanation = NULL; +#endif + } + + m->rules = blob_rules; + m->rule_count = count; + return ARBITER_OK; +} + +static int parse_conditions(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_CONDITIONS)) { + LOG_ERR("blob: %u conditions exceeds max %d", count, + CONFIG_ARBITER_MAX_CONDITIONS); + return ARBITER_EMODEL; + } + + for (uint16_t i = 0; i < count; i++) { + const uint8_t *p = data + (size_t)i * WIRE_COND_SIZE; + + blob_conditions[i].fact_id = read_u16(p); + blob_conditions[i].op = (enum ARBITER_op)p[2]; + blob_conditions[i].group = (enum ARBITER_cond_group)p[3]; + blob_conditions[i].value = read_i32(p + 4); + blob_conditions[i].group_index = read_u16(p + 8); + blob_conditions[i].next = read_u16(p + 10); + } + + m->conditions = blob_conditions; + m->condition_count = count; + return ARBITER_OK; +} + +static int parse_expressions(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_EXPRESSIONS)) { + LOG_ERR("blob: %u expressions exceeds max %d", count, + CONFIG_ARBITER_MAX_EXPRESSIONS); + return ARBITER_EMODEL; + } + + for (uint16_t i = 0; i < count; i++) { + const uint8_t *p = data + (size_t)i * WIRE_EXPR_SIZE; + + blob_expressions[i].target_fact_id = read_u16(p); + blob_expressions[i].op = (enum ARBITER_expr_op)p[2]; + /* p[3] is padding */ + blob_expressions[i].left_fact_id = read_u16(p + 4); + blob_expressions[i].right_fact_id = read_u16(p + 6); + blob_expressions[i].left_literal = read_i32(p + 8); + blob_expressions[i].right_literal = read_i32(p + 12); + blob_expressions[i].scale = read_i32(p + 16); } - if (blob_len < ZRMB_HEADER_LEN) { - LOG_ERR("Blob too small: %zu < %d", blob_len, ZRMB_HEADER_LEN); + m->expressions = blob_expressions; + m->expr_count = count; + return ARBITER_OK; +} + +static int parse_actions(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_ACTIONS_PER_EVAL)) { + LOG_ERR("blob: %u actions exceeds max %d", count, + CONFIG_ARBITER_MAX_ACTIONS_PER_EVAL); return ARBITER_EMODEL; } - const struct zrmb_header *hdr = (const struct zrmb_header *)blob; + for (uint16_t i = 0; i < count; i++) { + const uint8_t *p = data + (size_t)i * WIRE_ACTION_SIZE; - /* Validate magic */ - if (hdr->magic != ZRMB_MAGIC) { - LOG_ERR("Invalid blob magic: 0x%08x", hdr->magic); + blob_actions[i].id = read_u16(p); + blob_actions[i].type = (enum ARBITER_action_type)p[2]; + blob_actions[i].safe_state_action = p[3] ? true : false; + /* p[4..5] target_fact_id placeholder */ + blob_actions[i].target_fact_id = read_u16(p + 4); + blob_actions[i].must_complete_within_ms = read_u16(p + 6); + blob_actions[i].target_value = read_i32(p + 8); + blob_actions[i].callback = NULL; +#if !defined(CONFIG_ARBITER_STRINGS) || CONFIG_ARBITER_STRINGS + blob_actions[i].name = NULL; +#endif + } + + m->actions = blob_actions; + m->action_count = count; + return ARBITER_OK; +} + +static int parse_modes(const uint8_t *__restrict data, uint16_t count, + struct ARBITER_model *__restrict m) +{ + if (unlikely(count > CONFIG_ARBITER_MAX_MODES)) { + LOG_ERR("blob: %u modes exceeds max %d", count, + CONFIG_ARBITER_MAX_MODES); return ARBITER_EMODEL; } - /* Validate version */ - if (hdr->version != ZRMB_VERSION) { - LOG_ERR("Unsupported blob version: %u", hdr->version); + /* Mode entries are just uint16 indices; names come from strings. */ + for (uint16_t i = 0; i < count; i++) { + blob_mode_names[i] = NULL; + } + + m->mode_names = blob_mode_names; + m->mode_count = count; + return ARBITER_OK; +} + +/* ── Main loader ─────────────────────────────────────────────────── */ + +int ARBITER_blob_load(const uint8_t *__restrict blob, size_t blob_len, + struct ARBITER_model *__restrict model_out) +{ + if (unlikely(blob == NULL || model_out == NULL)) { + return ARBITER_EINVAL; + } + + /* Minimum size: header only */ + if (unlikely(blob_len < ZRMB_HEADER_LEN)) { + LOG_ERR("blob: too short (%zu < %d)", blob_len, + ZRMB_HEADER_LEN); return ARBITER_EMODEL; } - /* Validate lengths */ - if (hdr->total_len > blob_len) { - LOG_ERR("Blob total_len %u exceeds buffer %zu", - hdr->total_len, blob_len); + /* ── Validate magic ──────────────────────────────────────── */ + if (blob[0] != ZRMB_MAGIC_0 || blob[1] != ZRMB_MAGIC_1 || + blob[2] != ZRMB_MAGIC_2 || blob[3] != ZRMB_MAGIC_3) { + LOG_ERR("blob: bad magic"); return ARBITER_EMODEL; } - if (hdr->header_len < ZRMB_HEADER_LEN) { - LOG_ERR("Invalid header_len: %u", hdr->header_len); + /* ── Validate version ────────────────────────────────────── */ + uint16_t version = read_u16(blob + 4); + + if (unlikely(version != ZRMB_VERSION)) { + LOG_ERR("blob: unsupported version %u", version); return ARBITER_EMODEL; } - /* Validate CRC over everything except the CRC field itself. - * CRC field is at offset 80 (4 bytes before end of header). - */ - uint32_t crc = compute_crc32(blob, offsetof(struct zrmb_header, crc32)); - uint32_t crc_rest = compute_crc32( - blob + offsetof(struct zrmb_header, crc32) + sizeof(uint32_t), - hdr->total_len - offsetof(struct zrmb_header, crc32) - - sizeof(uint32_t)); - - /* Combine CRC segments (simplified: recompute over full blob - * skipping CRC field is complex; for v0 just verify basic integrity) */ - (void)crc_rest; - if (hdr->crc32 != 0 && hdr->crc32 != crc) { - LOG_WRN("CRC mismatch (header region): expected 0x%08x, got 0x%08x", - hdr->crc32, crc); - /* Continue for v0 - strict CRC enforcement in future */ - } - - /* Copy hashes to model */ + /* ── Read header fields ──────────────────────────────────── */ + uint16_t flags = read_u16(blob + 6); + uint32_t header_len = read_u32(blob + 8); + uint32_t total_len = read_u32(blob + 12); + + if (unlikely(header_len != ZRMB_HEADER_LEN)) { + LOG_ERR("blob: unexpected header_len %u", header_len); + return ARBITER_EMODEL; + } + + if (unlikely(total_len > blob_len)) { + LOG_ERR("blob: total_len %u > blob_len %zu", total_len, + blob_len); + return ARBITER_EMODEL; + } + + /* ── CRC-32 verification ─────────────────────────────────── */ + /* CRC-32 sits at bytes 80-83 of the header. */ + uint32_t stored_crc = read_u32(blob + 80); + + if (stored_crc != 0) { + /* CRC computed over entire blob with CRC field zeroed. + * Equivalent to CRC(bytes[0:80]) + CRC_continue(zeros[4]) + * + CRC_continue(bytes[84:total_len]). */ + uint32_t crc = 0xFFFFFFFFU; + size_t j; + + /* Hash bytes 0..79 */ + for (j = 0; j < 80; j++) { + crc ^= blob[j]; + for (int b = 0; b < 8; b++) { + crc = (crc & 1U) + ? (crc >> 1) ^ 0xEDB88320U + : (crc >> 1); + } + } + /* Hash 4 zero bytes (CRC field placeholder) */ + for (j = 0; j < 4; j++) { + crc ^= 0; + for (int b = 0; b < 8; b++) { + crc = (crc & 1U) + ? (crc >> 1) ^ 0xEDB88320U + : (crc >> 1); + } + } + /* Hash bytes 84..total_len-1 */ + for (j = 84; j < total_len; j++) { + crc ^= blob[j]; + for (int b = 0; b < 8; b++) { + crc = (crc & 1U) + ? (crc >> 1) ^ 0xEDB88320U + : (crc >> 1); + } + } + crc ^= 0xFFFFFFFFU; + + if (unlikely(crc != stored_crc)) { + LOG_ERR("blob: CRC mismatch (stored=0x%08x computed=0x%08x)", + stored_crc, crc); + return ARBITER_EMODEL; + } + } + + /* ── Initialize model output ─────────────────────────────── */ memset(model_out, 0, sizeof(*model_out)); - memcpy((void *)model_out->model_hash, hdr->model_hash, 32); - memcpy((void *)model_out->schema_hash, hdr->schema_hash, 32); + memcpy((void *)model_out->model_hash, blob + 16, 32); + memcpy((void *)model_out->schema_hash, blob + 48, 32); - LOG_INF("Blob loaded: version=%u total_len=%u", - hdr->version, hdr->total_len); + /* Decode version from flags: bits 0-7 major, 8-15 minor */ + model_out->version[0] = (uint8_t)(flags & 0xFF); + model_out->version[1] = (uint8_t)((flags >> 8) & 0xFF); + model_out->version[2] = 0; - /* Section parsing would follow here. For v0, the blob loader - * validates the header and reports success. Full section parsing - * (facts, rules, conditions, actions, strings) will be implemented - * in Milestone 3. + /* ── Determine section count ─────────────────────────────── */ + /* + * Header layout (84 bytes): + * [0..79] fixed fields (magic, version, flags, lengths, hashes) + * [80..83] CRC-32 + * + * Section table starts at offset 84 (= header_len). + * total_len = 84 + table_size + data_size. + * When total_len == 84, there are 0 sections. */ + size_t table_start = ZRMB_HEADER_LEN; /* 84 */ + size_t payload_end = total_len; + size_t payload_size = (payload_end > table_start) + ? (payload_end - table_start) : 0; + size_t num_sections = 0; + + if (payload_size >= SECTION_ENTRY_SIZE) { + /* Peek at first entry's data offset to derive section count: + * first_data_offset = table_start + num_sections * 8 + * => num_sections = (first_data_offset - table_start) / 8 */ + uint16_t first_offset = read_u16(blob + table_start + 2); + + if (first_offset >= table_start && + first_offset <= payload_end) { + num_sections = (first_offset - table_start) + / SECTION_ENTRY_SIZE; + } + } + + /* ── Parse sections ──────────────────────────────────────── */ + for (size_t s = 0; s < num_sections; s++) { + const uint8_t *entry = blob + table_start + + s * SECTION_ENTRY_SIZE; + uint8_t sec_type = entry[0]; + uint16_t sec_off = read_u16(entry + 2); + uint16_t sec_count = read_u16(entry + 4); + uint16_t sec_esz = read_u16(entry + 6); + + /* Bounds check — offsets are from start of blob */ + size_t sec_end = (size_t)sec_off + + (size_t)sec_count * (size_t)sec_esz; + if (unlikely(sec_end > total_len)) { + LOG_ERR("blob: section %u overflows (end=%zu > %u)", + sec_type, sec_end, total_len); + return ARBITER_EMODEL; + } + + const uint8_t *sec_data = blob + sec_off; + int rc = ARBITER_OK; + + switch (sec_type) { + case SECTION_FACTS: + rc = parse_facts(sec_data, sec_count, model_out); + break; + case SECTION_RULES: + rc = parse_rules(sec_data, sec_count, model_out); + break; + case SECTION_CONDITIONS: + rc = parse_conditions(sec_data, sec_count, model_out); + break; + case SECTION_EXPRESSIONS: + rc = parse_expressions(sec_data, sec_count, model_out); + break; + case SECTION_ACTIONS: + rc = parse_actions(sec_data, sec_count, model_out); + break; + case SECTION_STRINGS: + /* String table — used for name resolution; skip + * for now (names set to NULL above). */ + break; + case SECTION_MODES: + rc = parse_modes(sec_data, sec_count, model_out); + break; + default: + LOG_WRN("blob: unknown section type %u, skipping", + sec_type); + break; + } + + if (unlikely(rc != ARBITER_OK)) { + return rc; + } + } + + /* If no facts or rules were loaded, set empty defaults so + * ARBITER_init() doesn't fail on NULL pointers. */ + if (model_out->facts == NULL) { + model_out->facts = blob_facts; + model_out->fact_count = 0; + } + if (model_out->rules == NULL) { + model_out->rules = blob_rules; + model_out->rule_count = 0; + } + + LOG_INF("blob: loaded %u facts, %u rules, %u conditions, " + "%u expressions, %u actions, %u modes", + model_out->fact_count, model_out->rule_count, + model_out->condition_count, model_out->expr_count, + model_out->action_count, model_out->mode_count); return ARBITER_OK; } diff --git a/lib/arbiter_engine.c b/lib/arbiter_engine.c index e70c26b..c9b9298 100644 --- a/lib/arbiter_engine.c +++ b/lib/arbiter_engine.c @@ -67,6 +67,106 @@ uint32_t ARBITER_get_last_eval_op_count(const struct ARBITER_ctx *ctx) return ctx->last_eval_op_count; } +int ARBITER_check_version(const struct ARBITER_model *model, + uint8_t min_major, uint8_t min_minor) +{ + if (unlikely(model == NULL)) { + return ARBITER_EINVAL; + } + + if (model->version[0] > min_major) { + return ARBITER_OK; + } + if (model->version[0] == min_major && + model->version[1] >= min_minor) { + return ARBITER_OK; + } + + LOG_WRN("Model version %u.%u.%u < required %u.%u", + model->version[0], model->version[1], model->version[2], + min_major, min_minor); + return ARBITER_EMODEL; +} + +#if defined(CONFIG_ARBITER_HOT_SWAP) && CONFIG_ARBITER_HOT_SWAP +int ARBITER_hot_swap(struct ARBITER_ctx *ctx, + const struct ARBITER_model *new_model) +{ + if (unlikely(ctx == NULL || new_model == NULL)) { + return ARBITER_EINVAL; + } + + if (unlikely(!ctx->initialized)) { + LOG_ERR("hot_swap: context not initialized"); + return ARBITER_EINVAL; + } + + /* Validate new model basics */ + if (new_model->facts == NULL || new_model->rules == NULL) { + LOG_ERR("hot_swap: new model has NULL facts or rules"); + return ARBITER_EMODEL; + } + + if (new_model->fact_count > CONFIG_ARBITER_MAX_FACTS) { + LOG_ERR("hot_swap: new model has %u facts, max is %d", + new_model->fact_count, CONFIG_ARBITER_MAX_FACTS); + return ARBITER_EMODEL; + } + + /* + * Preserve fact values: for each fact in the new model that also + * existed in the old model (same index, same type), keep the + * current value. All other facts get their defaults. + */ + const struct ARBITER_model *__restrict old_model = ctx->model; + struct ARBITER_fact_value *__restrict fv = ctx->fact_values; + arbiter_index_t common = (new_model->fact_count < old_model->fact_count) + ? new_model->fact_count : old_model->fact_count; + + /* Reset facts beyond the common range to new defaults */ + for (arbiter_index_t i = common; i < new_model->fact_count; i++) { + int32_t def = new_model->facts[i].default_value; + + fv[i].value = def; + fv[i].prev_value = def; + fv[i].timestamp_ms = 0; + fv[i].valid = false; + fv[i].changed = false; + } + + /* For common facts, keep value if type matches; reset otherwise */ + for (arbiter_index_t i = 0; i < common; i++) { + if (new_model->facts[i].type != old_model->facts[i].type) { + int32_t def = new_model->facts[i].default_value; + + fv[i].value = def; + fv[i].prev_value = def; + fv[i].timestamp_ms = 0; + fv[i].valid = false; + fv[i].changed = false; + } + } + + /* Zero out facts beyond new count */ + for (arbiter_index_t i = new_model->fact_count; + i < old_model->fact_count && i < CONFIG_ARBITER_MAX_FACTS; i++) { + memset(&fv[i], 0, sizeof(fv[i])); + } + + /* Atomic swap: update model pointer and snapshot metadata */ + ctx->model = new_model; + ctx->snapshot.count = new_model->fact_count; + ctx->snapshot.frozen = false; + + LOG_INF("hot_swap: %s -> %s (facts: %u -> %u)", + old_model->name ? old_model->name : "unnamed", + new_model->name ? new_model->name : "unnamed", + old_model->fact_count, new_model->fact_count); + + return ARBITER_OK; +} +#endif /* CONFIG_ARBITER_HOT_SWAP */ + const char *ARBITER_version_string(void) { return ARBITER_VERSION_STRING; diff --git a/python/arbiter/canonical.py b/python/arbiter/canonical.py index f11e0b4..0b26b1f 100644 --- a/python/arbiter/canonical.py +++ b/python/arbiter/canonical.py @@ -23,6 +23,7 @@ class CanonicalModel: expressions: list[dict[str, Any]] = field(default_factory=list) hazards: list[dict[str, Any]] = field(default_factory=list) safety_goals: list[dict[str, Any]] = field(default_factory=list) + version: str | None = None model_hash: str = "" schema_hash: str = "" compiler_version: str = "0.1.0" @@ -113,6 +114,7 @@ def canonicalize(data: dict[str, Any]) -> CanonicalModel: expressions=expressions, hazards=data.get("hazards", []), safety_goals=data.get("safety_goals", []), + version=data.get("version"), fact_id_map=fact_id_map, rule_id_map=rule_id_map, action_id_map=action_id_map, diff --git a/python/arbiter/emit_blob.py b/python/arbiter/emit_blob.py index 07788b0..8b73204 100644 --- a/python/arbiter/emit_blob.py +++ b/python/arbiter/emit_blob.py @@ -1,71 +1,409 @@ # SPDX-License-Identifier: MIT -"""Binary blob (.zrmb) emitter for compiled ARB models.""" +"""Binary blob emitter for compiled ARB models (.zrmb format). + +Layout: + [ZRMB header — 84 bytes] + [Section table — N * 8 bytes] + [Section data …] + +Header (84 bytes): + magic 4B "ZRMB" + version 2B uint16 LE (currently 1) + flags 2B uint16 LE (bits 0-7: major, 8-15: minor packed version) + header_len 4B uint32 LE (always 84) + total_len 4B uint32 LE + model_hash 32B SHA-256 + schema_hash 32B SHA-256 + crc32 4B uint32 LE — CRC-32 over everything except this field + +Section table entry (8 bytes): + type 1B uint8 + pad 1B 0 + offset 2B uint16 LE — byte offset from start of blob + count 2B uint16 LE — element count + elem_size 2B uint16 LE — size of one element in bytes +""" from __future__ import annotations import struct import zlib +from typing import Any from .canonical import CanonicalModel -ZRMB_MAGIC = b"ZRMB" -ZRMB_VERSION = 1 +# Section type constants +SECTION_FACTS = 1 +SECTION_RULES = 2 +SECTION_CONDITIONS = 3 +SECTION_EXPRESSIONS = 4 +SECTION_ACTIONS = 5 +SECTION_STRINGS = 6 +SECTION_MODES = 7 +_HEADER_LEN = 84 +_SECTION_ENTRY_SIZE = 8 +_BLOB_VERSION = 1 -def emit_blob(model: CanonicalModel) -> bytes: - """Emit a .zrmb binary blob from a canonical model. +# Wire sizes for packed structs (all little-endian, uint16 indices) +_FACT_ELEM_SIZE = 16 # id(2) + type(1) + pad(1) + range_min(4) + range_max(4) + default(4) + stale(2) + safety(1) + pad(1) => rearranged below +_RULE_ELEM_SIZE = 20 +_COND_ELEM_SIZE = 12 +_EXPR_ELEM_SIZE = 20 +_ACTION_ELEM_SIZE = 12 + +# ── Operator / type maps matching arbiter_model.h enums ────────────── + +_TYPE_MAP = {"bool": 0, "int32": 1, "uint32": 2, "enum": 3} +_CLASS_MAP = { + "inference": 0, "constraint": 1, "mode_guard": 2, + "safety_guard": 3, "obligation": 4, "advisory": 5, +} +_OP_MAP = { + "==": 0, "!=": 1, "<": 2, "<=": 3, ">": 4, ">=": 5, + "in": 6, "not_in": 7, "stale": 8, "not_stale": 9, + "changed": 10, "delta_gt": 11, "delta_lt": 12, +} +_COND_GROUP_MAP = {"all": 0, "any": 1, "not": 2} +_EXPR_OP_MAP = { + "add": 0, "sub": 1, "mul": 2, "div": 3, "mod": 4, + "abs": 5, "negate": 6, "min": 7, "max": 8, "clamp": 9, + "shift_r": 10, "shift_l": 11, "scale": 12, "assign": 13, + "accumulate": 14, +} +_ACTION_TYPE_MAP = { + "callback": 0, "log": 1, "notify": 2, "set_fact": 3, + "set_mode": 4, "raise_fault": 5, "clear_fault": 6, +} + + +def _pack_hash(hex_str: str) -> bytes: + """Convert a 64-char hex hash to 32 bytes, zero-padded if short.""" + h = hex_str[:64].ljust(64, "0") + return bytes.fromhex(h) - Format v0: - magic: 4 bytes "ZRMB" - version: uint16_le - flags: uint16_le - header_len: uint32_le - total_len: uint32_le (placeholder, filled after) - model_hash: 32 bytes - schema_hash: 32 bytes - crc32: uint32_le (placeholder, filled after) + +def _encode_version_flags(model: CanonicalModel) -> int: + """Encode model version into the 16-bit flags field. + + Bits 0-7: major version, bits 8-15: minor version. + If the model has no version, returns 0. """ - model_hash_bytes = bytes.fromhex(model.model_hash[:64].ljust(64, "0")) - schema_hash_bytes = bytes.fromhex(model.schema_hash[:64].ljust(64, "0")) - - header = bytearray() - header += ZRMB_MAGIC - header += struct.pack(" 1 else 0 + return (minor << 8) | major + except (ValueError, IndexError): + pass + return 0 + + +def _pack_facts(model: CanonicalModel) -> bytes: + """Pack fact definitions. + + Wire layout per fact (16 bytes): + id: uint16 LE + type: uint8 + safety_rel: uint8 (bool) + range_min: int32 LE + range_max: int32 LE + default_value: int32 LE + """ + buf = bytearray() for i, f in enumerate(model.facts): - ftype = {"bool": 0, "int32": 1, "uint32": 2, "enum": 3}.get( - f.get("type", "bool"), 0 - ) rng = f.get("range", [0, 0]) rmin = rng[0] if isinstance(rng, list) and len(rng) >= 2 else 0 rmax = rng[1] if isinstance(rng, list) and len(rng) >= 2 else 0 - sections += struct.pack( - " bytes: + """Pack rule definitions. + + Wire layout per rule (20 bytes): + id: uint16 LE + rule_class: uint8 + safety_critical: uint8 + cond_start: uint16 LE + cond_count: uint16 LE + action_start: uint16 LE + action_count: uint16 LE + expr_start: uint16 LE + expr_count: uint16 LE + safety_goal_id: uint16 LE + set_mode: uint16 LE + """ + buf = bytearray() + cond_offset = 0 + for i, r in enumerate(model.rules): + rclass = _CLASS_MAP.get(r.get("class", "inference"), 0) + then = r.get("then", {}) + + set_mode = 0xFFFF + if isinstance(then, dict) and "set_mode" in then: + mid = model.mode_id_map.get(then["set_mode"]) + if mid is not None: + set_mode = mid + + when = r.get("when", {}) + cond_count = 0 + if isinstance(when, dict): + for gk in ("all", "any", "not"): + g = when.get(gk) + if isinstance(g, list): + cond_count += len(g) + + action_start = 0 + action_count = 0 + if isinstance(then, dict) and "action" in then: + aid = model.action_id_map.get(then["action"]) + if aid is not None: + action_start = aid + action_count = 1 + + safety_critical = 1 if ( + isinstance(then, dict) and then.get("criticality") == "safety_critical" + ) else 0 + + expr_start = r.get("_expr_start", 0) + expr_count = r.get("_expr_count", 0) + + buf += struct.pack( + " bytes: + """Pack condition definitions. + + Wire layout per condition (12 bytes): + fact_id: uint16 LE + op: uint8 + group: uint8 + value: int32 LE + group_index: uint16 LE + next: uint16 LE + """ + buf = bytearray() + for c in model.conditions: + fact_id = c.get("fact_id", 0) + op = _OP_MAP.get(c.get("op", "=="), 0) + group = _COND_GROUP_MAP.get(c.get("group", "all"), 0) + val = c.get("value", 0) + if isinstance(val, bool): + val = 1 if val else 0 + buf += struct.pack(" bytes: + """Pack expression definitions. + + Wire layout per expression (20 bytes): + target_fact_id: uint16 LE + op: uint8 + pad: uint8 + left_fact_id: uint16 LE + left_literal: int32 LE + right_fact_id: uint16 LE + pad2: uint16 LE + right_literal: int32 LE + scale: int32 LE + """ + buf = bytearray() + # Recalculate: we need 20 bytes. Let's lay it out more carefully. + # target(2) + op(1) + pad(1) + left_fact(2) + pad(2) + left_lit(4) + right_fact(2) + pad(2) + right_lit(4) = 20 + # Actually let's use a cleaner layout: + # target(2) + op(1) + pad(1) + left_fact(2) + right_fact(2) + left_lit(4) + right_lit(4) + scale(4) = 20 + for e in model.expressions: + target = e.get("target_fact_id", 0) + op = _EXPR_OP_MAP.get(e.get("op", "assign"), 13) + left_fact = e.get("left_fact_id", 0xFFFF) + right_fact = e.get("right_fact_id", 0xFFFF) + left_lit = e.get("left_literal", 0) + right_lit = e.get("right_literal", 0) + scale = e.get("scale", 1) + buf += struct.pack( + " bytes: + """Pack action definitions. + + Wire layout per action (12 bytes): + id: uint16 LE + type: uint8 + safe_state_action: uint8 + target_fact_id: uint16 LE + must_complete_within_ms: uint16 LE + target_value: int32 LE + """ + buf = bytearray() + for i, a in enumerate(model.actions): + atype = _ACTION_TYPE_MAP.get(a.get("type", "callback"), 0) + safe_state = 1 if a.get("safe_state_action", False) else 0 + must_complete = int(a.get("must_complete_within_ms", 0)) & 0xFFFF + buf += struct.pack(" tuple[bytes, list[int]]: + """Pack all name strings into a string table section. + + Returns (string_blob, offsets_within_blob). + Strings are null-terminated UTF-8. + """ + buf = bytearray() + offsets: list[int] = [] + seen: dict[str, int] = {} + + def _add(s: str | None) -> int: + if s is None: + return 0xFFFF + if s in seen: + return seen[s] + off = len(buf) + seen[s] = off + buf += s.encode("utf-8") + b"\x00" + offsets.append(off) + return off + + # Model name + _add(model.name) + + # Fact names + for f in model.facts: + _add(f.get("id")) + + # Rule names and explanations + for r in model.rules: + _add(r.get("id")) + then = r.get("then", {}) + if isinstance(then, dict): + _add(then.get("explanation")) + + # Action names + for a in model.actions: + _add(a.get("id")) + + # Mode names + for m in model.modes: + if isinstance(m, dict): + _add(m.get("id")) + + return bytes(buf), offsets + + +def _pack_modes(model: CanonicalModel) -> bytes: + """Pack mode names as uint16 string offsets (placeholder — modes are + identified by index; the blob consumer uses the string table).""" + buf = bytearray() + for i, m in enumerate(model.modes): + if isinstance(m, dict) and "id" in m: + buf += struct.pack(" bytes: + """Emit a .zrmb binary blob from a canonical model. + + Returns the complete blob as bytes. + """ + # Pack all section data + facts_data = _pack_facts(model) + rules_data = _pack_rules(model) + cond_data = _pack_conditions(model) + expr_data = _pack_expressions(model) + action_data = _pack_actions(model) + string_data, _ = _pack_strings(model) + mode_data = _pack_modes(model) + + sections: list[tuple[int, bytes, int, int]] = [] # (type, data, count, elem_size) + + fact_elem = 16 + if model.facts: + sections.append((SECTION_FACTS, facts_data, len(model.facts), fact_elem)) + + rule_elem = 20 + if model.rules: + sections.append((SECTION_RULES, rules_data, len(model.rules), rule_elem)) + + cond_elem = 12 + if model.conditions: + sections.append((SECTION_CONDITIONS, cond_data, len(model.conditions), cond_elem)) + + expr_elem = 20 + if model.expressions: + sections.append((SECTION_EXPRESSIONS, expr_data, len(model.expressions), expr_elem)) + + action_elem = 12 + if model.actions: + sections.append((SECTION_ACTIONS, action_data, len(model.actions), action_elem)) + + if string_data: + sections.append((SECTION_STRINGS, string_data, len(string_data), 1)) + + mode_elem = 2 + mode_count = len([m for m in model.modes if isinstance(m, dict) and "id" in m]) + if mode_count: + sections.append((SECTION_MODES, mode_data, mode_count, mode_elem)) + + # Compute layout + section_table_size = len(sections) * _SECTION_ENTRY_SIZE + data_start = _HEADER_LEN + section_table_size + + # Build section table and collect data blobs + table_buf = bytearray() + data_buf = bytearray() + current_offset = data_start + + for sec_type, sec_data, count, elem_size in sections: + table_buf += struct.pack( + "