Skip to content

Commit 3e6127d

Browse files
authored
Merge pull request #173 from Jammy2211/feature/rgb
Feature/rgb
2 parents b93f511 + 473dc10 commit 3e6127d

7 files changed

Lines changed: 155 additions & 10 deletions

File tree

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from .layout.layout import Layout2D
6767
from .structures.arrays.uniform_1d import Array1D
6868
from .structures.arrays.uniform_2d import Array2D
69+
from .structures.arrays.rgb import Array2DRGB
6970
from .structures.arrays.irregular import ArrayIrregular
7071
from .structures.grids.uniform_1d import Grid1D
7172
from .structures.grids.uniform_2d import Grid2D

autoarray/fixtures.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def make_array_2d_7x7():
6868
return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0))
6969

7070

71+
def make_array_2d_rgb_7x7():
72+
return aa.Array2DRGB(values=np.ones((7, 7, 3)), mask=make_mask_2d_7x7())
73+
74+
7175
def make_layout_2d_7x7():
7276
return aa.Layout2D(
7377
shape_2d=(7, 7),

autoarray/mask/derive/zoom_2d.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
if TYPE_CHECKING:
66
from autoarray.structures.arrays.uniform_2d import Array2D
7+
from autoarray.structures.arrays.rgb import Array2DRGB
78
from autoarray.mask.mask_2d import Mask2D
89

10+
911
from autoarray.structures.arrays import array_2d_util
1012
from autoarray.structures.grids import grid_2d_util
1113

@@ -242,8 +244,12 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D:
242244
The number pixels around the extracted array used as a buffer.
243245
"""
244246
from autoarray.structures.arrays.uniform_2d import Array2D
247+
from autoarray.structures.arrays.rgb import Array2DRGB
245248
from autoarray.mask.mask_2d import Mask2D
246249

250+
if isinstance(array, Array2DRGB):
251+
return self.array_2d_rgb_from(array=array, buffer=buffer)
252+
247253
extracted_array_2d = array_2d_util.extracted_array_2d_from(
248254
array_2d=np.array(array.native),
249255
y0=self.region[0] - buffer,
@@ -269,3 +275,53 @@ def array_2d_from(self, array: Array2D, buffer: int = 1) -> Array2D:
269275
arr = array_2d_util.convert_array_2d(array_2d=extracted_array_2d, mask_2d=mask)
270276

271277
return Array2D(values=arr, mask=mask, header=array.header)
278+
279+
def array_2d_rgb_from(self, array: Array2DRGB, buffer: int = 1) -> Array2DRGB:
280+
"""
281+
Extract the 2D region of an RGB array corresponding to the rectangle encompassing all unmasked values.
282+
283+
This works the same as the `array_2d_from` method, but for RGB arrays, meaning that it iterates over the three
284+
channels of the RGB array and extracts the region for each channel separately.
285+
286+
This is used to extract and visualize only the region of an RGB image that is used in an analysis.
287+
288+
Parameters
289+
----------
290+
buffer
291+
The number pixels around the extracted array used as a buffer.
292+
"""
293+
from autoarray.structures.arrays.rgb import Array2DRGB
294+
from autoarray.mask.mask_2d import Mask2D
295+
296+
for i in range(3):
297+
298+
extracted_array_2d = array_2d_util.extracted_array_2d_from(
299+
array_2d=np.array(array.native[:, :, i]),
300+
y0=self.region[0] - buffer,
301+
y1=self.region[1] + buffer,
302+
x0=self.region[2] - buffer,
303+
x1=self.region[3] + buffer,
304+
)
305+
306+
if i == 0:
307+
array_2d_rgb = np.zeros(
308+
(extracted_array_2d.shape[0], extracted_array_2d.shape[1], 3)
309+
)
310+
311+
array_2d_rgb[:, :, i] = extracted_array_2d
312+
313+
extracted_mask_2d = array_2d_util.extracted_array_2d_from(
314+
array_2d=np.array(self.mask),
315+
y0=self.region[0] - buffer,
316+
y1=self.region[1] + buffer,
317+
x0=self.region[2] - buffer,
318+
x1=self.region[3] + buffer,
319+
)
320+
321+
mask = Mask2D(
322+
mask=extracted_mask_2d,
323+
pixel_scales=array.pixel_scales,
324+
origin=array.mask.mask_centre,
325+
)
326+
327+
return Array2DRGB(values=array_2d_rgb.astype("int"), mask=mask)

autoarray/plot/mat_plot/two_d.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from autoarray.plot.visuals.two_d import Visuals2D
1515
from autoarray.mask.derive.zoom_2d import Zoom2D
1616
from autoarray.structures.arrays.uniform_2d import Array2D
17+
from autoarray.structures.arrays.rgb import Array2DRGB
1718

1819
from autoarray.structures.arrays import array_2d_util
1920

@@ -243,7 +244,6 @@ def plot_array(
243244
bypass
244245
If `True`, `plt.close` is omitted and the matplotlib figure remains open. This is used when making subplots.
245246
"""
246-
247247
if array is None or np.all(array == 0):
248248
return
249249

@@ -280,14 +280,25 @@ def plot_array(
280280

281281
origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"]
282282

283-
plt.imshow(
284-
X=array.native.array,
285-
aspect=aspect,
286-
cmap=self.cmap.cmap,
287-
norm=norm,
288-
extent=extent,
289-
origin=origin,
290-
)
283+
if isinstance(array, Array2DRGB):
284+
285+
plt.imshow(
286+
X=array.native.array,
287+
aspect=aspect,
288+
extent=extent,
289+
origin=origin,
290+
)
291+
292+
else:
293+
294+
plt.imshow(
295+
X=array.native.array,
296+
aspect=aspect,
297+
cmap=self.cmap.cmap,
298+
norm=norm,
299+
extent=extent,
300+
origin=origin,
301+
)
291302

292303
if visuals_2d.array_overlay is not None:
293304
self.array_overlay.overlay_array(
@@ -317,7 +328,12 @@ def plot_array(
317328
pixels=array.shape_native[1],
318329
)
319330

320-
self.title.set(auto_title=auto_labels.title, use_log10=self.use_log10)
331+
if isinstance(array, Array2DRGB):
332+
title = "RGB"
333+
else:
334+
title = auto_labels.title
335+
336+
self.title.set(auto_title=title, use_log10=self.use_log10)
321337
self.ylabel.set()
322338
self.xlabel.set()
323339

@@ -332,6 +348,7 @@ def plot_array(
332348
[annotate.set() for annotate in self.annotate]
333349

334350
if self.colorbar is not False:
351+
335352
cb = self.colorbar.set(
336353
units=self.units,
337354
ax=ax,

autoarray/structures/arrays/rgb.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from autoarray.abstract_ndarray import AbstractNDArray
2+
from autoarray.structures.arrays.uniform_2d import Array2D
3+
4+
5+
class Array2DRGB(Array2D):
6+
7+
def __init__(self, values, mask):
8+
"""
9+
A container for RGB images which have a final dimension of 3, which allows them to be visualized using
10+
the same functionality as `Array2D` objects.
11+
12+
By passing an RGB image to this class, the following visualization functionality is used when the RGB
13+
image is used in `Plotter` objects:
14+
15+
- The RGB image is plotted using the `imshow` function of Matplotlib.
16+
- Functionality which sets the scale of the axis, zooms the image, and sets the axis limits is used.
17+
- The colorbar is set to the RGB image, which is a 3D array with a final dimension of 3.
18+
- The formatting of the image is identical to that of `Array2D` objects, which means the image is plotted
19+
with the same aspect ratio as the original image making for easy subplot formatting.
20+
21+
This class always assumes the array is in its `native` representation, but with a final dimension of 3.
22+
23+
Parameters
24+
----------
25+
values
26+
The values of the RGB image, which is a 3D array with a final dimension of 3.
27+
mask
28+
The 2D mask associated with the array, defining the pixels each array value in its ``slim`` representation
29+
is paired with.
30+
"""
31+
32+
array = values
33+
34+
while isinstance(array, AbstractNDArray):
35+
array = array.array
36+
37+
self._array = array
38+
self.mask = mask
39+
40+
@property
41+
def native(self) -> "Array2D":
42+
"""
43+
Returns the RGB ndarray of shape [total_y_pixels, total_x_pixels, 3] in its `native` representation.
44+
"""
45+
return self

test_autoarray/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def make_array_2d_7x7():
6969
return fixtures.make_array_2d_7x7()
7070

7171

72+
@pytest.fixture(name="array_2d_rgb_7x7")
73+
def make_array_2d_rgb_7x7():
74+
return fixtures.make_array_2d_rgb_7x7()
75+
76+
7277
@pytest.fixture(name="layout_2d_7x7")
7378
def make_layout_2d_7x7():
7479
return fixtures.make_layout_2d_7x7()

test_autoarray/structures/plot/test_structure_plotters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,20 @@ def test__grid(
167167
grid_2d_plotter.figure_2d(color_array=color_array)
168168

169169
assert path.join(plot_path, "grid3.png") in plot_patch.paths
170+
171+
172+
def test__array_rgb(
173+
array_2d_rgb_7x7,
174+
plot_path,
175+
plot_patch,
176+
):
177+
array_plotter = aplt.Array2DPlotter(
178+
array=array_2d_rgb_7x7,
179+
mat_plot_2d=aplt.MatPlot2D(
180+
output=aplt.Output(path=plot_path, filename="array_rgb", format="png")
181+
),
182+
)
183+
184+
array_plotter.figure_2d()
185+
186+
assert path.join(plot_path, "array_rgb.png") in plot_patch.paths

0 commit comments

Comments
 (0)