File tree Expand file tree Collapse file tree
autoarray/structures/arrays Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -40,22 +40,25 @@ def convert_array_1d(
4040 """
4141 array_1d = array_2d_util .convert_array (array = array_1d )
4242
43+ is_numpy = True if isinstance (array_1d , np .ndarray ) else False
44+
4345 is_native = array_1d .shape [0 ] == mask_1d .shape_native [0 ]
4446
4547 mask_1d = jnp .array (mask_1d .array )
4648
4749 if is_native == store_native :
48- return array_1d
50+ array_1d = array_1d
4951 elif not store_native :
50- return array_1d_slim_from (
52+ array_1d = array_1d_slim_from (
5153 array_1d_native = array_1d ,
5254 mask_1d = mask_1d ,
5355 )
54-
55- return array_1d_native_from (
56- array_1d_slim = array_1d ,
57- mask_1d = mask_1d ,
58- )
56+ else :
57+ array_1d = array_1d_native_from (
58+ array_1d_slim = array_1d ,
59+ mask_1d = mask_1d ,
60+ )
61+ return np .array (array_1d ) if is_numpy else jnp .array (array_1d )
5962
6063
6164def array_1d_slim_from (
Original file line number Diff line number Diff line change @@ -122,10 +122,7 @@ def convert_array_2d(
122122 """
123123 array_2d = convert_array (array = array_2d ).copy ()
124124
125- if isinstance (array_2d , np .ndarray ):
126- is_numpy = True
127- else :
128- is_numpy = False
125+ is_numpy = True if isinstance (array_2d , np .ndarray ) else False
129126
130127 check_array_2d_and_mask_2d (array_2d = array_2d , mask_2d = mask_2d )
131128
@@ -135,17 +132,17 @@ def convert_array_2d(
135132 array_2d *= np .invert (mask_2d )
136133
137134 if is_native == store_native :
138- return np . array ( array_2d ) if is_numpy else jnp . array ( array_2d )
135+ array_2d = array_2d
139136 elif not store_native :
140137 array_2d = array_2d_slim_from (
141138 array_2d_native = array_2d ,
142139 mask_2d = mask_2d ,
143140 )
144- return np . array ( array_2d ) if is_numpy else jnp . array ( array_2d )
145- array_2d = array_2d_native_from (
146- array_2d_slim = array_2d ,
147- mask_2d = mask_2d ,
148- )
141+ else :
142+ array_2d = array_2d_native_from (
143+ array_2d_slim = array_2d ,
144+ mask_2d = mask_2d ,
145+ )
149146 return np .array (array_2d ) if is_numpy else jnp .array (array_2d )
150147
151148
You can’t perform that action at this time.
0 commit comments