Skip to content

Commit ec1e81e

Browse files
committed
remove autofit imports
1 parent ea139fc commit ec1e81e

4 files changed

Lines changed: 6 additions & 7 deletions

File tree

autoarray/geometry/geometry_2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from autoarray import type as ty
1414
from autoarray.geometry import geometry_util
1515

16-
from autofit.jax_wrapper import use_jax
17-
1816
logging.basicConfig()
1917
logger = logging.getLogger(__name__)
2018

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from autoarray.operators.over_sampling import over_sample_util
1212

13-
from autofit.jax_wrapper import register_pytree_node_class
13+
from autoarray.numpy_wrapper import register_pytree_node_class
1414

1515

1616
@register_pytree_node_class

autoarray/structures/vectors/uniform.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import jax.numpy as jnp
5-
# from autofit.jax_wrapper import numpy as np, use_jax
65
from typing import List, Optional, Tuple, Union
76

87
from autoarray.structures.arrays.uniform_2d import Array2D

test_autoarray/test_jax_changes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import autoarray as aa
1+
import jax.numpy as jnp
22
import pytest
33

4+
5+
import autoarray as aa
6+
47
from autoarray import Grid2D, Mask2D
5-
from autofit.jax_wrapper import numpy as np
68

79

810
@pytest.fixture(name="array")
@@ -33,4 +35,4 @@ def test_boolean_issue():
3335
mask=Mask2D.all_false((10, 10), pixel_scales=1.0),
3436
)
3537
values, keys = Grid2D.instance_flatten(grid)
36-
np.array(Grid2D.instance_unflatten(keys, values))
38+
jnp.array(Grid2D.instance_unflatten(keys, values))

0 commit comments

Comments
 (0)