diff --git a/src/tether/runtime/record.py b/src/tether/runtime/record.py index 5acc8c6..0eb13e7 100644 --- a/src/tether/runtime/record.py +++ b/src/tether/runtime/record.py @@ -6,9 +6,10 @@ Design notes: - Pure stdlib. Don't take a new runtime dep for the recorder. -- Synchronous file writes. The hook fires once per /act after inference; - flush() per record means a writer crash leaves at most one partial line - (reader skips it per D.1.11). +- Public writes enqueue to a bounded in-process queue; a daemon worker thread + performs file I/O so the /act event loop is not blocked by large JSONL lines. + The worker still flushes per record, so a writer crash leaves at most one + partial line (reader skips it per D.1.11). - Disk-full → degrade silently. Catches OSError, sets `self.degraded`, stops writing, but lets `tether serve` continue. /health surfaces this via `getattr(server, '_recorder', None).degraded` if a consumer wants. @@ -40,13 +41,17 @@ write to different sinks; the /act hook stamps `tether.record.seq` on the OTel span so the two ledgers can be cross-grepped by seq. """ + from __future__ import annotations +import copy import gzip import hashlib import io import json import logging +import queue +import threading import uuid from datetime import datetime, timezone from pathlib import Path @@ -55,6 +60,8 @@ logger = logging.getLogger(__name__) SCHEMA_VERSION = 1 +RECORD_QUEUE_MAXSIZE = 1000 +_QUEUE_STOP = object() ImageRedaction = Literal["full", "hash_only", "none"] @@ -107,9 +114,7 @@ def compute_config_hash(export_dir: str | Path) -> str: return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16] -def _redact_image( - image_b64: str | None, mode: ImageRedaction -) -> dict[str, Any]: +def _redact_image(image_b64: str | None, mode: ImageRedaction) -> dict[str, Any]: """Apply image redaction policy. 
Returns the request.image_* fields that should be in the record (varies by mode per D.1.8).""" out: dict[str, Any] = {} @@ -133,7 +138,9 @@ def _chain_hash(prev_hash: str, record: dict[str, Any]) -> str: ``prev_hash``. Deterministic (sorted keys, no whitespace). """ payload = {k: v for k, v in record.items() if k not in ("prev_record_hash", "record_hash")} - canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str).encode("utf-8") + canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str).encode( + "utf-8" + ) return hashlib.sha256(prev_hash.encode("ascii") + canonical).hexdigest() @@ -191,9 +198,7 @@ def __init__( # Filename: ---.jsonl[.gz] ts = datetime.now(timezone.utc) fname = ( - f"{ts.strftime('%Y%m%d-%H%M%S')}-" - f"{model_hash or 'unknownhash'}-" - f"{self.session_id}.jsonl" + f"{ts.strftime('%Y%m%d-%H%M%S')}-{model_hash or 'unknownhash'}-{self.session_id}.jsonl" ) if gzip_output: fname += ".gz" @@ -236,6 +241,7 @@ def __init__( if pro_customer_id: try: from tether.pro.fingerprint import compute_fingerprint + # Sign over the canonicalized header meta so the fingerprint # ties to this session's identity (not per-request). canonical = json.dumps( @@ -255,6 +261,14 @@ def __init__( self._seq = 0 self._prev_record_hash = "0" * 64 # tamper-evident hash-chain head self.degraded = False # set on first OSError; recorder stops writing + self._closed = False + self._record_queue: queue.Queue[object] = queue.Queue(maxsize=RECORD_QUEUE_MAXSIZE) + self._worker = threading.Thread( + target=self._run_worker, + name=f"RecordWriter-{self.session_id[:8]}", + daemon=True, + ) + self._worker.start() # Curate dual-write: when a FreeContributorCollector is attached, # write_request emits to BOTH the JSONL trace (audit) AND the # curate queue (training corpus). 
# ---------------------------------------------------------------
# Background worker
# ---------------------------------------------------------------

def _run_worker(self) -> None:
    """Drain ``self._record_queue`` until the stop sentinel arrives.

    Each queue item is either ``_QUEUE_STOP`` or a tuple of record dicts.
    Tuples are passed to ``self._emit`` one record at a time, in order.
    Anything else marks the writer degraded. Exceptions are routed to
    ``self._set_degraded`` instead of killing the thread, and
    ``task_done()`` is called exactly once per item so that
    ``Queue.join()`` callers are always released.
    """
    while True:
        item = self._record_queue.get()
        stopping = item is _QUEUE_STOP
        try:
            if stopping:
                self._close_file()
            elif isinstance(item, tuple):
                for record in item:
                    self._emit(record)
            else:
                self._set_degraded(TypeError("invalid record queue item"))
        except Exception as exc:  # noqa: BLE001 — worker must not die with queued work
            self._set_degraded(exc)
        finally:
            self._record_queue.task_done()
        # Decide on exit outside the try block: the sentinel must end the
        # thread even when _close_file() raised, otherwise the worker would
        # loop back and block on get() forever.
        if stopping:
            return

def _enqueue_records(self, records: tuple[dict[str, Any], ...]) -> bool:
    """Queue a batch of records for the worker thread without blocking.

    Args:
        records: Records to write, in order. They are enqueued as a single
            queue item.

    Returns:
        ``True`` if ``records`` is empty or the batch was enqueued.
        ``False`` if the writer is degraded or closed, if the batch could
        not be deep-copied, or if the queue is full. The last two cases
        also mark the writer degraded.
    """
    if not records:
        return True
    if self.degraded or self._closed:
        return False
    try:
        # Snapshot the batch so the caller may keep mutating its own
        # lists/dicts after this method returns.
        queued_records = copy.deepcopy(records)
    except Exception as exc:  # noqa: BLE001
        self._set_degraded(exc)
        return False
    try:
        self._record_queue.put_nowait(queued_records)
    except queue.Full:
        # Report the capacity of the queue actually in use rather than the
        # module-level default it was presumably built from.
        self._set_degraded(
            RuntimeError(f"record queue full (maxsize={self._record_queue.maxsize})")
        )
        return False
    return True

def flush_sync(self) -> None:
    """Block until all records queued so far have been written or dropped."""
    self._record_queue.join()

# ---------------------------------------------------------------
# File handle
# ---------------------------------------------------------------

def _close_file(self) -> None:
    """Flush and close the output file handle, if one is open.

    ``OSError`` raised while flushing or closing is logged as a warning
    rather than propagated. Whenever a handle was open, ``self._fh`` is
    reset to ``None`` and the close is logged at INFO level, even if the
    close itself failed.
    """
    if self._fh is None:
        return
    try:
        try:
            self._fh.flush()
        finally:
            # Attempt close() even when flush() failed, so the handle is
            # released here instead of whenever it gets garbage collected.
            self._fh.close()
    except OSError as e:
        logger.warning("RecordWriter close failed: %s", e)
    finally:
        self._fh = None
        logger.info("RecordWriter closed: %s (seq=%d)", self.filepath, self._seq)

# ---------------------------------------------------------------
# Records (D.1.3 / D.1.4 / D.1.6)
# ---------------------------------------------------------------

def _header_record(self) -> dict[str, Any]:
    """Build the session header record without writing it.

    Keys from ``self._header_meta`` are merged last, so they take
    precedence if they collide with the fixed keys.
    """
    return {
        "kind": "header",
        "schema_version": SCHEMA_VERSION,
        "session_id": self.session_id,
        "started_at": self.started_at,
        **self._header_meta,
    }

def _write_header(self) -> None:
    """Emit the header through ``self._emit`` at most once per writer."""
    if self._header_written:
        return
    self._emit(self._header_record())
    self._header_written = True
""" - if self.degraded: + if self.degraded or self._closed: return -1 # Sample rate (deterministic by seq for now; random sampling is a v2 nit) if self.sample_rate < 1.0 and (self._seq * self.sample_rate) % 1 >= self.sample_rate: self._seq += 1 return -1 - self._write_header() - seq = self._seq - self._seq += 1 request_obj: dict[str, Any] = { "instruction": instruction, @@ -398,7 +466,8 @@ def write_request( "action_dim": action_dim, }, "latency": latency_obj, - "denoise": denoise or { + "denoise": denoise + or { "steps_used": 0, "steps_configured": 0, "adaptive": False, @@ -424,7 +493,14 @@ def write_request( if routing is not None: record["routing"] = routing - self._emit(record) + needs_header = not self._header_written + queued_records = (self._header_record(), record) if needs_header else (record,) + if not self._enqueue_records(queued_records): + return -1 + if needs_header: + self._header_written = True + self._seq += 1 + # Curate dual-write: feed the same event into the contribution queue. # Failures here NEVER affect the JSONL trace — collector is best-effort. if self._curate_collector is not None and error is None: @@ -434,6 +510,7 @@ def write_request( QueueFull, hash_instruction, ) + event = CollectedEvent( timestamp=record["timestamp"], episode_id=self.session_id, @@ -456,7 +533,7 @@ def write_request( def write_footer(self, totals: dict[str, int]) -> None: """Emit footer on clean shutdown. 
Optional per D.1.6 — readers tolerate absence.""" - if self.degraded or not self._header_written: + if self.degraded or self._closed or not self._header_written: # Don't write a footer if we never wrote the header return record = { @@ -465,9 +542,12 @@ def write_footer(self, totals: dict[str, int]) -> None: "ended_at": _utc_now_iso(), **totals, } - self._emit(record) + self._enqueue_records((record,)) def close(self) -> None: + if self._closed: + return + self._closed = True # Stop the curate collector first so its drain has a chance to # flush queued events before the process exits. if self._curate_collector is not None: @@ -475,16 +555,13 @@ def close(self) -> None: self._curate_collector.stop() except Exception as exc: # noqa: BLE001 logger.warning("curate_collector.stop failed: %s", exc) - if self._fh is None: - return - try: - self._fh.flush() - self._fh.close() - except OSError as e: - logger.warning("RecordWriter close failed: %s", e) - finally: - self._fh = None - logger.info("RecordWriter closed: %s (seq=%d)", self.filepath, self._seq) + self.flush_sync() + self._record_queue.put(_QUEUE_STOP) + self._record_queue.join() + if self._worker is not threading.current_thread(): + self._worker.join(timeout=1.0) + if self._worker.is_alive(): + logger.warning("RecordWriter worker did not stop cleanly") # --------------------------------------------------------------- # Convenience diff --git a/src/tether/runtime/server.py b/src/tether/runtime/server.py index 8318b70..5f7999b 100644 --- a/src/tether/runtime/server.py +++ b/src/tether/runtime/server.py @@ -22,14 +22,13 @@ import io import json import logging +import os import time from pathlib import Path from typing import Any import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from .record import ( RecordWriter, diff --git a/tests/test_record.py b/tests/test_record.py index dd0bd25..cbfe891 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -5,11 +5,14 @@ Pure stdlib 
# ---------------------------------------------------------------------------
# Background writer
# ---------------------------------------------------------------------------


class TestBackgroundWriter:
    """Tests for RecordWriter's queue-plus-worker-thread write path."""

    # Upper bound on how long the gated ``_emit`` waits to be released.
    _GATE_TIMEOUT_S = 1.0

    @classmethod
    def _gate_emit(cls, rec, monkeypatch):
        """Replace ``rec._emit`` with a version that waits for the test.

        Returns ``(emit_started, release_emit)``. ``emit_started`` is set as
        soon as the gated ``_emit`` is entered. The real ``_emit`` runs only
        after ``release_emit`` is set. If the gate is not released within
        ``_GATE_TIMEOUT_S`` it raises ``TimeoutError``, so a forgotten
        release cannot hang the test run.
        """
        emit_started = threading.Event()
        release_emit = threading.Event()
        original_emit = rec._emit

        def gated_emit(record):
            emit_started.set()
            if not release_emit.wait(cls._GATE_TIMEOUT_S):
                raise TimeoutError("test did not release RecordWriter worker")
            original_emit(record)

        monkeypatch.setattr(rec, "_emit", gated_emit)
        return emit_started, release_emit

    def test_write_request_returns_before_background_emit_completes(self, tmp_path, monkeypatch):
        rec = _make_writer(tmp_path, gzip_output=False)
        emit_started, release_emit = self._gate_emit(rec, monkeypatch)

        try:
            started_at = time.perf_counter()
            seq = _dummy_request(rec)
            elapsed = time.perf_counter() - started_at

            assert seq == 0
            # Bound by the gate timeout instead of an arbitrary few
            # milliseconds: a call that waited on the gated _emit could not
            # finish sooner than this, and a loaded CI runner cannot trip it.
            assert elapsed < self._GATE_TIMEOUT_S
            assert emit_started.wait(1.0)
            # The gate is still shut, so nothing can have reached disk yet.
            assert not rec.filepath.exists()

            release_emit.set()
            rec.flush_sync()
            records = _read_all(rec.filepath)
            assert [r["kind"] for r in records] == ["header", "request"]
        finally:
            # Always release the worker and close the recorder, even when an
            # assertion above failed.
            release_emit.set()
            rec.close()

    def test_write_request_snapshots_mutable_inputs(self, tmp_path, monkeypatch):
        rec = _make_writer(tmp_path, gzip_output=False)
        emit_started, release_emit = self._gate_emit(rec, monkeypatch)

        state = [1.0, 2.0]
        actions = [[3.0, 4.0]]
        cache = {"hit": False, "nested": {"count": 1}}

        try:
            seq = _dummy_request(
                rec,
                state=state,
                actions=actions,
                action_dim=2,
                cache=cache,
            )
            assert seq == 0
            assert emit_started.wait(1.0)

            # Mutate the caller-owned objects while the record is still
            # waiting behind the gate.
            state[0] = 99.0
            actions[0][0] = 99.0
            cache["nested"]["count"] = 99

            release_emit.set()
            rec.flush_sync()
            request = _read_all(rec.filepath)[1]
            assert request["request"]["state"] == [1.0, 2.0]
            assert request["response"]["actions"] == [[3.0, 4.0]]
            assert request["cache"]["nested"]["count"] == 1
        finally:
            release_emit.set()
            rec.close()

    def test_queue_full_degrades_and_short_circuits(self, tmp_path, monkeypatch):
        # Must be patched before the writer is constructed so the queue is
        # created with capacity 1.
        monkeypatch.setattr("tether.runtime.record.RECORD_QUEUE_MAXSIZE", 1)
        rec = _make_writer(tmp_path, gzip_output=False)
        emit_started, release_emit = self._gate_emit(rec, monkeypatch)

        try:
            # First batch is taken by the worker, which then parks in the gate.
            assert _dummy_request(rec, i=0) == 0
            assert emit_started.wait(1.0)
            # Second batch fills the single queue slot.
            assert _dummy_request(rec, i=1) == 1

            # Third batch overflows: the writer degrades and keeps refusing.
            assert _dummy_request(rec, i=2) == -1
            assert rec.degraded is True
            assert _dummy_request(rec, i=3) == -1
        finally:
            release_emit.set()
            rec.close()

    def test_close_drains_pending_records(self, tmp_path):
        rec = _make_writer(tmp_path, gzip_output=False)
        for i in range(10):
            _dummy_request(rec, i=i)
        rec.write_footer({"total_requests": 10})

        rec.close()

        records = _read_all(rec.filepath)
        requests = [r for r in records if r["kind"] == "request"]
        assert len(requests) == 10
        assert requests[-1]["seq"] == 9
        assert records[-1]["kind"] == "footer"
        assert records[-1]["total_requests"] == 10
test_degraded_skips_footer(self, tmp_path): """write_footer is a no-op when degraded.""" rec = _make_writer(tmp_path) _dummy_request(rec) + rec.flush_sync() rec.degraded = True rec.write_footer({"total_requests": 99}) rec.close()