Skip to content

Commit d3c5bbf

Browse files
committed
same rules for grids
1 parent 49b8dc5 commit d3c5bbf

4 files changed

Lines changed: 12 additions & 14 deletions

File tree

autoarray/structures/arrays/array_1d_util.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ def convert_array_1d(
4444

4545
is_native = array_1d.shape[0] == mask_1d.shape_native[0]
4646

47-
mask_1d = jnp.array(mask_1d.array)
48-
4947
if is_native == store_native:
5048
array_1d = array_1d
5149
elif not store_native:
@@ -124,7 +122,7 @@ def array_1d_native_from(
124122
).astype("int")
125123

126124
return array_1d_via_indexes_1d_from(
127-
array_1d_slim=np.array(array_1d_slim),
125+
array_1d_slim=array_1d_slim,
128126
shape=shape,
129127
native_index_for_slim_index_1d=native_index_for_slim_index_1d,
130128
)

autoarray/structures/arrays/array_2d_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def convert_array_2d(
129129
is_native = len(array_2d.shape) == 2
130130

131131
if is_native and not skip_mask:
132-
array_2d *= np.invert(mask_2d)
132+
array_2d *= ~mask_2d
133133

134134
if is_native == store_native:
135135
array_2d = array_2d

autoarray/structures/grids/grid_1d_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def grid_1d_slim_from(
171171
"""
172172

173173
return array_1d_util.array_1d_slim_from(
174-
array_1d_native=np.array(grid_1d_native),
174+
array_1d_native=grid_1d_native,
175175
mask_1d=mask_1d,
176176
)
177177

autoarray/structures/grids/grid_2d_util.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ def convert_grid_2d(
120120

121121
if is_native:
122122
if not is_numpy:
123-
grid_2d = grid_2d.at[:, :, 0].multiply(jnp.invert(mask_2d))
124-
grid_2d = grid_2d.at[:, :, 1].multiply(jnp.invert(mask_2d))
123+
grid_2d = grid_2d.at[:, :, 0].multiply(~mask_2d)
124+
grid_2d = grid_2d.at[:, :, 1].multiply(~mask_2d)
125125
else:
126-
grid_2d[:, :, 0] *= np.invert(mask_2d)
127-
grid_2d[:, :, 1] *= np.invert(mask_2d)
126+
grid_2d[:, :, 0] *= ~mask_2d
127+
grid_2d[:, :, 1] *= ~mask_2d
128128

129129
if is_native == store_native:
130130
grid_2d = grid_2d
@@ -682,16 +682,16 @@ def grid_2d_slim_from(
682682
"""
683683

684684
grid_1d_slim_y = array_2d_util.array_2d_slim_from(
685-
array_2d_native=np.array(grid_2d_native[:, :, 0]),
686-
mask_2d=np.array(mask),
685+
array_2d_native=grid_2d_native[:, :, 0],
686+
mask_2d=mask,
687687
)
688688

689689
grid_1d_slim_x = array_2d_util.array_2d_slim_from(
690-
array_2d_native=np.array(grid_2d_native[:, :, 1]),
691-
mask_2d=np.array(mask),
690+
array_2d_native=grid_2d_native[:, :, 1],
691+
mask_2d=mask,
692692
)
693693

694-
return np.stack((grid_1d_slim_y, grid_1d_slim_x), axis=-1)
694+
return jnp.stack((grid_1d_slim_y, grid_1d_slim_x), axis=-1)
695695

696696

697697
def grid_2d_native_from(

0 commit comments

Comments
 (0)