Skip to content

Commit cf61b74

Browse files
committed
numba import
1 parent d76904e commit cf61b74

1 file changed

Lines changed: 32 additions & 25 deletions

File tree

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1+
from astropy.io import fits
12
import logging
2-
from functools import reduce
33
import numpy as np
4-
import sys, time
5-
import numba
6-
from numba import prange, config
4+
import time
75
import multiprocessing as mp
6+
import os
87
from typing import Tuple
98

109
from autoarray import numba_util
@@ -1425,32 +1424,40 @@ def jit_loop_preload_4(
14251424
)
14261425

14271426

1427+
try:
14281428

1429+
import numba
1430+
from numba import prange
14291431

1430-
@numba.jit("void(f8[:,:], i8)", nopython=True, parallel=True, cache = True)
1431-
def jit_loop2(
1432-
curvature_matrix: np.ndarray,
1433-
pix_pixels: int):
1434-
'''
1435-
Performs second stage of curvature matrix calculation using Numba parallelisation and JIT.
14361432

1437-
Parameters
1438-
----------
1439-
curvature_matrix
1440-
Curvature matrix this function operates on. Still requires third stage of calculation.
1441-
pix_pixels
1442-
Size of one dimension of the curvature matrix.
1433+
@numba.jit("void(f8[:,:], i8)", nopython=True, parallel=True, cache = True)
1434+
def jit_loop2(
1435+
curvature_matrix: np.ndarray,
1436+
pix_pixels: int):
1437+
'''
1438+
Performs second stage of curvature matrix calculation using Numba parallelisation and JIT.
14431439
1444-
Returns
1445-
-------
1446-
none
1447-
Updates shared object.
1448-
'''
1440+
Parameters
1441+
----------
1442+
curvature_matrix
1443+
Curvature matrix this function operates on. Still requires third stage of calculation.
1444+
pix_pixels
1445+
Size of one dimension of the curvature matrix.
1446+
1447+
Returns
1448+
-------
1449+
none
1450+
Updates shared object.
1451+
'''
1452+
1453+
curvature_matrix_temp = curvature_matrix.copy()
1454+
for i in prange(pix_pixels):
1455+
for j in range(pix_pixels):
1456+
curvature_matrix[i, j] = curvature_matrix_temp[i, j] + curvature_matrix_temp[j, i]
1457+
1458+
except ModuleNotFoundError:
14491459

1450-
curvature_matrix_temp = curvature_matrix.copy()
1451-
for i in prange(pix_pixels):
1452-
for j in range(pix_pixels):
1453-
curvature_matrix[i, j] = curvature_matrix_temp[i, j] + curvature_matrix_temp[j, i]
1460+
pass
14541461

14551462

14561463
@numba_util.jit(cache = True)

0 commit comments

Comments
 (0)