Skip to content
12 changes: 0 additions & 12 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,6 @@ def __init__(self, grid: xgcm.Grid, mesh):
ptyping.assert_valid_mesh(mesh)
self._ds = ds

@classmethod
def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None):
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
if xgcm_kwargs is None:
xgcm_kwargs = {}

xgcm_kwargs = {**_DEFAULT_XGCM_KWARGS, **xgcm_kwargs}

ds = _drop_field_data(ds)
grid = xgcm.Grid(ds, **xgcm_kwargs)
return cls(grid, mesh=mesh)

def __repr__(self):
return xgrid_repr(self)

Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_sgrid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def _attach_sgrid_metadata(ds: xr.Dataset, grid: SGrid2DMetadata | SGrid3DMetada
0,
grid.to_attrs(),
)
# ds.attrs["Conventions"] = "SGRID" # TODO: re-enable once XGrid.from_dataset is gone
ds.attrs["Conventions"] = "SGRID"
return ds


Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from parcels import FieldSet
from parcels._datasets.structured.generic import datasets as datasets_structured

SKIP_BY_DEFAULT = {"validation", "flaky"}


Expand All @@ -17,3 +20,16 @@ def pytest_collection_modifyitems(config, items):
@pytest.fixture
def tmp_parquet(tmp_path):
return tmp_path / "tmp.parquet"


@pytest.fixture
def fieldset() -> FieldSet:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit 'magical' that the fieldset() is defined here, but can be used in all the test files without an explicit import in these test files. New devs may get confused?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in person - this is just pytest magic :)

"""FieldSet with U and V"""
ds = datasets_structured["ds_2d_left"].copy()
ds = ds[["U_A_grid", "V_A_grid", "grid"]].rename(
{
"U_A_grid": "U",
"V_A_grid": "V",
}
)
Comment on lines +29 to +34
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to worry about the values in the arrays of these datasets? For example, if they are by default zero then some advection tests may pass irrespective of how good the advection scheme is. Or is it simply not intended that the values in these arrays are used?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in person - these datasets have random data at the moment, where other datasets can be used for more flow specific tests.

return FieldSet.from_sgrid_conventions(ds, mesh="flat")
99 changes: 55 additions & 44 deletions tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import numpy as np
import pytest
import uxarray as ux
import xarray as xr

from parcels import Field, UxGrid, VectorField, XGrid
from parcels._core.fieldset import FieldSet
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
Expand All @@ -18,7 +17,7 @@

def test_field_init_param_types():
data = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(data, mesh="flat")
grid = FieldSet.from_sgrid_conventions(data, mesh="flat").data_g.grid

with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
Field(name=123, data=data["data_g"], grid=grid, interp_method=XLinear)
Expand Down Expand Up @@ -46,25 +45,28 @@ def test_field_init_param_types():
Field(name="test", data=data["data_g"], grid=123, interp_method=XLinear)


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
ux.UxDataArray(),
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="uxdata-grid",
),
pytest.param(
xr.DataArray(),
UxGrid(
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
z=datasets_unstructured["stommel_gyre_delaunay"].coords["zf"],
mesh="flat",
),
id="xarray-uxgrid",
),
],
)
# @pytest.mark.parametrize(
# "data,grid",
# [
# pytest.param(
# ux.UxDataArray(),
# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
# id="uxdata-grid",
# ),
# pytest.param(
# xr.DataArray(),
# UxGrid(
# datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
# z=datasets_unstructured["stommel_gyre_delaunay"].coords["zf"],
# mesh="flat",
# ),
# id="xarray-uxgrid",
# ),
# ],
# )
@pytest.mark.skip(
"Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_field_incompatible_combination(data, grid):
with pytest.raises(ValueError, match="Incompatible data-grid combination."):
Field(
Expand All @@ -75,16 +77,19 @@ def test_field_incompatible_combination(data, grid):
)


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
datasets_structured["ds_2d_left"]["data_g"],
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="ds_2d_left",
), # TODO: Perhaps this test should be expanded to cover more datasets?
],
)
# @pytest.mark.parametrize(
# "data,grid",
# [
# pytest.param(
# datasets_structured["ds_2d_left"]["data_g"],
# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
# id="ds_2d_left",
# ), # TODO: Perhaps this test should be expanded to cover more datasets?
# ],
# )
@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_field_init_structured_grid(data, grid):
"""Test creating a field."""
field = Field(
Expand All @@ -98,6 +103,9 @@ def test_field_init_structured_grid(data, grid):
assert field.grid == grid


@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_field_init_fail_on_float_time_dim():
"""Test field initialisation fails when given float array as time dimension.

Expand All @@ -124,16 +132,19 @@ def test_field_init_fail_on_float_time_dim():
)


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
datasets_structured["ds_2d_left"]["data_g"],
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
id="ds_2d_left",
),
],
)
# @pytest.mark.parametrize(
# "data,grid",
# [
# pytest.param(
# datasets_structured["ds_2d_left"]["data_g"],
# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
# id="ds_2d_left",
# ),
# ],
# )
@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_field_time_interval(data, grid):
"""Test creating a field."""
field = Field(name="test_field", data=data, grid=grid, interp_method=XLinear)
Expand All @@ -148,7 +159,7 @@ def test_vectorfield_init_different_time_intervals():

def test_field_invalid_interpolator():
ds = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(ds, mesh="flat")
grid = FieldSet.from_sgrid_conventions(ds, mesh="flat").data_g.grid

def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
return 0.0
Expand All @@ -165,7 +176,7 @@ def invalid_interpolator_wrong_signature(particle_positions, grid_positions, inv

def test_vectorfield_invalid_interpolator():
ds = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(ds, mesh="flat")
grid = FieldSet.from_sgrid_conventions(ds, mesh="flat").data_g.grid

def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
return 0.0
Expand Down
38 changes: 21 additions & 17 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,18 @@
import pytest
import xarray as xr

from parcels import Field, ParticleFile, ParticleSet, VectorField, XGrid, convert
from parcels import Field, ParticleFile, ParticleSet, XGrid, convert
from parcels._core.fieldset import CalendarError, FieldSet, _datetime_to_msg
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.structured.generic import datasets_sgrid
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels.interpolators import XLinear, XLinear_Velocity
from parcels.interpolators import XLinear
from tests import utils

ds = datasets_structured["ds_2d_left"]


@pytest.fixture
def fieldset() -> FieldSet:
"""Fixture to create a FieldSet object for testing."""
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U_A_grid"], grid, interp_method=XLinear)
V = Field("V", ds["V_A_grid"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V, vector_interp_method=XLinear_Velocity)

return FieldSet(
[U, V, UV],
)


def test_fieldset_init_wrong_types():
with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"):
FieldSet([1.0, 2.0, 3.0])
Expand Down Expand Up @@ -65,19 +52,28 @@ def test_fieldset_add_constant_field(fieldset):
assert fieldset.test_constant_field[time, z, lat, lon] == 1.0


@pytest.mark.skip(
"Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_add_field(fieldset):
grid = XGrid.from_dataset(ds, mesh="flat")
field = Field("test_field", ds["U_A_grid"], grid, interp_method=XLinear)
fieldset.add_field(field)
assert fieldset.test_field == field


@pytest.mark.skip(
"Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_add_field_wrong_type(fieldset):
not_a_field = 1.0
with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"):
fieldset.add_field(not_a_field, "test_field")


@pytest.mark.skip(
"Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_add_field_already_exists(fieldset):
grid = XGrid.from_dataset(ds, mesh="flat")
field = Field("test_field", ds["U_A_grid"], grid, interp_method=XLinear)
Expand All @@ -97,8 +93,7 @@ def test_fieldset_gridset(fieldset):


def test_fieldset_no_UV(tmp_parquet):
grid = XGrid.from_dataset(ds, mesh="flat")
fieldset = FieldSet([Field("P", ds["U_A_grid"], grid, interp_method=XLinear)])
fieldset = FieldSet.from_sgrid_conventions(ds[["U_A_grid", "grid"]].rename({"U_A_grid": "P"}), mesh="flat")

def SampleP(particles, fieldset):
particles.dlon += fieldset.P[particles]
Expand All @@ -125,6 +120,9 @@ def test_fieldset_from_structured_generic_datasets(ds):
def test_fieldset_gridset_multiple_grids(): ...


@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_time_interval():
grid1 = XGrid.from_dataset(ds, mesh="flat")
field1 = Field("field1", ds["U_A_grid"], grid1, interp_method=XLinear)
Expand All @@ -149,6 +147,9 @@ def test_fieldset_time_interval_constant_fields():
assert fieldset.time_interval is None


@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_init_incompatible_calendars():
ds1 = ds.copy()
ds1["time"] = (
Expand All @@ -174,6 +175,9 @@ def test_fieldset_init_incompatible_calendars():
FieldSet([U, V, incompatible_calendar])


@pytest.mark.skip(
"Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646"
) # TODO: Remove or replace
def test_fieldset_add_field_incompatible_calendars(fieldset):
ds_test = ds.copy()
ds_test["time"] = (
Expand Down
14 changes: 4 additions & 10 deletions tests/test_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
import xgcm

import parcels.tutorial
from parcels import Field, XGrid
from parcels import XGrid
from parcels._core.fieldset import FieldSet
from parcels._core.index_search import _latlon_rad_to_xyz, _search_indices_curvilinear_2d
from parcels._datasets.structured.generic import datasets
from parcels.interpolators import XLinear


@pytest.fixture
def field_cone():
ds = datasets["2d_left_unrolled_cone"]
grid = XGrid.from_dataset(ds, mesh="flat")
field = Field(
name="test_field",
data=ds["data_g"],
grid=grid,
interp_method=XLinear,
)
return field
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
return fieldset.data_g


def test_grid_indexing_fpoints(field_cone):
Expand Down
Loading
Loading