|
1 | 1 | import numpy as np |
2 | | -from jax.tree_util import register_pytree_node_class |
3 | 2 |
|
4 | 3 | import pytest |
5 | 4 |
|
6 | 5 | from autoarray.structures.triangles.abstract import HEIGHT_FACTOR |
7 | 6 |
|
8 | | -from autoarray.structures.triangles.coordinate_array import ( |
9 | | - CoordinateArrayTriangles, |
| 7 | +from autoarray.structures.triangles.coordinate_array_np import ( |
| 8 | + CoordinateArrayTrianglesNp, |
10 | 9 | ) |
11 | 10 |
|
12 | | -CoordinateArrayTriangles = register_pytree_node_class(CoordinateArrayTriangles) |
13 | | - |
14 | 11 |
|
15 | 12 | def test__two(two_triangles): |
16 | 13 |
|
@@ -52,7 +49,7 @@ def test__trivial_triangles(one_triangle): |
52 | 49 |
|
53 | 50 |
|
54 | 51 | def test__above(): |
55 | | - triangles = CoordinateArrayTriangles( |
| 52 | + triangles = CoordinateArrayTrianglesNp( |
56 | 53 | coordinates=np.array([[0, 1]]), |
57 | 54 | side_length=1.0, |
58 | 55 | ) |
@@ -87,7 +84,7 @@ def test__above(): |
87 | 84 |
|
88 | 85 | @pytest.fixture |
89 | 86 | def upside_down(): |
90 | | - return CoordinateArrayTriangles( |
| 87 | + return CoordinateArrayTrianglesNp( |
91 | 88 | coordinates=np.array([[1, 0]]), |
92 | 89 | side_length=1.0, |
93 | 90 | ) |
@@ -279,104 +276,37 @@ def test_means(one_triangle): |
279 | 276 |
|
280 | 277 |
|
281 | 278 | def test_triangles_touch(): |
282 | | - triangles = CoordinateArrayTriangles( |
| 279 | + triangles = CoordinateArrayTrianglesNp( |
283 | 280 | np.array([[0, 0], [2, 0]]), |
284 | 281 | ) |
285 | 282 |
|
286 | 283 | assert max(triangles.triangles[0][:, 0]) == min(triangles.triangles[1][:, 0]) |
287 | 284 |
|
288 | | - triangles = CoordinateArrayTriangles( |
| 285 | + triangles = CoordinateArrayTrianglesNp( |
289 | 286 | np.array([[0, 0], [0, 1]]), |
290 | 287 | ) |
291 | 288 | assert max(triangles.triangles[0][:, 1]) == min(triangles.triangles[1][:, 1]) |
292 | 289 |
|
293 | 290 |
|
294 | 291 | def test_from_grid_regression(): |
295 | | - triangles = CoordinateArrayTriangles.for_limits_and_scale( |
296 | | - x_min=-4.75, |
297 | | - x_max=4.75, |
298 | | - y_min=-4.75, |
299 | | - y_max=4.75, |
300 | | - scale=0.5, |
| 292 | + triangles = CoordinateArrayTrianglesNp.for_limits_and_scale( |
| 293 | + x_min=-2.0, |
| 294 | + x_max=2.0, |
| 295 | + y_min=-2.0, |
| 296 | + y_max=2.0, |
| 297 | + scale=1.5, |
301 | 298 | ) |
302 | 299 |
|
303 | 300 | x = triangles.vertices[:, 0] |
304 | | - assert min(x) <= -4.75 |
305 | | - assert max(x) >= 4.75 |
| 301 | + assert min(x) <= -2.0 |
| 302 | + assert max(x) >= 2.0 |
306 | 303 |
|
307 | 304 | y = triangles.vertices[:, 1] |
308 | | - assert min(y) <= -4.75 |
309 | | - assert max(y) >= 4.75 |
310 | | - |
| 305 | + assert min(y) <= -2.0 |
| 306 | + assert max(y) >= 2.0 |
311 | 307 |
|
312 | | -@pytest.fixture |
313 | | -def one_triangle(): |
314 | | - return CoordinateArrayTriangles( |
315 | | - coordinates=np.array([[0, 0]]), |
316 | | - side_length=1.0, |
317 | | - ) |
318 | 308 |
|
319 | | - |
320 | | -def test_neighborhood(one_triangle): |
321 | | - import jax |
322 | | - |
323 | | - assert np.allclose( |
324 | | - np.array(jax.jit(one_triangle.neighborhood)().triangles), |
325 | | - np.array( |
326 | | - [ |
327 | | - [ |
328 | | - [-0.5, -0.4330126941204071], |
329 | | - [-1.0, 0.4330126941204071], |
330 | | - [0.0, 0.4330126941204071], |
331 | | - ], |
332 | | - [ |
333 | | - [0.0, -1.299038052558899], |
334 | | - [-0.5, -0.4330126941204071], |
335 | | - [0.5, -0.4330126941204071], |
336 | | - ], |
337 | | - [ |
338 | | - [0.0, 0.4330126941204071], |
339 | | - [0.5, -0.4330126941204071], |
340 | | - [-0.5, -0.4330126941204071], |
341 | | - ], |
342 | | - [ |
343 | | - [0.5, -0.4330126941204071], |
344 | | - [0.0, 0.4330126941204071], |
345 | | - [1.0, 0.4330126941204071], |
346 | | - ], |
347 | | - ] |
348 | | - ), |
349 | | - ) |
350 | | - |
351 | | - |
352 | | -def test_up_sample(one_triangle): |
353 | | - import jax |
354 | | - |
355 | | - up_sampled = jax.jit(one_triangle.up_sample)() |
356 | | - assert np.allclose( |
357 | | - np.array(up_sampled.triangles), |
358 | | - np.array( |
359 | | - [ |
360 | | - [ |
361 | | - [[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]], |
362 | | - [ |
363 | | - [0.25, 0.0], |
364 | | - [0.5, -0.4330126941204071], |
365 | | - [0.0, -0.4330126941204071], |
366 | | - ], |
367 | | - [ |
368 | | - [-0.25, 0.0], |
369 | | - [0.0, -0.4330126941204071], |
370 | | - [-0.5, -0.4330126941204071], |
371 | | - ], |
372 | | - [[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]], |
373 | | - ] |
374 | | - ] |
375 | | - ), |
376 | | - ) |
377 | | - |
378 | | - |
379 | | -def test_means(one_triangle): |
| 309 | +def test_means_up_sampled(one_triangle): |
380 | 310 | assert len(one_triangle.means) == 1 |
381 | 311 |
|
382 | 312 | up_sampled = one_triangle.up_sample() |
|
0 commit comments