@@ -147,6 +147,16 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
147147 over_sample_size = sub_size , mask = mask
148148 )
149149
150+
151+ @property
152+ def sub_is_uniform (self ) -> bool :
153+ """
154+ Returns True if the sub_size is uniform across all pixels in the mask.
155+ """
156+ return np .all (
157+ np .isclose (self .sub_size .array , self .sub_size .array [0 ])
158+ )
159+
150160 def tree_flatten (self ):
151161 return (self .mask , self .sub_size ), ()
152162
@@ -185,7 +195,7 @@ def sub_pixel_areas(self) -> np.ndarray:
185195 """
186196 The area of every sub-pixel in the mask.
187197 """
188- sub_pixel_areas = jnp .zeros (self .sub_total )
198+ sub_pixel_areas = np .zeros (self .sub_total )
189199
190200 k = 0
191201
@@ -221,15 +231,24 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D":
221231 except AttributeError :
222232 pass
223233
224- # binned_array_2d = over_sample_util.binned_array_2d_from(
225- # array_2d=jnp.array(array),
226- # mask_2d=jnp.array(self.mask),
227- # sub_size=jnp.array(self.sub_size).astype("int"),
228- # )
234+ if self .sub_is_uniform :
235+ binned_array_2d = array .reshape (
236+ self .mask .shape_slim , self .sub_size [0 ] ** 2
237+ ).mean (axis = 1 )
238+ else :
239+
240+ # Define group sizes
241+ group_sizes = jnp .array (self .sub_size .array .astype ("int" ) ** 2 )
242+
243+ # Compute the cumulative sum of group sizes to get split points
244+ split_indices = jnp .cumsum (group_sizes )
245+
246+ # Ensure correct concatenation by making 0 a JAX array
247+ start_indices = jnp .concatenate ((jnp .array ([0 ]), split_indices [:- 1 ]))
229248
230- binned_array_2d = array . reshape (
231- self . mask . shape_slim , self . sub_size [ 0 ] ** 2
232- ) .mean (axis = 1 )
249+ # Compute the group means
250+ binned_array_2d = jnp . array (
251+ [ array [ start : end ] .mean () for start , end in zip ( start_indices , split_indices )] )
233252
234253 return Array2D (
235254 values = binned_array_2d ,
0 commit comments