@@ -120,11 +120,11 @@ def convert_grid_2d(
120120
121121 if is_native :
122122 if not is_numpy :
123- grid_2d = grid_2d .at [:, :, 0 ].multiply (jnp . invert ( mask_2d ) )
124- grid_2d = grid_2d .at [:, :, 1 ].multiply (jnp . invert ( mask_2d ) )
123+ grid_2d = grid_2d .at [:, :, 0 ].multiply (~ mask_2d )
124+ grid_2d = grid_2d .at [:, :, 1 ].multiply (~ mask_2d )
125125 else :
126- grid_2d [:, :, 0 ] *= np . invert ( mask_2d )
127- grid_2d [:, :, 1 ] *= np . invert ( mask_2d )
126+ grid_2d [:, :, 0 ] *= ~ mask_2d
127+ grid_2d [:, :, 1 ] *= ~ mask_2d
128128
129129 if is_native == store_native :
130130 grid_2d = grid_2d
@@ -682,16 +682,16 @@ def grid_2d_slim_from(
682682 """
683683
684684 grid_1d_slim_y = array_2d_util .array_2d_slim_from (
685- array_2d_native = np . array ( grid_2d_native [:, :, 0 ]) ,
686- mask_2d = np . array ( mask ) ,
685+ array_2d_native = grid_2d_native [:, :, 0 ],
686+ mask_2d = mask ,
687687 )
688688
689689 grid_1d_slim_x = array_2d_util .array_2d_slim_from (
690- array_2d_native = np . array ( grid_2d_native [:, :, 1 ]) ,
691- mask_2d = np . array ( mask ) ,
690+ array_2d_native = grid_2d_native [:, :, 1 ],
691+ mask_2d = mask ,
692692 )
693693
694- return np .stack ((grid_1d_slim_y , grid_1d_slim_x ), axis = - 1 )
694+ return jnp .stack ((grid_1d_slim_y , grid_1d_slim_x ), axis = - 1 )
695695
696696
697697def grid_2d_native_from (
0 commit comments