|
8 | 8 | from autoarray import type as ty |
9 | 9 | from autoarray.numpy_wrapper import use_jax, np as jnp |
10 | 10 |
|
| 11 | +def native_index_for_slim_index_2d_from( |
| 12 | + mask_2d: np.ndarray, |
| 13 | +) -> np.ndarray: |
| 14 | + """ |
| 15 | + Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its |
| 16 | + corresponding native 2D pixel using its (y,x) pixel indexes. |
| 17 | +
|
| 18 | + For example, for the following ``Mask2D``: |
| 19 | +
|
| 20 | + :: |
| 21 | + [[True, True, True, True] |
| 22 | + [True, False, False, True], |
| 23 | + [True, False, True, True], |
| 24 | + [True, True, True, True]] |
| 25 | +
|
| 26 | + This has three unmasked (``False`` values) which have the ``slim`` indexes: |
| 27 | +
|
| 28 | + :: |
| 29 | + [0, 1, 2] |
| 30 | +
|
| 31 | + The array ``native_index_for_slim_index_2d`` is therefore: |
| 32 | +
|
| 33 | + :: |
| 34 | + [[1,1], [1,2], [2,1]] |
| 35 | +
|
| 36 | + Parameters |
| 37 | + ---------- |
| 38 | + mask_2d |
| 39 | + A 2D array of bools, where `False` values are unmasked. |
| 40 | +
|
| 41 | + Returns |
| 42 | + ------- |
| 43 | + ndarray |
| 44 | + An array that maps pixels from a slimmed array of shape [total_unmasked_pixels] to its native array |
| 45 | + of shape [total_pixels, total_pixels]. |
| 46 | +
|
| 47 | + Examples |
| 48 | + -------- |
| 49 | + mask_2d = np.array([[True, True, True], |
| 50 | + [True, False, True] |
| 51 | + [True, True, True]]) |
| 52 | +
|
| 53 | + native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) |
| 54 | + """ |
| 55 | + return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T |
| 56 | + |
11 | 57 |
|
12 | 58 | def mask_2d_centres_from( |
13 | 59 | shape_native: Tuple[int, int], |
@@ -531,136 +577,56 @@ def mask_slim_indexes_from( |
531 | 577 | return np.where(mask_flat == return_masked_indexes)[0] |
532 | 578 |
|
533 | 579 |
|
534 | | -@numba_util.jit() |
535 | | -def check_if_edge_pixel(mask_2d: np.ndarray, y: int, x: int) -> bool: |
536 | | - """ |
537 | | - Checks if an input [y,x] pixel on the input `mask` is an edge-pixel. |
538 | | -
|
539 | | - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 |
540 | | - direct neighbors is masked (is `True`). |
541 | | -
|
542 | | - Parameters |
543 | | - ---------- |
544 | | - mask_2d |
545 | | - The mask for which the input pixel is checked if it is an edge pixel. |
546 | | - y |
547 | | - The y pixel coordinate on the mask that is checked for if it is an edge pixel. |
548 | | - x |
549 | | - The x pixel coordinate on the mask that is checked for if it is an edge pixel. |
550 | | -
|
551 | | - Returns |
552 | | - ------- |
553 | | - bool |
554 | | - If `True` the pixel on the mask is an edge pixel, else a `False` is returned because it is not. |
555 | | - """ |
556 | | - |
557 | | - if ( |
558 | | - mask_2d[y + 1, x] |
559 | | - or mask_2d[y - 1, x] |
560 | | - or mask_2d[y, x + 1] |
561 | | - or mask_2d[y, x - 1] |
562 | | - or mask_2d[y + 1, x + 1] |
563 | | - or mask_2d[y + 1, x - 1] |
564 | | - or mask_2d[y - 1, x + 1] |
565 | | - or mask_2d[y - 1, x - 1] |
566 | | - ): |
567 | | - return True |
568 | | - else: |
569 | | - return False |
570 | | - |
571 | | - |
572 | | -@numba_util.jit() |
573 | | -def total_edge_pixels_from(mask_2d: np.ndarray) -> int: |
574 | | - """ |
575 | | - Returns the total number of edge-pixels in a mask. |
576 | | -
|
577 | | - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 |
578 | | - direct neighbors is masked (is `True`). |
579 | | -
|
580 | | - Parameters |
581 | | - ---------- |
582 | | - mask_2d |
583 | | - The mask for which the total number of edge pixels is computed. |
584 | | -
|
585 | | - Returns |
586 | | - ------- |
587 | | - int |
588 | | - The total number of edge pixels. |
589 | | - """ |
590 | | - |
591 | | - edge_pixel_total = 0 |
592 | | - |
593 | | - for y in range(1, mask_2d.shape[0] - 1): |
594 | | - for x in range(1, mask_2d.shape[1] - 1): |
595 | | - if not mask_2d[y, x]: |
596 | | - if check_if_edge_pixel(mask_2d=mask_2d, y=y, x=x): |
597 | | - edge_pixel_total += 1 |
598 | | - |
599 | | - return edge_pixel_total |
600 | | - |
601 | | - |
602 | | -@numba_util.jit() |
603 | 580 | def edge_1d_indexes_from(mask_2d: np.ndarray) -> np.ndarray: |
604 | 581 | """ |
605 | 582 | Returns a 1D array listing all edge pixel indexes in the mask. |
606 | 583 |
|
607 | | - An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least 1 of its 8 |
| 584 | + An edge pixel is defined as a pixel on the mask which is unmasked (has a `False`) value and at least one of its 8 |
608 | 585 | direct neighbors is masked (is `True`). |
609 | 586 |
|
610 | | - For example, for the following ``Mask2D``: |
611 | | -
|
612 | | - :: |
613 | | - [[True, True, True, True, True], |
614 | | - [True, False, False, False, True], |
615 | | - [True, False, False, False, True], |
616 | | - [True, False, False, False, True], |
617 | | - [True, True, True, True, True]] |
618 | | -
|
619 | | - The `edge_slim` indexes (given via ``mask_2d.derive_indexes.edge_slim``) is given by: |
620 | | -
|
621 | | - :: |
622 | | - [0, 1, 2, 3, 5, 6, 7, 8] |
623 | | -
|
624 | | - Note that index 4 is skipped, which corresponds to the ``False`` value in the centre of the mask, because it |
625 | | - does not neighbor a ``True`` value in any one of the eight neighboring directions and is therefore not at |
626 | | - an edge. |
627 | | -
|
628 | 587 | Parameters |
629 | 588 | ---------- |
630 | 589 | mask_2d |
631 | | - The mask for which the 1D edge pixel indexes are computed. |
| 590 | + A 2D boolean array where `False` values indicate unmasked pixels. |
632 | 591 |
|
633 | 592 | Returns |
634 | 593 | ------- |
635 | 594 | np.ndarray |
636 | | - The 1D indexes of all edge pixels on the mask. |
637 | | - """ |
638 | | - |
639 | | - edge_pixel_total = total_edge_pixels_from(mask_2d) |
| 595 | + A 1D array of indexes of all edge pixels on the mask. |
640 | 596 |
|
641 | | - edge_pixels = np.zeros(edge_pixel_total) |
642 | | - edge_index = 0 |
643 | | - regular_index = 0 |
| 597 | + Examples |
| 598 | + -------- |
| 599 | + >>> mask = np.array([ |
| 600 | + ... [True, True, True, True, True], |
| 601 | + ... [True, False, False, True, True], |
| 602 | + ... [True, False, False, False, True], |
| 603 | + ... [True, True, False, True, True], |
| 604 | + ... [True, True, True, True, True] |
| 605 | + ... ]) |
| 606 | + >>> edge_1d_indexes_from(mask) |
| 607 | + array([1, 2, 5, 7, 8, 9]) |
| 608 | + """ |
| 609 | + # Pad the mask to handle edge cases without index errors |
| 610 | + padded_mask = np.pad(mask_2d, pad_width=1, mode='constant', constant_values=True) |
| 611 | + |
| 612 | + # Identify neighbors in 3x3 regions around each pixel |
| 613 | + neighbors = ( |
| 614 | + padded_mask[:-2, 1:-1] | padded_mask[2:, 1:-1] | # Up, Down |
| 615 | + padded_mask[1:-1, :-2] | padded_mask[1:-1, 2:] | # Left, Right |
| 616 | + padded_mask[:-2, :-2] | padded_mask[:-2, 2:] | # Top-left, Top-right |
| 617 | + padded_mask[2:, :-2] | padded_mask[2:, 2:] # Bottom-left, Bottom-right |
| 618 | + ) |
644 | 619 |
|
645 | | - for y in range(1, mask_2d.shape[0] - 1): |
646 | | - for x in range(1, mask_2d.shape[1] - 1): |
647 | | - if not mask_2d[y, x]: |
648 | | - if ( |
649 | | - mask_2d[y + 1, x] |
650 | | - or mask_2d[y - 1, x] |
651 | | - or mask_2d[y, x + 1] |
652 | | - or mask_2d[y, x - 1] |
653 | | - or mask_2d[y + 1, x + 1] |
654 | | - or mask_2d[y + 1, x - 1] |
655 | | - or mask_2d[y - 1, x + 1] |
656 | | - or mask_2d[y - 1, x - 1] |
657 | | - ): |
658 | | - edge_pixels[edge_index] = regular_index |
659 | | - edge_index += 1 |
| 620 | + # Identify edge pixels: False values with at least one True neighbor |
| 621 | + edge_mask = ~mask_2d & neighbors |
660 | 622 |
|
661 | | - regular_index += 1 |
| 623 | + # Create an index array where False entries get sequential 1D indices |
| 624 | + index_array = np.full(mask_2d.shape, fill_value=-1, dtype=int) |
| 625 | + false_indices = np.flatnonzero(~mask_2d) |
| 626 | + index_array[~mask_2d] = np.arange(len(false_indices)) |
662 | 627 |
|
663 | | - return edge_pixels |
| 628 | + # Return the 1D indexes of the edge pixels |
| 629 | + return index_array[edge_mask] |
664 | 630 |
|
665 | 631 |
|
666 | 632 | @numba_util.jit() |
@@ -911,62 +877,4 @@ def rescaled_mask_2d_from(mask_2d: np.ndarray, rescale_factor: float) -> np.ndar |
911 | 877 | return np.isclose(rescaled_mask_2d, 1) |
912 | 878 |
|
913 | 879 |
|
914 | | -@numba_util.jit() |
915 | | -def native_index_for_slim_index_2d_from( |
916 | | - mask_2d: np.ndarray, |
917 | | -) -> np.ndarray: |
918 | | - """ |
919 | | - Returns an array of shape [total_unmasked_pixels] that maps every unmasked pixel to its |
920 | | - corresponding native 2D pixel using its (y,x) pixel indexes. |
921 | | -
|
922 | | - For example, for the following ``Mask2D``: |
923 | | -
|
924 | | - :: |
925 | | - [[True, True, True, True] |
926 | | - [True, False, False, True], |
927 | | - [True, False, True, True], |
928 | | - [True, True, True, True]] |
929 | | -
|
930 | | - This has three unmasked (``False`` values) which have the ``slim`` indexes: |
931 | | -
|
932 | | - :: |
933 | | - [0, 1, 2] |
934 | | -
|
935 | | - The array ``native_index_for_slim_index_2d`` is therefore: |
936 | | -
|
937 | | - :: |
938 | | - [[1,1], [1,2], [2,1]] |
939 | | -
|
940 | | - Parameters |
941 | | - ---------- |
942 | | - mask_2d |
943 | | - A 2D array of bools, where `False` values are unmasked. |
944 | | -
|
945 | | - Returns |
946 | | - ------- |
947 | | - ndarray |
948 | | - An array that maps pixels from a slimmed array of shape [total_unmasked_pixels] to its native array |
949 | | - of shape [total_pixels, total_pixels]. |
950 | | -
|
951 | | - Examples |
952 | | - -------- |
953 | | - mask_2d = np.array([[True, True, True], |
954 | | - [True, False, True] |
955 | | - [True, True, True]]) |
956 | | -
|
957 | | - native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d) |
958 | | - """ |
959 | | - if use_jax: |
960 | | - return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T |
961 | | - else: |
962 | | - total_pixels = np.sum(~mask_2d) |
963 | | - native_index_for_slim_index_2d = np.zeros(shape=(total_pixels, 2)) |
964 | | - slim_index = 0 |
965 | | - |
966 | | - for y in range(mask_2d.shape[0]): |
967 | | - for x in range(mask_2d.shape[1]): |
968 | | - if not mask_2d[y, x]: |
969 | | - native_index_for_slim_index_2d[slim_index, :] = y, x |
970 | | - slim_index += 1 |
971 | 880 |
|
972 | | - return native_index_for_slim_index_2d |
0 commit comments