From ca51326e6ad298464abb26e6c1eb032006416ee1 Mon Sep 17 00:00:00 2001 From: tereshchuk1 Date: Fri, 12 Jun 2026 05:43:48 +0200 Subject: [PATCH 1/2] t push --forceAdd MultiViewLightGBM model and include lightgbm in nox test session --- drevalpy/models/__init__.py | 7 + .../models/baselines/hyperparameters.yaml | 28 ++ .../models/baselines/multi_view_lightgbm.py | 267 ++++++++++++++++++ noxfile.py | 4 +- pyproject.toml | 2 + tests/models/test_baselines.py | 6 +- 6 files changed, 311 insertions(+), 3 deletions(-) create mode 100644 drevalpy/models/baselines/multi_view_lightgbm.py diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 7ed4dd14..7053e96e 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -30,8 +30,10 @@ "AdaBoostDecisionTree", "LassoModel", "MultiViewXGBoost", + "MultiViewLightGBM", ] +from .baselines.multi_view_lightgbm import MultiViewLightGBM from .baselines.multi_view_random_forest import MultiViewRandomForest from .baselines.multi_view_xgboost import MultiViewXGBoost from .baselines.naive_pred import ( @@ -98,6 +100,11 @@ "DIPK": DIPKModel, "PharmaFormer": PharmaFormerModel, "SRMF": SRMF, + "KNNRegressor": KNNRegressor, + "AdaBoostDecisionTree": AdaBoostDecisionTree, + "Lasso": LassoModel, + "MultiViewXGBoost": MultiViewXGBoost, + "MultiViewLightGBM": MultiViewLightGBM, "Precily": PrecilyModel, } diff --git a/drevalpy/models/baselines/hyperparameters.yaml b/drevalpy/models/baselines/hyperparameters.yaml index 35fe0f1b..c32f3381 100644 --- a/drevalpy/models/baselines/hyperparameters.yaml +++ b/drevalpy/models/baselines/hyperparameters.yaml @@ -265,3 +265,31 @@ MultiViewXGBoost: - 1 reg_lambda: - 0.1 + +MultiViewLightGBM: + cell_line_views: + - - gene_expression + - methylation + - mutations + - copy_number_variation_gistic + - gene_expression + - proteomics + drug_views: + - fingerprints + learning_rate: + - 0.1 + max_depth: + - 10 + num_leaves: + - 63 + - 127 + subsample: + - 0.8 + colsample_bytree: + - 0.6 + - 0.8 + reg_alpha: + - 0 + - 1 + reg_lambda: + - 0.1 diff --git a/drevalpy/models/baselines/multi_view_lightgbm.py b/drevalpy/models/baselines/multi_view_lightgbm.py new file mode 100644 index 00000000..4c73d976 --- /dev/null +++ b/drevalpy/models/baselines/multi_view_lightgbm.py @@ -0,0 +1,267 @@ +"""Contains the baseline MultiViewLightGBM model.""" + +import json +import os + +import joblib +import numpy as np +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset + +from ..drp_model import DRPModel +from ..utils import ( + ProteomicsMedianCenterAndImputeTransformer, + _get_view_as_list, + load_multi_cell_line_view, + load_single_drug_view, + prepare_expression_and_methylation, + prepare_proteomics, +) + + +class MultiViewLightGBM(DRPModel): + """LightGBM model with multi-omic cell line features and drug fingerprints.""" + + cell_line_views = [ + "gene_expression", + "methylation", + "mutations", + "copy_number_variation_gistic", + ] + drug_views = ["fingerprints"] + + def __init__(self): + """Initializes the MultiViewLightGBM model.""" + super().__init__() + self.model = None + self.gene_expression_scaler = StandardScaler() + # methylation-specific defaults + self.methylation_scaler = StandardScaler() + self.methylation_pca = None + self.pca_ncomp = 100 + # proteomics-specific defaults + self.proteomics_transformer = None + self.proteomics_feature_threshold = 0.7 + self.proteomics_n_features = 1000 + self.proteomics_normalization_width = 0.3 + self.proteomics_normalization_downshift = 1.8 + + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: MultiViewLightGBM + """ + return "MultiViewLightGBM" + + def build_model(self, hyperparameters: dict) -> None: + """ + Builds the model from hyperparameters. + + :param hyperparameters: dictionary containing the hyperparameters. + :raises ImportError: if lightgbm is not installed. + """ + try: + import lightgbm as lgb + except ImportError as e: + raise ImportError( + "MultiViewLightGBM requires the optional 'lightgbm' extra. " + "Install it with: pip install drevalpy[lightgbm] (or `poetry install -E lightgbm`)." + ) from e + + self.log_hyperparameters(hyperparameters) + self.hyperparameters = hyperparameters + self.cell_line_views = _get_view_as_list( + hyperparameters.get( + "cell_line_views", + ["gene_expression", "methylation", "mutations", "copy_number_variation_gistic"], + ) + ) + self.drug_views = _get_view_as_list(hyperparameters.get("drug_views", ["fingerprints"])) + if "methylation" in self.cell_line_views: + self.pca_ncomp = hyperparameters.get("methylation_n_components", 100) + if "proteomics" in self.cell_line_views: + self.proteomics_feature_threshold = hyperparameters.get("proteomics_feature_threshold", 0.7) + self.proteomics_n_features = hyperparameters.get("proteomics_n_features", 1000) + self.proteomics_normalization_width = hyperparameters.get("proteomics_normalization_width", 0.3) + self.proteomics_normalization_downshift = hyperparameters.get("proteomics_normalization_downshift", 1.8) + self.proteomics_transformer = ProteomicsMedianCenterAndImputeTransformer( + feature_threshold=self.proteomics_feature_threshold, + n_features=self.proteomics_n_features, + normalization_downshift=self.proteomics_normalization_downshift, + normalization_width=self.proteomics_normalization_width, + ) + self.model = lgb.LGBMRegressor( + n_estimators=hyperparameters.get("n_estimators", 100), + learning_rate=hyperparameters.get("learning_rate", 0.1), + max_depth=hyperparameters.get("max_depth", 6), + num_leaves=hyperparameters.get("num_leaves", 63), + subsample=hyperparameters.get("subsample", 0.8), + colsample_bytree=hyperparameters.get("colsample_bytree", 0.8), + reg_alpha=hyperparameters.get("reg_alpha", 0.0), + reg_lambda=hyperparameters.get("reg_lambda", 0.0), + random_state=42, + n_jobs=-1, + verbosity=-1, + ) + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the cell line features. + + :param data_path: data path e.g. data/ + :param dataset_name: dataset name e.g. GDSC1 + :returns: FeatureDataset containing the cell line omics features + """ + return load_multi_cell_line_view(self.cell_line_views, data_path, dataset_name, self.get_model_name()) + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset | None: + """ + Loads the drug features. + + :param data_path: path to the drug features, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC1 + :returns: FeatureDataset containing the drug features + """ + return load_single_drug_view(self.drug_views, data_path, dataset_name, self.get_model_name()) + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + model_checkpoint_dir: str = "", + ) -> None: + """ + Trains the model. + + :param output: training dataset containing the response output + :param cell_line_input: cell line omics features + :param drug_input: drug features + :param output_earlystopping: not used + :param model_checkpoint_dir: not used + """ + if "methylation" in self.cell_line_views: + first_cl_feature = next(iter(cell_line_input.features.values())) + n_met_features = first_cl_feature["methylation"].shape[0] + n_components = min(self.pca_ncomp, n_met_features) + self.methylation_pca = PCA(n_components=n_components) + + if "gene_expression" in self.cell_line_views or "methylation" in self.cell_line_views: + cell_line_input = prepare_expression_and_methylation( + cell_line_input=cell_line_input, + cell_line_ids=np.unique(output.cell_line_ids), + training=True, + gene_expression_scaler=self.gene_expression_scaler, + methylation_scaler=self.methylation_scaler, + methylation_pca=self.methylation_pca, + ) + + if "proteomics" in self.cell_line_views: + cell_line_input = prepare_proteomics( + cell_line_input=cell_line_input, + cell_line_ids=np.unique(output.cell_line_ids), + training=True, + transformer=self.proteomics_transformer, + ) + + inputs = self.get_feature_matrices( + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + + array_list = [inputs[view] for view in self.cell_line_views + self.drug_views] + x = np.concatenate(array_list, axis=1) + self.model.fit(x, output.response) + + def predict( + self, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + ) -> np.ndarray: + """ + Predicts the response for the given input. + + :param cell_line_ids: cell line ids + :param drug_ids: drug ids + :param cell_line_input: cell line omics features + :param drug_input: drug features + :returns: predicted response + """ + if "gene_expression" in self.cell_line_views or "methylation" in self.cell_line_views: + cell_line_input = prepare_expression_and_methylation( + cell_line_input=cell_line_input, + cell_line_ids=np.unique(cell_line_ids), + training=False, + gene_expression_scaler=self.gene_expression_scaler, + methylation_scaler=self.methylation_scaler, + methylation_pca=self.methylation_pca, + ) + + if "proteomics" in self.cell_line_views: + cell_line_input = prepare_proteomics( + cell_line_input=cell_line_input, + cell_line_ids=np.unique(cell_line_ids), + training=False, + transformer=self.proteomics_transformer, + ) + + inputs = self.get_feature_matrices( + cell_line_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_input=cell_line_input, + drug_input=drug_input, + ) + + array_list = [inputs[view] for view in self.cell_line_views + self.drug_views] + x = np.concatenate(array_list, axis=1) + return self.model.predict(x) + + def save(self, directory: str) -> None: + """ + Saves the model to disk. + + :param directory: target directory + """ + os.makedirs(directory, exist_ok=True) + joblib.dump(self.model, os.path.join(directory, "model.pkl")) + with open(os.path.join(directory, "hyperparameters.json"), "w") as f: + json.dump(self.hyperparameters, f) + if "gene_expression" in self.cell_line_views: + joblib.dump(self.gene_expression_scaler, os.path.join(directory, "gene_scaler.pkl")) + if "methylation" in self.cell_line_views: + joblib.dump(self.methylation_scaler, os.path.join(directory, "methylation_scaler.pkl")) + joblib.dump(self.methylation_pca, os.path.join(directory, "methylation_pca.pkl")) + if self.proteomics_transformer is not None: + joblib.dump(self.proteomics_transformer, os.path.join(directory, "proteomics_transformer.pkl")) + + @classmethod + def load(cls, directory: str) -> "MultiViewLightGBM": + """ + Loads the model from disk. + + :param directory: directory containing the saved model files + :returns: restored MultiViewLightGBM instance + """ + instance = cls() + with open(os.path.join(directory, "hyperparameters.json")) as f: + hyperparameters = json.load(f) + instance.build_model(hyperparameters) + instance.model = joblib.load(os.path.join(directory, "model.pkl")) + if "gene_expression" in instance.cell_line_views: + instance.gene_expression_scaler = joblib.load(os.path.join(directory, "gene_scaler.pkl")) + if "methylation" in instance.cell_line_views: + instance.methylation_scaler = joblib.load(os.path.join(directory, "methylation_scaler.pkl")) + instance.methylation_pca = joblib.load(os.path.join(directory, "methylation_pca.pkl")) + transformer_path = os.path.join(directory, "proteomics_transformer.pkl") + if os.path.exists(transformer_path): + instance.proteomics_transformer = joblib.load(transformer_path) + return instance diff --git a/noxfile.py b/noxfile.py index 01667e2f..dcae8772 100644 --- a/noxfile.py +++ b/noxfile.py @@ -144,7 +144,7 @@ def tests(session: Session) -> None: :param session: The Session object. """ - session.install(".[xgboost,precily]") + session.install(".[xgboost,precily,lightgbm]") session.install("coverage[toml]", "pytest", "pygments") try: session.run( @@ -188,7 +188,7 @@ def typeguard(session: Session) -> None: :param session: The Session object. """ - session.install(".[xgboost,precily]") + session.install(".[xgboost,precily,lightgbm]") session.install("pytest", "typeguard", "pygments") session.run( "pytest", diff --git a/pyproject.toml b/pyproject.toml index 1df70431..9da375d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } wandb = ">=0.24.0" xgboost = { version = "^3.2.0", optional = true } +lightgbm = { version = "^4.0.0", optional = true } typer = ">=0.26,<0.27" rich = "^15.0.0" gseapy = { version = "^1.1.0", optional = true } @@ -61,6 +62,7 @@ poetry-plugin-export = ">=1.8" [tool.poetry.extras] multiprocessing = ["ray", "pydantic"] xgboost = ["xgboost"] +lightgbm = ["lightgbm"] precily = ["gseapy"] [tool.poetry.dependencies.ray] diff --git a/tests/models/test_baselines.py b/tests/models/test_baselines.py index 410bb8e8..ffc7a22b 100644 --- a/tests/models/test_baselines.py +++ b/tests/models/test_baselines.py @@ -65,6 +65,7 @@ def test_random_forest_respects_max_depth(max_depth_input, expected) -> None: "KNNRegressor", "Lasso", "MultiViewXGBoost", + "MultiViewLightGBM", ], ) @pytest.mark.parametrize("test_mode", ["LTO", "LPO", "LCO", "LDO"]) @@ -86,6 +87,8 @@ def test_baselines( """ if model_name == "MultiViewXGBoost": pytest.importorskip("xgboost", reason="MultiViewXGBoost requires the optional 'xgboost' extra") + if model_name == "MultiViewLightGBM": + pytest.importorskip("lightgbm", reason="MultiViewLightGBM requires the optional 'lightgbm' extra") drug_response = sample_dataset drug_response.split_dataset( n_cv_splits=2, @@ -337,6 +340,7 @@ def _call_other_baselines(model: str, train_dataset: DrugResponseDataset, val_da "AdaBoostDecisionTree", "SVR", "MultiViewXGBoost", + "MultiViewLightGBM", ]: # test a hpam config with cell_line_views == "gene expression" and one with "proteomics covered_gex = False @@ -356,7 +360,7 @@ def _call_other_baselines(model: str, train_dataset: DrugResponseDataset, val_da else: hpams = hpams[:2] model_instance = model_class() - if model != "MultiViewXGBoost": + if model not in ("MultiViewXGBoost", "MultiViewLightGBM"): assert isinstance(model_instance, SklearnModel) for hpam_combi in hpams: if model == "RandomForest" or model == "GradientBoosting": From 89f0cab59b018d7d9ce77de3d706ff83c53ae032 Mon Sep 17 00:00:00 2001 From: tereshchuk1 Date: Fri, 12 Jun 2026 06:03:27 +0200 Subject: [PATCH 2/2] Update poetry.lock for lightgbm dependency --- drevalpy/models/__init__.py | 6 +--- .../models/baselines/hyperparameters.yaml | 5 ++-- poetry.lock | 30 ++++++++++++++++++- pyproject.toml | 6 ++-- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 7053e96e..c0fb3e00 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -96,15 +96,11 @@ "SimpleNeuralNetwork": SimpleNeuralNetwork, "MultiViewNeuralNetwork": MultiViewNeuralNetwork, "MultiViewXGBoost": MultiViewXGBoost, + "MultiViewLightGBM": MultiViewLightGBM, # Published models "DIPK": DIPKModel, "PharmaFormer": PharmaFormerModel, "SRMF": SRMF, - "KNNRegressor": KNNRegressor, - "AdaBoostDecisionTree": AdaBoostDecisionTree, - "Lasso": LassoModel, - "MultiViewXGBoost": MultiViewXGBoost, - "MultiViewLightGBM": MultiViewLightGBM, "Precily": PrecilyModel, } diff --git a/drevalpy/models/baselines/hyperparameters.yaml b/drevalpy/models/baselines/hyperparameters.yaml index c32f3381..8af4bcf1 100644 --- a/drevalpy/models/baselines/hyperparameters.yaml +++ b/drevalpy/models/baselines/hyperparameters.yaml @@ -278,9 +278,8 @@ MultiViewLightGBM: - fingerprints learning_rate: - 0.1 - max_depth: - - 10 num_leaves: + - 31 - 63 - 127 subsample: @@ -292,4 +291,6 @@ MultiViewLightGBM: - 0 - 1 reg_lambda: + - 0 - 0.1 + - 1 diff --git a/poetry.lock b/poetry.lock index ac3bf4b6..f7e8622d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2377,6 +2377,33 @@ files = [ {file = "librt-0.11.0.tar.gz", hash = "sha256:075dc3ef4458a278e0195cbf6ac9d38808d9b906c5a6c7f7f79c3888276a3fb1"}, ] +[[package]] +name = "lightgbm" +version = "4.6.0" +description = "LightGBM Python-package" +optional = true +python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"lightgbm\"" +files = [ + {file = "lightgbm-4.6.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:b7a393de8a334d5c8e490df91270f0763f83f959574d504c7ccb9eee4aef70ed"}, + {file = "lightgbm-4.6.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2dafd98d4e02b844ceb0b61450a660681076b1ea6c7adb8c566dfd66832aafad"}, + {file = "lightgbm-4.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4d68712bbd2b57a0b14390cbf9376c1d5ed773fa2e71e099cac588703b590336"}, + {file = "lightgbm-4.6.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:cb19b5afea55b5b61cbb2131095f50538bd608a00655f23ad5d25ae3e3bf1c8d"}, + {file = "lightgbm-4.6.0-py3-none-win_amd64.whl", hash = "sha256:37089ee95664b6550a7189d887dbf098e3eadab03537e411f52c63c121e3ba4b"}, + {file = "lightgbm-4.6.0.tar.gz", hash = "sha256:cb1c59720eb569389c0ba74d14f52351b573af489f230032a1c9f314f8bab7fe"}, +] + +[package.dependencies] +numpy = ">=1.17.0" +scipy = "*" + +[package.extras] +arrow = ["cffi (>=1.15.1)", "pyarrow (>=6.0.1)"] +dask = ["dask[array,dataframe,distributed] (>=2.0.0)", "pandas (>=0.24.0)"] +pandas = ["pandas (>=0.24.0)"] +scikit-learn = ["scikit-learn (>=0.24.2)"] + [[package]] name = "lightning-utilities" version = "0.15.3" @@ -6824,6 +6851,7 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy (>=1.0.1) ; platform_python_implementation != \"PyPy\""] [extras] +lightgbm = ["lightgbm"] multiprocessing = ["pydantic", "ray"] precily = ["gseapy"] xgboost = ["xgboost"] @@ -6831,4 +6859,4 @@ xgboost = ["xgboost"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "acf951602b19284828d120e9a14999ba14616b735b103aeb7dd5baf244379a2f" +content-hash = "5ea0d65a35030e7f01831e034c6011c5719854714308b980611e21887883334a" diff --git a/pyproject.toml b/pyproject.toml index 9da375d1..7135a705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,11 @@ poetry = ">=2.0.1" starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } wandb = ">=0.24.0" -xgboost = { version = "^3.2.0", optional = true } -lightgbm = { version = "^4.0.0", optional = true } +xgboost = { version = ">=3.2.0", optional = true } +lightgbm = { version = ">=4.0.0", optional = true } typer = ">=0.26,<0.27" rich = "^15.0.0" -gseapy = { version = "^1.1.0", optional = true } +gseapy = { version = ">=1.1.0", optional = true } [tool.poetry.requires-plugins] poetry-plugin-export = ">=1.8"