11from __future__ import annotations
2- from autoarray .numpy_wrapper import np , use_jax
32import numpy
3+ import jax .numpy as jnp
44from skimage import measure
55from scipy .spatial import ConvexHull
66from scipy .spatial import QhullError
@@ -42,16 +42,13 @@ def contour_array(self):
4242 return self ._contour_array
4343
4444 pixel_centres = geometry_util .grid_pixel_centres_2d_slim_from (
45- grid_scaled_2d_slim = np .array (self .grid ),
45+ grid_scaled_2d_slim = jnp .array (self .grid ),
4646 shape_native = self .shape_native ,
4747 pixel_scales = self .pixel_scales ,
4848 ).astype ("int" )
4949
50- arr = np .zeros (self .shape_native )
51- if use_jax :
52- arr = arr .at [tuple (np .array (pixel_centres ).T )].set (1 )
53- else :
54- arr [tuple (np .array (pixel_centres ).T )] = 1
50+ arr = jnp .zeros (self .shape_native )
51+ arr = arr .at [tuple (jnp .array (pixel_centres ).T )].set (1 )
5552
5653 return arr
5754
@@ -74,7 +71,7 @@ def contour_list(self):
7471 pixel_scales = self .pixel_scales ,
7572 )
7673
77- factor = 0.5 * np .array (self .pixel_scales ) * np .array ([- 1.0 , 1.0 ])
74+ factor = 0.5 * jnp .array (self .pixel_scales ) * jnp .array ([- 1.0 , 1.0 ])
7875 grid_scaled_1d += factor
7976
8077 contour_list .append (Grid2DIrregular (values = grid_scaled_1d ))
@@ -104,7 +101,7 @@ def hull(
104101 hull_x = grid_convex [hull_vertices , 0 ]
105102 hull_y = grid_convex [hull_vertices , 1 ]
106103
107- grid_hull = np .zeros ((len (hull_vertices ), 2 ))
104+ grid_hull = jnp .zeros ((len (hull_vertices ), 2 ))
108105
109106 grid_hull [:, 1 ] = hull_x
110107 grid_hull [:, 0 ] = hull_y
0 commit comments