diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..238042c1 --- /dev/null +++ b/Makefile @@ -0,0 +1,3 @@ +l lint: + @echo "Executing lint in backend code (pre-commit)" + pre-commit run --show-diff-on-failure --color=always --all-files diff --git a/openhexa/sdk/pipelines/parameter/__init__.py b/openhexa/sdk/pipelines/parameter/__init__.py index 477ed764..774fd1ec 100644 --- a/openhexa/sdk/pipelines/parameter/__init__.py +++ b/openhexa/sdk/pipelines/parameter/__init__.py @@ -5,6 +5,7 @@ from openhexa.sdk.pipelines.exceptions import InvalidParameterError, ParameterValueError +from .choices import ChoicesFromFile from .decorator import FunctionWithParameter, Parameter, parameter, validate_parameters from .types import ( TYPES_BY_PYTHON_TYPE, @@ -56,6 +57,8 @@ "SecretType", # Registry "TYPES_BY_PYTHON_TYPE", + # Dynamic choices + "ChoicesFromFile", # Widgets "DHIS2Widget", "IASOWidget", diff --git a/openhexa/sdk/pipelines/parameter/ast_constructible.py b/openhexa/sdk/pipelines/parameter/ast_constructible.py new file mode 100644 index 00000000..67f7c67c --- /dev/null +++ b/openhexa/sdk/pipelines/parameter/ast_constructible.py @@ -0,0 +1,44 @@ +"""Mixin for classes that can reconstruct themselves from an AST Call node.""" + +import ast +import inspect + + +class AstConstructible: + """Mixin that enables reconstruction of a class instance from an AST Call node. + + Any class whose ``__init__`` takes only scalar (``ast.Constant``) arguments + can inherit from this mixin and get ``from_ast_call`` for free. Adding or + renaming ``__init__`` parameters does *not* require touching the parser. + + To make the AST parser recognise a new subclass by name, add one entry to + ``_AST_CALLABLE_TYPES`` in ``runtime.py`` (and ensure the subclass module is + imported there). Auto-registration via ``__init_subclass__`` would not remove + that requirement — the registry entry only exists after the module is imported, + so an explicit import would still be needed. 
+ """ + + @classmethod + def from_ast_call(cls, node: ast.Call) -> "AstConstructible": + """Reconstruct an instance from an AST Call node. + + Maps positional args to ``__init__`` parameter names via ``inspect.signature``, + merges keyword args and calls ``cls``. Non-literal or surplus args raise ``ValueError``. + """ + param_names = list(inspect.signature(cls).parameters.keys()) + kwargs = {} + for i, arg in enumerate(node.args): + if i >= len(param_names): + raise ValueError(f"{cls.__name__}() takes at most {len(param_names)} positional argument(s).") + if not isinstance(arg, ast.Constant): + raise ValueError( + f"{cls.__name__}() positional argument {i + 1} must be a literal value, not a dynamic expression." + ) + kwargs[param_names[i]] = arg.value + for kw in node.keywords: + if not isinstance(kw.value, ast.Constant): + raise ValueError( + f"{cls.__name__}() keyword argument '{kw.arg}' must be a literal value, not a dynamic expression." + ) + kwargs[kw.arg] = kw.value.value + return cls(**kwargs) diff --git a/openhexa/sdk/pipelines/parameter/choices.py b/openhexa/sdk/pipelines/parameter/choices.py new file mode 100644 index 00000000..ddacc635 --- /dev/null +++ b/openhexa/sdk/pipelines/parameter/choices.py @@ -0,0 +1,67 @@ +"""Dynamic choices classes for pipeline parameters.""" + +from openhexa.sdk.pipelines.exceptions import InvalidParameterError + +from .ast_constructible import AstConstructible + +_SUPPORTED_FORMATS = {"csv", "json", "yaml", "yml"} + + +class ChoicesFromFile(AstConstructible): + """Descriptor for choices loaded dynamically from a file in the workspace file system. + + Parameters + ---------- + path : str + Path to the file in the workspace file system (e.g. "data/districts.csv"). + column : str, optional + Column name (CSV) or key (JSON/YAML) to use as choice values. + Required when the file has more than one column/key. + format : str, optional + One of "csv", "json", "yaml", "yml"; other values raise InvalidParameterError. Sent as-is to the platform. 
+ """ + + def __init__(self, path: str, column: str | None = None, format: str | None = None): + self.path = path + self.column = column + self.format = format + self._validate_spec() + + def _validate_spec(self): + """Validate the path and column specification.""" + if not self.path or not isinstance(self.path, str): + raise InvalidParameterError("ChoicesFromFile path must be a non-empty string.") + if self.column is not None and not isinstance(self.column, str): + raise InvalidParameterError("ChoicesFromFile column must be a string.") + if self.format is not None and self.format not in _SUPPORTED_FORMATS: + raise InvalidParameterError( + f"ChoicesFromFile format '{self.format}' is not supported. " + f"Supported formats: {', '.join(sorted(_SUPPORTED_FORMATS))}." + ) + + def __repr__(self) -> str: + """Return a string representation of the ChoicesFromFile instance.""" + parts = [repr(self.path)] + if self.column is not None: + parts.append(f"column={self.column!r}") + if self.format is not None: + parts.append(f"format={self.format!r}") + return f"ChoicesFromFile({', '.join(parts)})" + + def __eq__(self, other: object) -> bool: + """Check equality based on path, column, and format.""" + if not isinstance(other, ChoicesFromFile): + return NotImplemented + return self.path == other.path and self.column == other.column and self.format == other.format + + def __hash__(self) -> int: + """Return hash based on path, column, and format.""" + return hash((self.path, self.column, self.format)) + + def to_dict(self) -> dict: + """Return a dictionary representation of the choices spec.""" + return { + "format": self.format, + "path": self.path, + "column": self.column, + } diff --git a/openhexa/sdk/pipelines/parameter/decorator.py b/openhexa/sdk/pipelines/parameter/decorator.py index 4884ffb5..01050fb4 100644 --- a/openhexa/sdk/pipelines/parameter/decorator.py +++ b/openhexa/sdk/pipelines/parameter/decorator.py @@ -15,6 +15,7 @@ S3Connection, ) +from .choices import 
ChoicesFromFile from .types import TYPES_BY_PYTHON_TYPE, Boolean, DHIS2ConnectionType, IASOConnectionType, Secret from .widgets import DHIS2Widget, IASOWidget @@ -42,7 +43,7 @@ def __init__( | File ], name: str | None = None, - choices: typing.Sequence | None = None, + choices: typing.Sequence | ChoicesFromFile | str | None = None, help: str | None = None, default: typing.Any | None = None, widget: DHIS2Widget | IASOWidget | None = None, @@ -66,14 +67,18 @@ def __init__( if choices is not None: if not self.type.accepts_choices: raise InvalidParameterError(f"Parameters of type {self.type} don't accept choices.") - elif len(choices) == 0: - raise InvalidParameterError("Choices, if provided, cannot be empty.") - - try: - for choice in choices: - self.type.validate(choice) - except ParameterValueError: - raise InvalidParameterError(f"The provided choices are not valid for the {self.type} parameter type.") + if isinstance(choices, str): + choices = ChoicesFromFile(choices) + elif not isinstance(choices, ChoicesFromFile): + if len(choices) == 0: + raise InvalidParameterError("Choices, if provided, cannot be empty.") + try: + for choice in choices: + self.type.validate(choice) + except ParameterValueError: + raise InvalidParameterError( + f"The provided choices are not valid for the {self.type} parameter type." 
+ ) self.choices = choices self.name = name @@ -100,11 +105,11 @@ def validate(self, value: typing.Any) -> typing.Any: def to_dict(self) -> dict[str, typing.Any]: """Return a dictionary representation of the Parameter instance.""" - return { + d = { "code": self.code, "type": self.type.spec_type, "name": self.name, - "choices": self.choices, + "choices": None if isinstance(self.choices, ChoicesFromFile) else self.choices, "help": self.help, "default": self.default, "widget": self.widget.value if self.widget else None, @@ -113,6 +118,9 @@ def to_dict(self) -> dict[str, typing.Any]: "multiple": self.multiple, "directory": self.directory, } + if isinstance(self.choices, ChoicesFromFile): + d["choices_from_file"] = self.choices.to_dict() + return d def _validate_single(self, value: typing.Any): # Normalize empty values to None and handles default @@ -129,7 +137,11 @@ def _validate_single(self, value: typing.Any): return None pre_validated = self.type.validate(normalized_value) - if self.choices is not None and pre_validated not in self.choices: + if ( + self.choices is not None + and not isinstance(self.choices, ChoicesFromFile) + and pre_validated not in self.choices + ): raise ParameterValueError(f"The provided value for {self.code} is not included in the provided choices.") return pre_validated @@ -152,7 +164,11 @@ def _validate_multiple(self, value: typing.Any): raise ParameterValueError(f"{self.code} is required") pre_validated = [self.type.validate(single_value) for single_value in normalized_value] - if self.choices is not None and any(v not in self.choices for v in pre_validated): + if ( + self.choices is not None + and not isinstance(self.choices, ChoicesFromFile) + and any(v not in self.choices for v in pre_validated) + ): raise ParameterValueError( f"One of the provided values for {self.code} is not included in the provided choices." 
) @@ -174,7 +190,7 @@ def _validate_default(self, default: typing.Any, multiple: bool): except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") - if self.choices is not None: + if self.choices is not None and not isinstance(self.choices, ChoicesFromFile): if isinstance(default, list): if not all(d in self.choices for d in default): raise InvalidParameterError( @@ -227,7 +243,7 @@ def parameter( | File ], name: str | None = None, - choices: typing.Sequence | None = None, + choices: typing.Sequence | ChoicesFromFile | str | None = None, help: str | None = None, widget: DHIS2Widget | IASOWidget | None = None, connection: str | None = None, @@ -261,7 +277,7 @@ def parameter( An optional default value for the parameter (should be of the type defined by the type parameter) required : bool, default=True Whether the parameter is mandatory - multiple : bool, default=True + multiple : bool, default=False Whether this parameter should be provided multiple values (if True, the value must be provided as a list of values of the chosen type) directory : str, optional diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index c1e3195e..e23888a6 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -6,6 +6,7 @@ import io import os import sys +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -16,6 +17,7 @@ from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound from openhexa.sdk.pipelines.parameter import ( TYPES_BY_PYTHON_TYPE, + ChoicesFromFile, DHIS2Widget, IASOWidget, Parameter, @@ -25,6 +27,12 @@ from .pipeline import Pipeline +# Maps AST function names to classes that support from_ast_call(). +# Add an entry here when introducing a new AstConstructible type. 
+_AST_CALLABLE_TYPES: dict[str, type] = { + "ChoicesFromFile": ChoicesFromFile, +} + @dataclass class Argument: @@ -33,6 +41,7 @@ class Argument: name: str # Use str instead of string types: list[type] = field(default_factory=list) default_value: Any = None + transform: Callable | None = None def import_pipeline(pipeline_dir_path: str) -> Pipeline: @@ -172,6 +181,12 @@ def _get_decorator_arg_value(decorator: ast.Call, arg: Argument, index: int) -> return (keyword.value.id, True) elif isinstance(keyword.value, ast.List): return ([el.value for el in keyword.value.elts], True) + elif isinstance(keyword.value, ast.Call): + func = keyword.value.func + func_name = func.id if isinstance(func, ast.Name) else None  # only bare names can be registered + if func_name not in _AST_CALLABLE_TYPES: + raise ValueError(f"Unsupported call in {arg.name} argument: {ast.unparse(func)}") + return _AST_CALLABLE_TYPES[func_name].from_ast_call(keyword.value), True elif isinstance(keyword.value, ast.Attribute): if keyword.value.attr in DHIS2Widget.__members__: return getattr(DHIS2Widget, keyword.value.attr), True @@ -201,6 +216,8 @@ def _get_decorator_spec(decorator: ast.Call, args: tuple[Argument, ...]) -> dict args_spec = {} for i, arg in enumerate(args): value, is_keyword = _get_decorator_arg_value(decorator, arg, i) + if arg.transform is not None: + value = arg.transform(value) args_spec[arg.name] = {"value": value, "is_keyword": is_keyword} return args_spec @@ -287,7 +304,11 @@ def get_pipeline(pipeline_path: Path) -> Pipeline: Argument("code", [ast.Constant]), Argument("type", [ast.Name]), Argument("name", [ast.Constant]), - Argument("choices", [ast.List]), + Argument( + "choices", + [ast.List, ast.Call, ast.Constant], + transform=lambda v: ChoicesFromFile(v) if isinstance(v, str) else v, + ), Argument("help", [ast.Constant]), Argument("default", [ast.Constant, ast.List]), Argument("widget", [ast.Attribute]), diff --git a/pyproject.toml b/pyproject.toml index d6e9ccbd..40085654 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-20,7 +20,7 @@ requires-python = ">=3.11,<3.15" # the main constraint for supported Python vers dependencies = [ "urllib3<3", "multiprocess~=0.70.15", - "requests>=2.31,<2.34", + "requests>=2.31,<3", "PyYAML~=6.0", "click~=8.1.3", "jinja2>3,<4", diff --git a/tests/test_choices.py b/tests/test_choices.py new file mode 100644 index 00000000..a816ba85 --- /dev/null +++ b/tests/test_choices.py @@ -0,0 +1,392 @@ +"""Tests for ChoicesFromFile dynamic parameter choices.""" + +import tempfile +from unittest import TestCase + +import pytest + +from openhexa.sdk.pipelines.exceptions import InvalidParameterError +from openhexa.sdk.pipelines.parameter import ChoicesFromFile, Parameter, parameter +from openhexa.sdk.pipelines.runtime import get_pipeline + +# --------------------------------------------------------------------------- +# ChoicesFromFile construction +# --------------------------------------------------------------------------- + + +class TestChoicesFromFileConstruction: + def test_format_defaults_to_none(self): + fc = ChoicesFromFile("districts.csv") + assert fc.format is None + assert fc.path == "districts.csv" + assert fc.column is None + + def test_explicit_format_accepted(self): + fc = ChoicesFromFile("data/regions.json", column="code", format="json") + assert fc.format == "json" + assert fc.column == "code" + + def test_explicit_format_yaml(self): + assert ChoicesFromFile("list.yaml", format="yaml").format == "yaml" + + def test_yml_explicit_format_accepted(self): + assert ChoicesFromFile("list.yml", format="yml").format == "yml" + + def test_invalid_explicit_format_raises(self): + with pytest.raises(InvalidParameterError, match="Supported formats"): + ChoicesFromFile("districts.csv", format="excel") + + def test_any_extension_accepted(self): + fc = ChoicesFromFile("districts.xlsx") + assert fc.format is None + + def test_no_extension_accepted(self): + fc = ChoicesFromFile("districts") + assert fc.format is None + + def test_empty_path_raises(self): + with 
pytest.raises(InvalidParameterError): + ChoicesFromFile("") + + def test_non_string_column_raises(self): + with pytest.raises(InvalidParameterError): + ChoicesFromFile("districts.csv", column=42) + + def test_to_dict(self): + fc = ChoicesFromFile("data/districts.csv", column="code", format="csv") + assert fc.to_dict() == {"format": "csv", "path": "data/districts.csv", "column": "code"} + + def test_to_dict_no_column(self): + fc = ChoicesFromFile("districts.csv") + assert fc.to_dict() == {"format": None, "path": "districts.csv", "column": None} + + +# --------------------------------------------------------------------------- +# String shorthand — Parameter.__init__ +# --------------------------------------------------------------------------- + + +class TestStringShorthand: + # --- happy paths --- + + def test_string_shorthand_csv(self): + p = Parameter(code="district", type=str, choices="districts.csv") + assert p.choices == ChoicesFromFile("districts.csv") + + def test_string_shorthand_json(self): + p = Parameter(code="district", type=str, choices="data/regions.json") + assert isinstance(p.choices, ChoicesFromFile) + assert p.choices.format is None + + def test_string_shorthand_yaml(self): + p = Parameter(code="district", type=str, choices="list.yaml") + assert isinstance(p.choices, ChoicesFromFile) + assert p.choices.format is None + + def test_string_shorthand_any_extension(self): + p = Parameter(code="district", type=str, choices="list.yml") + assert isinstance(p.choices, ChoicesFromFile) + assert p.choices.format is None + + def test_string_shorthand_leading_slash_stripped(self): + p = Parameter(code="district", type=str, choices="/choices.csv") + assert p.choices.path == "/choices.csv" # ChoicesFromFile stores as-is; stripping is app-side + + def test_string_shorthand_serialises_same_as_explicit(self): + shorthand = Parameter(code="district", type=str, choices="districts.csv").to_dict() + explicit = Parameter(code="district", type=str, 
choices=ChoicesFromFile("districts.csv")).to_dict() + assert shorthand == explicit + + # --- static list still works --- + + def test_static_list_unaffected(self): + p = Parameter(code="country", type=str, choices=["UG", "KE"]) + assert p.choices == ["UG", "KE"] + + def test_explicit_choices_from_file_unaffected(self): + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv", column="code")) + assert p.choices == ChoicesFromFile("districts.csv", column="code") + + # --- any string path is accepted (format defaults to None) --- + + def test_string_no_extension_accepted(self): + p = Parameter(code="district", type=str, choices="nodot") + assert isinstance(p.choices, ChoicesFromFile) + assert p.choices.format is None + + def test_string_any_extension_accepted(self): + p = Parameter(code="district", type=str, choices="file.xlsx") + assert isinstance(p.choices, ChoicesFromFile) + assert p.choices.format is None + + def test_empty_string_raises(self): + with pytest.raises(InvalidParameterError): + Parameter(code="district", type=str, choices="") + + # --- column cannot be specified via shorthand --- + + def test_shorthand_has_no_column(self): + p = Parameter(code="district", type=str, choices="districts.csv") + assert p.choices == ChoicesFromFile("districts.csv") + + def test_decorator_with_string_shorthand(self): + @parameter(code="district", type=str, choices="districts.csv") + def my_pipeline(district): + pass + + params = my_pipeline.get_all_parameters() + assert isinstance(params[0].choices, ChoicesFromFile) + + +# --------------------------------------------------------------------------- +# String shorthand — AST round-trip +# --------------------------------------------------------------------------- + + +class TestAstStringShorthand(TestCase): + def _write_pipeline(self, tmpdir, param_line): + with open(f"{tmpdir}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + "from openhexa.sdk.pipelines import pipeline, parameter", + "", + 
param_line, + "@pipeline(name='Test pipeline')", + "def test_pipeline(district):", + " pass", + ] + ) + ) + + def test_ast_string_shorthand_csv(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='districts.csv')", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices"] is None + assert param_dict["choices_from_file"] == {"format": None, "path": "districts.csv", "column": None} + + def test_ast_string_shorthand_json(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='regions.json')", + ) + p = get_pipeline(tmpdir) + assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] is None + + def test_ast_string_shorthand_any_extension(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='list.yml')", + ) + p = get_pipeline(tmpdir) + assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] is None + + def test_ast_string_shorthand_same_output_as_explicit(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='districts.csv')", + ) + shorthand_dict = get_pipeline(tmpdir).to_dict()["parameters"][0] + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices=ChoicesFromFile('districts.csv'))", + ) + # need the import for the explicit form + with open(f"{tmpdir}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + "from openhexa.sdk.pipelines import pipeline, parameter", + "from openhexa.sdk.pipelines.parameter import ChoicesFromFile", + "", + "@parameter('district', type=str, choices=ChoicesFromFile('districts.csv'))", + "@pipeline(name='Test pipeline')", + "def test_pipeline(district):", + " pass", + ] + ) + ) + 
explicit_dict = get_pipeline(tmpdir).to_dict()["parameters"][0] + + assert shorthand_dict == explicit_dict + + def test_ast_static_list_unaffected(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('country', type=str, choices=['UG', 'KE'])", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices"] == ["UG", "KE"] + assert "choices_from_file" not in param_dict + + def test_ast_string_no_extension_accepted(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='nodot')", + ) + p = get_pipeline(tmpdir) + assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] is None + + def test_ast_string_any_extension_accepted(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices='file.xlsx')", + ) + p = get_pipeline(tmpdir) + assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] is None + + +# --------------------------------------------------------------------------- +# Parameter integration +# --------------------------------------------------------------------------- + + +class TestParameterWithChoicesFromFile: + def test_accepts_file_choices(self): + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv")) + assert isinstance(p.choices, ChoicesFromFile) + + def test_to_dict_emits_file_choices_key(self): + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv", column="code", format="csv")) + d = p.to_dict() + assert d["choices"] is None + assert d["choices_from_file"] == {"format": "csv", "path": "districts.csv", "column": "code"} + + def test_to_dict_no_file_choices_key_for_static_choices(self): + p = Parameter(code="country", type=str, choices=["UG", "KE"]) + d = p.to_dict() + assert d["choices"] == ["UG", "KE"] + assert "choices_from_file" not 
in d + + def test_rejects_file_choices_on_bool_type(self): + with pytest.raises(InvalidParameterError, match="don't accept choices"): + Parameter(code="flag", type=bool, choices=ChoicesFromFile("flags.csv")) + + def test_validate_single_skips_choices_check(self): + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv")) + # Any string value passes — the platform validates against the resolved list + assert p.validate("any_value") == "any_value" + + def test_validate_multiple_skips_choices_check(self): + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv"), multiple=True) + assert p.validate(["A", "B", "C"]) == ["A", "B", "C"] + + def test_default_not_validated_against_file_choices(self): + # Should not raise even though default isn't in any resolved list + p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv"), default="UNKNOWN") + assert p.default == "UNKNOWN" + + def test_decorator_with_file_choices(self): + @parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv")) + def my_pipeline(district): + pass + + params = my_pipeline.get_all_parameters() + assert len(params) == 1 + assert isinstance(params[0].choices, ChoicesFromFile) + + +# --------------------------------------------------------------------------- +# AST round-trip +# --------------------------------------------------------------------------- + + +class TestAstChoicesFromFile(TestCase): + def _write_pipeline(self, tmpdir, param_line): + with open(f"{tmpdir}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + "from openhexa.sdk.pipelines import pipeline, parameter", + "from openhexa.sdk.pipelines.parameter import ChoicesFromFile", + "", + param_line, + "@pipeline(name='Test pipeline')", + "def test_pipeline(district):", + " pass", + ] + ) + ) + + def test_file_choices_positional_path(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', 
type=str, choices=ChoicesFromFile('districts.csv'))", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices"] is None + assert param_dict["choices_from_file"] == {"format": None, "path": "districts.csv", "column": None} + + def test_file_choices_with_column(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices=ChoicesFromFile('data/districts.csv', column='code'))", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices_from_file"] == {"format": None, "path": "data/districts.csv", "column": "code"} + + def test_file_choices_with_column_positional(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices=ChoicesFromFile('data/districts.csv', 'code'))", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices_from_file"] == {"format": None, "path": "data/districts.csv", "column": "code"} + + def test_file_choices_explicit_format(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices=ChoicesFromFile('regions.json', column='id', format='json'))", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices_from_file"]["format"] == "json" + + def test_file_choices_format_none_by_default(self): + with tempfile.TemporaryDirectory() as tmpdir: + self._write_pipeline( + tmpdir, + "@parameter('district', type=str, choices=ChoicesFromFile('list.yml'))", + ) + p = get_pipeline(tmpdir) + param_dict = p.to_dict()["parameters"][0] + assert param_dict["choices_from_file"]["format"] is None + + def test_unsupported_call_in_choices_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + 
"from openhexa.sdk.pipelines import pipeline, parameter", + "", + "@parameter('district', type=str, choices=dict(a=1))", + "@pipeline(name='Test pipeline')", + "def test_pipeline(district):", + " pass", + ] + ) + ) + with self.assertRaisesRegex(ValueError, "Unsupported call"): + get_pipeline(tmpdir)