diff --git a/ax/analysis/healthcheck/no_effects_analysis.py b/ax/analysis/healthcheck/no_effects_analysis.py index d4bb35960ac..8de435ab88e 100644 --- a/ax/analysis/healthcheck/no_effects_analysis.py +++ b/ax/analysis/healthcheck/no_effects_analysis.py @@ -100,6 +100,11 @@ def compute( df_tone = check_experiment_effects_per_metric( data=data, objective_names=set(objective_names) ) + if df_tone.empty: + raise UserInputError( + "TestOfNoEffectAnalysis requires at least one trial with two " + "or more arms, since the test compares arms within a trial." + ) metrics_tone = df_tone.groupby("metric_name")["has_effect"].sum() > 0 metrics_with_effects = [i for i in metrics_tone.index if metrics_tone[i]] objectives_without_effects = set() diff --git a/ax/analysis/healthcheck/tests/test_no_effects_analysis.py b/ax/analysis/healthcheck/tests/test_no_effects_analysis.py index 5dce9606cea..92d1bb8a1ae 100644 --- a/ax/analysis/healthcheck/tests/test_no_effects_analysis.py +++ b/ax/analysis/healthcheck/tests/test_no_effects_analysis.py @@ -101,6 +101,74 @@ def test_raises_error_without_data(self) -> None: with self.assertRaises(AxError): self.tone.compute(experiment=self.experiment) + def test_raises_error_without_sample_sizes(self) -> None: + # GIVEN an experiment whose data has no `n` column (as produced by + # standard trial evaluation, e.g. `Client.complete_trial`) + self.experiment.attach_data( + data=Data( + df=pd.DataFrame( + { + "arm_name": ["0_0", "0_1", "0_2"], + "metric_name": ["branin"] * 3, + "mean": [0.0, 1.0, 2.0], + "sem": [0.1] * 3, + "trial_index": [0] * 3, + "metric_signature": ["branin"] * 3, + } + ) + ) + ) + # WHEN we compute the healthcheck + # THEN it raises a UserInputError instead of a KeyError + with self.assertRaisesRegex(UserInputError, "per-arm sample sizes"): + self.tone.compute(experiment=self.experiment) + + def test_deterministic_data_detects_effects(self) -> None: + # GIVEN an experiment with deterministic data (sem == 0) and + # clearly different means + self.experiment.attach_data( + data=Data( + df=pd.DataFrame( + { + "arm_name": ["0_0", "0_1", "0_2"], + "metric_name": ["branin"] * 3, + "mean": [0.0, 1.0, 2.0], + "sem": [0.0] * 3, + "trial_index": [0] * 3, + "n": [1000] * 3, + "metric_signature": ["branin"] * 3, + } + ) + ) + ) + # WHEN we compute the healthcheck + card = self.tone.compute(experiment=self.experiment) + # THEN it is a PASS (previously a NaN p-value silently read as + # "no effect" and produced a WARNING) + self.assertEqual(card.get_status(), HealthcheckStatus.PASS) + + def test_raises_error_with_only_single_arm_trials(self) -> None: + # GIVEN an experiment whose trials all have a single arm + self.experiment.attach_data( + data=Data( + df=pd.DataFrame( + { + "arm_name": ["0_0"], + "metric_name": ["branin"], + "mean": [1.0], + "sem": [0.1], + "trial_index": [0], + "n": [1000], + "metric_signature": ["branin"], + } + ) + ) + ) + # WHEN we compute the healthcheck + # THEN it raises a UserInputError + with self.assertRaisesRegex(UserInputError, "two or more arms"): + self.tone.compute(experiment=self.experiment) + def test_multi_objective_partial_no_effects(self) -> None: # GIVEN we have a multi-objective experiment with one metric with no effects self.moo_experiment.attach_data( diff --git a/ax/utils/stats/no_effects.py b/ax/utils/stats/no_effects.py index fcb10b80cde..e2ce5681ad5 100644 --- a/ax/utils/stats/no_effects.py +++ b/ax/utils/stats/no_effects.py @@ -12,10 +12,26 @@ import pandas as pd import scipy from ax.core.data import Data +from ax.exceptions.core import UserInputError from ax.utils.stats.math_utils import relativize from scipy.stats import norm +def _validate_sample_sizes(df: pd.DataFrame) -> None: + """Ensure per-arm sample sizes are available for the test of no effect. + + The ``n`` column is optional on Ax ``Data`` and is not produced by + standard trial evaluation (e.g. ``Client.complete_trial``), so it must + be validated before use. + """ + if "n" not in df.columns or df["n"].isnull().any(): + raise UserInputError( + "The test of no effect requires per-arm sample sizes: every row " + "of the data must have a value in the `n` column. Attach data " + "that includes sample sizes to use this test." + ) + + def check_experiment_effects_per_metric( data: Data, objective_names: set[str] | None = None, @@ -44,6 +60,7 @@ def check_experiment_effects_per_metric( """ df = data.df + _validate_sample_sizes(df) df_grouped = df.groupby(["metric_name", "trial_index"]) cols = ["metric_name", "trial_index", "p_value", "has_effect"] @@ -56,6 +73,10 @@ def check_experiment_effects_per_metric( for metric_name, trial_index in df_grouped.groups.keys(): dfm = df_grouped.get_group((metric_name, trial_index)) + if len(dfm) < 2: + # A single-arm group cannot show within-trial effects. + continue + p_value, f_stat = no_effect_test_welch( means=list(dfm["mean"].values), sems=list(dfm["sem"].values), @@ -112,11 +133,12 @@ def check_experiment_effects( Returns: effective: True if the null of no treatment effects can be rejected. ineffective_on_objectives: List of objectives on which the - null of no treatment effects can be rejected. + null of no treatment effects could not be rejected. bounds_df: The minimum and maximum bounds on possible effects. """ df = data.df + _validate_sample_sizes(df) df_grouped = df.groupby("metric_name") K = len(df_grouped) ps = [] @@ -228,6 +250,11 @@ def no_effect_test_welch( Returns: A tuple containing - the p-value and - the test-statistic value. + + Raises: + UserInputError: If there are fewer than two arms, an arm has one or + fewer observations, or a mix of zero and positive sems makes the + test undefined. """ means_arr = np.array(means) sems_arr = np.array(sems) @@ -235,7 +262,36 @@ def no_effect_test_welch( K = len(means_arr) + if K < 2: + raise UserInputError( + f"Welch's test of no effect requires at least two arms, received {K}." + ) + if np.any(ns_arr <= 1): + raise UserInputError( + "Welch's test of no effect requires more than one observation " + "(n > 1) in each arm." + ) + variances = np.multiply(sems_arr**2, ns_arr) + + zero_variance = variances <= 0 + if np.any(zero_variance): + # The mean of an arm with sem == 0 is known exactly, so Welch's F + # statistic is undefined (its weights divide by the variance). + # Previously this silently produced a NaN p-value, which read as + # "no effect" even when the exactly-known means clearly differed. + if np.unique(means_arr[zero_variance]).size > 1: + # Two exactly-known means differ: an exact effect. + return 0.0, float("inf") + if np.all(zero_variance): + # All means are known exactly and are identical: no effect. + return 1.0, 0.0 + raise UserInputError( + "Welch's test of no effect requires a positive sem for each " + "arm; found a mix of zero and positive sems, for which the " + "test is undefined." + ) + ws = np.divide(ns_arr, variances) W = np.sum(ws) diff --git a/ax/utils/stats/tests/test_no_effects.py b/ax/utils/stats/tests/test_no_effects.py new file mode 100644 index 00000000000..eae743f7683 --- /dev/null +++ b/ax/utils/stats/tests/test_no_effects.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import pandas as pd +from ax.core.data import Data +from ax.exceptions.core import UserInputError +from ax.utils.common.testutils import TestCase +from ax.utils.stats.no_effects import ( + check_experiment_effects, + check_experiment_effects_per_metric, + no_effect_test_welch, +) + + +def _data(rows: list[dict[str, object]]) -> Data: + base = {"trial_index": 0, "metric_name": "m1", "metric_signature": "m1"} + return Data(df=pd.DataFrame([{**base, **row} for row in rows])) + + +class TestNoEffects(TestCase): + def test_effects_detected(self) -> None: + # GIVEN two arms with clearly different means + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.01, "n": 1000}, + {"arm_name": "0_1", "mean": 2.0, "sem": 0.01, "n": 1000}, + ] + ) + # WHEN we run the test of no effect + df_tone = check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + # THEN an effect is detected + self.assertTrue(df_tone["has_effect"].all()) + + def test_missing_sample_sizes(self) -> None: + # GIVEN data without the optional `n` column (as produced by standard + # trial evaluation, e.g. `Client.complete_trial`) + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.1}, + {"arm_name": "0_1", "mean": 2.0, "sem": 0.1}, + ] + ) + # WHEN we run the test of no effect + # THEN it raises a UserInputError instead of a KeyError + with self.assertRaisesRegex(UserInputError, "per-arm sample sizes"): + check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + + def test_check_experiment_effects_missing_sample_sizes(self) -> None: + # GIVEN data without the optional `n` column + data = _data( + [ + {"arm_name": "status_quo", "mean": 1.0, "sem": 0.1}, + {"arm_name": "0_0", "mean": 2.0, "sem": 0.1}, + ] + ) + # WHEN we run the overall (across-metric) test of no effect + # THEN it raises a UserInputError instead of a KeyError + with self.assertRaisesRegex(UserInputError, "per-arm sample sizes"): + check_experiment_effects(data=data, objective_names={"m1"}) + + def test_null_sample_sizes(self) -> None: + # GIVEN data where some rows are missing a sample size + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.1, "n": 1000}, + {"arm_name": "0_1", "mean": 2.0, "sem": 0.1, "n": None}, + ] + ) + # WHEN we run the test of no effect + # THEN it raises a UserInputError + with self.assertRaisesRegex(UserInputError, "per-arm sample sizes"): + check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + + def test_zero_sem_with_different_means(self) -> None: + # GIVEN deterministic data (sem == 0) with clearly different means + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.0, "n": 1000}, + {"arm_name": "0_1", "mean": 2.0, "sem": 0.0, "n": 1000}, + ] + ) + # WHEN we run the test of no effect + df_tone = check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + # THEN the exact effect is detected (previously a NaN p-value silently + # read as "no effect") + self.assertEqual(df_tone["p_value"].item(), 0.0) + self.assertTrue(df_tone["has_effect"].item()) + + def test_zero_sem_with_equal_means(self) -> None: + # GIVEN deterministic data (sem == 0) with identical means + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.0, "n": 1000}, + {"arm_name": "0_1", "mean": 1.0, "sem": 0.0, "n": 1000}, + ] + ) + # WHEN we run the test of no effect + df_tone = check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + # THEN no effect is detected, with a well-defined p-value + self.assertEqual(df_tone["p_value"].item(), 1.0) + self.assertFalse(df_tone["has_effect"].item()) + + def test_mixed_zero_and_positive_sems(self) -> None: + # GIVEN one arm with sem == 0 and another with a positive sem + # WHEN we run Welch's test directly + # THEN it raises a UserInputError since the test is undefined + with self.assertRaisesRegex(UserInputError, "positive sem"): + no_effect_test_welch(means=[1.0, 1.0], sems=[0.0, 0.1], ns=[1000, 1000]) + + def test_single_arm_groups_are_skipped(self) -> None: + # GIVEN one single-arm trial and one two-arm trial + data = _data( + [ + {"arm_name": "0_0", "mean": 1.0, "sem": 0.1, "n": 1000}, + { + "arm_name": "1_0", + "mean": 1.0, + "sem": 0.1, + "n": 1000, + "trial_index": 1, + }, + { + "arm_name": "1_1", + "mean": 2.0, + "sem": 0.1, + "n": 1000, + "trial_index": 1, + }, + ] + ) + # WHEN we run the test of no effect + df_tone = check_experiment_effects_per_metric(data=data, objective_names={"m1"}) + # THEN the single-arm trial is skipped (previously it produced a NaN + # p-value) and the two-arm trial is tested + self.assertEqual(df_tone["trial_index"].tolist(), [1]) + self.assertTrue(df_tone["has_effect"].item()) + + def test_welch_requires_at_least_two_arms(self) -> None: + with self.assertRaisesRegex(UserInputError, "at least two arms"): + no_effect_test_welch(means=[1.0], sems=[0.1], ns=[1000]) + + def test_welch_requires_more_than_one_observation(self) -> None: + with self.assertRaisesRegex(UserInputError, "more than one observation"): + no_effect_test_welch(means=[1.0, 2.0], sems=[0.1, 0.1], ns=[1, 1000])