diff --git a/.gitignore b/.gitignore index 61d7627..f8efa46 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ lightgbm-sys/target examples/binary_classification/target/ examples/multiclass_classification/target/ examples/regression/target/ + +.idea diff --git a/examples/multiclass_classification/src/main.rs b/examples/multiclass_classification/src/main.rs index 4f39225..86d5896 100644 --- a/examples/multiclass_classification/src/main.rs +++ b/examples/multiclass_classification/src/main.rs @@ -63,7 +63,7 @@ fn main() -> std::io::Result<()> { } }; - let booster = Booster::train(train_dataset, ¶ms).unwrap(); + let booster = Booster::train(train_dataset, None, ¶ms).unwrap(); let result = booster.predict(test_features).unwrap(); let mut tp = 0; diff --git a/src/booster.rs b/src/booster.rs index e8d4d4d..0830628 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,11 +1,10 @@ -use libc::{c_char, c_double, c_longlong, c_void}; use std; use std::convert::TryInto; use std::ffi::CString; -use serde_json::Value; - +use libc::{c_char, c_double, c_longlong, c_void}; use lightgbm_sys; +use serde_json::Value; use crate::{Dataset, Error, Result}; @@ -14,6 +13,13 @@ pub struct Booster { handle: lightgbm_sys::BoosterHandle, } +/// Represents the score during training on either the train or validation set +#[derive(Debug, PartialEq)] +pub struct EvalResult { + pub metric: String, + pub score: f64, +} + impl Booster { fn new(handle: lightgbm_sys::BoosterHandle) -> Self { Booster { handle } @@ -69,9 +75,44 @@ impl Booster { /// "metric": "auc" /// } /// }; - /// let bst = Booster::train(dataset, ¶ms).unwrap(); + /// let bst = Booster::train(dataset, None, ¶ms).unwrap(); /// ``` - pub fn train(dataset: Dataset, parameter: &Value) -> Result { + /// Validation data can be provided aswell. + /// ``` + /// extern crate serde_json; + /// use lightgbm::{Dataset, Booster}; + /// use serde_json::json; + /// + /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], + /// vec![0.7, 0.4, 0.5, 0.1], + /// vec![0.9, 0.8, 0.5, 0.1], + /// vec![0.2, 0.2, 0.8, 0.7], + /// vec![0.1, 0.7, 1.0, 0.9]]; + /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; + /// let train_data = Dataset::from_mat(data, label).unwrap(); + /// + /// let data = vec![ + /// vec![0.9, 0.6, 0.2, 0.1], + /// vec![0.5, 0.7, 0.2, 0.1], + /// vec![0.2, 0.1, 0.6, 0.8]]; + /// let label = vec![0.0, 0.0, 1.0]; + /// let val_data = Dataset::from_mat(data, label); + /// + /// let params = json!{ + /// { + /// "num_iterations": 3, + /// "objective": "binary", + /// "metric": "auc" + /// } + /// }; + /// + /// let bst = Booster::train(train_data, val_data.ok(), ¶ms).unwrap(); + /// ``` + pub fn train( + train_data: Dataset, + val_data: Option, + parameter: &Value, + ) -> Result { // get num_iterations let num_iterations: i64 = if parameter["num_iterations"].is_null() { 100 @@ -80,10 +121,22 @@ impl Booster { }; // exchange params {"x": "y", "z": 1} => "x=y z=1" + // and {"k" = ["a", "b"]} => "k=a,b" let params_string = parameter .as_object() .unwrap() .iter() + .map(|(k, v)| match v { + Value::Array(a) => { + let v_formatted = a.iter().map(|x| x.to_string() + ",").collect::(); + let v_formatted = v_formatted + .replace("\",\"", ",") + .trim_end_matches(',') + .to_string(); + (k, v_formatted) + } + _ => (k, v.to_string()), + }) .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join(" "); @@ -91,13 +144,21 @@ impl Booster { let mut handle = std::ptr::null_mut(); lgbm_call!(lightgbm_sys::LGBM_BoosterCreate( - dataset.handle, + train_data.handle, params_cstring.as_ptr() as *const c_char, &mut handle ))?; + // the following has to borrow val_data to avoid dropping the dataset + if let Some(validation_data) = &val_data { + lgbm_call!(lightgbm_sys::LGBM_BoosterAddValidData( + handle, + validation_data.handle + ))?; + } + let mut is_finished: i32 = 0; - for _ in 1..num_iterations { + for _ in 0..num_iterations { lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter( handle, &mut is_finished @@ -187,7 +248,7 @@ impl Booster { .collect::>(); lgbm_call!(lightgbm_sys::LGBM_BoosterGetFeatureNames( self.handle, - num_feature as i32, + num_feature, &mut num_feature_names, feature_name_length as u64, &mut out_buffer_len, @@ -200,6 +261,109 @@ impl Booster { Ok(output) } + /// Get number of evaluation metrics + pub fn num_eval(&self) -> Result { + let mut out_len = 0; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalCounts( + self.handle, + &mut out_len + ))?; + Ok(out_len) + } + + /// Get names of evaluation metrics + pub fn eval_names(&self) -> Result> { + let num_metrics = self.num_eval()?; + + ///////////////////////////////////////////////////////////////////// + // call with 0-sized buffer to find out how much space to allocate + ///////////////////////////////////////////////////////////////////// + let mut num_eval_names = 0; + let mut out_buffer_len = 0; + + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( + self.handle, + 0, + &mut num_eval_names, + 0, + &mut out_buffer_len, + std::ptr::null_mut() as *mut *mut c_char + )) + .unwrap(); + + ///////////////////////////////////////////////////////////////////// + // sanity check + ///////////////////////////////////////////////////////////////////// + if num_eval_names != num_metrics { + return Err(Error::new(format!( + "expected num_eval_names==num_metrics, but got {num_eval_names}!={num_metrics}. This is a bug in lightgbm or its rust wrapper" + ))); + } + + ///////////////////////////////////////////////////////////////////// + // get the actual strings + ///////////////////////////////////////////////////////////////////// + + let mut out_strs = (0..num_metrics) + .map(|_| (0..out_buffer_len).map(|_| 0).collect::>()) + .collect::>(); + + let mut out_strs_pointers = out_strs + .iter_mut() + .map(|s| s.as_mut_ptr()) + .collect::>(); + + let metric_name_length = out_buffer_len; + + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames( + self.handle, + num_metrics, + &mut num_eval_names, + metric_name_length, + &mut out_buffer_len, + out_strs_pointers.as_mut_ptr() as *mut *mut c_char + )) + .unwrap(); + + drop(out_strs_pointers); // don't let pointers outlive their target + + let mut output = Vec::with_capacity(out_strs.len()); + for mut out_str in out_strs { + let first_null = out_str + .iter() + .enumerate() + .find(|(_, e)| **e == 0) + .map(|(i, _)| i) + .expect("string not null terminated, possible memory corruption"); + out_str.truncate(first_null + 1); + + let string = CString::from_vec_with_nul(out_str) + .expect("string memory invariant violated, possible memory corruption") + .into_string() + .map_err(|_| Error::new("name not valid UTF-8"))?; + output.push(string); + } + + Ok(output) + } + + pub fn get_eval(&self, data_index: i32) -> Result> { + let names = self.eval_names()?; + let mut out_len = 0; + let out_result: Vec = vec![Default::default(); names.len()]; + lgbm_call!(lightgbm_sys::LGBM_BoosterGetEval( + self.handle, + data_index, + &mut out_len, + out_result.as_ptr() as *mut c_double + ))?; + Ok(names + .into_iter() + .zip(out_result) + .map(|(metric, score)| EvalResult { metric, score }) + .collect()) + } + // Get Feature Importance pub fn feature_importance(&self) -> Result> { let num_feature = self.num_feature()?; @@ -296,18 +460,23 @@ impl Drop for Booster { #[cfg(test)] mod tests { - use super::*; - use serde_json::json; use std::fs; use std::path::Path; + use serde_json::json; + + use super::*; + fn _read_train_file() -> Result { - Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") + Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", + None, + ) } fn _train_booster(params: &Value) -> Booster { let dataset = _read_train_file().unwrap(); - Booster::train(dataset, ¶ms).unwrap() + Booster::train(dataset, None, ¶ms).unwrap() } fn _default_params() -> Value { @@ -350,9 +519,98 @@ mod tests { assert_eq!(num_feature, 28); } + #[test] + fn eval_names() { + let params = json! { + { + "num_iterations": 1, + "objective": "binary", + "metrics": ["auc", "l1"], + "data_random_seed": 0 + } + }; + let bst = _train_booster(¶ms); + let eval_names = bst.eval_names().unwrap(); + assert_eq!(eval_names, vec!["auc", "l1"]) + } + + #[test] + fn get_eval_sample_dataset() { + let params = json! { + { + "num_iterations": 30, + "objective": "binary", + "boosting_type": "gbdt", + "metrics": ["binary_logloss","auc"], + "label_column": 0, + "max_bin": 255, + "tree_learner": "serial", + "feature_fraction": 0.8, + "is_enable_sparse": true, + "data_random_seed": 0 + } + }; + let train = _read_train_file().unwrap(); + let val = Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.test", + Some(train.handle), + ) + .unwrap(); + + let bst = Booster::train(train, Some(val), ¶ms).unwrap(); + + let eval_train = bst.get_eval(0); + let eval_val = bst.get_eval(1); + assert!(eval_val.is_ok()); + assert!(eval_train.is_ok()); + let eval_invalid = bst.get_eval(420); + assert!(eval_invalid.is_err()); + } + + #[test] + fn get_eval() { + let data = vec![ + vec![1.0, 0.1, 0.2, 0.1], + vec![0.7, 0.4, 0.5, 0.1], + vec![0.9, 0.8, 0.5, 0.1], + vec![0.2, 0.2, 0.8, 0.7], + vec![0.1, 0.7, 1.0, 0.9], + ]; + let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; + let train_data = Dataset::from_mat(data, label).unwrap(); + + let data = vec![ + vec![0.9, 0.6, 0.2, 0.1], + vec![0.5, 0.7, 0.2, 0.1], + vec![0.2, 0.1, 0.6, 0.8], + ]; + let label = vec![0.0, 0.0, 1.0]; + let val_data = Dataset::from_mat(data, label); + + let params = json! { + { + "num_iterations": 3, + "objective": "binary", + "metric": ["auc","l1"] + } + }; + + let bst = Booster::train(train_data, val_data.ok(), ¶ms).unwrap(); + + let train_res = bst.get_eval(0).unwrap(); + let val_res = bst.get_eval(1).unwrap(); + let invalid_res = bst.get_eval(420); + assert!(invalid_res.is_err()); + assert_eq!(train_res[0].metric, "auc"); + assert_eq!(val_res[1].metric, "l1"); + assert!(0.0 <= train_res[0].score && train_res[0].score <= 1.0); // make shure values make sense + assert!(0.0 <= train_res[1].score); + } + #[test] fn feature_importance() { - let params = _default_params(); + let mut params = _default_params(); + params["num_iterations"] = "0".parse().unwrap(); let bst = _train_booster(¶ms); let feature_importance = bst.feature_importance().unwrap(); assert_eq!(feature_importance, vec![0.0; 28]); diff --git a/src/dataset.rs b/src/dataset.rs index a2bcbdb..406a028 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -1,5 +1,6 @@ use libc::{c_char, c_void}; use lightgbm_sys; +use lightgbm_sys::DatasetHandle; use std; use std::convert::TryInto; use std::ffi::CString; @@ -32,7 +33,7 @@ use crate::{Error, Result}; /// ``` /// use lightgbm::Dataset; /// -/// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train").unwrap(); +/// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", None).unwrap(); /// ``` pub struct Dataset { pub(crate) handle: lightgbm_sys::DatasetHandle, @@ -112,21 +113,28 @@ impl Dataset { /// 0 0.1 0.9 1.0 /// ``` /// + /// You can provide a Dataset Handle to align bin mappers between Datasets + /// /// Example /// ``` /// use lightgbm::Dataset; /// - /// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train"); + /// let dataset = Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", None); /// ``` - pub fn from_file(file_path: &str) -> Result { + pub fn from_file(file_path: &str, dataset_handle: Option) -> Result { let file_path_str = CString::new(file_path).unwrap(); let params = CString::new("").unwrap(); let mut handle = std::ptr::null_mut(); + let reference = match dataset_handle { + Some(h) => h, + None => std::ptr::null_mut(), + }; + lgbm_call!(lightgbm_sys::LGBM_DatasetCreateFromFile( file_path_str.as_ptr() as *const c_char, params.as_ptr() as *const c_char, - std::ptr::null_mut(), + reference, &mut handle ))?; @@ -257,7 +265,10 @@ impl Drop for Dataset { mod tests { use super::*; fn read_train_file() -> Result { - Dataset::from_file(&"lightgbm-sys/lightgbm/examples/binary_classification/binary.train") + Dataset::from_file( + &"lightgbm-sys/lightgbm/examples/binary_classification/binary.train", + None, + ) } #[test]