Skip to content

Commit 578bc4c

Browse files
committed
array 1d now use convert rules
1 parent 0fa9646 commit 578bc4c

2 files changed

Lines changed: 17 additions & 17 deletions

File tree

autoarray/structures/arrays/array_1d_util.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,25 @@ def convert_array_1d(
4040
"""
4141
array_1d = array_2d_util.convert_array(array=array_1d)
4242

43+
is_numpy = True if isinstance(array_1d, np.ndarray) else False
44+
4345
is_native = array_1d.shape[0] == mask_1d.shape_native[0]
4446

4547
mask_1d = jnp.array(mask_1d.array)
4648

4749
if is_native == store_native:
48-
return array_1d
50+
array_1d = array_1d
4951
elif not store_native:
50-
return array_1d_slim_from(
52+
array_1d = array_1d_slim_from(
5153
array_1d_native=array_1d,
5254
mask_1d=mask_1d,
5355
)
54-
55-
return array_1d_native_from(
56-
array_1d_slim=array_1d,
57-
mask_1d=mask_1d,
58-
)
56+
else:
57+
array_1d = array_1d_native_from(
58+
array_1d_slim=array_1d,
59+
mask_1d=mask_1d,
60+
)
61+
return np.array(array_1d) if is_numpy else jnp.array(array_1d)
5962

6063

6164
def array_1d_slim_from(

autoarray/structures/arrays/array_2d_util.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,7 @@ def convert_array_2d(
122122
"""
123123
array_2d = convert_array(array=array_2d).copy()
124124

125-
if isinstance(array_2d, np.ndarray):
126-
is_numpy = True
127-
else:
128-
is_numpy = False
125+
is_numpy = True if isinstance(array_2d, np.ndarray) else False
129126

130127
check_array_2d_and_mask_2d(array_2d=array_2d, mask_2d=mask_2d)
131128

@@ -135,17 +132,17 @@ def convert_array_2d(
135132
array_2d *= np.invert(mask_2d)
136133

137134
if is_native == store_native:
138-
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
135+
array_2d = array_2d
139136
elif not store_native:
140137
array_2d = array_2d_slim_from(
141138
array_2d_native=array_2d,
142139
mask_2d=mask_2d,
143140
)
144-
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
145-
array_2d = array_2d_native_from(
146-
array_2d_slim=array_2d,
147-
mask_2d=mask_2d,
148-
)
141+
else:
142+
array_2d = array_2d_native_from(
143+
array_2d_slim=array_2d,
144+
mask_2d=mask_2d,
145+
)
149146
return np.array(array_2d) if is_numpy else jnp.array(array_2d)
150147

151148

0 commit comments

Comments
 (0)