diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index f925c40d7..d42989042 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -114,14 +114,18 @@ jobs: # Update this key to force a new cache (sync with packaging.yml) prefix-key: "python-v3" - - name: Install + - name: Install sedonadb-expr + run: | + pip install -e "python/sedonadb-expr" -vv + + - name: Install sedonadb run: | # Keep this export in sync with the export in dev/release/verify-release-candidate.sh export MATURIN_PEP517_ARGS="--features s2geography" pip install -e "python/sedonadb/[test]" -vv - # Unset so `--features s2geography` (sedonadb-only) doesn't - # carry into the plugin install. - unset MATURIN_PEP517_ARGS + + - name: Install sedonadb-zarr + run: | pip install -e "python/sedonadb-zarr/[test]" -vv - name: Download minimal geoarrow-data assets @@ -132,18 +136,23 @@ jobs: run: | docker compose up --wait --detach postgis - - name: Run tests + - name: Run tests (sedonadb) env: # Ensure that we don't skip tests that we didn't intend to SEDONADB_PYTHON_NO_SKIP_TESTS: "true" run: | - cd python + cd python/sedonadb python -m pytest -vv - - name: Run doctests + - name: Run doctests (sedonadb) run: | - cd python - python -m pytest --doctest-modules + cd python/sedonadb + python -m pytest --doctest-modules python/ + + - name: Run tests (sedonadb-expr) + run: | + cd python/sedonadb-expr + python -m pytest -vv - name: Shutdown docker compose services if: always() diff --git a/python/sedonadb-expr/.gitignore b/python/sedonadb-expr/.gitignore new file mode 100644 index 000000000..71528ae4b --- /dev/null +++ b/python/sedonadb-expr/.gitignore @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Generated files +python/sedonadb_expr/_version.py +python/sedonadb_expr/_generated/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +*.egg + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# IDE +.idea/ +.vscode/ +*.swp +*.swo diff --git a/python/sedonadb-expr/README.md b/python/sedonadb-expr/README.md new file mode 100644 index 000000000..83296cf9a --- /dev/null +++ b/python/sedonadb-expr/README.md @@ -0,0 +1,38 @@ + + +# SedonaDB Expr + +A standalone Python package for SedonaDB expressions. This is an optional +dependency of the `sedonadb` package that powers the type-specific accessors +without bloating the core package for non-interactive usage. + +## Installation + +```shell +pip install sedonadb-expr +``` + +## Example + +```python +import sedonadb_expr + +print(sedonadb_expr.__version__) +``` diff --git a/python/sedonadb-expr/_version.py b/python/sedonadb-expr/_version.py new file mode 100644 index 000000000..ff985f4c5 --- /dev/null +++ b/python/sedonadb-expr/_version.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Version source for hatchling - reads from workspace Cargo.toml. + +This file is used by hatchling at build time to determine the version. +The build hook then generates a static _version.py inside the package. +""" + +import re +from pathlib import Path + + +def get_version() -> str: + """Read version from the workspace root Cargo.toml.""" + here = Path(__file__).parent + cargo_toml = here.parent.parent / "Cargo.toml" + + if not cargo_toml.exists(): + raise FileNotFoundError(f"Could not find workspace Cargo.toml at {cargo_toml}") + + content = cargo_toml.read_text() + + match = re.search( + r'\[workspace\.package\].*?version\s*=\s*"([^"]+)"', + content, + re.DOTALL, + ) + if match: + return match.group(1) + + raise ValueError("Could not find workspace.package.version in Cargo.toml") diff --git a/python/sedonadb-expr/hatch_build.py b/python/sedonadb-expr/hatch_build.py new file mode 100644 index 000000000..277a82289 --- /dev/null +++ b/python/sedonadb-expr/hatch_build.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Hatch build hook for sedonadb-expr. + +This hook runs during sdist and wheel builds to generate Python source +files from the docs/reference/sql documentation files. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + +class CustomBuildHook(BuildHookInterface): + """Custom build hook that generates Python sources from SQL docs.""" + + PLUGIN_NAME = "custom" + + def initialize(self, version: str, build_data: dict[str, Any]) -> None: + """ + Called before the build process starts. + + Args: + version: The version being built + build_data: Mutable dict to modify build behavior + """ + # Import the _codegen module directly to avoid triggering __init__.py, + # which imports from _generated (which doesn't exist yet). + import importlib.util + + here = Path(__file__).parent + codegen_path = here / "python" / "sedonadb_expr" / "_codegen.py" + spec = importlib.util.spec_from_file_location("_codegen", codegen_path) + codegen_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(codegen_module) + + self._generate_version(version, codegen_module.LICENSE_HEADER) + self._generate_sources(codegen_module, here) + + def _generate_version(self, version: str, license_header: str) -> None: + """Generate _version.py with the static version string.""" + here = Path(__file__).parent + version_file = here / "python" / "sedonadb_expr" / "_version.py" + + content = f'''{license_header} +# Auto-generated at build time - do not edit +__version__ = "{version}" +''' + version_file.write_text(content) + self.app.display_info(f"Generated _version.py with version {version}") + + def _generate_sources(self, codegen_module: Any, here: Path) -> None: + """Generate Python source files from docs/reference/sql.""" + generate_sources = codegen_module.generate_sources + + docs_sql = here.parent.parent / "docs" / "reference" / "sql" + output_dir = here / "python" / "sedonadb_expr" / "_generated" + + result = generate_sources(docs_sql, output_dir) + + if result.total_functions == 0 and not docs_sql.exists(): + self.app.display_warning( + f"docs/reference/sql not found at {docs_sql}, skipping generation" + ) + return + + self.app.display_info( + f"Generated {result.total_functions} functions total, " + f"{result.geo_method_count} geo methods" + ) diff --git a/python/sedonadb-expr/pyproject.toml b/python/sedonadb-expr/pyproject.toml new file mode 100644 index 000000000..e36243aa6 --- /dev/null +++ b/python/sedonadb-expr/pyproject.toml @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[build-system] +requires = ["hatchling", "pyyaml"] +build-backend = "hatchling.build" + +[project] +name = "sedonadb-expr" +readme = "README.md" +requires-python = ">=3.9" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +dynamic = ["version"] + +[project.optional-dependencies] +test = [ + "pytest", + "pyyaml", +] + +[tool.hatch.version] +source = "code" +path = "_version.py" +expression = "get_version()" + +[tool.hatch.build.targets.wheel] +packages = ["python/sedonadb_expr"] + +[tool.hatch.build.targets.wheel.hooks.custom] +path = "hatch_build.py" + +[tool.hatch.build.targets.sdist.hooks.custom] +path = "hatch_build.py" diff --git a/python/sedonadb-expr/python/sedonadb_expr/__init__.py b/python/sedonadb-expr/python/sedonadb_expr/__init__.py new file mode 100644 index 000000000..8536d706f --- /dev/null +++ b/python/sedonadb-expr/python/sedonadb_expr/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from sedonadb_expr._version import __version__ +from sedonadb_expr._generated.geo_functions import GeoFunctions +from sedonadb_expr._generated.geo_methods import GeoMethods + +__all__ = [ + "__version__", + "GeoFunctions", + "GeoMethods", +] diff --git a/python/sedonadb-expr/python/sedonadb_expr/_codegen.py b/python/sedonadb-expr/python/sedonadb_expr/_codegen.py new file mode 100644 index 000000000..d1edc1c4d --- /dev/null +++ b/python/sedonadb-expr/python/sedonadb_expr/_codegen.py @@ -0,0 +1,730 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Code generation for sedonadb-expr. + +This module generates Python source files from docs/reference/sql documentation files. +It can be invoked during the build process or run as a standalone script. +""" + +from __future__ import annotations + +import re +import textwrap +from pathlib import Path +from typing import Any + +import yaml + + +# Type to parameter name mapping (matches R version) +TYPE_TO_PARAM: dict[str, str] = { + "geometry": "geom", + "geography": "geom", + "raster": "rast", + "float64": "x", + "double": "x", + "integer": "n", + "int64": "n", + "string": "s", + "boolean": "b", + "crs": "crs", +} + +# Types that qualify for geo methods (first arg piped in) +GEO_TYPES = {"geometry", "geography"} + +DOCS_BASE_URL = "https://sedona.apache.org/sedonadb/latest/reference/sql" + + +def camel_to_snake(name: str) -> str: + """Convert CamelCase/PascalCase to snake_case. + + Examples: + AsBinary -> as_binary + GeomFromWKB -> geom_from_wkb + AsEWKT -> as_ewkt + LineInterpolatePoint -> line_interpolate_point + """ + # Insert underscore before uppercase letters that follow lowercase letters + # or before uppercase letters that are followed by lowercase letters + result = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name) + result = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", "_", result) + return result.lower() + + +LICENSE_HEADER = """\ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" + + +class ArgInfo: + """Information about a kernel argument.""" + + def __init__( + self, + type: str, + name: str | None = None, + description: str | None = None, + optional: bool = False, + ): + self.type = type + self.name = name + self.description = description + self.optional = optional + + +class KernelInfo: + """Parsed kernel information.""" + + def __init__( + self, + args: list[ArgInfo] | None = None, + returns: str = "unknown", + variadic: bool = False, + kernel_signatures: list[str] | None = None, + ): + self.args = args if args is not None else [] + self.returns = returns + self.variadic = variadic + self.kernel_signatures = ( + kernel_signatures if kernel_signatures is not None else [] + ) + + @property + def has_optional_args(self) -> bool: + """Return True if any argument is optional.""" + return any(arg.optional for arg in self.args) + + +class FunctionInfo: + """Parsed function information from a .qmd file.""" + + def __init__( + self, + name: str, + title: str, + description: str, + kernels: list[dict[str, Any]], + is_geo_method: bool = False, + kernel_info: KernelInfo | None = None, + sql_name: str | None = None, + ): + self.name = name + self.title = title + self.description = description + self.kernels = kernels + self.is_geo_method = is_geo_method + self.kernel_info = kernel_info + self.sql_name = sql_name or name # e.g., "ST_AsBinary" + + @property + def method_name(self) -> str: + """Return the snake_case method name derived from the SQL function name. + + e.g., ST_AsBinary -> as_binary, ST_GeomFromWKB -> geom_from_wkb + """ + sql = self.sql_name + # Strip prefix (ST_, RS_, S2_, SD_) + for prefix in ("ST_", "RS_", "S2_", "SD_"): + if sql.upper().startswith(prefix): + sql = sql[len(prefix) :] + break + return camel_to_snake(sql) + + +def extract_frontmatter(file_path: Path) -> dict[str, Any]: + """Extract YAML frontmatter from a .qmd file.""" + content = file_path.read_text() + lines = content.split("\n") + + # Find YAML delimiters + delimiters = [i for i, line in enumerate(lines) if line.strip() == "---"] + if len(delimiters) < 2: + raise ValueError(f"Could not find YAML frontmatter in {file_path}") + + yaml_text = "\n".join(lines[delimiters[0] + 1 : delimiters[1]]) + return yaml.safe_load(yaml_text) + + +def extract_description_section(file_path: Path) -> str | None: + """Extract the ## Description section from the .qmd file body.""" + content = file_path.read_text() + lines = content.split("\n") + + # Find end of frontmatter + delimiters = [i for i, line in enumerate(lines) if line.strip() == "---"] + if len(delimiters) < 2: + return None + + body_lines = lines[delimiters[1] + 1 :] + + # Find ## Description section + desc_start = None + for i, line in enumerate(body_lines): + if line.startswith("## Description"): + desc_start = i + break + + if desc_start is None: + return None + + # Find next section or end + remaining = body_lines[desc_start + 1 :] + next_section = None + for i, line in enumerate(remaining): + if line.startswith("## "): + next_section = i + break + + if next_section is None: + desc_lines = remaining + else: + desc_lines = remaining[:next_section] + + # Process lines: preserve markdown lists, join paragraphs + result_lines: list[str] = [] + current_paragraph: list[str] = [] + + for line in desc_lines: + stripped = line.strip() + # Check if this is a list item (-, *, or numbered) + is_list_item = bool(re.match(r"^[-*]|\d+\.", stripped)) + + if not stripped: + # Empty line: flush current paragraph and add blank line for separation + if current_paragraph: + result_lines.append(" ".join(current_paragraph)) + current_paragraph = [] + result_lines.append("") # Preserve paragraph break + elif is_list_item: + # List item: flush paragraph first, then add list item + if current_paragraph: + result_lines.append(" ".join(current_paragraph)) + current_paragraph = [] + result_lines.append(stripped) + else: + # Regular text: accumulate into paragraph + current_paragraph.append(stripped) + + # Flush any remaining paragraph + if current_paragraph: + result_lines.append(" ".join(current_paragraph)) + + desc_text = "\n".join(result_lines).strip() + return desc_text if desc_text else None + + +def type_to_param_name( + arg_type: str, index: int = 0, needs_suffix: bool = False +) -> str: + """Generate parameter name from type.""" + base_name = TYPE_TO_PARAM.get(arg_type, "arg") + if needs_suffix: + suffix = chr(ord("a") + index) # 0=a, 1=b, 2=c, ... + return f"{base_name}_{suffix}" + return base_name + + +def parse_kernel_args(kernel_args: list) -> list[ArgInfo]: + """Parse kernel arguments into ArgInfo objects.""" + result = [] + for arg in kernel_args: + if isinstance(arg, str): + result.append(ArgInfo(type=arg)) + elif isinstance(arg, dict): + result.append( + ArgInfo( + type=arg.get("type", "unknown"), + name=arg.get("name"), + description=arg.get("description"), + ) + ) + else: + result.append(ArgInfo(type="unknown")) + return result + + +def generate_arg_names(arg_info_list: list[ArgInfo]) -> list[str]: + """Generate argument names for a kernel's args.""" + types = [info.type for info in arg_info_list] + type_counts: dict[str, int] = {} + type_totals: dict[str, int] = {} + + # Count total occurrences of each type + for t in types: + type_totals[t] = type_totals.get(t, 0) + 1 + + arg_names = [] + for info in arg_info_list: + arg_type = info.type + arg_name = info.name + + if arg_name is None: + type_counts[arg_type] = type_counts.get(arg_type, 0) + 1 + needs_suffix = type_totals.get(arg_type, 0) > 1 + arg_name = type_to_param_name( + arg_type, type_counts[arg_type] - 1, needs_suffix + ) + + arg_names.append(arg_name) + + return arg_names + + +def parse_kernel_params(kernels: list[dict], fn_name: str = "unknown") -> KernelInfo: + """Parse kernel arguments and generate parameter info.""" + if not kernels: + return KernelInfo() + + # Process all kernels + all_kernel_info = [parse_kernel_args(k.get("args", [])) for k in kernels] + all_kernel_args = [generate_arg_names(info) for info in all_kernel_info] + + # Find max args + kernel_lengths = [len(args) for args in all_kernel_args] + max_args = max(kernel_lengths) if kernel_lengths else 0 + + # Check for argument name conflicts + has_conflict = False + for pos in range(max_args): + names_at_pos = set() + for args in all_kernel_args: + if pos < len(args): + names_at_pos.add(args[pos]) + if len(names_at_pos) > 1: + has_conflict = True + break + + returns = kernels[0].get("returns", "unknown") + + if has_conflict: + # Build signature strings for documentation + kernel_signatures = [] + for i, args in enumerate(all_kernel_args): + types = [info.type for info in all_kernel_info[i]] + sig = ", ".join(f"{arg} ({t})" for arg, t in zip(args, types)) + kernel_signatures.append(sig) + + return KernelInfo( + args=[], + returns=returns, + variadic=True, + kernel_signatures=kernel_signatures, + ) + + # Use kernel with most arguments as reference + ref_idx = kernel_lengths.index(max(kernel_lengths)) if kernel_lengths else 0 + arg_info = all_kernel_info[ref_idx] if all_kernel_info else [] + arg_names = all_kernel_args[ref_idx] if all_kernel_args else [] + + # Determine minimum args (args present in all kernels) + min_args = min(kernel_lengths) if kernel_lengths else 0 + + # Update ArgInfo with generated names and optional flag + for i, info in enumerate(arg_info): + if info.name is None: + info.name = arg_names[i] + # Args beyond min_args are optional (not present in all kernels) + info.optional = i >= min_args + + return KernelInfo(args=arg_info, returns=returns, variadic=False) + + +def parse_qmd_file(qmd_path: Path) -> FunctionInfo | None: + """Parse a .qmd file and return FunctionInfo.""" + fn_name = qmd_path.stem # e.g., "st_envelope" + + try: + frontmatter = extract_frontmatter(qmd_path) + except Exception: + return None + + kernels = frontmatter.get("kernels", []) + if not kernels: + return None + + # Check if first argument of any kernel is geometry/geography + is_geo_method = False + for kernel in kernels: + args = kernel.get("args", []) + if args: + first_arg = args[0] + first_type = ( + first_arg if isinstance(first_arg, str) else first_arg.get("type", "") + ) + if first_type in GEO_TYPES: + is_geo_method = True + break + + # Get properly-cased SQL function name from title field + sql_name = frontmatter.get("title", fn_name) + title = frontmatter.get("description", frontmatter.get("title", fn_name)) + description = extract_description_section(qmd_path) or "" + + kernel_info = parse_kernel_params(kernels, fn_name) + + return FunctionInfo( + name=fn_name, + title=title, + description=description, + kernels=kernels, + is_geo_method=is_geo_method, + kernel_info=kernel_info, + sql_name=sql_name, + ) + + +def wrap_docstring(text: str, width: int = 88, indent: str = " ") -> str: + """Wrap text for docstrings, preserving markdown lists.""" + if not text: + return "" + + result_lines: list[str] = [] + for i, line in enumerate(text.split("\n")): + if not line.strip(): + result_lines.append("") + continue + + # Wrap each line separately + wrapped = textwrap.fill(line, width=width - len(indent)) + for j, wrapped_line in enumerate(wrapped.split("\n")): + if i == 0 and j == 0: + # First line of first paragraph - no indent + result_lines.append(wrapped_line) + else: + result_lines.append(indent + wrapped_line) + + return "\n".join(result_lines) + + +def generate_method_docstring(func: FunctionInfo) -> str: + """Generate docstring for a method.""" + title = func.title.strip() + parts = [f'"""{title}'] + + if func.description and func.description.strip() != title: + parts.append("") + parts.append(wrap_docstring(func.description, indent=" ")) + + kernel_info = func.kernel_info + if kernel_info: + if kernel_info.variadic and kernel_info.kernel_signatures: + # Variadic mode: document with bulleted list of supported combinations + # Skip the first arg (piped in via self._expr) from each signature + parts.append("") + parts.append("Variants:") + for sig in kernel_info.kernel_signatures: + # Split signature, skip first arg, rejoin + arg_parts = [p.strip() for p in sig.split(",")] + remaining = ", ".join(arg_parts[1:]) if len(arg_parts) > 1 else "" + if remaining: + parts.append(f" - {remaining}") + else: + parts.append(" - (no additional arguments)") + elif kernel_info.args: + # Skip first arg (piped in via self._expr) + remaining_args = kernel_info.args[1:] if len(kernel_info.args) > 1 else [] + if remaining_args: + parts.append("") + parts.append("Args:") + for arg in remaining_args: + desc = arg.description or f"Input {arg.type}" + parts.append(f" {arg.name}: {desc}") + + parts.append("") + parts.append("See Also:") + parts.append(f" {DOCS_BASE_URL}/{func.name}/") + parts.append('"""') + + joined = "\n ".join(parts) + return "\n".join(line.rstrip() for line in joined.split("\n")) + + +def generate_function_docstring(func: FunctionInfo) -> str: + """Generate docstring for a standalone function property.""" + title = func.title.strip() + parts = [f'"""{title}'] + + if func.description and func.description.strip() != title: + parts.append("") + parts.append(wrap_docstring(func.description, indent=" ")) + + kernel_info = func.kernel_info + if kernel_info: + if kernel_info.variadic and kernel_info.kernel_signatures: + # Variadic mode: document with bulleted list of supported combinations + parts.append("") + parts.append("Variants:") + for sig in kernel_info.kernel_signatures: + parts.append(f" - {sig}") + elif kernel_info.args: + parts.append("") + parts.append("Args:") + for arg in kernel_info.args: + desc = arg.description or f"Input {arg.type}" + parts.append(f" {arg.name}: {desc}") + + parts.append("") + parts.append("See Also:") + parts.append(f" {DOCS_BASE_URL}/{func.name}/") + parts.append('"""') + + joined = "\n ".join(parts) + return "\n".join(line.rstrip() for line in joined.split("\n")) + + +def generate_geo_methods_py(functions: list[FunctionInfo]) -> str: + """Generate geo_methods.py content.""" + # Filter to only geo methods (first arg is geometry/geography) + geo_funcs = [f for f in functions if f.is_geo_method] + + lines = [ + LICENSE_HEADER, + "", + '"""Auto-generated geometry/geography methods - do not edit."""', + "", + "from typing import Generic, TypeVar", + "", + "from sedonadb_expr.utils import MISSING, filter_missing_args", + "", + 'ExprT = TypeVar("ExprT")', + "", + "", + "class GeoMethods(Generic[ExprT]):", + ' """Geometry and geography methods accessible via expr.geo."""', + "", + " def __init__(self, expr: ExprT) -> None:", + " self._expr = expr", + ] + + for func in sorted(geo_funcs, key=lambda f: f.name): + # Method name: derived from SQL function name (e.g., ST_AsBinary -> as_binary) + method_name = func.method_name + + kernel_info = func.kernel_info + if not kernel_info: + continue + + # Build method signature - skip first arg (piped in) + remaining_args = kernel_info.args[1:] if len(kernel_info.args) > 1 else [] + # Check if any remaining args are optional + has_optional = any(arg.optional for arg in remaining_args) + + if kernel_info.variadic: + params = "self, *args" + call_args = "*args" + use_filter = False + elif remaining_args: + # Build param strings with MISSING default for optional args + param_strs = [] + for arg in remaining_args: + if arg.optional: + param_strs.append(f"{arg.name}=MISSING") + else: + param_strs.append(arg.name) + params = "self, " + ", ".join(param_strs) + call_args = ", ".join(arg.name for arg in remaining_args) + use_filter = has_optional + else: + params = "self" + call_args = "" + use_filter = False + + docstring = generate_method_docstring(func) + + lines.extend( + [ + "", + f" def {method_name}({params}) -> ExprT:", + f" {docstring}", + ] + ) + + if call_args: + if use_filter: + lines.append( + f' return self._expr._call("{func.name}", *filter_missing_args({call_args}))' + ) + else: + lines.append( + f' return self._expr._call("{func.name}", {call_args})' + ) + else: + lines.append(f' return self._expr._call("{func.name}")') + + lines.append("") + return "\n".join(lines) + + +def generate_geo_functions_py(functions: list[FunctionInfo]) -> str: + """Generate geo_functions.py content.""" + # Filter to only geo methods (these become callable properties) + geo_funcs = [f for f in functions if f.is_geo_method] + + lines = [ + LICENSE_HEADER, + "", + '"""Auto-generated geometry/geography functions - do not edit."""', + "", + "from typing import Callable, Generic, TypeVar", + "", + 'ExprT = TypeVar("ExprT")', + "", + "", + "class GeoFunctions(Generic[ExprT]):", + ' """Geometry and geography functions accessible via a factory."""', + "", + " def __init__(self, factory) -> None:", + " self._factory = factory", + ] + + for func in sorted(geo_funcs, key=lambda f: f.name): + # Property name: derived from SQL function name (e.g., ST_AsBinary -> as_binary) + prop_name = func.method_name + + docstring = generate_function_docstring(func) + + lines.extend( + [ + "", + " @property", + f" def {prop_name}(self) -> Callable[..., ExprT]:", + f" {docstring}", + f' return self._factory["{func.name}"]', + ] + ) + + lines.append("") + return "\n".join(lines) + + +class GenerationResult: + """Result of code generation.""" + + def __init__( + self, + total_functions: int, + geo_method_count: int, + generated_files: list[Path], + ): + self.total_functions = total_functions + self.geo_method_count = geo_method_count + self.generated_files = generated_files + + +def parse_qmd_files(docs_sql: Path, pattern: str) -> list[FunctionInfo]: + """Parse all .qmd files in a directory and return function definitions. + + Args: + docs_sql: Path to directory containing .qmd files. + + Returns: + List of parsed FunctionInfo objects. + """ + qmd_files = sorted(docs_sql.glob(pattern)) + functions: list[FunctionInfo] = [] + for qmd_file in qmd_files: + func = parse_qmd_file(qmd_file) + if func: + functions.append(func) + return functions + + +def generate_sources(docs_sql: Path, output_dir: Path) -> GenerationResult: + """Generate Python source files from docs/reference/sql. + + Args: + docs_sql: Path to docs/reference/sql directory containing .qmd files. + output_dir: Path to output directory for generated files. + + Returns: + GenerationResult with statistics about generated code. + """ + # Ensure output directory exists + output_dir.mkdir(parents=True, exist_ok=True) + + # Create __init__.py for the generated module + init_file = output_dir / "__init__.py" + init_file.write_text( + f"{LICENSE_HEADER}\n" + "# Auto-generated module - do not edit\n" + "# Generated from docs/reference/sql\n" + ) + + generated_files: list[Path] = [init_file] + + if not docs_sql.exists(): + return GenerationResult( + total_functions=0, + geo_method_count=0, + generated_files=generated_files, + ) + + functions = parse_qmd_files(docs_sql, "st_*.qmd") + + # Generate geo_methods.py + geo_methods_content = generate_geo_methods_py(functions) + geo_methods_file = output_dir / "geo_methods.py" + geo_methods_file.write_text(geo_methods_content) + generated_files.append(geo_methods_file) + + # Generate geo_functions.py + geo_functions_content = generate_geo_functions_py(functions) + geo_functions_file = output_dir / "geo_functions.py" + geo_functions_file.write_text(geo_functions_content) + generated_files.append(geo_functions_file) + + # Count stats + geo_method_count = sum(1 for f in functions if f.is_geo_method) + + return GenerationResult( + total_functions=len(functions), + geo_method_count=geo_method_count, + generated_files=generated_files, + ) + + +if __name__ == "__main__": + # Allow running as a standalone script for development/debugging + here = Path(__file__).parent + docs_sql = here.parent.parent.parent.parent / "docs" / "reference" / "sql" + output_dir = here / "_generated" + + result = generate_sources(docs_sql, output_dir) + + print(f"Generated {result.total_functions} functions total") + print(f"Generated {result.geo_method_count} geo methods") + print("Output files:") + for f in result.generated_files: + print(f" - {f}") diff --git a/python/sedonadb-expr/python/sedonadb_expr/utils.py b/python/sedonadb-expr/python/sedonadb_expr/utils.py new file mode 100644 index 000000000..c81ef8a5c --- /dev/null +++ b/python/sedonadb-expr/python/sedonadb_expr/utils.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Utility classes and functions for sedonadb-expr.""" + +from typing import Any + + +class _MissingType: + """Sentinel type for missing/omitted arguments. + + This is distinct from None, which represents a valid NULL value. + Use the MISSING singleton instance rather than creating new instances. + """ + + +MISSING = _MissingType() +"""Sentinel value for missing/omitted arguments. + +Use this as the default value for optional parameters. +""" + + +def filter_missing_args(*args: Any): + """Filter out trailing MISSING arguments, validating ordering. + + Args: + *args: Arguments to filter + + Returns: + Tuple of non-MISSING arguments + + Raises: + ValueError: If MISSING arguments appear before non-MISSING arguments + """ + if not args: + return () + + # Find indices of missing args + is_missing = [arg is MISSING for arg in args] + + if not any(is_missing): + return args + + # Find last non-missing arg + last_non_missing = -1 + for i in range(len(args) - 1, -1, -1): + if not is_missing[i]: + last_non_missing = i + break + + # Check no missing args before non-missing args + if last_non_missing >= 0 and any(is_missing[: last_non_missing + 1]): + raise ValueError("Missing arguments must be at the end of the argument list") + + # Return args up to and including last non-missing + if last_non_missing < 0: + return () + return args[: last_non_missing + 1] diff --git a/python/sedonadb-expr/tests/__init__.py b/python/sedonadb-expr/tests/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/python/sedonadb-expr/tests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/sedonadb-expr/tests/test_codegen.py b/python/sedonadb-expr/tests/test_codegen.py new file mode 100644 index 000000000..d007a7a22 --- /dev/null +++ b/python/sedonadb-expr/tests/test_codegen.py @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from sedonadb_expr import _codegen + +SAMPLE_QMD = """\ +--- +title: ST_Buffer +description: Computes a buffered geometry. +kernels: + - returns: geometry + args: + - geometry + - name: distance + type: float64 + description: Radius of the buffer +--- + +## Description + +Returns a geometry covering all points within a given distance. +This paragraph could have more than one line. + +This is the second paragraph. + +- Followed by a list! +- Second bullet point + +## The Next Section + +...if there is one +""" + + +@pytest.fixture +def sample_qmd_path(): + with TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "st_buffer.qmd" + path.write_text(SAMPLE_QMD) + yield path + + +def test_camel_to_snake(): + assert _codegen.camel_to_snake("AsBinary") == "as_binary" + assert _codegen.camel_to_snake("GeomFromWKB") == "geom_from_wkb" + assert _codegen.camel_to_snake("AsEWKT") == "as_ewkt" + assert _codegen.camel_to_snake("LineInterpolatePoint") == "line_interpolate_point" + + +def test_extract_frontmatter(sample_qmd_path: Path): + fm = _codegen.extract_frontmatter(sample_qmd_path) + assert fm["title"] == "ST_Buffer" + + +def test_extract_description_section(sample_qmd_path: Path): + desc = _codegen.extract_description_section(sample_qmd_path) + expected = """\ +Returns a geometry covering all points within a given distance. This paragraph could have more than one line. + +This is the second paragraph. + +- Followed by a list! +- Second bullet point""" + assert desc == expected + + +def test_generate_method_docstring_with_args(): + # Test case: method with description and args (first arg skipped) + func = _codegen.FunctionInfo( + name="st_buffer", + title="ST_Buffer", + description="Returns a buffered geometry.", + kernels=[], + kernel_info=_codegen.KernelInfo( + args=[ + _codegen.ArgInfo( + type="geometry", name="geom", description="Input geometry" + ), + _codegen.ArgInfo( + type="float64", name="distance", description="Buffer distance" + ), + ], + returns="geometry", + ), + ) + docstring = _codegen.generate_method_docstring(func) + expected = '''\ +"""ST_Buffer + + Returns a buffered geometry. + + Args: + distance: Buffer distance + + See Also: + https://sedona.apache.org/sedonadb/latest/reference/sql/st_buffer/ + """''' + assert docstring == expected + + +def test_generate_method_docstring_variadic(): + # Variadic mode is triggered when parameter names conflict across kernels + # e.g., ST_Buffer has (geom, distance, params) vs (geog, distance, num_quad_segs) + # where position 3 conflicts: "params" vs "num_quad_segs" + func = _codegen.FunctionInfo( + name="st_buffer", + title="ST_Buffer", + description="Creates a buffer.", + kernels=[], + kernel_info=_codegen.KernelInfo( + args=[], + returns="geometry", + variadic=True, + kernel_signatures=[ + "geom (geometry), distance (float64)", + "geom (geometry), distance (float64), params (string)", + "geog (geography), distance (float64)", + "geog (geography), distance (float64), num_quad_segs (integer)", + "geog (geography), distance (float64), params (string)", + ], + ), + ) + docstring = _codegen.generate_method_docstring(func) + # Method docstring skips the first arg (piped in via self._expr) + expected = '''\ +"""ST_Buffer + + Creates a buffer. + + Variants: + - distance (float64) + - distance (float64), params (string) + - distance (float64) + - distance (float64), num_quad_segs (integer) + - distance (float64), params (string) + + See Also: + https://sedona.apache.org/sedonadb/latest/reference/sql/st_buffer/ + """''' + assert docstring == expected + + +def test_generate_function_docstring_with_args(): + # Test case: function with description and args (all args included) + func = _codegen.FunctionInfo( + name="st_buffer", + title="ST_Buffer", + description="Returns a buffered geometry.", + kernels=[], + kernel_info=_codegen.KernelInfo( + args=[ + _codegen.ArgInfo( + type="geometry", name="geom", description="Input geometry" + ), + _codegen.ArgInfo( + type="float64", name="distance", description="Buffer distance" + ), + ], + returns="geometry", + ), + ) + docstring = _codegen.generate_function_docstring(func) + expected = '''\ +"""ST_Buffer + + Returns a buffered geometry. + + Args: + geom: Input geometry + distance: Buffer distance + + See Also: + https://sedona.apache.org/sedonadb/latest/reference/sql/st_buffer/ + """''' + assert docstring == expected + + +def test_generate_function_docstring_variadic(): + # Variadic mode is triggered when parameter names conflict across kernels + func = _codegen.FunctionInfo( + name="st_buffer", + title="ST_Buffer", + description="Creates a buffer.", + kernels=[], + kernel_info=_codegen.KernelInfo( + args=[], + returns="geometry", + variadic=True, + kernel_signatures=[ + "geom (geometry), distance (float64)", + "geom (geometry), distance (float64), params (string)", + "geom (geography), distance (float64)", + "geom (geography), distance (float64), num_quad_segs (integer)", + "geom (geography), distance (float64), params (string)", + ], + ), + ) + docstring = _codegen.generate_function_docstring(func) + expected = '''\ +"""ST_Buffer + + Creates a buffer. + + Variants: + - geom (geometry), distance (float64) + - geom (geometry), distance (float64), params (string) + - geom (geography), distance (float64) + - geom (geography), distance (float64), num_quad_segs (integer) + - geom (geography), distance (float64), params (string) + + See Also: + https://sedona.apache.org/sedonadb/latest/reference/sql/st_buffer/ + """''' + assert docstring == expected + + +def test_generate_sources(sample_qmd_path: Path): + docs_sql = sample_qmd_path.parent + with TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) / "output" + + result = _codegen.generate_sources(docs_sql, output_dir) + + assert result.total_functions == 1 + assert result.geo_method_count == 1 + assert len(result.generated_files) == 3 + + # Verify generated files compile as valid Python + for file_path in result.generated_files: + code = file_path.read_text() + compile(code, str(file_path), "exec") diff --git a/python/sedonadb-expr/tests/test_utils.py b/python/sedonadb-expr/tests/test_utils.py new file mode 100644 index 000000000..103a5ecc8 --- /dev/null +++ b/python/sedonadb-expr/tests/test_utils.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from sedonadb_expr import GeoFunctions, GeoMethods +from sedonadb_expr.utils import MISSING, filter_missing_args + + +class MockExpr: + """Mock expression that records _call invocations.""" + + def __init__(self): + self.calls = [] + + def _call(self, name, *args): + self.calls.append((name, args)) + return self + + +class MockFunctions: + """Mock function mapping that records _call invocations.""" + + def __init__(self): + self.calls = [] + + def __getitem__(self, name, *args): + def fn(*args): + self.calls.append((name, args)) + + return fn + + +def test_filter_missing_args(): + """Tests for filter_missing_args utility.""" + # Passthrough when no missing + assert filter_missing_args(1, 2, 3) == (1, 2, 3) + + # All missing returns empty + assert filter_missing_args(MISSING, MISSING) == () + + # Trailing missing are filtered + assert filter_missing_args(1, 2, MISSING, MISSING) == (1, 2) + assert filter_missing_args(1, MISSING) == (1,) + + # Empty args + assert filter_missing_args() == () + + # Missing before non-missing raises + with pytest.raises(ValueError, match="Missing arguments must be at the end"): + filter_missing_args(MISSING, 1) + + # Missing in middle raises + with pytest.raises(ValueError, match="Missing arguments must be at the end"): + filter_missing_args(1, MISSING, 2) + + +def test_geo_methods_missing_args(): + """Tests for MISSING argument handling in generated GeoMethods.""" + mock = MockExpr() + geo = GeoMethods(mock) + + # force4d with no args passes no extra arguments + geo.force4d() + assert mock.calls[-1] == ("st_force4d", ()) + + # force4d with z only passes just z + geo.force4d(z=1.0) + assert mock.calls[-1] == ("st_force4d", (1.0,)) + + # force4d with both z and m passes both + geo.force4d(z=1.0, m=2.0) + assert mock.calls[-1] == ("st_force4d", (1.0, 2.0)) + + # translate with partial missing filters correctly + geo.translate(deltaX=1.0, deltaY=2.0) + assert mock.calls[-1] == ("st_translate", (1.0, 2.0)) + + # force4d with MISSING z but non-MISSING m should raise + with pytest.raises(ValueError, match="Missing arguments must be at the end"): + geo.force4d(z=MISSING, m=2.0) + + # translate with MISSING in middle should raise + with pytest.raises(ValueError, match="Missing arguments must be at the end"): + geo.translate(deltaX=1.0, deltaY=MISSING, deltaZ=3.0) + + +def test_geo_functions(): + """Tests for GeoFunctions property access.""" + factory = MockFunctions() + geo_fns = GeoFunctions(factory) + + # Properties return callables + assert callable(geo_fns.affine) + assert callable(geo_fns.buffer) + + # Calling returned function invokes factory + geo_fns.envelope("geom_arg") + assert factory.calls[-1] == ("st_envelope", ("geom_arg",)) diff --git a/python/sedonadb/python/sedonadb/expr/expression.py b/python/sedonadb/python/sedonadb/expr/expression.py index 0ce6c57a3..497b9b753 100644 --- a/python/sedonadb/python/sedonadb/expr/expression.py +++ b/python/sedonadb/python/sedonadb/expr/expression.py @@ -34,6 +34,10 @@ from sedonadb.functions import Functions +if TYPE_CHECKING: + from sedonadb_expr import GeoMethods + + class Expr: """A column expression. @@ -212,6 +216,15 @@ def desc(self, nulls_first: bool = False) -> "SortExpr": """ return SortExpr(self._impl.desc(nulls_first)) + @property + def geo(self) -> "GeoMethods[Expr]": + from sedonadb_expr import GeoMethods + + return GeoMethods(self) + + def _call(self, name, *args) -> "Expr": + return self.funcs[name](*args) + # Arithmetic operators ------------------------------------------------- # # Each binary dunder routes through the shared `_binary` helper, which diff --git a/python/sedonadb/python/sedonadb/expr/literal.py b/python/sedonadb/python/sedonadb/expr/literal.py index a0c21828c..5fbbfcb04 100644 --- a/python/sedonadb/python/sedonadb/expr/literal.py +++ b/python/sedonadb/python/sedonadb/expr/literal.py @@ -20,6 +20,9 @@ from sedonadb.utility import sedona # noqa: F401 if TYPE_CHECKING: + from sedonadb_expr import GeoMethods + + from sedonadb.expr import Expr from sedonadb.functions import Functions @@ -71,6 +74,15 @@ def funcs(self) -> "Functions": return Functions(self._ctx, self) + @property + def geo(self) -> "GeoMethods[Expr]": + from sedonadb_expr import GeoMethods + + return GeoMethods(self) + + def _call(self, name, *args) -> "Expr": + return self.funcs[name](*args) + def alias(self, name: str): """Give this literal a column name. diff --git a/python/sedonadb/python/sedonadb/functions/__init__.py b/python/sedonadb/python/sedonadb/functions/__init__.py index 2469becfc..ac3e4591a 100644 --- a/python/sedonadb/python/sedonadb/functions/__init__.py +++ b/python/sedonadb/python/sedonadb/functions/__init__.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from sedonadb.functions.table import TableFunctions + from sedonadb.expr.expression import Expr + from sedonadb_expr import GeoFunctions class Functions: @@ -46,6 +48,12 @@ def table(self) -> "TableFunctions": return TableFunctions(self._ctx) + @property + def geo(self) -> "GeoFunctions[Expr]": + from sedonadb_expr import GeoFunctions + + return GeoFunctions(self) + def __getattr__(self, name) -> Union["ScalarUdf", "AggregateUdf"]: try: return ScalarUdf(self._ctx._impl.scalar_udf(name), self._ctx, self._expr) diff --git a/python/sedonadb/tests/expr/test_function_expression.py b/python/sedonadb/tests/expr/test_function_expression.py index add20624f..0675a684a 100644 --- a/python/sedonadb/tests/expr/test_function_expression.py +++ b/python/sedonadb/tests/expr/test_function_expression.py @@ -18,6 +18,9 @@ from sedonadb.expr import Expr from sedonadb.expr.expression import ScalarUdf, AggregateUdf +import shapely +import pytest + def test_scalar_st_function_returns_expr(con): st_geomfromwkt = con.funcs.st_geomfromwkt @@ -27,6 +30,10 @@ def test_scalar_st_function_returns_expr(con): assert isinstance(e, Expr) assert repr(e) == 'Expr(st_geomfromwkt(Utf8("POINT (0 1)")))' + # Also check piped function from literal + e = con.lit("POINT (0 1)").funcs.st_geomfromwkt() + assert repr(e) == 'Expr(st_geomfromwkt(Utf8("POINT (0 1)")))' + def test_scalar_st_function_alias_returns_expr(con): st_geomfromtext = con.funcs.st_geomfromtext @@ -46,6 +53,10 @@ def test_scalar_st_function_with_column(con): assert isinstance(e, Expr) assert repr(e) == "Expr(st_area(geom))" + # Also check piped function from column + e = con.col("geom").funcs.st_geomfromwkt() + assert repr(e) == "Expr(st_geomfromwkt(geom))" + def test_scalar_st_function_with_multiple_args(con): st_buffer = con.funcs.st_buffer @@ -108,3 +119,27 @@ def test_function_expression_composed(con): repr(e) == 'Expr(st_area(st_geomfromwkt(Utf8("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"))))' ) + + +def test_geo_functions_accessor(con): + pytest.importorskip("sedonadb_expr") + + # Check function as resolved from the geo accessor + e = con.funcs.geo.as_text(con.col("foofy")) + assert isinstance(e, Expr) + assert repr(e) == "Expr(st_astext(foofy))" + + +def test_geo_methods_accessor(con): + pytest.importorskip("sedonadb_expr") + + # Check piped function from literal via .geo accessor + e = con.lit(shapely.Point(0, 1)).geo.as_text() + assert ( + repr(e) + == """Expr(st_astext(Binary("1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,240,63") FieldMetadata { inner: {"ARROW:extension:metadata": "{}", "ARROW:extension:name": "geoarrow.wkb"} }))""" + ) + + # Check piped function from Expr via .geo accessor + e = con.col("foofy").geo.as_text() + assert repr(e) == "Expr(st_astext(foofy))"