11from __future__ import annotations
2- import numpy
2+ import numpy as np
33import jax .numpy as jnp
44from skimage import measure
55from scipy .spatial import ConvexHull
@@ -42,7 +42,7 @@ def contour_array(self):
4242 return self ._contour_array
4343
4444 pixel_centres = geometry_util .grid_pixel_centres_2d_slim_from (
45- grid_scaled_2d_slim = jnp .array (self .grid ),
45+ grid_scaled_2d_slim = np .array (self .grid ),
4646 shape_native = self .shape_native ,
4747 pixel_scales = self .pixel_scales ,
4848 ).astype ("int" )
@@ -56,7 +56,7 @@ def contour_array(self):
5656 def contour_list (self ):
5757 # make sure to use base numpy to convert JAX array back to a normal array
5858 contour_indices_list = measure .find_contours (
59- numpy .array (self .contour_array . array ), 0
59+ np .array (self .contour_array ), 0
6060 )
6161
6262 if len (contour_indices_list ) == 0 :
@@ -71,7 +71,7 @@ def contour_list(self):
7171 pixel_scales = self .pixel_scales ,
7272 )
7373
74- factor = 0.5 * jnp .array (self .pixel_scales ) * jnp .array ([- 1.0 , 1.0 ])
74+ factor = 0.5 * np .array (self .pixel_scales ) * np .array ([- 1.0 , 1.0 ])
7575 grid_scaled_1d += factor
7676
7777 contour_list .append (Grid2DIrregular (values = grid_scaled_1d ))
@@ -86,10 +86,10 @@ def hull(
8686 return None
8787
8888 # cast JAX arrays to base numpy arrays
89- grid_convex = numpy .zeros ((len (self .grid ), 2 ))
89+ grid_convex = np .zeros ((len (self .grid ), 2 ))
9090
91- grid_convex [:, 0 ] = numpy .array (self .grid [:, 1 ])
92- grid_convex [:, 1 ] = numpy .array (self .grid [:, 0 ])
91+ grid_convex [:, 0 ] = np .array (self .grid [:, 1 ])
92+ grid_convex [:, 1 ] = np .array (self .grid [:, 0 ])
9393
9494 try :
9595 hull = ConvexHull (grid_convex )
0 commit comments