Skip to content

Commit 7feb31a

Browse files
committed
Add new parameters to imshow: cmap, vmin, vmax, and origin for enhanced image display control
1 parent e4fb09f commit 7feb31a

3 files changed

Lines changed: 292 additions & 28 deletions

File tree

Examples/plot_image2d.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
1616
Key bindings on the image panel: **R** reset view · **C** toggle colorbar ·
1717
**L** / **S** cycle colour-scale modes.
18+
19+
New ``imshow`` parameters
20+
-------------------------
21+
``cmap``
22+
Colormap name passed directly to :meth:`~anyplotlib.figure_plots.Axes.imshow`
23+
(e.g. ``"viridis"``, ``"inferno"``). Defaults to ``"gray"``.
24+
``vmin`` / ``vmax``
25+
Colormap clipping limits in data units. Values outside the range are
26+
clamped to the colormap endpoints. Defaults to the data min/max.
27+
``origin``
28+
``"upper"`` (default) places row 0 at the top (image convention).
29+
``"lower"`` places row 0 at the bottom (scientific / matrix convention)
30+
and automatically reverses the y-axis so tick values increase upward.
1831
"""
1932
import numpy as np
2033
import anyplotlib as apl
@@ -47,13 +60,14 @@ def _ring(r, r0, width, amp):
4760
ax_img = fig.add_subplot(gs[0, 0])
4861
ax_hist = fig.add_subplot(gs[1, 0])
4962

50-
# ── Image panel ───────────────────────────────────────────────────────────────
51-
v = ax_img.imshow(image, axes=[x, y], units="nm")
52-
v.set_colormap("inferno")
53-
63+
# ── Image panel — cmap, vmin, vmax supplied directly to imshow ────────────────
5464
vmin_init = float(image.min())
5565
vmax_init = float(image.max())
56-
v.set_clim(vmin=vmin_init, vmax=vmax_init)
66+
67+
# Pass cmap, vmin, and vmax directly — no separate set_colormap / set_clim call
68+
# needed for the initial display.
69+
v = ax_img.imshow(image, axes=[x, y], units="nm",
70+
cmap="inferno", vmin=vmin_init, vmax=vmax_init)
5771

5872
# First-order spot markers
5973
dx = x[1] - x[0]
@@ -98,12 +112,28 @@ def _apply_high(event):
98112
fig
99113

100114
# %%
101-
# Adjust colour map
102-
# ------------------
115+
# Adjust colour map and display range
116+
# ------------------------------------
103117
# :meth:`~anyplotlib.figure_plots.Plot2D.set_colormap` switches the palette;
104118
# :meth:`~anyplotlib.figure_plots.Plot2D.set_clim` adjusts the display range.
119+
# Both are equivalent to passing ``cmap`` / ``vmin`` / ``vmax`` at construction.
105120

106121
v.set_colormap("viridis")
107122
v.set_clim(vmin=0.0, vmax=0.8)
108123

109124
fig
125+
126+
# %%
127+
# origin='lower' — scientific / matrix convention
128+
# ------------------------------------------------
129+
# Passing ``origin='lower'`` places row 0 of the data at the *bottom* of the
130+
# image, matching the matplotlib / scientific convention. The y-axis is
131+
# automatically reversed so tick values still increase upward.
132+
133+
mat = np.arange(64, dtype=float).reshape(8, 8) # row 0 = small values
134+
135+
fig2, ax2 = apl.subplots()
136+
v2 = ax2.imshow(mat, cmap="plasma", origin="lower")
137+
138+
fig2
139+

anyplotlib/figure_plots.py

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -153,22 +153,42 @@ def __init__(self, fig: "Figure", spec: SubplotSpec): # noqa: F821
153153
# ------------------------------------------------------------------
154154
def imshow(self, data: np.ndarray,
155155
axes: list | None = None,
156-
units: str = "px") -> "Plot2D":
156+
units: str = "px",
157+
cmap: str | None = None,
158+
vmin: float | None = None,
159+
vmax: float | None = None,
160+
origin: str = "upper") -> "Plot2D":
157161
"""Attach a 2-D image to this axes cell.
158162
159163
Parameters
160164
----------
161-
data : np.ndarray shape (H, W) or (H, W, C)
165+
data : np.ndarray, shape (H, W) or (H, W, C)
166+
Image data. RGB/RGBA arrays use only the first channel.
162167
axes : [x_axis, y_axis], optional
168+
Physical coordinate arrays for each axis.
163169
units : str, optional
170+
Axis units label. Default ``"px"``.
171+
cmap : str, optional
172+
Colormap name (e.g. ``"viridis"``, ``"inferno"``).
173+
Defaults to ``"gray"``.
174+
vmin, vmax : float, optional
175+
Colormap clipping limits in data units. Values outside this
176+
range are clamped to the colormap endpoints. Defaults to the
177+
data min / max.
178+
origin : ``"upper"`` | ``"lower"``, optional
179+
Where row 0 of the array is placed. ``"upper"`` (default)
180+
puts row 0 at the top, matching the usual image convention.
181+
``"lower"`` puts row 0 at the bottom, matching the matplotlib
182+
convention for matrices / scientific plots.
164183
165184
Returns
166185
-------
167186
Plot2D
168187
"""
169188
x_axis = axes[0] if axes and len(axes) > 0 else None
170189
y_axis = axes[1] if axes and len(axes) > 1 else None
171-
plot = Plot2D(data, x_axis=x_axis, y_axis=y_axis, units=units)
190+
plot = Plot2D(data, x_axis=x_axis, y_axis=y_axis, units=units,
191+
cmap=cmap, vmin=vmin, vmax=vmax, origin=origin)
172192
self._attach(plot)
173193
return plot
174194

@@ -495,17 +515,36 @@ class Plot2D:
495515
"""
496516

497517
def __init__(self, data: np.ndarray,
498-
x_axis=None, y_axis=None, units: str = "px"):
518+
x_axis=None, y_axis=None, units: str = "px",
519+
cmap: str | None = None,
520+
vmin: float | None = None,
521+
vmax: float | None = None,
522+
origin: str = "upper"):
499523
self._id: str = "" # assigned by Axes._attach
500524
self._fig: object = None # assigned by Axes._attach
501525

526+
_valid_origins = ("upper", "lower")
527+
if origin not in _valid_origins:
528+
raise ValueError(
529+
f"origin must be one of {_valid_origins!r}, got {origin!r}"
530+
)
531+
self._origin: str = origin
532+
502533
data = np.asarray(data)
503534
if data.ndim == 3:
504535
data = data[:, :, 0]
505536
if data.ndim != 2:
506537
raise ValueError(f"data must be 2-D (H x W), got {data.shape}")
507538

508539
h, w = data.shape
540+
541+
# origin='lower' — row 0 at the bottom, matching matplotlib's matrix
542+
# convention. Flip the data so our renderer (which always draws row 0
543+
# at the top) shows the correct orientation, and reverse the y-axis so
544+
# tick values increase upward.
545+
if origin == "lower":
546+
data = np.flipud(data)
547+
509548
x_axis_given = x_axis is not None
510549
y_axis_given = y_axis is not None
511550
if x_axis is None:
@@ -515,12 +554,20 @@ def __init__(self, data: np.ndarray,
515554
x_axis = np.asarray(x_axis, dtype=float)
516555
y_axis = np.asarray(y_axis, dtype=float)
517556

518-
img_u8, vmin, vmax = _normalize_image(data)
557+
if origin == "lower":
558+
y_axis = y_axis[::-1]
559+
560+
img_u8, raw_vmin, raw_vmax = _normalize_image(data)
519561
self._raw_u8 = img_u8
520-
self._raw_vmin = vmin
521-
self._raw_vmax = vmax
562+
self._raw_vmin = raw_vmin
563+
self._raw_vmax = raw_vmax
564+
565+
cmap_name = cmap if cmap is not None else "gray"
566+
cmap_lut = _build_colormap_lut(cmap_name)
522567

523-
cmap_lut = _build_colormap_lut("gray")
568+
# vmin/vmax clip the colormap in data units; default to the full range.
569+
disp_min = float(vmin) if vmin is not None else raw_vmin
570+
disp_max = float(vmax) if vmax is not None else raw_vmax
524571

525572
# Compute physical pixel scale (data-units per pixel) from axis arrays
526573
scale_x = float(abs(x_axis[-1] - x_axis[0]) / max(w - 1, 1)) if len(x_axis) >= 2 else 1.0
@@ -538,14 +585,14 @@ def __init__(self, data: np.ndarray,
538585
"units": units,
539586
"scale_x": scale_x,
540587
"scale_y": scale_y,
541-
"display_min": vmin,
542-
"display_max": vmax,
543-
"raw_min": vmin,
544-
"raw_max": vmax,
588+
"display_min": disp_min,
589+
"display_max": disp_max,
590+
"raw_min": raw_vmin,
591+
"raw_max": raw_vmax,
545592
"show_colorbar": False,
546593
"log_scale": False,
547594
"scale_mode": "linear",
548-
"colormap_name": "gray",
595+
"colormap_name": cmap_name,
549596
"colormap_data": cmap_lut,
550597
"zoom": 1.0,
551598
"center_x": 0.5,
@@ -589,13 +636,21 @@ def to_state_dict(self) -> dict:
589636
# ------------------------------------------------------------------
590637
def update(self, data: np.ndarray,
591638
x_axis=None, y_axis=None, units: str | None = None) -> None:
592-
"""Replace the image data."""
639+
"""Replace the image data.
640+
641+
The ``origin`` supplied at construction is automatically re-applied
642+
so the new data is displayed with the same orientation.
643+
"""
593644
data = np.asarray(data)
594645
if data.ndim == 3:
595646
data = data[:, :, 0]
596647
if data.ndim != 2:
597648
raise ValueError(f"data must be 2-D, got {data.shape}")
598649
h, w = data.shape
650+
651+
if self._origin == "lower":
652+
data = np.flipud(data)
653+
599654
img_u8, vmin, vmax = _normalize_image(data)
600655
self._raw_u8, self._raw_vmin, self._raw_vmax = img_u8, vmin, vmax
601656

@@ -604,21 +659,24 @@ def update(self, data: np.ndarray,
604659
self._state["image_width"] = w
605660
self._state["has_axes"] = True
606661
if y_axis is not None:
607-
self._state["y_axis"] = np.asarray(y_axis, float).tolist()
662+
ya = np.asarray(y_axis, float)
663+
if self._origin == "lower":
664+
ya = ya[::-1]
665+
self._state["y_axis"] = ya.tolist()
608666
self._state["image_height"] = h
609667
self._state["has_axes"] = True
610668
if units is not None:
611669
self._state["units"] = units
612670

613671
self._state.update({
614-
"image_b64": self._encode_bytes(img_u8),
672+
"image_b64": self._encode_bytes(img_u8),
615673
"image_width": w,
616674
"image_height": h,
617-
"display_min": vmin,
618-
"display_max": vmax,
619-
"raw_min": vmin,
620-
"raw_max": vmax,
621-
"colormap_data": _build_colormap_lut(self._state["colormap_name"]),
675+
"display_min": vmin,
676+
"display_max": vmax,
677+
"raw_min": vmin,
678+
"raw_max": vmax,
679+
"colormap_data": _build_colormap_lut(self._state["colormap_name"]),
622680
})
623681
self._push()
624682

0 commit comments

Comments
 (0)