Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions deepem/data/classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Central registry of detection/segmentation classes.

Each entry maps a CLI flag (dict key) to its internal name (`internal_name`).
The internal name is the load-bearing identifier: it names the PyTorch output
submodule (and therefore the state_dict / ONNX output), the wandb metric, the
loss criterion, the dataset key used by the sampler, and the LHS of
`zettaset_lookup`. Flag and internal name can differ freely.

Adding a new task (e.g., "psd_v2", "ribosome") is a one-line edit:

'ribo': ClassSpec(internal_name='ribosome', binarize=True),

The matching `--ribo <weight>` CLI flag (train) and `--ribo` switch (test)
become available automatically; no edits to option.py needed.
"""
from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass(frozen=True)
class ClassSpec:
"""Specification for a detection/segmentation class.

Attributes:
internal_name: The single load-bearing identifier for this class.
Used as the PyTorch output submodule name (and state_dict /
ONNX key), the wandb metric name, the loss criterion key, the
dataset key consumed by the sampler, and the LHS of
`zettaset_lookup`. Distinct from the CLI flag (the REGISTRY
dict key).
channels: Output channels. Either an int, or a callable
opt -> int for classes whose width depends on another flag
(e.g. blood_vessel takes opt.blv_num_channels).
binarize: Whether the dataset should binarize the annotation.
Bool or callable opt -> bool.
semantic_id: If set, this class participates in semantic_mapping
when --sem is on (label ID used by the combined-semantic GT).
"""
internal_name: str
channels: Union[int, Callable] = 1
binarize: Union[bool, Callable] = False
semantic_id: Optional[int] = None

def resolve_channels(self, opt) -> int:
return self.channels(opt) if callable(self.channels) else self.channels

def resolve_binarize(self, opt) -> bool:
return self.binarize(opt) if callable(self.binarize) else self.binarize


REGISTRY: dict[str, ClassSpec] = {
# Affinity / boundary
'aff': ClassSpec(internal_name='affinity', channels=3),
'long': ClassSpec(internal_name='long_range', channels=lambda opt: len(opt.edges)),
'bdr': ClassSpec(internal_name='boundary'),

# Detection (binary)
'syn': ClassSpec(internal_name='synapse', binarize=True),
'psd': ClassSpec(internal_name='synapse', binarize=True),
'mit': ClassSpec(internal_name='mitochondria', binarize=True),
'mye': ClassSpec(internal_name='myelin', binarize=True),
'fld': ClassSpec(internal_name='fold', binarize=True),
'glia': ClassSpec(internal_name='glia', binarize=True, semantic_id=5),
'img': ClassSpec(internal_name='image'),
'mito_to_cell': ClassSpec(internal_name='mitochondria_to_cell'),

# Multi-channel blood vessel: binarize only when collapsed to 1 channel
'blv': ClassSpec(
internal_name='blood_vessel',
channels=lambda opt: opt.blv_num_channels,
binarize=lambda opt: opt.blv_num_channels == 1,
semantic_id=7,
),

# Semantic-segmentation classes (binarize only when --sem is off)
'soma': ClassSpec(internal_name='soma', binarize=True, semantic_id=3),
'dend': ClassSpec(internal_name='dendrite', semantic_id=1),
'axon': ClassSpec(internal_name='axon', semantic_id=2),
'nucl': ClassSpec(internal_name='nucleus', semantic_id=4),
'ecs': ClassSpec(internal_name='extracellular_space', semantic_id=6),
'other': ClassSpec(internal_name='other_class', semantic_id=10),

# Embedding (consumed by MeanLoss / metric learning)
'vec': ClassSpec(internal_name='embedding', channels=lambda opt: opt.embed_dim),
'mito_emb': ClassSpec(internal_name='mitochondria_embedding', channels=lambda opt: opt.mito_emb_dim),
}


def semantic_mapping() -> dict[str, int]:
"""Class-name -> label-ID dict, derived from semantic_id fields."""
return {
spec.internal_name: spec.semantic_id
for spec in REGISTRY.values()
if spec.semantic_id is not None
}


def requires_binarize(opt) -> list[str]:
"""Output names whose dataset annotation must be binarized.

Mirrors the legacy hand-maintained list: includes classes with
binarize=True, then drops any that are taken over by semantic_mapping
when --sem is on.
"""
names = {
spec.internal_name
for spec in REGISTRY.values()
if spec.resolve_binarize(opt)
}
if getattr(opt, 'sem', False):
names -= set(semantic_mapping().keys())
return sorted(names)
72 changes: 24 additions & 48 deletions deepem/test/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,32 +55,24 @@ def initialize(self):
self.parser.add_argument('--delta_d', type=float, default=1.5)
self.parser.add_argument('--scale_init', type=float, default=1.0)

# Multiclass detection
self.parser.add_argument('--aff', action='store_true')
# Per-class enable flags — auto-declared from the central registry.
# `vec` and `long` take integer channel counts (legacy), so they are
# declared explicitly and skipped here.
from deepem.data.classes import REGISTRY as CLASS_REGISTRY
for flag_name in CLASS_REGISTRY:
if flag_name in ('vec', 'long'):
continue
self.parser.add_argument(f'--{flag_name}', action='store_true')

self.parser.add_argument('--long', type=int, default=0)
self.parser.add_argument('--aff_deprecated', type=int, default=None)
self.parser.add_argument('--bdr', action='store_true')
self.parser.add_argument('--syn', action='store_true')
self.parser.add_argument('--psd', action='store_true')
self.parser.add_argument('--mit', action='store_true')
self.parser.add_argument('--mito_to_cell', action='store_true')
self.parser.add_argument('--mye', action='store_true')
self.parser.add_argument('--mye_thresh', type=float, default=0.5)
self.parser.add_argument('--blv', action='store_true')
self.parser.add_argument('--blv_num_channels', type=int, default=1)
self.parser.add_argument('--glia', action='store_true')
self.parser.add_argument('--sem', action='store_true')
self.parser.add_argument('--img', action='store_true')
self.parser.add_argument('--merge_classes', type=str, default=[], nargs='+')

# Semantic segmentation
self.parser.add_argument('--semantic', action='store_true')
self.parser.add_argument('--dend', action='store_true') # Dendrite
self.parser.add_argument('--axon', action='store_true') # Axon
self.parser.add_argument('--soma', action='store_true') # Soma
self.parser.add_argument('--nucl', action='store_true') # Nucleus
self.parser.add_argument('--ecs', action='store_true') # Extracellular space
self.parser.add_argument('--other', action='store_true') # Other class

# Test-time augmentation
self.parser.add_argument('--test_aug', type=int, default=None, nargs='+')
Expand Down Expand Up @@ -224,49 +216,33 @@ def parse(self):
opt.inputsz = (aniso_z, opt.inputsz[1], opt.inputsz[2])
opt.in_spec = dict(input=(1,) + opt.inputsz)

# Per-class out_spec entries (driven by class registry).
from deepem.data.classes import REGISTRY as CLASS_REGISTRY
for flag_name, spec in CLASS_REGISTRY.items():
if flag_name in ('vec', 'long'):
continue # Integer-valued in test; handled below.
if getattr(opt, flag_name, False):
opt.out_spec[spec.internal_name] = (spec.resolve_channels(opt),) + opt.outputsz

# Integer-valued channel flags (legacy CLI shape).
if opt.vec:
opt.out_spec['embedding'] = (opt.vec,) + opt.outputsz
if opt.aff:
opt.out_spec['affinity'] = (3,) + opt.outputsz
if opt.long:
opt.out_spec['long_range'] = (opt.long,) + opt.outputsz
if opt.aff_deprecated:
opt.out_spec['affinity'] = (opt.aff_deprecated,) + opt.outputsz
if opt.bdr:
opt.out_spec['boundary'] = (1,) + opt.outputsz
if opt.syn:
opt.out_spec['synapse'] = (1,) + opt.outputsz
if opt.psd:
opt.out_spec['synapse'] = (1,) + opt.outputsz
if opt.mit:
opt.out_spec['mitochondria'] = (1,) + opt.outputsz

# mito_to_cell also modifies in_spec.
if opt.mito_to_cell:
opt.in_spec['input_mitochondria'] = (1,) + opt.inputsz
opt.out_spec['mitochondria_to_cell'] = (1,) + opt.outputsz
if opt.mye:
opt.out_spec['myelin'] = (1,) + opt.outputsz
if opt.blv:
opt.out_spec['blood_vessel'] = (opt.blv_num_channels,) + opt.outputsz
if opt.glia:
opt.out_spec['glia'] = (1,) + opt.outputsz

# --sem: shorthand for enabling all combined semantic-segmentation heads.
if opt.sem:
opt.out_spec['soma'] = (1,) + opt.outputsz
opt.out_spec['axon'] = (1,) + opt.outputsz
opt.out_spec['dendrite'] = (1,) + opt.outputsz
opt.out_spec['glia'] = (1,) + opt.outputsz
opt.out_spec['bvessel'] = (1,) + opt.outputsz
if opt.img:
opt.out_spec['image'] = (1,) + opt.outputsz
if opt.dend:
opt.out_spec['dendrite'] = (1,) + opt.outputsz
if opt.axon:
opt.out_spec['axon'] = (1,) + opt.outputsz
if opt.soma:
opt.out_spec['soma'] = (1,) + opt.outputsz
if opt.nucl:
opt.out_spec['nucleus'] = (1,) + opt.outputsz
if opt.ecs:
opt.out_spec['extracellular_space'] = (1,) + opt.outputsz
if opt.other:
opt.out_spec['other_class'] = (1,) + opt.outputsz

# Semantic segmentation
if opt.semantic:
Expand Down
18 changes: 5 additions & 13 deletions deepem/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,11 @@ def make_forward_scanner(opt, data_name=None):
return ForwardScanner(dataset, opt.scan_spec, **opt.scan_params)


SEMANTIC_MAPPING = {
'dendrite': 1,
'axon': 2,
'soma': 3,
'nucleus': 4,
'glia': 5,
'extracellular_space': 6,
'blood_vessel': 7,
'other_class': 10,
}

# Training class_dict iteration order (deepem/train/option.py); chunkflow's
# channel_voting does argmax+1 over exactly this channel ordering.
from deepem.data.classes import semantic_mapping as _semantic_mapping
SEMANTIC_MAPPING = _semantic_mapping()

# chunkflow's channel_voting does argmax+1 over exactly this channel ordering.
# Frozen contract with chunkflow — do not auto-derive from REGISTRY.
ARGMAX_ORDER = [
'blood_vessel', 'glia', 'soma', 'dendrite', 'axon',
'nucleus', 'extracellular_space', 'other_class',
Expand Down
111 changes: 26 additions & 85 deletions deepem/train/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,43 +206,25 @@ def initialize(self):
self.parser.add_argument('--sr_scale_z', type=int, default=5,
help='Z upsampling factor (default: 5 for 40nm:8nm ratio)')

# Long-range affinity
self.parser.add_argument('--long', type=float, default=0)
# Long-range affinity edges (--long weight comes from class registry).
self.parser.add_argument('--edges', type=vec3, default=[], nargs='+')

# Multiclass detection
self.parser.add_argument('--aff', type=float, default=0) # Affinity
self.parser.add_argument('--bdr', type=float, default=0) # Boundary
self.parser.add_argument('--syn', type=float, default=0) # Synapse
self.parser.add_argument('--psd', type=float, default=0) # Synapse
self.parser.add_argument('--mit', type=float, default=0) # Mitochondria
self.parser.add_argument('--mye', type=float, default=0) # Myelin
self.parser.add_argument('--fld', type=float, default=0) # Fold
self.parser.add_argument('--blv', type=float, default=0) # Blood vessel
# Per-class loss weights — auto-declared from the central registry.
# To add a new task, add an entry to deepem/data/classes.py REGISTRY.
from deepem.data.classes import REGISTRY as CLASS_REGISTRY
for flag_name in CLASS_REGISTRY:
self.parser.add_argument(f'--{flag_name}', type=float, default=0)

# Auxiliary knobs referenced by class specs / dataset.
self.parser.add_argument('--blv_num_channels', type=int, default=1)
self.parser.add_argument('--glia', type=float, default=0) # Glia
self.parser.add_argument('--glia_mask', action='store_true')
self.parser.add_argument('--img', type=float, default=0) # Image
self.parser.add_argument('--mito_to_cell', type=float, default=0) # Mito to cell
self.parser.add_argument('--mito_to_cell_mode', type=str, default='random') # Mito to cell mode
self.parser.add_argument('--mito_to_cell_mode', type=str, default='random')
self.parser.add_argument('--merge_classes', type=str, default=[], nargs='+') # for onnx export
self.parser.add_argument('--embed_dim', type=int, default=12)
self.parser.add_argument('--mito_emb_dim', type=int, default=6)

# Semantic segmentation
self.parser.add_argument('--sem', action='store_true')
self.parser.add_argument('--dend', type=float, default=0) # Dendrite
self.parser.add_argument('--axon', type=float, default=0) # Axon
self.parser.add_argument('--soma', type=float, default=0) # Soma
self.parser.add_argument('--nucl', type=float, default=0) # Nucleus
self.parser.add_argument('--ecs', type=float, default=0) # Extracellular space
self.parser.add_argument('--other', type=float, default=0) # Other class

# Metric learning
self.parser.add_argument('--vec', type=float, default=0)
self.parser.add_argument('--embed_dim', type=int, default=12)

# Mitochondria embedding
self.parser.add_argument('--mito_emb', type=float, default=0)
self.parser.add_argument('--mito_emb_dim', type=int, default=6)

# Test training
self.parser.add_argument('--test', action='store_true')
Expand Down Expand Up @@ -382,56 +364,14 @@ def parse(self):
opt.aug_params['sr_mode'] = opt.sr_mode
opt.aug_params['sr_scale_z'] = opt.sr_scale_z

# Multiclass detection
class_keys = list()
class_dict = {
'aff': ('affinity', 3),
'long': ('long_range', len(opt.edges)),
'bdr': ('boundary', 1),
'syn': ('synapse', 1),
'psd': ('synapse', 1),
'mit': ('mitochondria', 1),
'mye': ('myelin', 1),
'fld': ('fold', 1),
'blv': ('blood_vessel', opt.blv_num_channels),
'glia': ('glia', 1),
'soma': ('soma', 1),
'img': ('image', 1),
'vec': ('embedding', opt.embed_dim),
'dend': ('dendrite', 1),
'axon': ('axon', 1),
'nucl': ('nucleus', 1),
'ecs': ('extracellular_space', 1),
'other': ('other_class', 1),
'mito_to_cell': ('mitochondria_to_cell', 1),
'mito_emb': ('mitochondria_embedding', opt.mito_emb_dim),
}

semantic_mapping = {
'dendrite': 1,
'axon': 2,
'soma': 3,
'nucleus': 4,
'glia': 5,
'extracellular_space': 6,
'blood_vessel': 7,
'other_class': 10,
}

requires_binarize = [
"synapse",
"mitochondria",
"myelin",
"fold",
"glia",
"soma",
]

if opt.blv_num_channels == 1:
requires_binarize.append("blood_vessel")

if opt.sem:
requires_binarize = [x for x in requires_binarize if x not in semantic_mapping]
# Multiclass detection (driven by deepem/data/classes.py REGISTRY).
from deepem.data.classes import (
REGISTRY as CLASS_REGISTRY,
semantic_mapping as _semantic_mapping_from_registry,
requires_binarize as _requires_binarize_from_registry,
)
semantic_mapping = _semantic_mapping_from_registry()
requires_binarize = _requires_binarize_from_registry(opt)

# Test training
if opt.test:
Expand All @@ -440,14 +380,15 @@ def parse(self):
opt.avgs_intv = 10
opt.imgs_intv = 100

for k, v in class_dict.items():
loss_w = args[k]
class_keys = list()
for flag_name, spec in CLASS_REGISTRY.items():
loss_w = args[flag_name]
if loss_w > 0:
output_name, num_channels = v
num_channels = spec.resolve_channels(opt)
assert num_channels > 0
opt.out_spec[output_name] = (num_channels,) + opt.outputsz
opt.loss_weight[output_name] = loss_w
class_keys.append(k)
opt.out_spec[spec.internal_name] = (num_channels,) + opt.outputsz
opt.loss_weight[spec.internal_name] = loss_w
class_keys.append(flag_name)

# Mito-to-cell assignment hacks
if opt.mito_to_cell > 0:
Expand Down