Skip to content

Commit b9a6fa5

Browse files
Jammy2211claude
authored andcommitted
Overhaul 2D plot styling and subplot layout
- Default colormap driven by segmentdata.py via config (autoarray) - Colorbar: 3 ticks (min/mid/max), unit on middle, configurable fraction/pad/rotation/labelsize (separate figure vs subplot sizes) - Tick labels: arcsecond suffix, inward positioning, 2 d.p. formatting, vertical y-axis labels - Axes box aspect set from data extent (restores old square-panel behaviour, eliminates subplot whitespace) - hide_unused_axes() called in all fixed-grid subplot functions - subplot_imaging_dataset consolidated to 3x3 grid; subplot_imaging removed - Position dots default to black - save_reconstruction_csv() extracted to inversion_plots.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e007e8c commit b9a6fa5

11 files changed

Lines changed: 284 additions & 58 deletions

File tree

autoarray/config/visualize/general.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,27 @@ units:
1515
cb_unit: $\,\,\mathrm{e^{-}}\,\mathrm{s^{-1}}$ # The string or latex unit label used for the colorbar of the image, for example electrons per second.
1616
scaled_symbol: '"' # The symbol used when plotting spatial coordinates computed via the pixel_scale (e.g. for Astronomy data this is arc-seconds).
1717
unscaled_symbol: pix # The symbol used when plotting spatial coordinates in unscaled pixel units.
18+
colormap: autoarray # Default colormap for 2D plots. Use any matplotlib name to override (e.g. jet, viridis).
19+
ticks:
20+
extent_factor_2d: 0.75 # Fraction of half-extent used for 2D tick positions (< 1.0 pulls ticks inward from edges).
21+
number_of_ticks_2d: 3 # Number of ticks on each spatial axis of 2D plots.
22+
colorbar:
23+
fraction: 0.047 # Fraction of original axes to use for the colorbar.
24+
pad: 0.01 # Padding between colorbar and axes.
25+
labelrotation: 90 # Rotation of colorbar tick labels in degrees.
26+
labelsize: 22 # Font size of colorbar tick labels for single-panel figures.
27+
labelsize_subplot: 24 # Font size of colorbar tick labels for subplot panels.
1828
mat_plot:
1929
figure:
2030
figsize: (7, 7) # Default figure size. Override via aplt.Figure(figsize=(...)).
2131
yticks:
2232
fontsize: 22 # Default y-tick font size. Override via aplt.YTicks(fontsize=...).
33+
yticks_subplot:
34+
fontsize: 22 # Default y-tick font size for subplot panels.
2335
xticks:
2436
fontsize: 22 # Default x-tick font size. Override via aplt.XTicks(fontsize=...).
37+
xticks_subplot:
38+
fontsize: 22 # Default x-tick font size for subplot panels.
2539
title:
2640
fontsize: 24 # Default title font size. Override via aplt.Title(fontsize=...).
2741
ylabel:

autoarray/dataset/plot/imaging_plots.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,15 @@ def subplot_imaging_dataset(
9292
title="Point Spread Function",
9393
colormap=colormap,
9494
use_log10=use_log10,
95+
cb_unit="",
9596
)
9697
plot_array(
9798
dataset.psf.kernel,
9899
ax=axes[4],
99100
title="PSF (log10)",
100101
colormap=colormap,
101102
use_log10=True,
103+
cb_unit="",
102104
)
103105

104106
plot_array(
@@ -107,6 +109,7 @@ def subplot_imaging_dataset(
107109
title="Signal-To-Noise Map",
108110
colormap=colormap,
109111
use_log10=use_log10,
112+
cb_unit="",
110113
grid=grid,
111114
positions=positions,
112115
lines=lines,
@@ -120,6 +123,7 @@ def subplot_imaging_dataset(
120123
title="Over Sample Size (Light Profiles)",
121124
colormap=colormap,
122125
use_log10=use_log10,
126+
cb_unit="",
123127
)
124128

125129
over_sample_size_pix = getattr(getattr(dataset, "grids", None), "over_sample_size_pixelization", None)
@@ -130,8 +134,11 @@ def subplot_imaging_dataset(
130134
title="Over Sample Size (Pixelization)",
131135
colormap=colormap,
132136
use_log10=use_log10,
137+
cb_unit="",
133138
)
134139

140+
from autoarray.plot.utils import hide_unused_axes
141+
hide_unused_axes(axes)
135142
plt.tight_layout()
136143
subplot_save(fig, output_path, output_filename, output_format)
137144

autoarray/dataset/plot/interferometer_plots.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from autoarray.plot.array import plot_array
77
from autoarray.plot.grid import plot_grid
88
from autoarray.plot.yx import plot_yx
9-
from autoarray.plot.utils import subplot_save
9+
from autoarray.plot.utils import subplot_save, hide_unused_axes
1010
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
1111

1212

@@ -84,6 +84,7 @@ def subplot_interferometer_dataset(
8484
use_log10=use_log10,
8585
)
8686

87+
hide_unused_axes(axes)
8788
plt.tight_layout()
8889
subplot_save(fig, output_path, output_filename, output_format)
8990

@@ -138,5 +139,6 @@ def subplot_interferometer_dirty_images(
138139
use_log10=use_log10,
139140
)
140141

142+
hide_unused_axes(axes)
141143
plt.tight_layout()
142144
subplot_save(fig, output_path, output_filename, output_format)

autoarray/fit/plot/fit_imaging_plots.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import matplotlib.pyplot as plt
44

55
from autoarray.plot.array import plot_array
6-
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax
6+
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes
77

88

99
def subplot_fit_imaging(
@@ -118,5 +118,6 @@ def subplot_fit_imaging(
118118
lines=lines,
119119
)
120120

121+
hide_unused_axes(axes)
121122
plt.tight_layout()
122123
subplot_save(fig, output_path, output_filename, output_format)

autoarray/fit/plot/fit_interferometer_plots.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from autoarray.plot.array import plot_array
77
from autoarray.plot.yx import plot_yx
8-
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax
8+
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes
99

1010

1111
def subplot_fit_interferometer(
@@ -98,6 +98,7 @@ def subplot_fit_interferometer(
9898
plot_axis_type="scatter",
9999
)
100100

101+
hide_unused_axes(axes)
101102
plt.tight_layout()
102103
subplot_save(fig, output_path, output_filename, output_format)
103104

@@ -191,5 +192,6 @@ def subplot_fit_interferometer_dirty_images(
191192
use_log10=use_log10,
192193
)
193194

195+
hide_unused_axes(axes)
194196
plt.tight_layout()
195197
subplot_save(fig, output_path, output_filename, output_format)

autoarray/inversion/plot/inversion_plots.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import csv
12
import logging
23
import numpy as np
3-
from typing import Optional
4+
from pathlib import Path
5+
from typing import Optional, Union
46

57
import matplotlib.pyplot as plt
68
from autoconf import conf
79

810
from autoarray.inversion.mappers.abstract import Mapper
911
from autoarray.plot.array import plot_array
10-
from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save
12+
from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save, hide_unused_axes
1113
from autoarray.inversion.plot.mapper_plots import plot_mapper
1214
from autoarray.structures.arrays.uniform_2d import Array2D
1315

@@ -215,6 +217,7 @@ def _recon_array():
215217
except (TypeError, Exception):
216218
pass
217219

220+
hide_unused_axes(axes)
218221
plt.tight_layout()
219222
subplot_save(fig, output_path, f"{output_filename}_{mapper_index}", output_format)
220223

@@ -332,7 +335,40 @@ def subplot_mappings(
332335
lines=lines,
333336
)
334337

338+
hide_unused_axes(axes)
335339
plt.tight_layout()
336340
subplot_save(
337341
fig, output_path, f"{output_filename}_{pixelization_index}", output_format
338342
)
343+
344+
345+
def save_reconstruction_csv(
346+
inversion,
347+
output_path: Union[str, Path],
348+
) -> None:
349+
"""Write a CSV of each mapper's reconstruction and noise map to *output_path*.
350+
351+
One file is written per mapper: ``source_plane_reconstruction_{i}.csv``,
352+
with columns ``y``, ``x``, ``reconstruction``, ``noise_map``.
353+
354+
Parameters
355+
----------
356+
inversion
357+
An ``AbstractInversion`` instance.
358+
output_path
359+
Directory in which to write the CSV files.
360+
"""
361+
output_path = Path(output_path)
362+
mapper_list = inversion.cls_list_from(cls=Mapper)
363+
364+
for i, mapper in enumerate(mapper_list):
365+
y = mapper.source_plane_mesh_grid[:, 0]
366+
x = mapper.source_plane_mesh_grid[:, 1]
367+
reconstruction = inversion.reconstruction_dict[mapper]
368+
noise_map = inversion.reconstruction_noise_map_dict[mapper]
369+
370+
with open(output_path / f"source_plane_reconstruction_{i}.csv", mode="w", newline="") as f:
371+
writer = csv.writer(f)
372+
writer.writerow(["y", "x", "reconstruction", "noise_map"])
373+
for j in range(len(x)):
374+
writer.writerow([float(y[j]), float(x[j]), float(reconstruction[j]), float(noise_map[j])])

autoarray/plot/array.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def plot_array(
4747
title: str = "",
4848
xlabel: str = "",
4949
ylabel: str = "",
50-
colormap: str = "jet",
50+
colormap: Optional[str] = None,
5151
vmin: Optional[float] = None,
5252
vmax: Optional[float] = None,
5353
use_log10: bool = False,
54-
aspect: str = "auto",
54+
cb_unit: Optional[str] = None,
5555
origin_imshow: str = "upper",
5656
# --- figure control (used only when ax is None) -----------------------------
5757
figsize: Optional[Tuple[int, int]] = None,
@@ -105,8 +105,6 @@ def plot_array(
105105
Explicit color scale limits.
106106
use_log10
107107
When ``True`` a ``LogNorm`` is applied.
108-
aspect
109-
Passed directly to ``imshow``.
110108
origin_imshow
111109
Passed directly to ``imshow`` (``"upper"`` or ``"lower"``).
112110
figsize
@@ -135,6 +133,10 @@ def plot_array(
135133
if array is None or array.size == 0:
136134
return
137135

136+
if colormap is None:
137+
from autoarray.plot.utils import _default_colormap
138+
colormap = _default_colormap()
139+
138140
# convert overlay params (safe for None and already-numpy inputs)
139141
border = numpy_grid(border)
140142
origin = numpy_grid(origin)
@@ -180,16 +182,33 @@ def plot_array(
180182
else:
181183
norm = None
182184

185+
# Compute the axes-box aspect ratio from the data extent so that the
186+
# physical cell is correctly shaped and tight_layout has no whitespace
187+
# to absorb. This reproduces the old "square" subplot behaviour where
188+
# ratio = x_range / y_range was passed to plt.subplot(aspect=ratio).
189+
if extent is not None:
190+
x_range = abs(extent[1] - extent[0])
191+
y_range = abs(extent[3] - extent[2])
192+
_box_aspect = (x_range / y_range) if y_range > 0 else 1.0
193+
else:
194+
h, w = array.shape[:2]
195+
_box_aspect = (w / h) if h > 0 else 1.0
196+
183197
im = ax.imshow(
184198
array,
185199
cmap=colormap,
186200
norm=norm,
187201
extent=extent,
188-
aspect=aspect,
202+
aspect="auto", # image fills the axes box; box shape set below
189203
origin=origin_imshow,
190204
)
191205

192-
plt.colorbar(im, ax=ax)
206+
# Shape the axes box to match the data so there is no surrounding
207+
# whitespace when the panel is embedded in a subplot grid.
208+
ax.set_aspect(_box_aspect, adjustable="box")
209+
210+
from autoarray.plot.utils import _apply_colorbar
211+
_apply_colorbar(im, ax, cb_unit=cb_unit, is_subplot=not owns_figure)
193212

194213
# --- overlays --------------------------------------------------------------
195214
if array_overlay is not None:
@@ -198,7 +217,7 @@ def plot_array(
198217
cmap="Greys",
199218
alpha=0.5,
200219
extent=extent,
201-
aspect=aspect,
220+
aspect="auto",
202221
origin=origin_imshow,
203222
)
204223

@@ -223,7 +242,7 @@ def plot_array(
223242
ax.scatter(mesh_grid[:, 1], mesh_grid[:, 0], s=1, c="w", alpha=0.5)
224243

225244
if positions is not None:
226-
colors = ["r", "g", "b", "m", "c", "y"]
245+
colors = ["k", "g", "b", "m", "c", "y"]
227246
for i, pos in enumerate(positions):
228247
ax.scatter(pos[:, 1], pos[:, 0], s=20, c=colors[i % len(colors)], zorder=5)
229248

@@ -263,7 +282,7 @@ def plot_array(
263282
pass
264283

265284
# --- labels / ticks --------------------------------------------------------
266-
apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel)
285+
apply_labels(ax, title=title, xlabel=xlabel, ylabel=ylabel, is_subplot=not owns_figure)
267286

268287
if extent is not None:
269288
apply_extent(ax, extent)

autoarray/plot/grid.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def plot_grid(
3232
title: str = "",
3333
xlabel: str = 'x (")',
3434
ylabel: str = 'y (")',
35-
colormap: str = "jet",
35+
colormap: Optional[str] = None,
3636
buffer: float = 0.1,
3737
extent: Optional[Tuple[float, float, float, float]] = None,
3838
force_symmetric_extent: bool = True,
@@ -101,6 +101,10 @@ def plot_grid(
101101

102102
lines = numpy_lines(lines)
103103

104+
if colormap is None:
105+
from autoarray.plot.utils import _default_colormap
106+
colormap = _default_colormap()
107+
104108
owns_figure = ax is None
105109
if owns_figure:
106110
figsize = figsize or conf_figsize("figures")
@@ -126,7 +130,8 @@ def plot_grid(
126130
ecolor=colors,
127131
)
128132

129-
plt.colorbar(sc, ax=ax)
133+
from autoarray.plot.utils import _apply_colorbar
134+
_apply_colorbar(sc, ax)
130135
else:
131136
if y_errors is None and x_errors is None:
132137
ax.scatter(grid[:, 1], grid[:, 0], s=1, c="k")

autoarray/plot/inversion.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def plot_inversion_reconstruction(
2121
title: str = "Reconstruction",
2222
xlabel: str = 'x (")',
2323
ylabel: str = 'y (")',
24-
colormap: str = "jet",
24+
colormap: Optional[str] = None,
2525
vmin: Optional[float] = None,
2626
vmax: Optional[float] = None,
2727
use_log10: bool = False,
@@ -76,6 +76,10 @@ def plot_inversion_reconstruction(
7676
from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
7777
from autoarray.inversion.mesh.interpolator.knn import InterpolatorKNearestNeighbor
7878

79+
if colormap is None:
80+
from autoarray.plot.utils import _default_colormap
81+
colormap = _default_colormap()
82+
7983
owns_figure = ax is None
8084
if owns_figure:
8185
figsize = figsize or conf_figsize("figures")
@@ -192,7 +196,8 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
192196
aspect="auto",
193197
origin="upper",
194198
)
195-
plt.colorbar(im, ax=ax)
199+
from autoarray.plot.utils import _apply_colorbar
200+
_apply_colorbar(im, ax)
196201
else:
197202
y_edges, x_edges = mapper.mesh_geometry.edges_transformed.T
198203
Y, X = np.meshgrid(y_edges, x_edges, indexing="ij")
@@ -204,7 +209,8 @@ def _plot_rectangular(ax, pixel_values, mapper, norm, colormap, extent):
204209
norm=norm,
205210
cmap=colormap,
206211
)
207-
plt.colorbar(im, ax=ax)
212+
from autoarray.plot.utils import _apply_colorbar
213+
_apply_colorbar(im, ax)
208214

209215

210216
def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
@@ -240,4 +246,5 @@ def _plot_delaunay(ax, pixel_values, mapper, norm, colormap):
240246
vals = pixel_values
241247

242248
tc = ax.tripcolor(x, y, vals, cmap=colormap, norm=norm, shading="gouraud")
243-
plt.colorbar(tc, ax=ax)
249+
from autoarray.plot.utils import _apply_colorbar
250+
_apply_colorbar(tc, ax)

autoarray/plot/segmentdata.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,3 +1042,15 @@
10421042
]
10431043
),
10441044
}
1045+
1046+
COLORMAP_NAME = "autoarray"
1047+
1048+
1049+
def register():
1050+
"""Register the autoarray segmentdata colormap with matplotlib (idempotent)."""
1051+
import matplotlib
1052+
import matplotlib.colors as mcolors
1053+
1054+
if COLORMAP_NAME not in matplotlib.colormaps:
1055+
cmap = mcolors.LinearSegmentedColormap(COLORMAP_NAME, segmentdata)
1056+
matplotlib.colormaps.register(cmap)

0 commit comments

Comments
 (0)