Skip to content

Commit 083ed0b

Browse files
committed
over sampling unit tests
1 parent 5430647 commit 083ed0b

2 files changed

Lines changed: 38 additions & 9 deletions

File tree

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,16 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
147147
over_sample_size=sub_size, mask=mask
148148
)
149149

150+
151+
@property
152+
def sub_is_uniform(self) -> bool:
153+
"""
154+
Returns True if the sub_size is uniform across all pixels in the mask.
155+
"""
156+
return np.all(
157+
np.isclose(self.sub_size.array, self.sub_size.array[0])
158+
)
159+
150160
def tree_flatten(self):
151161
return (self.mask, self.sub_size), ()
152162

@@ -185,7 +195,7 @@ def sub_pixel_areas(self) -> np.ndarray:
185195
"""
186196
The area of every sub-pixel in the mask.
187197
"""
188-
sub_pixel_areas = jnp.zeros(self.sub_total)
198+
sub_pixel_areas = np.zeros(self.sub_total)
189199

190200
k = 0
191201

@@ -221,15 +231,24 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D":
221231
except AttributeError:
222232
pass
223233

224-
# binned_array_2d = over_sample_util.binned_array_2d_from(
225-
# array_2d=jnp.array(array),
226-
# mask_2d=jnp.array(self.mask),
227-
# sub_size=jnp.array(self.sub_size).astype("int"),
228-
# )
234+
if self.sub_is_uniform:
235+
binned_array_2d = array.reshape(
236+
self.mask.shape_slim, self.sub_size[0] ** 2
237+
).mean(axis=1)
238+
else:
239+
240+
# Define group sizes
241+
group_sizes = jnp.array(self.sub_size.array.astype("int") ** 2)
242+
243+
# Compute the cumulative sum of group sizes to get split points
244+
split_indices = jnp.cumsum(group_sizes)
245+
246+
# Ensure correct concatenation by making 0 a JAX array
247+
start_indices = jnp.concatenate((jnp.array([0]), split_indices[:-1]))
229248

230-
binned_array_2d = array.reshape(
231-
self.mask.shape_slim, self.sub_size[0] ** 2
232-
).mean(axis=1)
249+
# Compute the group means
250+
binned_array_2d = jnp.array(
251+
[array[start:end].mean() for start, end in zip(start_indices, split_indices)])
233252

234253
return Array2D(
235254
values=binned_array_2d,

test_autoarray/operators/over_sample/test_over_sampler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def test__binned_array_2d_from():
7070
pixel_scales=1.0,
7171
)
7272

73+
over_sampling = aa.OverSampler(
74+
mask=mask, sub_size=aa.Array2D(values=[2, 2], mask=mask)
75+
)
76+
77+
arr = np.array([1.0, 5.0, 7.0, 10.0, 10.0, 10.0, 10.0, 10.0])
78+
79+
binned_array_2d = over_sampling.binned_array_2d_from(array=arr)
80+
81+
assert binned_array_2d.slim == pytest.approx(np.array([5.75, 10.0]), 1.0e-4)
82+
7383
over_sampling = aa.OverSampler(
7484
mask=mask, sub_size=aa.Array2D(values=[1, 2], mask=mask)
7585
)

0 commit comments

Comments
 (0)