Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `prune_chart_to_tree` (default `False`): when on, chart rows whose
`chart_strain_field` value isn't a tree tip are filtered out before
drawing. CLI form:
`--prune-chart-to-tree / --no-prune-chart-to-tree`.
([#6](https://github.com/jbloomlab/tree-annotated-plot/issues/6))

## [0.2.2] - 2026-05-09

### Fixed
Expand Down
11 changes: 8 additions & 3 deletions src/tree_annotated_plot/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,14 @@ class PlotConfig:
"When off (default), tree tips not present in the chart's strain "
"set are a fatal error. When on, those tips (and any internal "
"nodes whose subtrees become empty) are dropped before drawing, "
"with single-child internals collapsed into their kept child. "
"Chart strains not present in the tree are *always* fatal "
"regardless of this flag — pruning would silently lose plot data.",
"with single-child internals collapsed into their kept child.",
] = False

prune_chart_to_tree: Annotated[
bool,
"When off (default), chart strains not present in the tree are a "
"fatal error. When on, chart rows whose `chart_strain_field` "
"value isn't a tree tip are filtered out before drawing.",
] = False

strict_version: Annotated[
Expand Down
107 changes: 101 additions & 6 deletions src/tree_annotated_plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def plot(
scale_bar: bool = False,
branch_length_units: str | None = None,
prune_tree_to_chart: bool = False,
prune_chart_to_tree: bool = False,
strict_version: bool = True,
connect_leader_to_label: bool = False,
strain_label_font_size: float = 10.0,
Expand Down Expand Up @@ -82,6 +83,7 @@ def plot(
scale_bar=scale_bar,
branch_length_units=branch_length_units,
prune_tree_to_chart=prune_tree_to_chart,
prune_chart_to_tree=prune_chart_to_tree,
strict_version=strict_version,
connect_leader_to_label=connect_leader_to_label,
strain_label_font_size=strain_label_font_size,
Expand Down Expand Up @@ -168,12 +170,33 @@ def _build(

chart_strains = _extract_chart_strains(spec, axis_hits, config.chart_strain_field)

if config.prune_chart_to_tree and (set(chart_strains) - set(tip_names)):
_prune_chart_spec_to_strains(
spec,
chart_strain_field=config.chart_strain_field,
keep_strains=set(tip_names),
)
chart = alt.Chart.from_dict(spec)
chart_strains = _extract_chart_strains(
spec, axis_hits, config.chart_strain_field
)
if not chart_strains:
raise ValueError(
"prune_chart_to_tree=True dropped every chart row: no "
"chart strain matched any tree tip under "
f"chart_strain_field={config.chart_strain_field!r} / "
f"tree_strain_field={config.tree_strain_field!r}. Pruning "
"is meant for charts that bundle a superset of strains; "
"an empty intersection suggests a wrong field choice."
)

_reconcile_tips_and_strains(
tree_strains=tip_names,
chart_strains=chart_strains,
chart_strain_field=config.chart_strain_field,
tree_strain_field=config.tree_strain_field,
prune_tree_to_chart=config.prune_tree_to_chart,
prune_chart_to_tree=config.prune_chart_to_tree,
chart_spec=spec,
tree_source=tree,
)
Expand Down Expand Up @@ -641,6 +664,70 @@ def _extract_chart_strains(
return _extract_field_values_from_spec_data(spec, chart_strain_field)


def _prune_chart_spec_to_strains(
    spec: dict, *, chart_strain_field: str, keep_strains: set[str]
) -> None:
    """Drop, in place, every spec row whose strain is not in `keep_strains`.

    The Vega-Lite spec dict is edited directly and nothing is returned.
    Three places are touched:
      - each list-valued entry of the top-level `datasets` mapping;
      - each inline `data.values` list, at any nesting depth;
      - each list-valued `sort` on an encoding channel whose `field` is
        `chart_strain_field` (sorts bound to other fields are untouched).

    A row is removed only when it is a dict that carries
    `chart_strain_field` with a value outside `keep_strains`. Non-dict
    rows and rows lacking the field are kept, since nothing identifies
    them as belonging to a dropped strain.

    Raises:
        ValueError: if a visited node's `data` mapping has a `url` key;
            URL-backed rows are not available to filter here.
    """

    def surviving(rows: list) -> list:
        """Return the rows that stay, in their original order."""
        return [
            row
            for row in rows
            if not isinstance(row, dict)
            or chart_strain_field not in row
            or row[chart_strain_field] in keep_strains
        ]

    def prune_inline_data(data: Any) -> None:
        """Filter one node's `data.values`; reject URL-backed data."""
        if not isinstance(data, dict):
            return
        if "url" in data:
            raise ValueError(
                f"chart references data via URL ({data['url']!r}); "
                "URL data is not supported, so prune_chart_to_tree "
                "cannot filter it. Materialize the data inline "
                "(via alt.Chart(df) with a pandas DataFrame) before "
                "saving the chart."
            )
        inline_rows = data.get("values")
        if isinstance(inline_rows, list):
            data["values"] = surviving(inline_rows)

    def prune_strain_sorts(encoding: Any) -> None:
        """Trim explicit sort lists on channels bound to the strain field."""
        if not isinstance(encoding, dict):
            return
        for channel in encoding.values():
            if not isinstance(channel, dict):
                continue
            if channel.get("field") != chart_strain_field:
                continue
            order = channel.get("sort")
            if isinstance(order, list):
                channel["sort"] = [name for name in order if name in keep_strains]

    def visit(node: Any) -> None:
        """Depth-first walk: handle this node first, then its children."""
        if isinstance(node, list):
            for child in node:
                visit(child)
            return
        if not isinstance(node, dict):
            return
        prune_inline_data(node.get("data"))
        prune_strain_sorts(node.get("encoding"))
        for key, child in node.items():
            # `data` was handled above and `datasets` is filtered up front,
            # so neither subtree is descended into.
            if key not in ("data", "datasets"):
                visit(child)

    # Named datasets are filtered before the walk, matching the order in
    # which the spec is mutated if a URL error is raised later.
    named = spec.get("datasets") if isinstance(spec, dict) else None
    if isinstance(named, dict):
        for dataset_name in list(named):
            rows = named[dataset_name]
            if isinstance(rows, list):
                named[dataset_name] = surviving(rows)

    visit(spec)


def _extract_field_values_from_spec_data(spec: dict, field: str) -> list[str]:
"""Walk spec for inline / named data and return distinct values of `field`.

Expand Down Expand Up @@ -708,13 +795,16 @@ def _reconcile_tips_and_strains(
chart_strain_field: str,
tree_strain_field: str,
prune_tree_to_chart: bool,
prune_chart_to_tree: bool,
chart_spec: dict,
tree_source: Any,
) -> None:
"""Verify tree strains and chart strains are reconcilable.

Three asymmetries:
- chart strains not in tree → always fatal.
- chart strains not in tree → fatal unless `prune_chart_to_tree=True`
(in which case the chart spec has already been pre-filtered upstream
and this set is expected to be empty by the time we get here).
- tree tips not in chart → fatal unless `prune_tree_to_chart=True`.
- (duplicate tree_strain_field values across tips → handled by the
separate `_check_no_duplicate_tip_strains`.)
Expand All @@ -728,7 +818,9 @@ def _reconcile_tips_and_strains(
chart_minus_tree = chart_set - tree_set
tree_minus_chart = tree_set - chart_set

if not chart_minus_tree and (not tree_minus_chart or prune_tree_to_chart):
chart_ok = not chart_minus_tree
tree_ok = not tree_minus_chart or prune_tree_to_chart
if chart_ok and tree_ok:
return

hints = _candidate_field_hints(
Expand All @@ -748,6 +840,7 @@ def _reconcile_tips_and_strains(
chart_minus_tree=chart_minus_tree,
tree_minus_chart=tree_minus_chart,
prune_tree_to_chart=prune_tree_to_chart,
prune_chart_to_tree=prune_chart_to_tree,
hints=hints,
)
)
Expand All @@ -762,14 +855,16 @@ def _format_strain_mismatch(
chart_minus_tree: set[str],
tree_minus_chart: set[str],
prune_tree_to_chart: bool,
prune_chart_to_tree: bool,
hints: list[str],
) -> str:
parts: list[str] = []
if chart_minus_tree:
if chart_minus_tree and not prune_chart_to_tree:
parts.append(
f"{len(chart_minus_tree)} chart strain(s) are not present in the "
"tree (these would be silently dropped if we pruned, so this is "
"always fatal)."
"tree. Pass `prune_chart_to_tree=True` to drop the offending "
"chart rows automatically (use with care — this discards plot "
"data)."
)
if tree_minus_chart and not prune_tree_to_chart:
parts.append(
Expand All @@ -783,7 +878,7 @@ def _format_strain_mismatch(
)
parts.append("Sample chart_strain_field values: " f"{sorted(chart_strains)[:5]}")
parts.append("Sample tree_strain_field values: " f"{sorted(tree_strains)[:5]}")
if chart_minus_tree:
if chart_minus_tree and not prune_chart_to_tree:
parts.append(f"Sample chart-only values: {sorted(chart_minus_tree)[:5]}")
if tree_minus_chart and not prune_tree_to_chart:
parts.append(f"Sample tree-only values: {sorted(tree_minus_chart)[:5]}")
Expand Down
97 changes: 94 additions & 3 deletions tests/test_reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _chart_for_strains(strains: list[str], *, height: int = 200) -> alt.Chart:
)


# ---------- chart-not-in-tree (always fatal) ----------
# ---------- chart-not-in-tree (fatal unless prune_chart_to_tree) ----------


def test_chart_strain_not_in_tree_is_fatal_default() -> None:
Expand All @@ -90,9 +90,9 @@ def test_chart_strain_not_in_tree_is_fatal_default() -> None:
)


def test_chart_strain_not_in_tree_is_fatal_even_with_prune() -> None:
def test_chart_strain_not_in_tree_is_fatal_with_only_tree_prune() -> None:
"""`prune_tree_to_chart=True` only drops *tree* tips. A chart strain
not in the tree still raises — pruning would silently lose plot data."""
not in the tree still raises — the two flags are orthogonal."""
chart = _chart_for_strains(["A1", "A2", "X"])
with pytest.raises(ValueError, match="not present in the tree"):
tree_annotated_plot.plot(
Expand All @@ -105,6 +105,97 @@ def test_chart_strain_not_in_tree_is_fatal_even_with_prune() -> None:
)


def test_chart_strain_not_in_tree_succeeds_with_prune_chart_to_tree() -> None:
    """With `prune_chart_to_tree=True`, a chart strain absent from the tree
    (`X`) is dropped instead of being fatal: it no longer appears in any
    strain-axis sort list nor in the chart data of the combined plot."""
    tree_tips = {"A1", "A2", "A3", "B1", "B2"}
    chart = _chart_for_strains(["A1", "A2", "A3", "B1", "B2", "X"])
    combined = tree_annotated_plot.plot(
        _auspice_two_clades(),
        chart,
        chart_strain_field="strain",
        tree_strain_field="name",
        branch_length="div",
        prune_chart_to_tree=True,
    )
    assert isinstance(combined, alt.HConcatChart)
    combined_spec = combined.to_dict()

    # Collect every explicit sort list bound to the strain field across
    # the concatenated panels; each may only mention tree tips.
    strain_sorts = [
        channel["sort"]
        for panel in combined.hconcat
        for channel in panel.to_dict().get("encoding", {}).values()
        if isinstance(channel, dict)
        and channel.get("field") == "strain"
        and isinstance(channel.get("sort"), list)
    ]
    assert strain_sorts, "expected at least one strain-axis encoding with a sort"
    for strain_sort in strain_sorts:
        assert "X" not in strain_sort
        assert set(strain_sort) <= tree_tips

    # The user-chart rows (hoisted into `datasets`) must not carry X either.
    for rows in (combined_spec.get("datasets") or {}).values():
        is_row_table = isinstance(rows, list) and rows and isinstance(rows[0], dict)
        if is_row_table and "strain" in rows[0]:
            assert all(row["strain"] != "X" for row in rows)


def test_prune_chart_to_tree_zero_overlap_raises() -> None:
    """An empty chart/tree intersection under `prune_chart_to_tree=True`
    must fail with a clear error instead of rendering an empty plot."""
    disjoint_chart = _chart_for_strains(["X", "Y", "Z"])
    with pytest.raises(ValueError) as excinfo:
        tree_annotated_plot.plot(
            _auspice_two_clades(),
            disjoint_chart,
            chart_strain_field="strain",
            tree_strain_field="name",
            branch_length="div",
            prune_chart_to_tree=True,
        )
    assert "dropped every chart row" in str(excinfo.value)


def test_prune_chart_and_tree_combined() -> None:
    """Both prune flags together: the chart carries an extra strain (X) and
    the tree carries tips the chart lacks (B1, B2). Each side is trimmed,
    leaving the intersection."""
    chart = _chart_for_strains(["A1", "A2", "A3", "X"])
    combined = tree_annotated_plot.plot(
        _auspice_two_clades(),
        chart,
        chart_strain_field="strain",
        tree_strain_field="name",
        branch_length="div",
        prune_tree_to_chart=True,
        prune_chart_to_tree=True,
    )
    assert isinstance(combined, alt.HConcatChart)

    # Gather the strain-axis sort lists in panel order; the last one found
    # is the one checked against the expected intersection {A1, A2, A3}.
    strain_sorts = [
        channel["sort"]
        for panel in combined.hconcat
        for channel in panel.to_dict().get("encoding", {}).values()
        if isinstance(channel, dict)
        and channel.get("field") == "strain"
        and isinstance(channel.get("sort"), list)
    ]
    assert strain_sorts
    assert set(strain_sorts[-1]) == {"A1", "A2", "A3"}


def test_prune_chart_to_tree_default_error_mentions_flag() -> None:
    """Without opting in, a chart strain missing from the tree is fatal,
    and the error text points the user at `prune_chart_to_tree=True`."""
    chart_with_extra = _chart_for_strains(["A1", "A2", "A3", "B1", "B2", "X"])
    with pytest.raises(ValueError) as excinfo:
        tree_annotated_plot.plot(
            _auspice_two_clades(),
            chart_with_extra,
            chart_strain_field="strain",
            tree_strain_field="name",
            branch_length="div",
        )
    assert "prune_chart_to_tree=True" in str(excinfo.value)


# ---------- tree-not-in-chart (fatal unless prune) ----------


Expand Down