Skip to content

Commit 731fcb5

Browse files
committed
black
1 parent 4e0b925 commit 731fcb5

17 files changed

Lines changed: 110 additions & 116 deletions

File tree

autoarray/fixtures.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def make_blurring_grid_2d_7x7():
110110
return aa.Grid2D.from_mask(mask=make_blurring_mask_2d_7x7())
111111

112112

113-
114113
def make_image_7x7():
115114
return aa.Array2D.ones(shape_native=(7, 7), pixel_scales=(1.0, 1.0))
116115

autoarray/geometry/geometry_util.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float]
181181

182182
return pixel_scales
183183

184+
184185
@numba_util.jit()
185186
def central_pixel_coordinates_2d_numba_from(
186187
shape_native: Tuple[int, int],
@@ -205,6 +206,7 @@ def central_pixel_coordinates_2d_numba_from(
205206
"""
206207
return (float(shape_native[0] - 1) / 2, float(shape_native[1] - 1) / 2)
207208

209+
208210
@numba_util.jit()
209211
def central_scaled_coordinate_2d_numba_from(
210212
shape_native: Tuple[int, int],
@@ -305,6 +307,7 @@ def central_scaled_coordinate_2d_from(
305307

306308
return (y_pixel, x_pixel)
307309

310+
308311
def pixel_coordinates_2d_from(
309312
scaled_coordinates_2d: Tuple[float, float],
310313
shape_native: Tuple[int, int],
@@ -589,9 +592,9 @@ def grid_pixel_centres_2d_slim_from(
589592
centres_scaled = np.array(centres_scaled)
590593
pixel_scales = np.array(pixel_scales)
591594
sign = np.array([-1.0, 1.0])
592-
return (
593-
(sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5
594-
).astype(int)
595+
return ((sign * grid_scaled_2d_slim / pixel_scales) + centres_scaled + 0.5).astype(
596+
int
597+
)
595598

596599

597600
def grid_pixel_indexes_2d_slim_from(
@@ -647,9 +650,7 @@ def grid_pixel_indexes_2d_slim_from(
647650
)
648651

649652
return (
650-
(grid_pixels_2d_slim * np.array([shape_native[1], 1]))
651-
.sum(axis=1)
652-
.astype(int)
653+
(grid_pixels_2d_slim * np.array([shape_native[1], 1])).sum(axis=1).astype(int)
653654
)
654655

655656

@@ -698,9 +699,7 @@ def grid_scaled_2d_slim_from(
698699
centres_scaled = np.array(centres_scaled)
699700
pixel_scales = np.array(pixel_scales)
700701
sign = np.array([-1, 1])
701-
return (
702-
(grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
703-
)
702+
return (grid_pixels_2d_slim - centres_scaled - 0.5) * pixel_scales * sign
704703

705704

706705
def grid_pixel_centres_2d_from(
@@ -750,9 +749,7 @@ def grid_pixel_centres_2d_from(
750749
centres_scaled = np.array(centres_scaled)
751750
pixel_scales = np.array(pixel_scales)
752751
sign = np.array([-1.0, 1.0])
753-
return (
754-
(sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5
755-
).astype(int)
752+
return ((sign * grid_scaled_2d / pixel_scales) + centres_scaled + 0.5).astype(int)
756753

757754

758755
def extent_symmetric_from(

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,7 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]:
526526
)
527527

528528
mapped_reconstructed_image = self.convolver.convolve_image_no_blurring(
529-
image=mapped_reconstructed_image,
530-
mask=self.mask
529+
image=mapped_reconstructed_image, mask=self.mask
531530
)
532531

533532
else:

autoarray/mask/derive/indexes_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,9 @@ def edge_slim(self) -> np.ndarray:
198198
199199
print(derive_indexes_2d.edge_slim)
200200
"""
201-
return mask_2d_util.edge_1d_indexes_from(mask_2d=np.array(self.mask).astype("bool")).astype(
202-
"int"
203-
)
201+
return mask_2d_util.edge_1d_indexes_from(
202+
mask_2d=np.array(self.mask).astype("bool")
203+
).astype("int")
204204

205205
@property
206206
def edge_native(self) -> np.ndarray:

autoarray/mask/mask_1d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Mask1DKeys(Enum):
2525
PIXSCA = "PIXSCA"
2626
ORIGIN = "ORIGIN"
2727

28+
2829
class Mask1D(Mask):
2930
def __init__(
3031
self,

autoarray/mask/mask_2d_util.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from autoarray import exc
77
from autoarray.numpy_wrapper import np as jnp
88

9+
910
def native_index_for_slim_index_2d_from(
1011
mask_2d: np.ndarray,
1112
) -> np.ndarray:
@@ -400,6 +401,7 @@ def mask_2d_via_pixel_coordinates_from(
400401
return mask_2d
401402
return buffed_mask_2d_from(mask_2d=mask_2d, buffer=buffer) # Apply buf
402403

404+
403405
def min_false_distance_to_edge(mask: np.ndarray) -> Tuple[int, int]:
404406
"""
405407
Compute the minimum 1D distance in the y and x directions from any `False` value at the mask's extreme positions
@@ -618,14 +620,18 @@ def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray:
618620
array([0, 1, 2, 3, 5, 6, 7, 8])
619621
"""
620622
# Pad the mask to handle edge cases without index errors
621-
padded_mask = np.pad(mask_2d, pad_width=1, mode='constant', constant_values=True)
623+
padded_mask = np.pad(mask_2d, pad_width=1, mode="constant", constant_values=True)
622624

623625
# Identify neighbors in 3x3 regions around each pixel
624626
neighbors = (
625-
padded_mask[:-2, 1:-1] | padded_mask[2:, 1:-1] | # Up, Down
626-
padded_mask[1:-1, :-2] | padded_mask[1:-1, 2:] | # Left, Right
627-
padded_mask[:-2, :-2] | padded_mask[:-2, 2:] | # Top-left, Top-right
628-
padded_mask[2:, :-2] | padded_mask[2:, 2:] # Bottom-left, Bottom-right
627+
padded_mask[:-2, 1:-1]
628+
| padded_mask[2:, 1:-1] # Up, Down
629+
| padded_mask[1:-1, :-2]
630+
| padded_mask[1:-1, 2:] # Left, Right
631+
| padded_mask[:-2, :-2]
632+
| padded_mask[:-2, 2:] # Top-left, Top-right
633+
| padded_mask[2:, :-2]
634+
| padded_mask[2:, 2:] # Bottom-left, Bottom-right
629635
)
630636

631637
# Identify edge pixels: False values with at least one True neighbor
@@ -708,10 +714,10 @@ def border_slim_indexes_from(mask_2d: np.ndarray) -> np.ndarray:
708714

709715
# Identify border pixels: where the full length in any direction is True
710716
border_mask = (
711-
(up_sums == np.arange(height)[:, None]) |
712-
(down_sums == np.arange(height - 1, -1, -1)[:, None]) |
713-
(left_sums == np.arange(width)[None, :]) |
714-
(right_sums == np.arange(width - 1, -1, -1)[None, :])
717+
(up_sums == np.arange(height)[:, None])
718+
| (down_sums == np.arange(height - 1, -1, -1)[:, None])
719+
| (left_sums == np.arange(width)[None, :])
720+
| (right_sums == np.arange(width - 1, -1, -1)[None, :])
715721
) & ~mask_2d
716722

717723
# Create an index array where False entries get sequential 1D indices
@@ -767,14 +773,16 @@ def buffed_mask_2d_from(mask_2d: np.ndarray, buffer: int = 1) -> np.ndarray:
767773
buffer_range = np.arange(-buffer, buffer + 1)
768774

769775
# Generate all possible neighbors for each False entry
770-
dy, dx = np.meshgrid(buffer_range, buffer_range, indexing='ij')
776+
dy, dx = np.meshgrid(buffer_range, buffer_range, indexing="ij")
771777
neighbors = np.stack([dy.ravel(), dx.ravel()], axis=-1)
772778

773779
# Calculate all neighboring positions for all False coordinates
774780
all_neighbors = np.add(np.array(false_coords).T[:, np.newaxis], neighbors)
775781

776782
# Clip the neighbors to stay within the bounds of the mask
777-
valid_neighbors = np.clip(all_neighbors, [0, 0], [mask_2d.shape[0] - 1, mask_2d.shape[1] - 1])
783+
valid_neighbors = np.clip(
784+
all_neighbors, [0, 0], [mask_2d.shape[0] - 1, mask_2d.shape[1] - 1]
785+
)
778786

779787
# Update the buffed mask: set all the neighbors to False
780788
buffed_mask_2d[valid_neighbors[:, :, 0], valid_neighbors[:, :, 1]] = False
@@ -833,6 +841,3 @@ def rescaled_mask_2d_from(mask_2d: np.ndarray, rescale_factor: float) -> np.ndar
833841
rescaled_mask_2d[:, 0] = 1
834842
rescaled_mask_2d[:, rescaled_mask_2d.shape[1] - 1] = 1
835843
return np.isclose(rescaled_mask_2d, 1)
836-
837-
838-

autoarray/operators/over_sampling/over_sample_util.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,10 @@ def grid_2d_slim_over_sampled_via_mask_from(
463463
# )
464464
# else:
465465
grid_slim[sub_index, 0] = -(
466-
y_scaled
467-
- y_sub_half
468-
+ y1 * y_sub_step
469-
+ (y_sub_step / 2.0)
466+
y_scaled - y_sub_half + y1 * y_sub_step + (y_sub_step / 2.0)
470467
)
471468
grid_slim[sub_index, 1] = (
472-
x_scaled
473-
- x_sub_half
474-
+ x1 * x_sub_step
475-
+ (x_sub_step / 2.0)
469+
x_scaled - x_sub_half + x1 * x_sub_step + (x_sub_step / 2.0)
476470
)
477471
sub_index += 1
478472

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D":
226226
# sub_size=np.array(self.sub_size).astype("int"),
227227
# )
228228

229-
binned_array_2d = array.reshape(self.mask.shape_slim, self.sub_size[0]**2).mean(axis=1)
229+
binned_array_2d = array.reshape(
230+
self.mask.shape_slim, self.sub_size[0] ** 2
231+
).mean(axis=1)
230232

231233
return Array2D(
232234
values=binned_array_2d,

autoarray/structures/arrays/array_2d_util.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ def convert_array(array: Union[np.ndarray, List]) -> np.ndarray:
2727
array = np.asarray(array)
2828
elif isinstance(array, jnp.ndarray):
2929
array = jax.lax.cond(
30-
type(array) is list,
31-
lambda _: jnp.asarray(array),
32-
lambda _: array,
33-
None
30+
type(array) is list, lambda _: jnp.asarray(array), lambda _: array, None
3431
)
3532
return array
3633

@@ -41,6 +38,7 @@ def check_array_2d(array_2d: np.ndarray):
4138
"An array input into the Array2D.__new__ method is not of shape 1."
4239
)
4340

41+
4442
def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D):
4543
"""
4644
The `manual` classmethods in the `Array2D` object take as input a list or ndarray which is returned as an
@@ -90,6 +88,7 @@ def check_array_2d_and_mask_2d(array_2d: np.ndarray, mask_2d: Mask2D):
9088
"""
9189
)
9290

91+
9392
def convert_array_2d(
9493
array_2d: Union[np.ndarray, List],
9594
mask_2d: Mask2D,
@@ -489,6 +488,7 @@ def index_slim_for_index_2d_from(indexes_2d: np.ndarray, shape_native) -> np.nda
489488

490489
return index_slim_for_index_native_2d
491490

491+
492492
def array_2d_slim_from(
493493
array_2d_native: np.ndarray,
494494
mask_2d: np.ndarray,
@@ -534,6 +534,7 @@ def array_2d_slim_from(
534534
"""
535535
return array_2d_native[~mask_2d.astype(bool)]
536536

537+
537538
def array_2d_native_from(
538539
array_2d_slim: np.ndarray,
539540
mask_2d: np.ndarray,
@@ -620,9 +621,7 @@ def array_2d_via_indexes_from(
620621
The native 2D array of values mapped from the slimmed array with dimensions (total_values, total_values).
621622
"""
622623
return (
623-
jnp.zeros(shape)
624-
.at[tuple(native_index_for_slim_index_2d.T)]
625-
.set(array_2d_slim)
624+
jnp.zeros(shape).at[tuple(native_index_for_slim_index_2d.T)].set(array_2d_slim)
626625
)
627626

628627

autoarray/structures/arrays/kernel_2d.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,9 @@ def convolved_array_from(self, array: Array2D) -> Array2D:
474474

475475
array_2d = array.native
476476

477-
convolved_array_2d = scipy.signal.convolve2d(array_2d._array, np.array(self.native._array), mode="same")
477+
convolved_array_2d = scipy.signal.convolve2d(
478+
array_2d._array, np.array(self.native._array), mode="same"
479+
)
478480

479481
convolved_array_1d = array_2d_util.array_2d_slim_from(
480482
mask_2d=np.array(array_2d.mask),
@@ -543,15 +545,11 @@ def convolve_image_no_blurring(self, image, mask, jax_method="fft"):
543545
kernels that are more than about 5x5. Default is `fft`.
544546
"""
545547

546-
slim_to_native = jnp.nonzero(
547-
jnp.logical_not(mask.array), size=image.shape[0]
548-
)
548+
slim_to_native = jnp.nonzero(jnp.logical_not(mask.array), size=image.shape[0])
549549

550550
expanded_array_native = jnp.zeros(mask.shape)
551551

552-
expanded_array_native = expanded_array_native.at[slim_to_native].set(
553-
image
554-
)
552+
expanded_array_native = expanded_array_native.at[slim_to_native].set(image)
555553

556554
kernel = np.array(self.native.array)
557555

@@ -571,4 +569,6 @@ def convolve_mapping_matrix(self, mapping_matrix, mask):
571569
image
572570
1D array of the values which are to be blurred with the convolver's PSF.
573571
"""
574-
return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))(mapping_matrix, mask).T
572+
return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None))(
573+
mapping_matrix, mask
574+
).T

0 commit comments

Comments
 (0)