Skip to content

Commit 44a2808

Browse files
committed
fix structure plotters
1 parent 3b6ab48 commit 44a2808

3 files changed

Lines changed: 11 additions & 8 deletions

File tree

autoarray/operators/contour.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
import numpy
2+
import numpy as np
33
import jax.numpy as jnp
44
from skimage import measure
55
from scipy.spatial import ConvexHull
@@ -42,7 +42,7 @@ def contour_array(self):
4242
return self._contour_array
4343

4444
pixel_centres = geometry_util.grid_pixel_centres_2d_slim_from(
45-
grid_scaled_2d_slim=jnp.array(self.grid),
45+
grid_scaled_2d_slim=np.array(self.grid),
4646
shape_native=self.shape_native,
4747
pixel_scales=self.pixel_scales,
4848
).astype("int")
@@ -56,7 +56,7 @@ def contour_array(self):
5656
def contour_list(self):
5757
# make sure to use base numpy to convert JAX array back to a normal array
5858
contour_indices_list = measure.find_contours(
59-
numpy.array(self.contour_array.array), 0
59+
np.array(self.contour_array), 0
6060
)
6161

6262
if len(contour_indices_list) == 0:
@@ -71,7 +71,7 @@ def contour_list(self):
7171
pixel_scales=self.pixel_scales,
7272
)
7373

74-
factor = 0.5 * jnp.array(self.pixel_scales) * jnp.array([-1.0, 1.0])
74+
factor = 0.5 * np.array(self.pixel_scales) * np.array([-1.0, 1.0])
7575
grid_scaled_1d += factor
7676

7777
contour_list.append(Grid2DIrregular(values=grid_scaled_1d))
@@ -86,10 +86,10 @@ def hull(
8686
return None
8787

8888
# cast JAX arrays to base numpy arrays
89-
grid_convex = numpy.zeros((len(self.grid), 2))
89+
grid_convex = np.zeros((len(self.grid), 2))
9090

91-
grid_convex[:, 0] = numpy.array(self.grid[:, 1])
92-
grid_convex[:, 1] = numpy.array(self.grid[:, 0])
91+
grid_convex[:, 0] = np.array(self.grid[:, 1])
92+
grid_convex[:, 1] = np.array(self.grid[:, 0])
9393

9494
try:
9595
hull = ConvexHull(grid_convex)

autoarray/plot/wrap/two_d/array_overlay.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@ def overlay_array(self, array, figure):
1919
aspect = figure.aspect_from(shape_native=array.shape_native)
2020
extent = array.extent_of_zoomed_array(buffer=0)
2121

22-
plt.imshow(X=array.native, aspect=aspect, extent=extent, **self.config_dict)
22+
print(type(array))
23+
24+
plt.imshow(X=array.native._array, aspect=aspect, extent=extent, **self.config_dict)

test_autoarray/structures/plot/test_structure_plotters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from os import path
44
import pytest
55
import numpy as np
6+
import jax.numpy as jnp
67
import shutil
78

89
directory = path.dirname(path.realpath(__file__))

0 commit comments

Comments
 (0)