From b4dce78eeeac3cb7731e19cd27879e335eb773d4 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 14 May 2026 15:13:31 +0900 Subject: [PATCH 1/2] feat: central class registry replaces per-class CLI flags Adding a new detection task no longer requires editing argparse declarations, class_dict, semantic_mapping, or requires_binarize. A new task is now a one-line entry in deepem/data/classes.py REGISTRY: 'ribo': ClassSpec(out_name='ribosome', binarize=True) The matching --ribo flag (train: float weight; test: store_true) becomes available automatically, and downstream wiring (out_spec, loss_weight, requires_binarize, semantic_mapping) is derived from the registry. Existing flags (--psd, --glia, --syn, etc.) remain unchanged, so seuron's argv generation needs no modification. train/option.py: - per-class argparse declarations replaced with a REGISTRY loop - hardcoded class_dict / semantic_mapping / requires_binarize replaced with calls into deepem.data.classes - --long moved out of the Long-range section (it lives in REGISTRY) test/option.py: - per-class store_true declarations replaced with a REGISTRY loop (--vec and --long stay explicit since they are integer-valued) - if-chain over class flags replaced with a REGISTRY loop; integer-channel and side-effect cases (vec, long, aff_deprecated, mito_to_cell, sem) kept as explicit blocks test/utils.py: - SEMANTIC_MAPPING derived from REGISTRY via semantic_mapping() - ARGMAX_ORDER kept hardcoded (frozen contract with chunkflow) Co-Authored-By: Claude Opus 4.7 (1M context) --- deepem/data/classes.py | 102 +++++++++++++++++++++++++++++++++++++ deepem/test/option.py | 72 +++++++++----------------- deepem/test/utils.py | 18 ++----- deepem/train/option.py | 111 ++++++++++------------------------------- 4 files changed, 157 insertions(+), 146 deletions(-) create mode 100644 deepem/data/classes.py diff --git a/deepem/data/classes.py b/deepem/data/classes.py new file mode 100644 index 0000000..a756118 --- /dev/null +++ b/deepem/data/classes.py @@ -0,0 +1,102 @@ +"""Central registry of detection/segmentation classes. + +Adding a new task (e.g., "psd_v2", "ribosome") is a one-line edit: + + 'ribo': ClassSpec(out_name='ribosome', binarize=True), + +The matching `--ribo ` CLI flag (train) and `--ribo` switch (test) +become available automatically; no edits to option.py needed. +""" +from dataclasses import dataclass +from typing import Callable, Optional, Union + + +@dataclass(frozen=True) +class ClassSpec: + """Specification for a detection/segmentation class. + + Attributes: + out_name: Output spec key. Used by model heads, loss, and as the + annotation name looked up in the dataset. + channels: Output channels. Either an int, or a callable + opt -> int for classes whose width depends on another flag + (e.g. blood_vessel takes opt.blv_num_channels). + binarize: Whether the dataset should binarize the annotation. + Bool or callable opt -> bool. + semantic_id: If set, this class participates in semantic_mapping + when --sem is on (label ID used by the combined-semantic GT). + """ + out_name: str + channels: Union[int, Callable] = 1 + binarize: Union[bool, Callable] = False + semantic_id: Optional[int] = None + + def resolve_channels(self, opt) -> int: + return self.channels(opt) if callable(self.channels) else self.channels + + def resolve_binarize(self, opt) -> bool: + return self.binarize(opt) if callable(self.binarize) else self.binarize + + +REGISTRY: dict[str, ClassSpec] = { + # Affinity / boundary + 'aff': ClassSpec(out_name='affinity', channels=3), + 'long': ClassSpec(out_name='long_range', channels=lambda opt: len(opt.edges)), + 'bdr': ClassSpec(out_name='boundary'), + + # Detection (binary) + 'syn': ClassSpec(out_name='synapse', binarize=True), + 'psd': ClassSpec(out_name='synapse', binarize=True), + 'mit': ClassSpec(out_name='mitochondria', binarize=True), + 'mye': ClassSpec(out_name='myelin', binarize=True), + 'fld': ClassSpec(out_name='fold', binarize=True), + 'glia': ClassSpec(out_name='glia', binarize=True, semantic_id=5), + 'img': ClassSpec(out_name='image'), + 'mito_to_cell': ClassSpec(out_name='mitochondria_to_cell'), + + # Multi-channel blood vessel: binarize only when collapsed to 1 channel + 'blv': ClassSpec( + out_name='blood_vessel', + channels=lambda opt: opt.blv_num_channels, + binarize=lambda opt: opt.blv_num_channels == 1, + semantic_id=7, + ), + + # Semantic-segmentation classes (binarize only when --sem is off) + 'soma': ClassSpec(out_name='soma', binarize=True, semantic_id=3), + 'dend': ClassSpec(out_name='dendrite', semantic_id=1), + 'axon': ClassSpec(out_name='axon', semantic_id=2), + 'nucl': ClassSpec(out_name='nucleus', semantic_id=4), + 'ecs': ClassSpec(out_name='extracellular_space', semantic_id=6), + 'other': ClassSpec(out_name='other_class', semantic_id=10), + + # Embedding (consumed by MeanLoss / metric learning) + 'vec': ClassSpec(out_name='embedding', channels=lambda opt: opt.embed_dim), + 'mito_emb': ClassSpec(out_name='mitochondria_embedding', channels=lambda opt: opt.mito_emb_dim), +} + + +def semantic_mapping() -> dict[str, int]: + """Class-name -> label-ID dict, derived from semantic_id fields.""" + return { + spec.out_name: spec.semantic_id + for spec in REGISTRY.values() + if spec.semantic_id is not None + } + + +def requires_binarize(opt) -> list[str]: + """Output names whose dataset annotation must be binarized. + + Mirrors the legacy hand-maintained list: includes classes with + binarize=True, then drops any that are taken over by semantic_mapping + when --sem is on. + """ + names = { + spec.out_name + for spec in REGISTRY.values() + if spec.resolve_binarize(opt) + } + if getattr(opt, 'sem', False): + names -= set(semantic_mapping().keys()) + return sorted(names) diff --git a/deepem/test/option.py b/deepem/test/option.py index 7c26edf..d4bd3ee 100644 --- a/deepem/test/option.py +++ b/deepem/test/option.py @@ -55,32 +55,24 @@ def initialize(self): self.parser.add_argument('--delta_d', type=float, default=1.5) self.parser.add_argument('--scale_init', type=float, default=1.0) - # Multiclass detection - self.parser.add_argument('--aff', action='store_true') + # Per-class enable flags — auto-declared from the central registry. + # `vec` and `long` take integer channel counts (legacy), so they are + # declared explicitly and skipped here. + from deepem.data.classes import REGISTRY as CLASS_REGISTRY + for flag_name in CLASS_REGISTRY: + if flag_name in ('vec', 'long'): + continue + self.parser.add_argument(f'--{flag_name}', action='store_true') + self.parser.add_argument('--long', type=int, default=0) self.parser.add_argument('--aff_deprecated', type=int, default=None) - self.parser.add_argument('--bdr', action='store_true') - self.parser.add_argument('--syn', action='store_true') - self.parser.add_argument('--psd', action='store_true') - self.parser.add_argument('--mit', action='store_true') - self.parser.add_argument('--mito_to_cell', action='store_true') - self.parser.add_argument('--mye', action='store_true') self.parser.add_argument('--mye_thresh', type=float, default=0.5) - self.parser.add_argument('--blv', action='store_true') self.parser.add_argument('--blv_num_channels', type=int, default=1) - self.parser.add_argument('--glia', action='store_true') self.parser.add_argument('--sem', action='store_true') - self.parser.add_argument('--img', action='store_true') self.parser.add_argument('--merge_classes', type=str, default=[], nargs='+') # Semantic segmentation self.parser.add_argument('--semantic', action='store_true') - self.parser.add_argument('--dend', action='store_true') # Dendrite - self.parser.add_argument('--axon', action='store_true') # Axon - self.parser.add_argument('--soma', action='store_true') # Soma - self.parser.add_argument('--nucl', action='store_true') # Nucleus - self.parser.add_argument('--ecs', action='store_true') # Extracellular space - self.parser.add_argument('--other', action='store_true') # Other class # Test-time augmentation self.parser.add_argument('--test_aug', type=int, default=None, nargs='+') @@ -224,49 +216,33 @@ def parse(self): opt.inputsz = (aniso_z, opt.inputsz[1], opt.inputsz[2]) opt.in_spec = dict(input=(1,) + opt.inputsz) + # Per-class out_spec entries (driven by class registry). + from deepem.data.classes import REGISTRY as CLASS_REGISTRY + for flag_name, spec in CLASS_REGISTRY.items(): + if flag_name in ('vec', 'long'): + continue # Integer-valued in test; handled below. + if getattr(opt, flag_name, False): + opt.out_spec[spec.out_name] = (spec.resolve_channels(opt),) + opt.outputsz + + # Integer-valued channel flags (legacy CLI shape). if opt.vec: opt.out_spec['embedding'] = (opt.vec,) + opt.outputsz - if opt.aff: - opt.out_spec['affinity'] = (3,) + opt.outputsz + if opt.long: + opt.out_spec['long_range'] = (opt.long,) + opt.outputsz if opt.aff_deprecated: opt.out_spec['affinity'] = (opt.aff_deprecated,) + opt.outputsz - if opt.bdr: - opt.out_spec['boundary'] = (1,) + opt.outputsz - if opt.syn: - opt.out_spec['synapse'] = (1,) + opt.outputsz - if opt.psd: - opt.out_spec['synapse'] = (1,) + opt.outputsz - if opt.mit: - opt.out_spec['mitochondria'] = (1,) + opt.outputsz + + # mito_to_cell also modifies in_spec. if opt.mito_to_cell: opt.in_spec['input_mitochondria'] = (1,) + opt.inputsz - opt.out_spec['mitochondria_to_cell'] = (1,) + opt.outputsz - if opt.mye: - opt.out_spec['myelin'] = (1,) + opt.outputsz - if opt.blv: - opt.out_spec['blood_vessel'] = (opt.blv_num_channels,) + opt.outputsz - if opt.glia: - opt.out_spec['glia'] = (1,) + opt.outputsz + + # --sem: shorthand for enabling all combined semantic-segmentation heads. if opt.sem: opt.out_spec['soma'] = (1,) + opt.outputsz opt.out_spec['axon'] = (1,) + opt.outputsz opt.out_spec['dendrite'] = (1,) + opt.outputsz opt.out_spec['glia'] = (1,) + opt.outputsz opt.out_spec['bvessel'] = (1,) + opt.outputsz - if opt.img: - opt.out_spec['image'] = (1,) + opt.outputsz - if opt.dend: - opt.out_spec['dendrite'] = (1,) + opt.outputsz - if opt.axon: - opt.out_spec['axon'] = (1,) + opt.outputsz - if opt.soma: - opt.out_spec['soma'] = (1,) + opt.outputsz - if opt.nucl: - opt.out_spec['nucleus'] = (1,) + opt.outputsz - if opt.ecs: - opt.out_spec['extracellular_space'] = (1,) + opt.outputsz - if opt.other: - opt.out_spec['other_class'] = (1,) + opt.outputsz # Semantic segmentation if opt.semantic: diff --git a/deepem/test/utils.py b/deepem/test/utils.py index af5cec6..b678951 100644 --- a/deepem/test/utils.py +++ b/deepem/test/utils.py @@ -93,19 +93,11 @@ def make_forward_scanner(opt, data_name=None): return ForwardScanner(dataset, opt.scan_spec, **opt.scan_params) -SEMANTIC_MAPPING = { - 'dendrite': 1, - 'axon': 2, - 'soma': 3, - 'nucleus': 4, - 'glia': 5, - 'extracellular_space': 6, - 'blood_vessel': 7, - 'other_class': 10, -} - -# Training class_dict iteration order (deepem/train/option.py); chunkflow's -# channel_voting does argmax+1 over exactly this channel ordering. +from deepem.data.classes import semantic_mapping as _semantic_mapping +SEMANTIC_MAPPING = _semantic_mapping() + +# chunkflow's channel_voting does argmax+1 over exactly this channel ordering. +# Frozen contract with chunkflow — do not auto-derive from REGISTRY. ARGMAX_ORDER = [ 'blood_vessel', 'glia', 'soma', 'dendrite', 'axon', 'nucleus', 'extracellular_space', 'other_class', diff --git a/deepem/train/option.py b/deepem/train/option.py index ac31b42..c4f6e76 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -206,43 +206,25 @@ def initialize(self): self.parser.add_argument('--sr_scale_z', type=int, default=5, help='Z upsampling factor (default: 5 for 40nm:8nm ratio)') - # Long-range affinity - self.parser.add_argument('--long', type=float, default=0) + # Long-range affinity edges (--long weight comes from class registry). self.parser.add_argument('--edges', type=vec3, default=[], nargs='+') - # Multiclass detection - self.parser.add_argument('--aff', type=float, default=0) # Affinity - self.parser.add_argument('--bdr', type=float, default=0) # Boundary - self.parser.add_argument('--syn', type=float, default=0) # Synapse - self.parser.add_argument('--psd', type=float, default=0) # Synapse - self.parser.add_argument('--mit', type=float, default=0) # Mitochondria - self.parser.add_argument('--mye', type=float, default=0) # Myelin - self.parser.add_argument('--fld', type=float, default=0) # Fold - self.parser.add_argument('--blv', type=float, default=0) # Blood vessel + # Per-class loss weights — auto-declared from the central registry. + # To add a new task, add an entry to deepem/data/classes.py REGISTRY. + from deepem.data.classes import REGISTRY as CLASS_REGISTRY + for flag_name in CLASS_REGISTRY: + self.parser.add_argument(f'--{flag_name}', type=float, default=0) + + # Auxiliary knobs referenced by class specs / dataset. self.parser.add_argument('--blv_num_channels', type=int, default=1) - self.parser.add_argument('--glia', type=float, default=0) # Glia self.parser.add_argument('--glia_mask', action='store_true') - self.parser.add_argument('--img', type=float, default=0) # Image - self.parser.add_argument('--mito_to_cell', type=float, default=0) # Mito to cell - self.parser.add_argument('--mito_to_cell_mode', type=str, default='random') # Mito to cell mode + self.parser.add_argument('--mito_to_cell_mode', type=str, default='random') self.parser.add_argument('--merge_classes', type=str, default=[], nargs='+') # for onnx export + self.parser.add_argument('--embed_dim', type=int, default=12) + self.parser.add_argument('--mito_emb_dim', type=int, default=6) # Semantic segmentation self.parser.add_argument('--sem', action='store_true') - self.parser.add_argument('--dend', type=float, default=0) # Dendrite - self.parser.add_argument('--axon', type=float, default=0) # Axon - self.parser.add_argument('--soma', type=float, default=0) # Soma - self.parser.add_argument('--nucl', type=float, default=0) # Nucleus - self.parser.add_argument('--ecs', type=float, default=0) # Extracellular space - self.parser.add_argument('--other', type=float, default=0) # Other class - - # Metric learning - self.parser.add_argument('--vec', type=float, default=0) - self.parser.add_argument('--embed_dim', type=int, default=12) - - # Mitochondria embedding - self.parser.add_argument('--mito_emb', type=float, default=0) - self.parser.add_argument('--mito_emb_dim', type=int, default=6) # Test training self.parser.add_argument('--test', action='store_true') @@ -382,56 +364,14 @@ def parse(self): opt.aug_params['sr_mode'] = opt.sr_mode opt.aug_params['sr_scale_z'] = opt.sr_scale_z - # Multiclass detection - class_keys = list() - class_dict = { - 'aff': ('affinity', 3), - 'long': ('long_range', len(opt.edges)), - 'bdr': ('boundary', 1), - 'syn': ('synapse', 1), - 'psd': ('synapse', 1), - 'mit': ('mitochondria', 1), - 'mye': ('myelin', 1), - 'fld': ('fold', 1), - 'blv': ('blood_vessel', opt.blv_num_channels), - 'glia': ('glia', 1), - 'soma': ('soma', 1), - 'img': ('image', 1), - 'vec': ('embedding', opt.embed_dim), - 'dend': ('dendrite', 1), - 'axon': ('axon', 1), - 'nucl': ('nucleus', 1), - 'ecs': ('extracellular_space', 1), - 'other': ('other_class', 1), - 'mito_to_cell': ('mitochondria_to_cell', 1), - 'mito_emb': ('mitochondria_embedding', opt.mito_emb_dim), - } - - semantic_mapping = { - 'dendrite': 1, - 'axon': 2, - 'soma': 3, - 'nucleus': 4, - 'glia': 5, - 'extracellular_space': 6, - 'blood_vessel': 7, - 'other_class': 10, - } - - requires_binarize = [ - "synapse", - "mitochondria", - "myelin", - "fold", - "glia", - "soma", - ] - - if opt.blv_num_channels == 1: - requires_binarize.append("blood_vessel") - - if opt.sem: - requires_binarize = [x for x in requires_binarize if x not in semantic_mapping] + # Multiclass detection (driven by deepem/data/classes.py REGISTRY). + from deepem.data.classes import ( + REGISTRY as CLASS_REGISTRY, + semantic_mapping as _semantic_mapping_from_registry, + requires_binarize as _requires_binarize_from_registry, + ) + semantic_mapping = _semantic_mapping_from_registry() + requires_binarize = _requires_binarize_from_registry(opt) # Test training if opt.test: @@ -440,14 +380,15 @@ def parse(self): opt.avgs_intv = 10 opt.imgs_intv = 100 - for k, v in class_dict.items(): - loss_w = args[k] + class_keys = list() + for flag_name, spec in CLASS_REGISTRY.items(): + loss_w = args[flag_name] if loss_w > 0: - output_name, num_channels = v + num_channels = spec.resolve_channels(opt) assert num_channels > 0 - opt.out_spec[output_name] = (num_channels,) + opt.outputsz - opt.loss_weight[output_name] = loss_w - class_keys.append(k) + opt.out_spec[spec.out_name] = (num_channels,) + opt.outputsz + opt.loss_weight[spec.out_name] = loss_w + class_keys.append(flag_name) # Mito-to-cell assignment hacks if opt.mito_to_cell > 0: From b6791474fab06f8e6e79e91248ec5955302b386a Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 14 May 2026 16:27:30 +0900 Subject: [PATCH 2/2] refactor(classes): rename ClassSpec.out_name to internal_name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This name is not just an "out spec key" — it's the load-bearing identifier used by: - PyTorch output submodule (and therefore state_dict / ONNX output names) - wandb metric name - loss criterion key - dataset key consumed by the sampler - LHS of zettaset_lookup The CLI flag (the REGISTRY dict key) can differ freely from internal_name, e.g. `apex_mit` -> `apex_mitochondria`. Renaming the field makes the distinction explicit at the call site. Co-Authored-By: Claude Opus 4.7 (1M context) --- deepem/data/classes.py | 62 ++++++++++++++++++++++++------------------ deepem/test/option.py | 2 +- deepem/train/option.py | 4 +-- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/deepem/data/classes.py b/deepem/data/classes.py index a756118..d1c5e37 100644 --- a/deepem/data/classes.py +++ b/deepem/data/classes.py @@ -1,8 +1,14 @@ """Central registry of detection/segmentation classes. +Each entry maps a CLI flag (dict key) to its internal name (`internal_name`). +The internal name is the load-bearing identifier: it names the PyTorch output +submodule (and therefore the state_dict / ONNX output), the wandb metric, the +loss criterion, the dataset key used by the sampler, and the LHS of +`zettaset_lookup`. Flag and internal name can differ freely. + Adding a new task (e.g., "psd_v2", "ribosome") is a one-line edit: - 'ribo': ClassSpec(out_name='ribosome', binarize=True), + 'ribo': ClassSpec(internal_name='ribosome', binarize=True), The matching `--ribo ` CLI flag (train) and `--ribo` switch (test) become available automatically; no edits to option.py needed. @@ -16,8 +22,12 @@ class ClassSpec: """Specification for a detection/segmentation class. Attributes: - out_name: Output spec key. Used by model heads, loss, and as the - annotation name looked up in the dataset. + internal_name: The single load-bearing identifier for this class. + Used as the PyTorch output submodule name (and state_dict / + ONNX key), the wandb metric name, the loss criterion key, the + dataset key consumed by the sampler, and the LHS of + `zettaset_lookup`. Distinct from the CLI flag (the REGISTRY + dict key). channels: Output channels. Either an int, or a callable opt -> int for classes whose width depends on another flag (e.g. blood_vessel takes opt.blv_num_channels). @@ -26,7 +36,7 @@ class ClassSpec: semantic_id: If set, this class participates in semantic_mapping when --sem is on (label ID used by the combined-semantic GT). """ - out_name: str + internal_name: str channels: Union[int, Callable] = 1 binarize: Union[bool, Callable] = False semantic_id: Optional[int] = None @@ -40,46 +50,46 @@ def resolve_binarize(self, opt) -> bool: REGISTRY: dict[str, ClassSpec] = { # Affinity / boundary - 'aff': ClassSpec(out_name='affinity', channels=3), - 'long': ClassSpec(out_name='long_range', channels=lambda opt: len(opt.edges)), - 'bdr': ClassSpec(out_name='boundary'), + 'aff': ClassSpec(internal_name='affinity', channels=3), + 'long': ClassSpec(internal_name='long_range', channels=lambda opt: len(opt.edges)), + 'bdr': ClassSpec(internal_name='boundary'), # Detection (binary) - 'syn': ClassSpec(out_name='synapse', binarize=True), - 'psd': ClassSpec(out_name='synapse', binarize=True), - 'mit': ClassSpec(out_name='mitochondria', binarize=True), - 'mye': ClassSpec(out_name='myelin', binarize=True), - 'fld': ClassSpec(out_name='fold', binarize=True), - 'glia': ClassSpec(out_name='glia', binarize=True, semantic_id=5), - 'img': ClassSpec(out_name='image'), - 'mito_to_cell': ClassSpec(out_name='mitochondria_to_cell'), + 'syn': ClassSpec(internal_name='synapse', binarize=True), + 'psd': ClassSpec(internal_name='synapse', binarize=True), + 'mit': ClassSpec(internal_name='mitochondria', binarize=True), + 'mye': ClassSpec(internal_name='myelin', binarize=True), + 'fld': ClassSpec(internal_name='fold', binarize=True), + 'glia': ClassSpec(internal_name='glia', binarize=True, semantic_id=5), + 'img': ClassSpec(internal_name='image'), + 'mito_to_cell': ClassSpec(internal_name='mitochondria_to_cell'), # Multi-channel blood vessel: binarize only when collapsed to 1 channel 'blv': ClassSpec( - out_name='blood_vessel', + internal_name='blood_vessel', channels=lambda opt: opt.blv_num_channels, binarize=lambda opt: opt.blv_num_channels == 1, semantic_id=7, ), # Semantic-segmentation classes (binarize only when --sem is off) - 'soma': ClassSpec(out_name='soma', binarize=True, semantic_id=3), - 'dend': ClassSpec(out_name='dendrite', semantic_id=1), - 'axon': ClassSpec(out_name='axon', semantic_id=2), - 'nucl': ClassSpec(out_name='nucleus', semantic_id=4), - 'ecs': ClassSpec(out_name='extracellular_space', semantic_id=6), - 'other': ClassSpec(out_name='other_class', semantic_id=10), + 'soma': ClassSpec(internal_name='soma', binarize=True, semantic_id=3), + 'dend': ClassSpec(internal_name='dendrite', semantic_id=1), + 'axon': ClassSpec(internal_name='axon', semantic_id=2), + 'nucl': ClassSpec(internal_name='nucleus', semantic_id=4), + 'ecs': ClassSpec(internal_name='extracellular_space', semantic_id=6), + 'other': ClassSpec(internal_name='other_class', semantic_id=10), # Embedding (consumed by MeanLoss / metric learning) - 'vec': ClassSpec(out_name='embedding', channels=lambda opt: opt.embed_dim), - 'mito_emb': ClassSpec(out_name='mitochondria_embedding', channels=lambda opt: opt.mito_emb_dim), + 'vec': ClassSpec(internal_name='embedding', channels=lambda opt: opt.embed_dim), + 'mito_emb': ClassSpec(internal_name='mitochondria_embedding', channels=lambda opt: opt.mito_emb_dim), } def semantic_mapping() -> dict[str, int]: """Class-name -> label-ID dict, derived from semantic_id fields.""" return { - spec.out_name: spec.semantic_id + spec.internal_name: spec.semantic_id for spec in REGISTRY.values() if spec.semantic_id is not None } @@ -93,7 +103,7 @@ def requires_binarize(opt) -> list[str]: when --sem is on. """ names = { - spec.out_name + spec.internal_name for spec in REGISTRY.values() if spec.resolve_binarize(opt) } diff --git a/deepem/test/option.py b/deepem/test/option.py index d4bd3ee..5e9fb3e 100644 --- a/deepem/test/option.py +++ b/deepem/test/option.py @@ -222,7 +222,7 @@ def parse(self): if flag_name in ('vec', 'long'): continue # Integer-valued in test; handled below. if getattr(opt, flag_name, False): - opt.out_spec[spec.out_name] = (spec.resolve_channels(opt),) + opt.outputsz + opt.out_spec[spec.internal_name] = (spec.resolve_channels(opt),) + opt.outputsz # Integer-valued channel flags (legacy CLI shape). if opt.vec: diff --git a/deepem/train/option.py b/deepem/train/option.py index c4f6e76..a2aec2a 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -386,8 +386,8 @@ def parse(self): if loss_w > 0: num_channels = spec.resolve_channels(opt) assert num_channels > 0 - opt.out_spec[spec.out_name] = (num_channels,) + opt.outputsz - opt.loss_weight[spec.out_name] = loss_w + opt.out_spec[spec.internal_name] = (num_channels,) + opt.outputsz + opt.loss_weight[spec.internal_name] = loss_w class_keys.append(flag_name) # Mito-to-cell assignment hacks