Skip to content

Commit

Permalink
Merge pull request #13 from vaaaaanquish/save_file
Browse files Browse the repository at this point in the history
Save file/Load file
  • Loading branch information
vaaaaanquish authored Jan 21, 2021
2 parents ae7df90 + dbb0e0f commit bf0dea7
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Build for ubuntu
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib
cargo build
- name: Run tests
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lightgbm"
version = "0.1.2"
version = "0.1.3"
authors = ["vaaaaanquish <[email protected]>"]
license = "MIT"
repository = "https://github.com/vaaaaanquish/LightGBM"
Expand Down
86 changes: 71 additions & 15 deletions src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,27 @@ use super::{LGBMResult, Dataset, LGBMError};
/// Core model in LightGBM, containing functions for training, evaluating and predicting.
pub struct Booster {
pub(super) handle: lightgbm_sys::BoosterHandle,
num_class: i64
}


impl Booster {
fn new(handle: lightgbm_sys::BoosterHandle, num_class: i64) -> LGBMResult<Self> {
Ok(Booster{handle, num_class})
fn new(handle: lightgbm_sys::BoosterHandle) -> LGBMResult<Self> {
Ok(Booster{handle})
}

/// Init from model file.
pub fn from_file(filename: String) -> LGBMResult<Self>{
let filename_str = CString::new(filename).unwrap();
let mut out_num_iterations = 0;
let mut handle = std::ptr::null_mut();
lgbm_call!(
lightgbm_sys::LGBM_BoosterCreateFromModelfile(
filename_str.as_ptr() as *const c_char,
&mut out_num_iterations,
&mut handle
)
).unwrap();
Ok(Booster::new(handle)?)
}

/// Create a new Booster model with given Dataset and parameters.
Expand Down Expand Up @@ -56,14 +70,6 @@ impl Booster {
num_iterations = parameter["num_iterations"].as_i64().unwrap();
}

// get num_class
let num_class: i64;
if parameter["num_class"].is_null(){
num_class = 1;
} else {
num_class = parameter["num_class"].as_i64().unwrap();
}

// exchange params {"x": "y", "z": 1} => "x=y z=1"
let params_string = parameter.as_object().unwrap().iter().map(|(k, v)| format!("{}={}", k, v)).collect::<Vec<_>>().join(" ");
let params_cstring = CString::new(params_string).unwrap();
Expand All @@ -81,7 +87,7 @@ impl Booster {
for _ in 1..num_iterations {
lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished))?;
}
Ok(Booster::new(handle, num_class)?)
Ok(Booster::new(handle)?)
}

/// Predict results for given data.
Expand All @@ -102,9 +108,19 @@ impl Booster {
let feature_length = data[0].len();
let params = CString::new("").unwrap();
let mut out_length: c_long = 0;
let out_result: Vec<f64> = vec![Default::default(); data.len() * self.num_class as usize];
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();

// get num_class
let mut num_class = 0;
lgbm_call!(
lightgbm_sys::LGBM_BoosterGetNumClasses(
self.handle,
&mut num_class
)
)?;

let out_result: Vec<f64> = vec![Default::default(); data_length * num_class as usize];

lgbm_call!(
lightgbm_sys::LGBM_BoosterPredictForMat(
self.handle,
Expand All @@ -124,13 +140,28 @@ impl Booster {

// reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class
let reshaped_output;
if self.num_class > 1{
reshaped_output = out_result.chunks(self.num_class as usize).map(|x| x.to_vec()).collect();
if num_class > 1{
reshaped_output = out_result.chunks(num_class as usize).map(|x| x.to_vec()).collect();
} else {
reshaped_output = vec![out_result];
}
Ok(reshaped_output)
}


/// Save model to file.
pub fn save_file(&self, filename: String){
let filename_str = CString::new(filename).unwrap();
lgbm_call!(
lightgbm_sys::LGBM_BoosterSaveModel(
self.handle,
0 as i32,
-1 as i32,
0 as i32,
filename_str.as_ptr() as *const c_char
)
).unwrap();
}
}


Expand All @@ -145,6 +176,9 @@ impl Drop for Booster {
mod tests {
use super::*;
use serde_json::json;
use std::path::Path;
use std::fs;

fn read_train_file() -> LGBMResult<Dataset> {
Dataset::from_file("lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string())
}
Expand Down Expand Up @@ -173,4 +207,26 @@ mod tests {
}
assert_eq!(normalized_result, vec![0, 0, 1]);
}

#[test]
fn save_file() {
let dataset = read_train_file().unwrap();
let params = json!{
{
"num_iterations": 1,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = Booster::train(dataset, &params).unwrap();
bst.save_file("./test/test_save_file.output".to_string());
assert!(Path::new("./test/test_save_file.output").exists());
fs::remove_file("./test/test_save_file.output");
}

#[test]
fn from_file(){
let bst = Booster::from_file("./test/test_from_file.input".to_string());
}
}
127 changes: 127 additions & 0 deletions test/test_from_file.input
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=27
objective=binary sigmoid:1
feature_names=Column_0 Column_1 Column_2 Column_3 Column_4 Column_5 Column_6 Column_7 Column_8 Column_9 Column_10 Column_11 Column_12 Column_13 Column_14 Column_15 Column_16 Column_17 Column_18 Column_19 Column_20 Column_21 Column_22 Column_23 Column_24 Column_25 Column_26 Column_27
feature_infos=[0.27500000000000002:6.6950000000000003] [-2.4169999999999998:2.4300000000000002] [-1.7429999999999999:1.7429999999999999] [0.019:5.7000000000000002] [-1.7429999999999999:1.7429999999999999] [0.159:4.1900000000000004] [-2.9409999999999998:2.9699999999999998] [-1.7410000000000001:1.7410000000000001] [0:2.173] [0.19:5.1929999999999996] [-2.9039999999999999:2.9089999999999998] [-1.742:1.7429999999999999] [0:2.2149999999999999] [0.26400000000000001:6.5229999999999997] [-2.7279999999999998:2.7269999999999999] [-1.742:1.742] [0:2.548] [0.36499999999999999:6.0679999999999996] [-2.4950000000000001:2.496] [-1.74:1.7429999999999999] [0:3.1019999999999999] [0.17199999999999999:13.098000000000001] [0.41899999999999998:7.3920000000000003] [0.46100000000000002:3.6819999999999999] [0.38400000000000001:6.5830000000000002] [0.092999999999999999:7.8600000000000003] [0.38900000000000001:4.5430000000000001] [0.48899999999999999:4.3159999999999998]
tree_sizes=

end of trees

feature_importances:

parameters:
[boosting: gbdt]
[objective: binary]
[metric: auc]
[tree_learner: serial]
[device_type: cpu]
[linear_tree: 0]
[data: ]
[valid: ]
[num_iterations: 1]
[learning_rate: 0.1]
[num_leaves: 31]
[num_threads: 0]
[deterministic: 0]
[force_col_wise: 0]
[force_row_wise: 0]
[histogram_pool_size: -1]
[max_depth: -1]
[min_data_in_leaf: 20]
[min_sum_hessian_in_leaf: 0.001]
[bagging_fraction: 1]
[pos_bagging_fraction: 1]
[neg_bagging_fraction: 1]
[bagging_freq: 0]
[bagging_seed: 3]
[feature_fraction: 1]
[feature_fraction_bynode: 1]
[feature_fraction_seed: 2]
[extra_trees: 0]
[extra_seed: 6]
[early_stopping_round: 0]
[first_metric_only: 0]
[max_delta_step: 0]
[lambda_l1: 0]
[lambda_l2: 0]
[linear_lambda: 0]
[min_gain_to_split: 0]
[drop_rate: 0.1]
[max_drop: 50]
[skip_drop: 0.5]
[xgboost_dart_mode: 0]
[uniform_drop: 0]
[drop_seed: 4]
[top_rate: 0.2]
[other_rate: 0.1]
[min_data_per_group: 100]
[max_cat_threshold: 32]
[cat_l2: 10]
[cat_smooth: 10]
[max_cat_to_onehot: 4]
[top_k: 20]
[monotone_constraints: ]
[monotone_constraints_method: basic]
[monotone_penalty: 0]
[feature_contri: ]
[forcedsplits_filename: ]
[refit_decay_rate: 0.9]
[cegb_tradeoff: 1]
[cegb_penalty_split: 0]
[cegb_penalty_feature_lazy: ]
[cegb_penalty_feature_coupled: ]
[path_smooth: 0]
[interaction_constraints: ]
[verbosity: 1]
[saved_feature_importance_type: 0]
[max_bin: 255]
[max_bin_by_feature: ]
[min_data_in_bin: 3]
[bin_construct_sample_cnt: 200000]
[data_random_seed: 0]
[is_enable_sparse: 1]
[enable_bundle: 1]
[use_missing: 1]
[zero_as_missing: 0]
[feature_pre_filter: 1]
[pre_partition: 0]
[two_round: 0]
[header: 0]
[label_column: ]
[weight_column: ]
[group_column: ]
[ignore_column: ]
[categorical_feature: ]
[forcedbins_filename: ]
[objective_seed: 5]
[num_class: 1]
[is_unbalance: 0]
[scale_pos_weight: 1]
[sigmoid: 1]
[boost_from_average: 1]
[reg_sqrt: 0]
[alpha: 0.9]
[fair_c: 1]
[poisson_max_delta_step: 0.7]
[tweedie_variance_power: 1.5]
[lambdarank_truncation_level: 30]
[lambdarank_norm: 1]
[label_gain: ]
[eval_at: ]
[multi_error_top_k: 1]
[auc_mu_weights: ]
[num_machines: 1]
[local_listen_port: 12400]
[time_out: 120]
[machine_list_filename: ]
[machines: ]
[gpu_platform_id: -1]
[gpu_device_id: -1]
[gpu_use_dp: 0]
[num_gpu: 1]

end of parameters

0 comments on commit bf0dea7

Please sign in to comment.