|
| 1 | +"""Validation helpers shared by nuisance-estimation components.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import Any |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from sklearn.utils.validation import check_array |
| 9 | + |
| 10 | + |
| 11 | +def validate_features(X: Any) -> np.ndarray: |
| 12 | + """Validate and coerce a feature matrix to a dense 2D numpy array.""" |
| 13 | + return check_array(X, ensure_2d=True, dtype=None, ensure_all_finite="allow-nan") |
| 14 | + |
| 15 | + |
| 16 | +def validate_vector(values: Any, *, name: str, n_samples: int | None = None) -> np.ndarray: |
| 17 | + """Validate and coerce a one-dimensional target or prediction vector.""" |
| 18 | + vector = check_array(values, ensure_2d=False, dtype=None, ensure_all_finite="allow-nan") |
| 19 | + vector = np.asarray(vector).reshape(-1) |
| 20 | + |
| 21 | + if n_samples is not None and vector.shape[0] != n_samples: |
| 22 | + raise ValueError(f"{name} must have length {n_samples}, got {vector.shape[0]}.") |
| 23 | + |
| 24 | + return vector |
| 25 | + |
| 26 | + |
| 27 | +def validate_binary_treatment(d: Any, *, n_samples: int | None = None) -> np.ndarray: |
| 28 | + """Validate that treatment assignments are binary-coded.""" |
| 29 | + treatment = validate_vector(d, name="d", n_samples=n_samples) |
| 30 | + unique_values = np.unique(treatment) |
| 31 | + |
| 32 | + if unique_values.size > 2 or not np.all(np.isin(unique_values, [0, 1])): |
| 33 | + raise ValueError( |
| 34 | + "Binary treatment is required for the built-in nuisance estimators. " |
| 35 | + "Expected values in {0, 1}." |
| 36 | + ) |
| 37 | + |
| 38 | + return treatment.astype(int, copy=False) |
| 39 | + |
| 40 | + |
| 41 | +def validate_manual_predictions( |
| 42 | + *, |
| 43 | + y: Any, |
| 44 | + d: Any, |
| 45 | + y_hat: Any, |
| 46 | + d_hat: Any, |
| 47 | + propensity_clip: float, |
| 48 | +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
| 49 | + """Validate manual nuisance predictions against observed outcomes and treatment.""" |
| 50 | + y_array = validate_vector(y, name="y") |
| 51 | + d_array = validate_binary_treatment(d, n_samples=y_array.shape[0]) |
| 52 | + y_hat_array = validate_vector(y_hat, name="y_hat", n_samples=y_array.shape[0]) |
| 53 | + d_hat_array = validate_vector(d_hat, name="d_hat", n_samples=y_array.shape[0]) |
| 54 | + |
| 55 | + if np.any(~np.isfinite(y_hat_array)): |
| 56 | + raise ValueError("y_hat must contain only finite values.") |
| 57 | + if np.any(~np.isfinite(d_hat_array)): |
| 58 | + raise ValueError("d_hat must contain only finite values.") |
| 59 | + |
| 60 | + d_hat_array = np.clip(d_hat_array.astype(float, copy=False), propensity_clip, 1.0 - propensity_clip) |
| 61 | + |
| 62 | + return y_array, d_array, y_hat_array, d_hat_array |
0 commit comments