1+ from astropy .io import fits
12import logging
23import numpy as np
3- from typing import Optional
4+ from pathlib import Path
45
56from autoconf import cached_property
7+ from autoconf .fitsable import ndarray_via_fits_from , output_to_fits
68
79from autoarray .dataset .abstract .dataset import AbstractDataset
810from autoarray .dataset .interferometer .w_tilde import WTildeInterferometer
1214from autoarray .structures .visibilities import Visibilities
1315from autoarray .structures .visibilities import VisibilitiesNoiseMap
1416
15- from autoarray .structures . arrays import array_2d_util
17+ from autoarray .inversion . inversion . interferometer import inversion_interferometer_util
1618
1719logger = logging .getLogger (__name__ )
1820
@@ -25,6 +27,7 @@ def __init__(
2527 uv_wavelengths : np .ndarray ,
2628 real_space_mask ,
2729 transformer_class = TransformerNUFFT ,
30+ preprocessing_directory = None ,
2831 ):
2932 """
3033 An interferometer dataset, containing the visibilities data, noise-map, real-space msk, Fourier transformer and
@@ -86,6 +89,12 @@ def __init__(
8689 uv_wavelengths = uv_wavelengths , real_space_mask = real_space_mask
8790 )
8891
92+ self .preprocessing_directory = (
93+ Path (preprocessing_directory )
94+ if preprocessing_directory is not None
95+ else None
96+ )
97+
8998 @cached_property
9099 def grids (self ):
91100 return GridsDataset (
@@ -120,7 +129,7 @@ def from_fits(
120129 file_path = noise_map_path , hdu = noise_map_hdu
121130 )
122131
123- uv_wavelengths = array_2d_util . numpy_array_2d_via_fits_from (
132+ uv_wavelengths = ndarray_via_fits_from (
124133 file_path = uv_wavelengths_path , hdu = uv_wavelengths_hdu
125134 )
126135
@@ -132,6 +141,23 @@ def from_fits(
132141 transformer_class = transformer_class ,
133142 )
134143
144+ def w_tilde_preprocessing (self ):
145+ if self .preprocessing_directory .is_dir ():
146+ filename = "{}/curvature_preload.fits" .format (self .preprocessing_directory )
147+
148+ if not self .preprocessing_directory .isfile (filename ):
149+ print ("The file {} does not exist" .format (filename ))
150+ logger .info ("INTERFEROMETER - Computing W-Tilde... May take a moment." )
151+
152+ curvature_preload = inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
153+ noise_map_real = self .noise_map .real ,
154+ uv_wavelengths = self .uv_wavelengths ,
155+ shape_masked_pixels_2d = self .transformer .grid .mask .shape_native_masked_pixels ,
156+ grid_radians_2d = self .transformer .grid .mask .unmasked_grid_sub_1 .in_radians .native ,
157+ )
158+
159+ fits .writeto (filename , data = curvature_preload )
160+
135161 @cached_property
136162 def w_tilde (self ):
137163 """
@@ -152,10 +178,8 @@ def w_tilde(self):
152178
153179 logger .info ("INTERFEROMETER - Computing W-Tilde... May take a moment." )
154180
155- from autoarray .inversion .inversion import inversion_util_secret
156-
157181 curvature_preload = (
158- inversion_util_secret .w_tilde_curvature_preload_interferometer_from (
182+ inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
159183 noise_map_real = np .array (self .noise_map .real ),
160184 uv_wavelengths = np .array (self .uv_wavelengths ),
161185 shape_masked_pixels_2d = np .array (
@@ -167,7 +191,7 @@ def w_tilde(self):
167191 )
168192 )
169193
170- w_matrix = inversion_util_secret .w_tilde_via_preload_from (
194+ w_matrix = inversion_interferometer_util .w_tilde_via_preload_from (
171195 w_tilde_preload = curvature_preload ,
172196 native_index_for_slim_index = self .real_space_mask .derive_indexes .native_for_slim ,
173197 )
@@ -245,8 +269,8 @@ def output_to_fits(
245269 self .noise_map .output_to_fits (file_path = noise_map_path , overwrite = overwrite )
246270
247271 if self .uv_wavelengths is not None and uv_wavelengths_path is not None :
248- array_2d_util . numpy_array_2d_to_fits (
249- array_2d = self .uv_wavelengths ,
272+ output_to_fits (
273+ values = self .uv_wavelengths ,
250274 file_path = uv_wavelengths_path ,
251275 overwrite = overwrite ,
252276 )
0 commit comments