@@ -21,7 +21,10 @@ def convert_grid(grid: Union[np.ndarray, List]) -> np.ndarray:
2121 except AttributeError :
2222 pass
2323
24- return jnp .asarray (grid )
24+ if isinstance (grid , list ):
25+ grid = np .asarray (grid )
26+
27+ return grid
2528
2629
2730def check_grid_slim (grid , shape_native ):
@@ -109,28 +112,33 @@ def convert_grid_2d(
109112
110113 grid_2d = convert_grid (grid = grid_2d )
111114
115+ is_numpy = True if isinstance (grid_2d , np .ndarray ) else False
116+
112117 check_grid_2d_and_mask_2d (grid_2d = grid_2d , mask_2d = mask_2d )
113118
114119 is_native = len (grid_2d .shape ) == 3
115120
116- mask_2d = jnp .array (mask_2d .array )
117-
118121 if is_native :
119- grid_2d = grid_2d .at [:, :, 0 ].multiply (jnp .invert (mask_2d ))
120- grid_2d = grid_2d .at [:, :, 1 ].multiply (jnp .invert (mask_2d ))
122+ if not is_numpy :
123+ grid_2d = grid_2d .at [:, :, 0 ].multiply (jnp .invert (mask_2d ))
124+ grid_2d = grid_2d .at [:, :, 1 ].multiply (jnp .invert (mask_2d ))
125+ else :
126+ grid_2d [:, :, 0 ] *= np .invert (mask_2d )
127+ grid_2d [:, :, 1 ] *= np .invert (mask_2d )
121128
122129 if is_native == store_native :
123- return grid_2d
130+ grid_2d = grid_2d
124131 elif not store_native :
125- return grid_2d_slim_from (
132+ grid_2d = grid_2d_slim_from (
126133 grid_2d_native = grid_2d ,
127134 mask = mask_2d ,
128135 )
129- return grid_2d_native_from (
130- grid_2d_slim = grid_2d ,
131- mask_2d = mask_2d ,
132- )
133-
136+ else :
137+ grid_2d = grid_2d_native_from (
138+ grid_2d_slim = grid_2d ,
139+ mask_2d = mask_2d ,
140+ )
141+ return np .array (grid_2d ) if is_numpy else jnp .array (grid_2d )
134142
135143def convert_grid_2d_to_slim (
136144 grid_2d : Union [np .ndarray , List ], mask_2d : Mask2D
0 commit comments