Skip to content

Commit 32f55f7

Browse files
committed
grid 2d casting
1 parent 578bc4c commit 32f55f7

1 file changed

Lines changed: 20 additions & 12 deletions

File tree

autoarray/structures/grids/grid_2d_util.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def convert_grid(grid: Union[np.ndarray, List]) -> np.ndarray:
2121
except AttributeError:
2222
pass
2323

24-
return jnp.asarray(grid)
24+
if isinstance(grid, list):
25+
grid = np.asarray(grid)
26+
27+
return grid
2528

2629

2730
def check_grid_slim(grid, shape_native):
@@ -109,28 +112,33 @@ def convert_grid_2d(
109112

110113
grid_2d = convert_grid(grid=grid_2d)
111114

115+
is_numpy = True if isinstance(grid_2d, np.ndarray) else False
116+
112117
check_grid_2d_and_mask_2d(grid_2d=grid_2d, mask_2d=mask_2d)
113118

114119
is_native = len(grid_2d.shape) == 3
115120

116-
mask_2d = jnp.array(mask_2d.array)
117-
118121
if is_native:
119-
grid_2d = grid_2d.at[:, :, 0].multiply(jnp.invert(mask_2d))
120-
grid_2d = grid_2d.at[:, :, 1].multiply(jnp.invert(mask_2d))
122+
if not is_numpy:
123+
grid_2d = grid_2d.at[:, :, 0].multiply(jnp.invert(mask_2d))
124+
grid_2d = grid_2d.at[:, :, 1].multiply(jnp.invert(mask_2d))
125+
else:
126+
grid_2d[:, :, 0] *= np.invert(mask_2d)
127+
grid_2d[:, :, 1] *= np.invert(mask_2d)
121128

122129
if is_native == store_native:
123-
return grid_2d
130+
grid_2d = grid_2d
124131
elif not store_native:
125-
return grid_2d_slim_from(
132+
grid_2d = grid_2d_slim_from(
126133
grid_2d_native=grid_2d,
127134
mask=mask_2d,
128135
)
129-
return grid_2d_native_from(
130-
grid_2d_slim=grid_2d,
131-
mask_2d=mask_2d,
132-
)
133-
136+
else:
137+
grid_2d = grid_2d_native_from(
138+
grid_2d_slim=grid_2d,
139+
mask_2d=mask_2d,
140+
)
141+
return np.array(grid_2d) if is_numpy else jnp.array(grid_2d)
134142

135143
def convert_grid_2d_to_slim(
136144
grid_2d: Union[np.ndarray, List], mask_2d: Mask2D

0 commit comments

Comments
 (0)