1+ import jax .numpy as jnp
2+ import numpy as np
13from typing import Tuple , Union
2- from autoarray . numpy_wrapper import np , use_jax
4+
35
46from autoarray import numba_util
57from autoarray import type as ty
@@ -179,8 +181,69 @@ def convert_pixel_scales_2d(pixel_scales: ty.PixelScales) -> Tuple[float, float]
179181
180182 return pixel_scales
181183
184+ @numba_util .jit ()
185+ def central_pixel_coordinates_2d_numba_from (
186+ shape_native : Tuple [int , int ],
187+ ) -> Tuple [float , float ]:
188+ """
189+ Returns the central pixel coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
190+ from the shape of that data structure.
191+
192+ Examples of the central pixels are as follows:
193+
194+ - For a 3x3 image, the central pixel is pixel [1, 1].
195+ - For a 4x4 image, the central pixel is [1.5, 1.5].
196+
197+ Parameters
198+ ----------
199+ shape_native
200+ The dimensions of the data structure, which can be in 1D, 2D or higher dimensions.
201+
202+ Returns
203+ -------
204+ The central pixel coordinates of the data structure.
205+ """
206+ return (float (shape_native [0 ] - 1 ) / 2 , float (shape_native [1 ] - 1 ) / 2 )
182207
183208@numba_util .jit ()
209+ def central_scaled_coordinate_2d_numba_from (
210+ shape_native : Tuple [int , int ],
211+ pixel_scales : ty .PixelScales ,
212+ origin : Tuple [float , float ] = (0.0 , 0.0 ),
213+ ) -> Tuple [float , float ]:
214+ """
215+ Returns the central scaled coordinates of a 2D geometry (and therefore a 2D data structure like an ``Array2D``)
216+ from the shape of that data structure.
217+
218+ This is computed by using the data structure's shape and converting it to scaled units using an input
219+ pixel-coordinates to scaled-coordinate conversion factor `pixel_scales`.
220+
221+ The origin of the scaled grid can also be input and moved from (0.0, 0.0).
222+
223+ Parameters
224+ ----------
225+ shape_native
226+ The 2D shape of the data structure whose central scaled coordinates are computed.
227+ pixel_scales
228+ The (y,x) scaled units to pixel units conversion factor of the 2D data structure.
229+ origin
230+ The (y,x) scaled units origin of the coordinate system the central scaled coordinate is computed on.
231+
232+ Returns
233+ -------
234+ The central coordinates of the 2D data structure in scaled units.
235+ """
236+
237+ central_pixel_coordinates = central_pixel_coordinates_2d_numba_from (
238+ shape_native = shape_native
239+ )
240+
241+ y_pixel = central_pixel_coordinates [0 ] + (origin [0 ] / pixel_scales [0 ])
242+ x_pixel = central_pixel_coordinates [1 ] - (origin [1 ] / pixel_scales [1 ])
243+
244+ return (y_pixel , x_pixel )
245+
246+
184247def central_pixel_coordinates_2d_from (
185248 shape_native : Tuple [int , int ],
186249) -> Tuple [float , float ]:
@@ -205,7 +268,6 @@ def central_pixel_coordinates_2d_from(
205268 return (float (shape_native [0 ] - 1 ) / 2 , float (shape_native [1 ] - 1 ) / 2 )
206269
207270
208- @numba_util .jit ()
209271def central_scaled_coordinate_2d_from (
210272 shape_native : Tuple [int , int ],
211273 pixel_scales : ty .PixelScales ,
@@ -234,7 +296,7 @@ def central_scaled_coordinate_2d_from(
234296 The central coordinates of the 2D data structure in scaled units.
235297 """
236298
237- central_pixel_coordinates = central_pixel_coordinates_2d_from (
299+ central_pixel_coordinates = central_pixel_coordinates_2d_numba_from (
238300 shape_native = shape_native
239301 )
240302
@@ -243,8 +305,6 @@ def central_scaled_coordinate_2d_from(
243305
244306 return (y_pixel , x_pixel )
245307
246-
247- @numba_util .jit ()
248308def pixel_coordinates_2d_from (
249309 scaled_coordinates_2d : Tuple [float , float ],
250310 shape_native : Tuple [int , int ],
@@ -352,7 +412,7 @@ def scaled_coordinates_2d_from(
352412 origin=(0.0, 0.0)
353413 )
354414 """
355- central_scaled_coordinates = central_scaled_coordinate_2d_from (
415+ central_scaled_coordinates = central_scaled_coordinate_2d_numba_from (
356416 shape_native = shape_native , pixel_scales = pixel_scales , origin = origins
357417 )
358418
@@ -382,18 +442,16 @@ def transform_grid_2d_to_reference_frame(
382442 grid
383443 The 2d grid of (y, x) coordinates which are transformed to a new reference frame.
384444 """
385- if use_jax :
386- shifted_grid_2d = grid_2d .array - np .array (centre )
387- else :
388- shifted_grid_2d = grid_2d - np .array (centre )
389- radius = np .sqrt (np .sum (shifted_grid_2d ** 2.0 , axis = 1 ))
390- theta_coordinate_to_profile = np .arctan2 (
445+ shifted_grid_2d = np .array (grid_2d ) - jnp .array (centre )
446+
447+ radius = jnp .sqrt (jnp .sum (shifted_grid_2d ** 2.0 , axis = 1 ))
448+ theta_coordinate_to_profile = jnp .arctan2 (
391449 shifted_grid_2d [:, 0 ], shifted_grid_2d [:, 1 ]
392- ) - np .radians (angle )
393- return np .vstack (
450+ ) - jnp .radians (angle )
451+ return jnp .vstack (
394452 [
395- radius * np .sin (theta_coordinate_to_profile ),
396- radius * np .cos (theta_coordinate_to_profile ),
453+ radius * jnp .sin (theta_coordinate_to_profile ),
454+ radius * jnp .cos (theta_coordinate_to_profile ),
397455 ]
398456 ).T
399457
@@ -435,7 +493,6 @@ def transform_grid_2d_from_reference_frame(
435493 return np .vstack ((y , x )).T
436494
437495
438- @numba_util .jit ()
439496def grid_pixels_2d_slim_from (
440497 grid_scaled_2d_slim : np .ndarray ,
441498 shape_native : Tuple [int , int ],
@@ -476,33 +533,15 @@ def grid_pixels_2d_slim_from(
476533 grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_scaled_2d_slim=grid_scaled_2d_slim, shape=(2,2),
477534 pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
478535 """
479-
480536 centres_scaled = central_scaled_coordinate_2d_from (
481537 shape_native = shape_native , pixel_scales = pixel_scales , origin = origin
482538 )
483- if use_jax :
484- centres_scaled = np .array (centres_scaled )
485- pixel_scales = np .array (pixel_scales )
486- sign = np .array ([- 1 , 1 ])
487- return (sign * grid_scaled_2d_slim / pixel_scales ) + centres_scaled + 0.5
488- else :
489- grid_pixels_2d_slim = np .zeros ((grid_scaled_2d_slim .shape [0 ], 2 ))
490- for slim_index in range (grid_scaled_2d_slim .shape [0 ]):
491- grid_pixels_2d_slim [slim_index , 0 ] = (
492- (- grid_scaled_2d_slim [slim_index , 0 ] / pixel_scales [0 ])
493- + centres_scaled [0 ]
494- + 0.5
495- )
496- grid_pixels_2d_slim [slim_index , 1 ] = (
497- (grid_scaled_2d_slim [slim_index , 1 ] / pixel_scales [1 ])
498- + centres_scaled [1 ]
499- + 0.5
500- )
501-
502- return grid_pixels_2d_slim
539+ centres_scaled = np .array (centres_scaled )
540+ pixel_scales = np .array (pixel_scales )
541+ sign = np .array ([- 1 , 1 ])
542+ return (sign * grid_scaled_2d_slim / pixel_scales ) + centres_scaled + 0.5
503543
504544
505- @numba_util .jit ()
506545def grid_pixel_centres_2d_slim_from (
507546 grid_scaled_2d_slim : np .ndarray ,
508547 shape_native : Tuple [int , int ],
@@ -547,32 +586,14 @@ def grid_pixel_centres_2d_slim_from(
547586 shape_native = shape_native , pixel_scales = pixel_scales , origin = origin
548587 )
549588
550- if use_jax :
551- centres_scaled = np .array (centres_scaled )
552- pixel_scales = np .array (pixel_scales )
553- sign = np .array ([- 1.0 , 1.0 ])
554- grid_pixels_2d_slim = (
555- (sign * grid_scaled_2d_slim / pixel_scales ) + centres_scaled + 0.5
556- ).astype (int )
557- else :
558- grid_pixels_2d_slim = np .zeros ((grid_scaled_2d_slim .shape [0 ], 2 ))
559-
560- for slim_index in range (grid_scaled_2d_slim .shape [0 ]):
561- grid_pixels_2d_slim [slim_index , 0 ] = int (
562- (- grid_scaled_2d_slim [slim_index , 0 ] / pixel_scales [0 ])
563- + centres_scaled [0 ]
564- + 0.5
565- )
566- grid_pixels_2d_slim [slim_index , 1 ] = int (
567- (grid_scaled_2d_slim [slim_index , 1 ] / pixel_scales [1 ])
568- + centres_scaled [1 ]
569- + 0.5
570- )
571-
572- return grid_pixels_2d_slim
589+ centres_scaled = np .array (centres_scaled )
590+ pixel_scales = np .array (pixel_scales )
591+ sign = np .array ([- 1.0 , 1.0 ])
592+ return (
593+ (sign * grid_scaled_2d_slim / pixel_scales ) + centres_scaled + 0.5
594+ ).astype (int )
573595
574596
575- @numba_util .jit ()
576597def grid_pixel_indexes_2d_slim_from (
577598 grid_scaled_2d_slim : np .ndarray ,
578599 shape_native : Tuple [int , int ],
@@ -625,25 +646,13 @@ def grid_pixel_indexes_2d_slim_from(
625646 origin = origin ,
626647 )
627648
628- if use_jax :
629- grid_pixel_indexes_2d_slim = (
630- (grid_pixels_2d_slim * np .array ([shape_native [1 ], 1 ]))
631- .sum (axis = 1 )
632- .astype (int )
633- )
634- else :
635- grid_pixel_indexes_2d_slim = np .zeros (grid_pixels_2d_slim .shape [0 ])
636-
637- for slim_index in range (grid_pixels_2d_slim .shape [0 ]):
638- grid_pixel_indexes_2d_slim [slim_index ] = int (
639- grid_pixels_2d_slim [slim_index , 0 ] * shape_native [1 ]
640- + grid_pixels_2d_slim [slim_index , 1 ]
641- )
642-
643- return grid_pixel_indexes_2d_slim
649+ return (
650+ (grid_pixels_2d_slim * np .array ([shape_native [1 ], 1 ]))
651+ .sum (axis = 1 )
652+ .astype (int )
653+ )
644654
645655
646- @numba_util .jit ()
647656def grid_scaled_2d_slim_from (
648657 grid_pixels_2d_slim : np .ndarray ,
649658 shape_native : Tuple [int , int ],
@@ -682,33 +691,18 @@ def grid_scaled_2d_slim_from(
682691 grid_pixels_2d_slim = grid_scaled_2d_slim_from(grid_pixels_2d_slim=grid_pixels_2d_slim, shape=(2,2),
683692 pixel_scales=(0.5, 0.5), origin=(0.0, 0.0))
684693 """
685-
686694 centres_scaled = central_scaled_coordinate_2d_from (
687695 shape_native = shape_native , pixel_scales = pixel_scales , origin = origin
688696 )
689- if use_jax :
690- centres_scaled = np .array (centres_scaled )
691- pixel_scales = np .array (pixel_scales )
692- sign = np .array ([- 1 , 1 ])
693- grid_scaled_2d_slim = (
694- (grid_pixels_2d_slim - centres_scaled - 0.5 ) * pixel_scales * sign
695- )
696- else :
697- grid_scaled_2d_slim = np .zeros ((grid_pixels_2d_slim .shape [0 ], 2 ))
698-
699- for slim_index in range (grid_scaled_2d_slim .shape [0 ]):
700- grid_scaled_2d_slim [slim_index , 0 ] = (
701- - (grid_pixels_2d_slim [slim_index , 0 ] - centres_scaled [0 ] - 0.5 )
702- * pixel_scales [0 ]
703- )
704- grid_scaled_2d_slim [slim_index , 1 ] = (
705- grid_pixels_2d_slim [slim_index , 1 ] - centres_scaled [1 ] - 0.5
706- ) * pixel_scales [1 ]
707-
708- return grid_scaled_2d_slim
697+
698+ centres_scaled = np .array (centres_scaled )
699+ pixel_scales = np .array (pixel_scales )
700+ sign = np .array ([- 1 , 1 ])
701+ return (
702+ (grid_pixels_2d_slim - centres_scaled - 0.5 ) * pixel_scales * sign
703+ )
709704
710705
711- @numba_util .jit ()
712706def grid_pixel_centres_2d_from (
713707 grid_scaled_2d : np .ndarray ,
714708 shape_native : Tuple [int , int ],
@@ -753,30 +747,12 @@ def grid_pixel_centres_2d_from(
753747 shape_native = shape_native , pixel_scales = pixel_scales , origin = origin
754748 )
755749
756- if use_jax :
757- centres_scaled = np .array (centres_scaled )
758- pixel_scales = np .array (pixel_scales )
759- sign = np .array ([- 1.0 , 1.0 ])
760- grid_pixels_2d = (
761- (sign * grid_scaled_2d / pixel_scales ) + centres_scaled + 0.5
762- ).astype (int )
763- else :
764- grid_pixels_2d = np .zeros ((grid_scaled_2d .shape [0 ], grid_scaled_2d .shape [1 ], 2 ))
765-
766- for y in range (grid_scaled_2d .shape [0 ]):
767- for x in range (grid_scaled_2d .shape [1 ]):
768- grid_pixels_2d [y , x , 0 ] = int (
769- (- grid_scaled_2d [y , x , 0 ] / pixel_scales [0 ])
770- + centres_scaled [0 ]
771- + 0.5
772- )
773- grid_pixels_2d [y , x , 1 ] = int (
774- (grid_scaled_2d [y , x , 1 ] / pixel_scales [1 ])
775- + centres_scaled [1 ]
776- + 0.5
777- )
778-
779- return grid_pixels_2d
750+ centres_scaled = np .array (centres_scaled )
751+ pixel_scales = np .array (pixel_scales )
752+ sign = np .array ([- 1.0 , 1.0 ])
753+ return (
754+ (sign * grid_scaled_2d / pixel_scales ) + centres_scaled + 0.5
755+ ).astype (int )
780756
781757
782758def extent_symmetric_from (
0 commit comments