Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ analysis_types:
num_classes: 2
bands: ["B04", "B03", "B02", "B08"] # Red, Green, Blue, NIR
classes: ["non_forest", "forest"]
scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow
thresholds:
alert_forest_loss: 5.0 # Alert if >5% forest loss
critical_forest_loss: 15.0 # Critical if >15% loss
Expand All @@ -38,6 +39,7 @@ analysis_types:
num_classes: 3
bands: ["B02", "B03", "B04", "B11"] # Blue, Green, Red, SWIR
classes: ["open_water", "sea_ice", "land"]
scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow
thresholds:
alert_ice_loss: 10.0 # Alert if >10% ice loss
critical_ice_loss: 25.0 # Critical if >25% loss
Expand All @@ -60,12 +62,19 @@ analysis_types:
display_name: "Flood Detection"
description: "Detect and monitor flooding events and affected areas"
model:
architecture: "unet"
architecture: "flood_unet"
weights: "models/unet_flood.pth"
in_channels: 3
num_classes: 3
model_sar:
architecture: "flood_unet"
weights: "models/unet_flood_sar.pth"
in_channels: 5
num_classes: 3
bands: ["B03", "B08", "B11"] # Green, NIR, SWIR
sar_bands: ["VV", "VH"]
classes: ["dry_land", "permanent_water", "flooded"]
scl_clear_labels: [4, 5, 6] # vegetation, bare soil, water (NO snow/ice)
thresholds:
alert_flood_area: 5.0 # Alert if >5% area flooded
critical_flood_area: 20.0 # Critical if >20% flooded
Expand All @@ -74,6 +83,7 @@ analysis_types:
- "flooded_percentage"
- "flooded_area_km2"
- "mndwi_stats"
- "affected_road_km"

# Drought Monitoring
drought:
Expand Down Expand Up @@ -153,6 +163,13 @@ satellite:
cloud_coverage_max: 20 # percentage
revisit_time: 5 # days

sentinel1:
bands: ["VV", "VH"]
resolution: 10 # meters
revisit_time: 6 # days
polarization: ["VV", "VH"]
orbit: "DESCENDING"

landsat8:
bands: ["B4", "B3", "B2", "B5"]
resolution: 30 # meters
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ fiona>=1.9.0
opencv-python>=4.5.0
pillow>=9.0.0
albumentations>=1.3.0
segmentation-models-pytorch>=0.3.3
timm>=0.9.0
scipy>=1.10.0

# Visualization
matplotlib>=3.5.0
Expand Down
51 changes: 42 additions & 9 deletions scripts/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,30 +46,63 @@
# Model loading
# ---------------------------------------------------------------------------

def _load_run_config(ckpt_path: Path) -> dict:
"""Load the full training config from the run directory."""
run_dir = ckpt_path.parent
config_path = run_dir / "config.yaml"
if config_path.exists():
try:
import yaml
with open(config_path) as f:
return yaml.safe_load(f) or {}
except Exception:
pass
return {}


def load_model(ckpt_path: Path) -> tuple[nn.Module, dict]:
from climatevision.models.unet import get_model
from climatevision.models.flood_unet import build_flood_model

ckpt = torch.load(ckpt_path, map_location="cpu")
cfg = ckpt.get("cfg", {})
# Trainer cfg is in ckpt['cfg']; full model cfg is in config.yaml next to checkpoint
cfg = _load_run_config(ckpt_path)
if not cfg:
cfg = ckpt.get("cfg", {})

arch = cfg.get("model", {}).get("architecture", "attention_unet")
state = ckpt.get("ema_state_dict") or ckpt.get("model_state_dict", ckpt)

# Infer in_channels from weight shape
in_ch = 4
# Infer in_channels and n_classes from weight shape
in_ch = cfg.get("model", {}).get("in_channels", 4)
n_classes = cfg.get("model", {}).get("num_classes", 2)
for key, val in state.items():
if "inc" in key and "weight" in key and val.ndim == 4:
in_ch = val.shape[1]
break
if val.ndim == 4:
if val.shape[1] in (3, 4, 5) and "encoder" not in key and "down" not in key:
# first conv layer — input channels
in_ch = val.shape[1]
if val.shape[0] in (2, 3) and ("outc" in key or "segmentation_head" in key or "classifier" in key):
# final conv layer — output classes
n_classes = val.shape[0]

# Flood models use smp-based architectures
if arch in ("flood_unet", "flood_unet_s2only"):
use_sar = in_ch == 5
model = build_flood_model(
use_sar=use_sar,
encoder_name=cfg.get("model", {}).get("encoder", "efficientnet-b7"),
)
else:
model = get_model(arch, n_channels=in_ch, n_classes=n_classes)

model = get_model(arch, n_channels=in_ch, n_classes=2)
model.load_state_dict(state, strict=False)
model.eval()

logger.info(
"Loaded %s (in_channels=%d) from epoch %d val_iou=%.4f",
"Loaded %s (in_channels=%d, classes=%d) from epoch %d val_iou=%.4f",
arch,
in_ch,
n_classes,
ckpt.get("epoch", 0),
ckpt.get("val_iou", 0.0),
)
Expand Down Expand Up @@ -233,7 +266,7 @@ def main() -> None:
"checkpoint": str(ckpt_path),
"architecture": cfg.get("model", {}).get("architecture", "unknown"),
"in_channels": in_channels,
"num_classes": 2,
"num_classes": cfg.get("model", {}).get("num_classes", 2),
"image_size": image_size,
"onnx_opset": args.opset,
"onnx_path": str(onnx_path),
Expand Down
Loading
Loading