From 3e586610473f384adf8fb1cb77f1558965129700 Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Thu, 21 Jan 2021 22:12:57 +0900 Subject: [PATCH 1/3] add from_file, save_file --- src/booster.rs | 86 +++++++++++++++++++++----- test/test_from_file.input | 127 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 15 deletions(-) create mode 100644 test/test_from_file.input diff --git a/src/booster.rs b/src/booster.rs index d24aa38..15b2cc8 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -13,13 +13,27 @@ use super::{LGBMResult, Dataset, LGBMError}; /// Core model in LightGBM, containing functions for training, evaluating and predicting. pub struct Booster { pub(super) handle: lightgbm_sys::BoosterHandle, - num_class: i64 } impl Booster { - fn new(handle: lightgbm_sys::BoosterHandle, num_class: i64) -> LGBMResult { - Ok(Booster{handle, num_class}) + fn new(handle: lightgbm_sys::BoosterHandle) -> LGBMResult { + Ok(Booster{handle}) + } + + /// Init from model file. + pub fn from_file(filename: String) -> LGBMResult{ + let filename_str = CString::new(filename).unwrap(); + let mut out_num_iterations = 0; + let mut handle = std::ptr::null_mut(); + lgbm_call!( + lightgbm_sys::LGBM_BoosterCreateFromModelfile( + filename_str.as_ptr() as *const c_char, + &mut out_num_iterations, + &mut handle + ) + ).unwrap(); + Ok(Booster::new(handle)?) } /// Create a new Booster model with given Dataset and parameters. @@ -56,14 +70,6 @@ impl Booster { num_iterations = parameter["num_iterations"].as_i64().unwrap(); } - // get num_class - let num_class: i64; - if parameter["num_class"].is_null(){ - num_class = 1; - } else { - num_class = parameter["num_class"].as_i64().unwrap(); - } - // exchange params {"x": "y", "z": 1} => "x=y z=1" let params_string = parameter.as_object().unwrap().iter().map(|(k, v)| format!("{}={}", k, v)).collect::>().join(" "); let params_cstring = CString::new(params_string).unwrap(); @@ -81,7 +87,7 @@ impl Booster { for _ in 1..num_iterations { lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished))?; } - Ok(Booster::new(handle, num_class)?) + Ok(Booster::new(handle)?) } /// Predict results for given data. @@ -102,9 +108,19 @@ impl Booster { let feature_length = data[0].len(); let params = CString::new("").unwrap(); let mut out_length: c_long = 0; - let out_result: Vec = vec![Default::default(); data.len() * self.num_class as usize]; let flat_data = data.into_iter().flatten().collect::>(); + // get num_class + let mut num_class = 0; + lgbm_call!( + lightgbm_sys::LGBM_BoosterGetNumClasses( + self.handle, + &mut num_class + ) + )?; + + let out_result: Vec = vec![Default::default(); data_length * num_class as usize]; + lgbm_call!( lightgbm_sys::LGBM_BoosterPredictForMat( self.handle, @@ -124,13 +140,28 @@ impl Booster { // reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class let reshaped_output; - if self.num_class > 1{ - reshaped_output = out_result.chunks(self.num_class as usize).map(|x| x.to_vec()).collect(); + if num_class > 1{ + reshaped_output = out_result.chunks(num_class as usize).map(|x| x.to_vec()).collect(); } else { reshaped_output = vec![out_result]; } Ok(reshaped_output) } + + + /// Save model to file. + pub fn save_file(&self, filename: String){ + let filename_str = CString::new(filename).unwrap(); + lgbm_call!( + lightgbm_sys::LGBM_BoosterSaveModel( + self.handle, + 0 as i32, + -1 as i32, + 0 as i32, + filename_str.as_ptr() as *const c_char + ) + ).unwrap(); + } } @@ -145,6 +176,9 @@ impl Drop for Booster { mod tests { use super::*; use serde_json::json; + use std::path::Path; + use std::fs; + fn read_train_file() -> LGBMResult { Dataset::from_file("lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()) } @@ -173,4 +207,26 @@ mod tests { } assert_eq!(normalized_result, vec![0, 0, 1]); } + + #[test] + fn save_file() { + let dataset = read_train_file().unwrap(); + let params = json!{ + { + "num_iterations": 1, + "objective": "binary", + "metric": "auc", + "data_random_seed": 0 + } + }; + let bst = Booster::train(dataset, ¶ms).unwrap(); + bst.save_file("./test/test_save_file.output".to_string()); + assert!(Path::new("./test/test_save_file.output").exists()); + fs::remove_file("./test/test_save_file.output"); + } + + #[test] + fn from_file(){ + let bst = Booster::from_file("./test/test_from_file.input".to_string()); + } } diff --git a/test/test_from_file.input b/test/test_from_file.input new file mode 100644 index 0000000..c0c27f1 --- /dev/null +++ b/test/test_from_file.input @@ -0,0 +1,127 @@ +tree +version=v3 +num_class=1 +num_tree_per_iteration=1 +label_index=0 +max_feature_idx=27 +objective=binary sigmoid:1 +feature_names=Column_0 Column_1 Column_2 Column_3 Column_4 Column_5 Column_6 Column_7 Column_8 Column_9 Column_10 Column_11 Column_12 Column_13 Column_14 Column_15 Column_16 Column_17 Column_18 Column_19 Column_20 Column_21 Column_22 Column_23 Column_24 Column_25 Column_26 Column_27 +feature_infos=[0.27500000000000002:6.6950000000000003] [-2.4169999999999998:2.4300000000000002] [-1.7429999999999999:1.7429999999999999] [0.019:5.7000000000000002] [-1.7429999999999999:1.7429999999999999] [0.159:4.1900000000000004] [-2.9409999999999998:2.9699999999999998] [-1.7410000000000001:1.7410000000000001] [0:2.173] [0.19:5.1929999999999996] [-2.9039999999999999:2.9089999999999998] [-1.742:1.7429999999999999] [0:2.2149999999999999] [0.26400000000000001:6.5229999999999997] [-2.7279999999999998:2.7269999999999999] [-1.742:1.742] [0:2.548] [0.36499999999999999:6.0679999999999996] [-2.4950000000000001:2.496] [-1.74:1.7429999999999999] [0:3.1019999999999999] [0.17199999999999999:13.098000000000001] [0.41899999999999998:7.3920000000000003] [0.46100000000000002:3.6819999999999999] [0.38400000000000001:6.5830000000000002] [0.092999999999999999:7.8600000000000003] [0.38900000000000001:4.5430000000000001] [0.48899999999999999:4.3159999999999998] +tree_sizes= + +end of trees + +feature_importances: + +parameters: +[boosting: gbdt] +[objective: binary] +[metric: auc] +[tree_learner: serial] +[device_type: cpu] +[linear_tree: 0] +[data: ] +[valid: ] +[num_iterations: 1] +[learning_rate: 0.1] +[num_leaves: 31] +[num_threads: 0] +[deterministic: 0] +[force_col_wise: 0] +[force_row_wise: 0] +[histogram_pool_size: -1] +[max_depth: -1] +[min_data_in_leaf: 20] +[min_sum_hessian_in_leaf: 0.001] +[bagging_fraction: 1] +[pos_bagging_fraction: 1] +[neg_bagging_fraction: 1] +[bagging_freq: 0] +[bagging_seed: 3] +[feature_fraction: 1] +[feature_fraction_bynode: 1] +[feature_fraction_seed: 2] +[extra_trees: 0] +[extra_seed: 6] +[early_stopping_round: 0] +[first_metric_only: 0] +[max_delta_step: 0] +[lambda_l1: 0] +[lambda_l2: 0] +[linear_lambda: 0] +[min_gain_to_split: 0] +[drop_rate: 0.1] +[max_drop: 50] +[skip_drop: 0.5] +[xgboost_dart_mode: 0] +[uniform_drop: 0] +[drop_seed: 4] +[top_rate: 0.2] +[other_rate: 0.1] +[min_data_per_group: 100] +[max_cat_threshold: 32] +[cat_l2: 10] +[cat_smooth: 10] +[max_cat_to_onehot: 4] +[top_k: 20] +[monotone_constraints: ] +[monotone_constraints_method: basic] +[monotone_penalty: 0] +[feature_contri: ] +[forcedsplits_filename: ] +[refit_decay_rate: 0.9] +[cegb_tradeoff: 1] +[cegb_penalty_split: 0] +[cegb_penalty_feature_lazy: ] +[cegb_penalty_feature_coupled: ] +[path_smooth: 0] +[interaction_constraints: ] +[verbosity: 1] +[saved_feature_importance_type: 0] +[max_bin: 255] +[max_bin_by_feature: ] +[min_data_in_bin: 3] +[bin_construct_sample_cnt: 200000] +[data_random_seed: 0] +[is_enable_sparse: 1] +[enable_bundle: 1] +[use_missing: 1] +[zero_as_missing: 0] +[feature_pre_filter: 1] +[pre_partition: 0] +[two_round: 0] +[header: 0] +[label_column: ] +[weight_column: ] +[group_column: ] +[ignore_column: ] +[categorical_feature: ] +[forcedbins_filename: ] +[objective_seed: 5] +[num_class: 1] +[is_unbalance: 0] +[scale_pos_weight: 1] +[sigmoid: 1] +[boost_from_average: 1] +[reg_sqrt: 0] +[alpha: 0.9] +[fair_c: 1] +[poisson_max_delta_step: 0.7] +[tweedie_variance_power: 1.5] +[lambdarank_truncation_level: 30] +[lambdarank_norm: 1] +[label_gain: ] +[eval_at: ] +[multi_error_top_k: 1] +[auc_mu_weights: ] +[num_machines: 1] +[local_listen_port: 12400] +[time_out: 120] +[machine_list_filename: ] +[machines: ] +[gpu_platform_id: -1] +[gpu_device_id: -1] +[gpu_use_dp: 0] +[num_gpu: 1] + +end of parameters From f3a88e466960397791ad3d6dec312ae2619390fa Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Thu, 21 Jan 2021 22:13:26 +0900 Subject: [PATCH 2/3] update version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index aae5c9c..9466d09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lightgbm" -version = "0.1.2" +version = "0.1.3" authors = ["vaaaaanquish <6syun9@gmail.com>"] license = "MIT" repository = "https://github.com/vaaaaanquish/LightGBM" From dbb0e0feafa7bd344723e9e97b6c0444972fe50c Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Thu, 21 Jan 2021 22:24:49 +0900 Subject: [PATCH 3/3] apt-get update --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 404d192..cd420e3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,6 +32,7 @@ jobs: - name: Build for ubuntu if: matrix.os == 'ubuntu-latest' run: | + sudo apt-get update sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib cargo build - name: Run tests