From e12857a9989ee70283eba0b9c4c9c24abc456bf9 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Thu, 4 May 2023 13:02:50 +0200 Subject: [PATCH 01/10] change function signature and fix tests --- .gitignore | 2 ++ examples/multiclass_classification/src/main.rs | 2 +- src/booster.rs | 8 ++++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 61d7627..f8efa46 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ lightgbm-sys/target examples/binary_classification/target/ examples/multiclass_classification/target/ examples/regression/target/ + +.idea diff --git a/examples/multiclass_classification/src/main.rs b/examples/multiclass_classification/src/main.rs index 4f39225..86d5896 100644 --- a/examples/multiclass_classification/src/main.rs +++ b/examples/multiclass_classification/src/main.rs @@ -63,7 +63,7 @@ fn main() -> std::io::Result<()> { } }; - let booster = Booster::train(train_dataset, ¶ms).unwrap(); + let booster = Booster::train(train_dataset, None, ¶ms).unwrap(); let result = booster.predict(test_features).unwrap(); let mut tp = 0; diff --git a/src/booster.rs b/src/booster.rs index e8d4d4d..10e6b2a 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -69,9 +69,9 @@ impl Booster { /// "metric": "auc" /// } /// }; - /// let bst = Booster::train(dataset, ¶ms).unwrap(); + /// let bst = Booster::train(dataset, None, ¶ms).unwrap(); /// ``` - pub fn train(dataset: Dataset, parameter: &Value) -> Result { + pub fn train(train_data: Dataset, val_data: Option ,parameter: &Value) -> Result { // get num_iterations let num_iterations: i64 = if parameter["num_iterations"].is_null() { 100 @@ -91,7 +91,7 @@ impl Booster { let mut handle = std::ptr::null_mut(); lgbm_call!(lightgbm_sys::LGBM_BoosterCreate( - dataset.handle, + train_data.handle, params_cstring.as_ptr() as *const c_char, &mut handle ))?; @@ -307,7 +307,7 @@ mod tests { fn _train_booster(params: &Value) -> Booster { let dataset = _read_train_file().unwrap(); - 
Booster::train(dataset, &params).unwrap() + Booster::train(dataset, None, &params).unwrap() } fn _default_params() -> Value { From 4510965db8454f25c9fa10432cc9aee4418d6197 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Thu, 4 May 2023 19:09:44 +0200 Subject: [PATCH 02/10] implement get_eval_names --- src/booster.rs | 81 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 10e6b2a..9d1465e 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,4 +1,4 @@ -use libc::{c_char, c_double, c_longlong, c_void}; +use libc::{c_char, c_double, c_int, c_longlong, c_void}; use std; use std::convert::TryInto; use std::ffi::CString; @@ -71,7 +71,38 @@ impl Booster { /// }; /// let bst = Booster::train(dataset, None, &params).unwrap(); /// ``` - pub fn train(train_data: Dataset, val_data: Option ,parameter: &Value) -> Result { + /// Validation data can be provided as well. + /// ``` + /// extern crate serde_json; + /// use lightgbm::{Dataset, Booster}; + /// use serde_json::json; + /// + /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], + /// vec![0.7, 0.4, 0.5, 0.1], + /// vec![0.9, 0.8, 0.5, 0.1], + /// vec![0.2, 0.2, 0.8, 0.7], + /// vec![0.1, 0.7, 1.0, 0.9]]; + /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; + /// let train_data = Dataset::from_mat(data, label).unwrap(); + /// + /// let data = vec![ + /// vec![0.9, 0.6, 0.2, 0.1], + /// vec![0.5, 0.7, 0.2, 0.1], + /// vec![0.2, 0.1, 0.6, 0.8]]; + /// let label = vec![0.0, 0.0, 1.0]; + /// let val_data = Dataset::from_mat(data, label); + /// + /// let params = json!{ + /// { + /// "num_iterations": 3, + /// "objective": "binary", + /// "metric": "auc" + /// } + /// }; + /// + /// let bst = Booster::train(train_data, val_data.ok(), &params).unwrap(); + /// ``` + pub fn train(train_data: Dataset, val_data: Option, parameter: &Value) -> Result { // get num_iterations let num_iterations: i64 = if parameter["num_iterations"].is_null() { 100 @@ -96,6 
+127,13 @@ impl Booster { &mut handle ))?; + if let Some(validation_data) = val_data { + lgbm_call!(lightgbm_sys::LGBM_BoosterAddValidData( + handle, + validation_data.handle + ))?; + } + let mut is_finished: i32 = 0; for _ in 1..num_iterations { lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter( @@ -200,6 +238,37 @@ impl Booster { Ok(output) } + + /// return the name of up to 20 evaluation metrics that were used + pub fn get_eval_names(&self) -> Result> { + let num_metrics = 20; + let feature_name_length = 32; + let mut num_eval_names = 0; + let mut out_buffer_len = 0; + let out_strs = (0..num_metrics) + .map(|_| { + CString::new(" ".repeat(feature_name_length)) + .unwrap() + .into_raw() as *mut c_char + }) + .collect::>(); + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( + self.handle, + num_metrics as i32, + &mut num_eval_names, + feature_name_length as u64, + &mut out_buffer_len, + out_strs.as_ptr() as *mut *mut c_char + ))?; + let output: Vec = out_strs + .into_iter() + .map(|s| unsafe { CString::from_raw(s).into_string().unwrap() }) + .take(num_eval_names as usize) + .collect(); + Ok(output) + } + + // Get Feature Importance pub fn feature_importance(&self) -> Result> { let num_feature = self.num_feature()?; @@ -350,6 +419,14 @@ mod tests { assert_eq!(num_feature, 28); } + #[test] + fn get_eval_names() { + let params = _default_params(); + let bst = _train_booster(&params); + let eval_names = bst.get_eval_names().unwrap(); + assert_eq!(eval_names, vec!["auc"]) + } + #[test] fn feature_importance() { let params = _default_params(); From 6182be17875413d284a17641daf6b97856b68b45 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Thu, 4 May 2023 19:52:40 +0200 Subject: [PATCH 03/10] tried to increase test complexity. eval_names() test fails when more than one metric is used. 
--- src/booster.rs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 9d1465e..b595803 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,4 +1,4 @@ -use libc::{c_char, c_double, c_int, c_longlong, c_void}; +use libc::{c_char, c_double, c_longlong, c_void}; use std; use std::convert::TryInto; use std::ffi::CString; @@ -238,10 +238,19 @@ impl Booster { Ok(output) } - - /// return the name of up to 20 evaluation metrics that were used - pub fn get_eval_names(&self) -> Result> { - let num_metrics = 20; + /// Get number of evaluation metrics + pub fn num_eval(&self) -> Result { + let mut out_len = 0; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalCounts( + self.handle, + &mut out_len + ))?; + Ok(out_len) + } + + /// Get names of evaluation metrics + pub fn eval_names(&self) -> Result> { + let num_metrics = self.num_eval()?; let feature_name_length = 32; let mut num_eval_names = 0; let mut out_buffer_len = 0; @@ -420,11 +429,13 @@ mod tests { } #[test] - fn get_eval_names() { - let params = _default_params(); + fn eval_names() { + let mut params = _default_params(); + params["metric"] = serde_json::Value::from(vec!["auc", "acc"]); + println!("{}", params); let bst = _train_booster(&params); - let eval_names = bst.get_eval_names().unwrap(); - assert_eq!(eval_names, vec!["auc"]) + let eval_names = bst.eval_names().unwrap(); + assert_eq!(eval_names, vec!["auc", "acc"]) } #[test] From 87995cc9526508f192f32ae0782c7b0815d48465 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Fri, 5 May 2023 09:54:28 +0200 Subject: [PATCH 04/10] fixed bug in parameter formatting: edge case where multiple values are provided for a key was not formatted properly --- src/booster.rs | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index b595803..692ba8a 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,11 +1,10 @@ -use 
libc::{c_char, c_double, c_longlong, c_void}; use std; use std::convert::TryInto; use std::ffi::CString; -use serde_json::Value; - +use libc::{c_char, c_double, c_longlong, c_void}; use lightgbm_sys; +use serde_json::{Value}; use crate::{Dataset, Error, Result}; @@ -115,6 +114,16 @@ impl Booster { .as_object() .unwrap() .iter() + .map(|(k, v)| + match v { + Value::Array(a) => { + let v_formatted = a.iter().map(|x| x.to_string() + ",").collect::(); + let v_formatted = v_formatted.replace("\",\"", ",") + .trim_end_matches(",").to_string(); + (k, v_formatted) + }, + _ => (k, v.to_string()) + }) .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join(" "); @@ -374,11 +383,13 @@ impl Drop for Booster { #[cfg(test)] mod tests { - use super::*; - use serde_json::json; use std::fs; use std::path::Path; + use serde_json::json; + + use super::*; + fn _read_train_file() -> Result { Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") } @@ -430,12 +441,18 @@ mod tests { #[test] fn eval_names() { - let mut params = _default_params(); - params["metric"] = serde_json::Value::from(vec!["auc", "acc"]); + let params = json! 
{ + { + "num_iterations": 10, + "objective": "binary", + "metric": ["auc", "l1"], + "data_random_seed": 0 + } + }; println!("{}", params); let bst = _train_booster(¶ms); let eval_names = bst.eval_names().unwrap(); - assert_eq!(eval_names, vec!["auc", "acc"]) + assert_eq!(eval_names, vec!["auc", "l1"]) } #[test] From c2ee104d92b63e9ec08d83f20ba0744bcd346911 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Fri, 5 May 2023 13:50:26 +0200 Subject: [PATCH 05/10] implemented bin matching for dataset file loading, implemented eval results --- src/booster.rs | 136 +++++++++++++++++++++++++++++++++++++++++++------ src/dataset.rs | 21 ++++++-- 2 files changed, 136 insertions(+), 21 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 692ba8a..4648b97 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -4,7 +4,7 @@ use std::ffi::CString; use libc::{c_char, c_double, c_longlong, c_void}; use lightgbm_sys; -use serde_json::{Value}; +use serde_json::Value; use crate::{Dataset, Error, Result}; @@ -13,6 +13,13 @@ pub struct Booster { handle: lightgbm_sys::BoosterHandle, } +/// Represents the score during training on either the train or validation set +#[derive(Debug, PartialEq)] +pub struct EvalResult { + pub metric: String, + pub score: f64, +} + impl Booster { fn new(handle: lightgbm_sys::BoosterHandle) -> Self { Booster { handle } @@ -101,7 +108,11 @@ impl Booster { /// /// let bst = Booster::train(train_data, val_data.ok(), ¶ms).unwrap(); /// ``` - pub fn train(train_data: Dataset, val_data: Option, parameter: &Value) -> Result { + pub fn train( + train_data: Dataset, + val_data: Option, + parameter: &Value, + ) -> Result { // get num_iterations let num_iterations: i64 = if parameter["num_iterations"].is_null() { 100 @@ -110,20 +121,22 @@ impl Booster { }; // exchange params {"x": "y", "z": 1} => "x=y z=1" + // and {"k" = ["a", "b"]} => "k=a,b" let params_string = parameter .as_object() .unwrap() .iter() - .map(|(k, v)| - match v { + .map(|(k, v)| match v { 
Value::Array(a) => { let v_formatted = a.iter().map(|x| x.to_string() + ",").collect::(); - let v_formatted = v_formatted.replace("\",\"", ",") - .trim_end_matches(",").to_string(); + let v_formatted = v_formatted + .replace("\",\"", ",") + .trim_end_matches(",") + .to_string(); (k, v_formatted) - }, - _ => (k, v.to_string()) - }) + } + _ => (k, v.to_string()), + }) .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join(" "); @@ -260,12 +273,12 @@ impl Booster { /// Get names of evaluation metrics pub fn eval_names(&self) -> Result> { let num_metrics = self.num_eval()?; - let feature_name_length = 32; + let metric_name_length = 32; let mut num_eval_names = 0; let mut out_buffer_len = 0; let out_strs = (0..num_metrics) .map(|_| { - CString::new(" ".repeat(feature_name_length)) + CString::new(" ".repeat(metric_name_length)) .unwrap() .into_raw() as *mut c_char }) @@ -274,7 +287,7 @@ impl Booster { self.handle, num_metrics as i32, &mut num_eval_names, - feature_name_length as u64, + metric_name_length as u64, &mut out_buffer_len, out_strs.as_ptr() as *mut *mut c_char ))?; @@ -286,6 +299,22 @@ impl Booster { Ok(output) } + pub fn get_eval(&self, data_index: i32) -> Result> { + let names = self.eval_names()?; + let mut out_len = 0; + let out_result: Vec = vec![Default::default(); names.len()]; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEval( + self.handle, + data_index, + &mut out_len, + out_result.as_ptr() as *mut c_double + ))?; + Ok(names + .into_iter() + .zip(out_result) + .map(|(metric, score)| EvalResult { metric, score }) + .collect()) + } // Get Feature Importance pub fn feature_importance(&self) -> Result> { @@ -391,7 +420,10 @@ mod tests { use super::*; fn _read_train_file() -> Result { - Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") + Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", + None, + ) } fn _train_booster(params: &Value) -> Booster { @@ -443,18 +475,90 @@ mod 
tests { fn eval_names() { let params = json! { { - "num_iterations": 10, + "num_iterations": 1, "objective": "binary", - "metric": ["auc", "l1"], + "metrics": ["auc", "l1"], "data_random_seed": 0 } }; - println!("{}", params); let bst = _train_booster(¶ms); let eval_names = bst.eval_names().unwrap(); assert_eq!(eval_names, vec!["auc", "l1"]) } + #[ignore] + #[test] + fn get_eval_broken() { + let params = json! { + { + "num_iterations": 30, + "objective": "binary", + "boosting_type": "gbdt", + "metrics": ["binary_logloss","auc"], + "label_column": 0, + "max_bin": 255, + "tree_learner": "serial", + "feature_fraction": 0.8, + "is_enable_sparse": true, + "data_random_seed": 0 + } + }; + let train = _read_train_file().unwrap(); + let val = Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.test", + Some(train.handle), + ) + .unwrap(); + + // this training segfaults at training step ffi call + let _bst = Booster::train(train, Some(val), ¶ms).unwrap(); + + //let eval_train = bst.get_eval(0).unwrap(); + //let eval_val = bst.get_eval(1).unwrap(); + //let eval_invalid = bst.get_eval(420); + //assert!(eval_invalid.is_err()); + } + + #[test] + fn get_eval() { + let data = vec![ + vec![1.0, 0.1, 0.2, 0.1], + vec![0.7, 0.4, 0.5, 0.1], + vec![0.9, 0.8, 0.5, 0.1], + vec![0.2, 0.2, 0.8, 0.7], + vec![0.1, 0.7, 1.0, 0.9], + ]; + let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; + let train_data = Dataset::from_mat(data, label).unwrap(); + + let data = vec![ + vec![0.9, 0.6, 0.2, 0.1], + vec![0.5, 0.7, 0.2, 0.1], + vec![0.2, 0.1, 0.6, 0.8], + ]; + let label = vec![0.0, 0.0, 1.0]; + let val_data = Dataset::from_mat(data, label); + + let params = json! 
{ + { + "num_iterations": 3, + "objective": "binary", + "metric": ["auc","l1"] + } + }; + + let bst = Booster::train(train_data, val_data.ok(), &params).unwrap(); + + let train_res = bst.get_eval(0).unwrap(); + let val_res = bst.get_eval(1).unwrap(); + let invalid_res = bst.get_eval(420); + assert!(invalid_res.is_err()); + assert_eq!(train_res[0].metric, "auc"); + assert_eq!(val_res[1].metric, "l1"); + assert!(0.0 <= train_res[0].score && train_res[0].score <= 1.0); // make sure values make sense + assert!(0.0 <= train_res[1].score); + } + #[test] fn feature_importance() { let params = _default_params(); diff --git a/src/dataset.rs b/src/dataset.rs index a2bcbdb..406a028 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -1,5 +1,6 @@ use libc::{c_char, c_void}; use lightgbm_sys; +use lightgbm_sys::DatasetHandle; use std; use std::convert::TryInto; use std::ffi::CString; @@ -32,7 +33,7 @@ use crate::{Error, Result}; /// ``` /// use lightgbm::Dataset; /// -/// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train").unwrap(); +/// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", None).unwrap(); /// ``` pub struct Dataset { pub(crate) handle: lightgbm_sys::DatasetHandle, @@ -112,21 +113,28 @@ impl Dataset { /// 0 0.1 0.9 1.0 /// ``` /// + /// You can provide a Dataset Handle to align bin mappers between Datasets + /// /// Example /// ``` /// use lightgbm::Dataset; /// - /// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train"); + /// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", None); /// ``` - pub fn from_file(file_path: &str) -> Result { + pub fn from_file(file_path: &str, dataset_handle: Option) -> Result { let file_path_str = CString::new(file_path).unwrap(); let params = CString::new("").unwrap(); let mut handle = std::ptr::null_mut(); + let reference = match 
dataset_handle { + Some(h) => h, + None => std::ptr::null_mut(), + }; + lgbm_call!(lightgbm_sys::LGBM_DatasetCreateFromFile( file_path_str.as_ptr() as *const c_char, params.as_ptr() as *const c_char, - std::ptr::null_mut(), + reference, &mut handle ))?; @@ -257,7 +265,10 @@ impl Drop for Dataset { mod tests { use super::*; fn read_train_file() -> Result { - Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") + Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", + None, + ) } #[test] From 7f6e73a94abed811f964689aeb6d66abb4b6fa73 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Fri, 5 May 2023 14:02:15 +0200 Subject: [PATCH 06/10] fix clippy warnings --- src/booster.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 4648b97..463d1c9 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -131,7 +131,7 @@ impl Booster { let v_formatted = a.iter().map(|x| x.to_string() + ",").collect::(); let v_formatted = v_formatted .replace("\",\"", ",") - .trim_end_matches(",") + .trim_end_matches(',') .to_string(); (k, v_formatted) } @@ -285,7 +285,7 @@ impl Booster { .collect::>(); lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( self.handle, - num_metrics as i32, + num_metrics, &mut num_eval_names, metric_name_length as u64, &mut out_buffer_len, From 151a11f05cd5149fba815f049798332a3564e519 Mon Sep 17 00:00:00 2001 From: Jannis Froese Date: Wed, 10 May 2023 18:17:02 +0200 Subject: [PATCH 07/10] don't prematurely drop validation data --- src/booster.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 463d1c9..12852a0 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -149,7 +149,8 @@ impl Booster { &mut handle ))?; - if let Some(validation_data) = val_data { + // the following has to borrow val_data to avoid dropping the dataset + if let Some(validation_data) = &val_data { 
lgbm_call!(lightgbm_sys::LGBM_BoosterAddValidData( handle, validation_data.handle @@ -486,7 +487,6 @@ mod tests { assert_eq!(eval_names, vec!["auc", "l1"]) } - #[ignore] #[test] fn get_eval_broken() { let params = json! { From 5e9d2395d154965e95eeb503053f632cab62a451 Mon Sep 17 00:00:00 2001 From: David Schwab Date: Wed, 10 May 2023 18:27:28 +0200 Subject: [PATCH 08/10] fix seg_fault and wrong train loop indexing --- src/booster.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index 463d1c9..23145cb 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -149,7 +149,7 @@ impl Booster { &mut handle ))?; - if let Some(validation_data) = val_data { + if let Some(validation_data) = &val_data { lgbm_call!(lightgbm_sys::LGBM_BoosterAddValidData( handle, validation_data.handle @@ -157,7 +157,7 @@ impl Booster { } let mut is_finished: i32 = 0; - for _ in 1..num_iterations { + for _ in 0..num_iterations { lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter( handle, &mut is_finished @@ -247,7 +247,7 @@ impl Booster { .collect::>(); lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames( self.handle, - num_feature as i32, + num_feature, &mut num_feature_names, feature_name_length as u64, &mut out_buffer_len, @@ -486,9 +486,8 @@ mod tests { assert_eq!(eval_names, vec!["auc", "l1"]) } - #[ignore] #[test] - fn get_eval_broken() { + fn get_eval_samle_dataset() { let params = json! 
{ { "num_iterations": 30, @@ -510,13 +509,14 @@ mod tests { ) .unwrap(); - // this training segfaults at training step ffi call - let _bst = Booster::train(train, Some(val), ¶ms).unwrap(); + let bst = Booster::train(train, Some(val), ¶ms).unwrap(); - //let eval_train = bst.get_eval(0).unwrap(); - //let eval_val = bst.get_eval(1).unwrap(); - //let eval_invalid = bst.get_eval(420); - //assert!(eval_invalid.is_err()); + let eval_train = bst.get_eval(0); + let eval_val = bst.get_eval(1); + assert!(eval_val.is_ok()); + assert!(eval_train.is_ok()); + let eval_invalid = bst.get_eval(420); + assert!(eval_invalid.is_err()); } #[test] @@ -561,7 +561,8 @@ mod tests { #[test] fn feature_importance() { - let params = _default_params(); + let mut params = _default_params(); + params["num_iterations"] = "0".parse().unwrap(); let bst = _train_booster(¶ms); let feature_importance = bst.feature_importance().unwrap(); assert_eq!(feature_importance, vec![0.0; 28]); From 8779e60232d1ba7f1f54e40c52423cca42c313fc Mon Sep 17 00:00:00 2001 From: Jannis Froese Date: Fri, 12 May 2023 15:54:36 +0200 Subject: [PATCH 09/10] fix unsoundness of eval_names --- src/booster.rs | 73 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/src/booster.rs b/src/booster.rs index f232e19..cccf44d 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -274,29 +274,72 @@ impl Booster { /// Get names of evaluation metrics pub fn eval_names(&self) -> Result> { let num_metrics = self.num_eval()?; - let metric_name_length = 32; + + ///////////////////////////////////////////////////////////////////// + // call with 0-sized buffer to find out how much space to allocate + ///////////////////////////////////////////////////////////////////// let mut num_eval_names = 0; let mut out_buffer_len = 0; - let out_strs = (0..num_metrics) - .map(|_| { - CString::new(" ".repeat(metric_name_length)) - .unwrap() - .into_raw() as *mut c_char - }) + + 
lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( + self.handle, + 0, + &mut num_eval_names, + 0, + &mut out_buffer_len, + std::ptr::null_mut() as *mut *mut c_char + )) + .unwrap(); + + ///////////////////////////////////////////////////////////////////// + // sanity check + ///////////////////////////////////////////////////////////////////// + assert_eq!(num_eval_names, num_metrics); + + ///////////////////////////////////////////////////////////////////// + // get the actual strings + ///////////////////////////////////////////////////////////////////// + + let mut out_strs = (0..num_metrics) + .map(|_| (0..out_buffer_len).map(|_| 0).collect::>()) + .collect::>(); + + let mut out_strs_pointers = out_strs + .iter_mut() + .map(|s| s.as_mut_ptr()) .collect::>(); + + let metric_name_length = out_buffer_len; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( self.handle, num_metrics, &mut num_eval_names, - metric_name_length as u64, + metric_name_length, &mut out_buffer_len, - out_strs.as_ptr() as *mut *mut c_char - ))?; - let output: Vec = out_strs - .into_iter() - .map(|s| unsafe { CString::from_raw(s).into_string().unwrap() }) - .take(num_eval_names as usize) - .collect(); + out_strs_pointers.as_mut_ptr() as *mut *mut c_char + )) + .unwrap(); + + drop(out_strs_pointers); // don't let pointers outlive their target + + let mut output = Vec::with_capacity(out_strs.len()); + for mut out_str in out_strs { + let first_null = out_str + .iter() + .enumerate() + .find(|(_, e)| **e == 0) + .map(|(i, _)| i) + .expect("string not null terminated, possible memory corruption"); + out_str.truncate(first_null + 1); + + let string = CString::from_vec_with_nul(out_str) + .expect("string memory invariant violated, possible memory corruption") + .into_string() + .map_err(|_| Error::new("name not valid UTF-8"))?; + output.push(string); + } + Ok(output) } From f1f0e76ed7b9fb0f09c50f47bbc55da8ed6f4839 Mon Sep 17 00:00:00 2001 From: Jannis Froese Date: Mon, 15 May 2023 13:35:38 +0200 
Subject: [PATCH 10/10] don't panic in sanity check --- src/booster.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/booster.rs b/src/booster.rs index cccf44d..0830628 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -294,7 +294,11 @@ impl Booster { ///////////////////////////////////////////////////////////////////// // sanity check ///////////////////////////////////////////////////////////////////// - assert_eq!(num_eval_names, num_metrics); + if num_eval_names != num_metrics { + return Err(Error::new(format!( + "expected num_eval_names==num_metrics, but got {num_eval_names}!={num_metrics}. This is a bug in lightgbm or its rust wrapper" + ))); + } ///////////////////////////////////////////////////////////////////// // get the actual strings