Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 111 additions & 34 deletions src/tether/runtime/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

Design notes:
- Pure stdlib. Don't take a new runtime dep for the recorder.
- Asynchronous file writes. The hook fires once per /act after inference;
  public writes enqueue to a bounded in-process queue and a daemon worker
  thread performs the file I/O, so the /act event loop is not blocked by
  large JSONL lines. The worker still flushes per record, so a writer crash
  leaves at most one partial line (reader skips it per D.1.11).
- Disk-full → degrade silently. Catches OSError, sets `self.degraded`,
stops writing, but lets `tether serve` continue. /health surfaces this
via `getattr(server, '_recorder', None).degraded` if a consumer wants.
Expand Down Expand Up @@ -40,13 +41,17 @@
write to different sinks; the /act hook stamps `tether.record.seq` on
the OTel span so the two ledgers can be cross-grepped by seq.
"""

from __future__ import annotations

import copy
import gzip
import hashlib
import io
import json
import logging
import queue
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
Expand All @@ -55,6 +60,8 @@
logger = logging.getLogger(__name__)

SCHEMA_VERSION = 1
RECORD_QUEUE_MAXSIZE = 1000
_QUEUE_STOP = object()

ImageRedaction = Literal["full", "hash_only", "none"]

Expand Down Expand Up @@ -107,9 +114,7 @@ def compute_config_hash(export_dir: str | Path) -> str:
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]


def _redact_image(
image_b64: str | None, mode: ImageRedaction
) -> dict[str, Any]:
def _redact_image(image_b64: str | None, mode: ImageRedaction) -> dict[str, Any]:
"""Apply image redaction policy. Returns the request.image_* fields
that should be in the record (varies by mode per D.1.8)."""
out: dict[str, Any] = {}
Expand All @@ -133,7 +138,9 @@ def _chain_hash(prev_hash: str, record: dict[str, Any]) -> str:
``prev_hash``. Deterministic (sorted keys, no whitespace).
"""
payload = {k: v for k, v in record.items() if k not in ("prev_record_hash", "record_hash")}
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str).encode("utf-8")
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str).encode(
"utf-8"
)
return hashlib.sha256(prev_hash.encode("ascii") + canonical).hexdigest()


Expand Down Expand Up @@ -191,9 +198,7 @@ def __init__(
# Filename: <YYYYMMDD>-<HHMMSS>-<model_hash>-<session_id>.jsonl[.gz]
ts = datetime.now(timezone.utc)
fname = (
f"{ts.strftime('%Y%m%d-%H%M%S')}-"
f"{model_hash or 'unknownhash'}-"
f"{self.session_id}.jsonl"
f"{ts.strftime('%Y%m%d-%H%M%S')}-{model_hash or 'unknownhash'}-{self.session_id}.jsonl"
)
if gzip_output:
fname += ".gz"
Expand Down Expand Up @@ -236,6 +241,7 @@ def __init__(
if pro_customer_id:
try:
from tether.pro.fingerprint import compute_fingerprint

# Sign over the canonicalized header meta so the fingerprint
# ties to this session's identity (not per-request).
canonical = json.dumps(
Expand All @@ -255,6 +261,14 @@ def __init__(
self._seq = 0
self._prev_record_hash = "0" * 64 # tamper-evident hash-chain head
self.degraded = False # set on first OSError; recorder stops writing
self._closed = False
self._record_queue: queue.Queue[object] = queue.Queue(maxsize=RECORD_QUEUE_MAXSIZE)
self._worker = threading.Thread(
target=self._run_worker,
name=f"RecordWriter-{self.session_id[:8]}",
daemon=True,
)
self._worker.start()
# Curate dual-write: when a FreeContributorCollector is attached,
# write_request emits to BOTH the JSONL trace (audit) AND the
# curate queue (training corpus). Independent failure modes; if
Expand All @@ -266,13 +280,56 @@ def __init__(
curate_collector.start()
except Exception as exc: # noqa: BLE001
logger.warning(
"curate_collector.start failed (curate dual-write disabled): %s", exc,
"curate_collector.start failed (curate dual-write disabled): %s",
exc,
)
self._curate_collector = None

# Open file lazily on first emit so the file isn't created if
# nothing ever gets recorded (e.g. empty test runs).

# ---------------------------------------------------------------
# Background worker
# ---------------------------------------------------------------

def _run_worker(self) -> None:
    """Consume the record queue on the writer thread until the stop sentinel.

    Each queue item is expected to be a tuple of record dicts; every record
    in the tuple is passed to ``_emit`` in order. Any other item marks the
    writer degraded. The ``_QUEUE_STOP`` sentinel closes the file and ends
    the loop.

    ``task_done()`` is called for every item, including the sentinel and
    items that failed, so callers blocked in ``Queue.join()`` are always
    released.
    """
    while True:
        item = self._record_queue.get()
        # Decide up front whether this iteration is the last one. If the
        # close below raises, the worker must still exit instead of looping
        # back into a blocking get() that nothing will ever satisfy.
        stopping = item is _QUEUE_STOP
        try:
            if stopping:
                self._close_file()
            elif isinstance(item, tuple):
                for record in item:
                    self._emit(record)
            else:
                self._set_degraded(TypeError("invalid record queue item"))
        except Exception as exc:  # noqa: BLE001 — worker must not die with queued work
            self._set_degraded(exc)
        finally:
            self._record_queue.task_done()
        if stopping:
            return

def _enqueue_records(self, records: tuple[dict[str, Any], ...]) -> bool:
    """Hand a batch of records to the writer thread without blocking.

    The batch is deep-copied before it is queued, so later mutation of the
    caller's dicts does not change what gets written.

    Returns:
        True if the batch was queued (or was empty). False if the writer is
        degraded or closed, the copy failed, or the queue is full; the last
        two cases also mark the writer degraded.
    """
    if not records:
        return True
    if self.degraded or self._closed:
        return False

    try:
        snapshot = copy.deepcopy(records)
    except Exception as exc:  # noqa: BLE001
        self._set_degraded(exc)
        return False

    try:
        # Non-blocking put: the caller must never stall on a slow disk.
        self._record_queue.put(snapshot, block=False)
    except queue.Full:
        overflow = RuntimeError(f"record queue full (maxsize={RECORD_QUEUE_MAXSIZE})")
        self._set_degraded(overflow)
        return False
    else:
        return True

def flush_sync(self) -> None:
    """Wait for the record queue to drain.

    Returns once every item queued so far has been marked done by the
    worker, whether it was written or dropped.
    """
    backlog = self._record_queue
    backlog.join()

# ---------------------------------------------------------------
# File handle
# ---------------------------------------------------------------
Expand All @@ -289,6 +346,18 @@ def _open_if_needed(self) -> None:
except OSError as e:
self._set_degraded(e)

def _close_file(self) -> None:
    """Flush and close the output handle, if one was opened.

    Does nothing when no file is open. ``OSError`` from flush or close is
    logged as a warning and not re-raised. ``self._fh`` is reset to ``None``
    in every case.
    """
    fh = self._fh
    if fh is None:
        return
    # Drop the reference first so the handle is never reused, whatever happens below.
    self._fh = None
    try:
        try:
            fh.flush()
        finally:
            # Attempt close() even when flush() failed (e.g. disk full), so
            # the handle is released explicitly rather than left to GC.
            fh.close()
    except OSError as e:
        logger.warning("RecordWriter close failed: %s", e)
    logger.info("RecordWriter closed: %s (seq=%d)", self.filepath, self._seq)

def _set_degraded(self, exc: BaseException) -> None:
self.degraded = True
logger.error(
Expand Down Expand Up @@ -316,17 +385,19 @@ def _emit(self, record: dict[str, Any]) -> None:
# Records (D.1.3 / D.1.4 / D.1.6)
# ---------------------------------------------------------------

def _header_record(self) -> dict[str, Any]:
    """Build the session header record without writing it.

    Entries from ``self._header_meta`` are merged last, so they take
    precedence over the fixed fields if a key collides.
    """
    return {
        "kind": "header",
        "schema_version": SCHEMA_VERSION,
        "session_id": self.session_id,
        "started_at": self.started_at,
        **self._header_meta,
    }

def _write_header(self) -> None:
    """Emit the header record once; later calls are no-ops."""
    if self._header_written:
        return
    self._emit(self._header_record())
    self._header_written = True

def write_request(
Expand Down Expand Up @@ -357,17 +428,14 @@ def write_request(

Returns -1 if recording was skipped (degraded or sample_rate drop).
"""
if self.degraded:
if self.degraded or self._closed:
return -1
# Sample rate (deterministic by seq for now; random sampling is a v2 nit)
if self.sample_rate < 1.0 and (self._seq * self.sample_rate) % 1 >= self.sample_rate:
self._seq += 1
return -1

self._write_header()

seq = self._seq
self._seq += 1

request_obj: dict[str, Any] = {
"instruction": instruction,
Expand Down Expand Up @@ -398,7 +466,8 @@ def write_request(
"action_dim": action_dim,
},
"latency": latency_obj,
"denoise": denoise or {
"denoise": denoise
or {
"steps_used": 0,
"steps_configured": 0,
"adaptive": False,
Expand All @@ -424,7 +493,14 @@ def write_request(
if routing is not None:
record["routing"] = routing

self._emit(record)
needs_header = not self._header_written
queued_records = (self._header_record(), record) if needs_header else (record,)
if not self._enqueue_records(queued_records):
return -1
if needs_header:
self._header_written = True
self._seq += 1

# Curate dual-write: feed the same event into the contribution queue.
# Failures here NEVER affect the JSONL trace — collector is best-effort.
if self._curate_collector is not None and error is None:
Expand All @@ -434,6 +510,7 @@ def write_request(
QueueFull,
hash_instruction,
)

event = CollectedEvent(
timestamp=record["timestamp"],
episode_id=self.session_id,
Expand All @@ -456,7 +533,7 @@ def write_request(
def write_footer(self, totals: dict[str, int]) -> None:
"""Emit footer on clean shutdown. Optional per D.1.6 — readers
tolerate absence."""
if self.degraded or not self._header_written:
if self.degraded or self._closed or not self._header_written:
# Don't write a footer if we never wrote the header
return
record = {
Expand All @@ -465,26 +542,26 @@ def write_footer(self, totals: dict[str, int]) -> None:
"ended_at": _utc_now_iso(),
**totals,
}
self._emit(record)
self._enqueue_records((record,))

def close(self) -> None:
    """Shut the recorder down; safe to call more than once.

    Order of operations:
      1. Mark the writer closed.
      2. Stop the curate collector, if attached (failures are logged only).
      3. Wait for already-queued records to be handled, then send the stop
         sentinel so the worker thread closes the file and exits.
      4. Join the worker for up to one second and warn if it is still alive.

    The file handle is never touched here: it is owned by the worker, which
    closes it when it receives the sentinel. Closing it on this thread would
    race with pending writes, and returning early when no file is open yet
    would leave the worker running.
    """
    if self._closed:
        return
    self._closed = True
    # Stop the curate collector first so its drain has a chance to
    # flush queued events before the process exits.
    if self._curate_collector is not None:
        try:
            self._curate_collector.stop()
        except Exception as exc:  # noqa: BLE001
            logger.warning("curate_collector.stop failed: %s", exc)
    self.flush_sync()
    self._record_queue.put(_QUEUE_STOP)
    self._record_queue.join()
    if self._worker is not threading.current_thread():
        self._worker.join(timeout=1.0)
        if self._worker.is_alive():
            logger.warning("RecordWriter worker did not stop cleanly")

# ---------------------------------------------------------------
# Convenience
Expand Down
3 changes: 1 addition & 2 deletions src/tether/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
import io
import json
import logging
import os
import time
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .record import (
RecordWriter,
Expand Down
Loading