Skip to content

Commit

Permalink
Merge pull request #7 from vaaaaanquish/fix_train_example
Browse files Browse the repository at this point in the history
fix train example
  • Loading branch information
vaaaaanquish authored Jan 15, 2021
2 parents b24c4d7 + eefac57 commit e2663f9
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lightgbm"
version = "0.1.0"
version = "0.1.1"
authors = ["vaaaaanquish <[email protected]>"]
license = "MIT"
repository = "https://github.com/vaaaaanquish/LightGBM"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
LightGBM Rust binding


Now: Done is better than perfect.


# develop

```
Expand Down
2 changes: 1 addition & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ authors = ["vaaaaanquish <[email protected]>"]
publish = false

[dependencies]
lightgbm = "0.1.0"
lightgbm = "0.1.1"
csv = "1.1.5"
itertools = "0.9.0"
27 changes: 20 additions & 7 deletions examples/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,27 @@ fn main() -> std::io::Result<()> {
// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0];
// let train_dataset = Dataset::from_mat(feature, label).unwrap();

let train_dataset = Dataset::from_file("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()).unwrap();
// let train_dataset = Dataset::from_file("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()).unwrap();

let mut train_rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train")?;
let mut train_labels: Vec<f32> = Vec::new();
let mut train_feature: Vec<Vec<f64>> = Vec::new();
for result in train_rdr.records() {
let record = result?;
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
train_labels.push(label);
train_feature.push(feature);
}
let train_dataset = Dataset::from_mat(train_feature, train_labels).unwrap();

let mut rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.test")?;
let mut test_labels: Vec<i8> = Vec::new();
let mut test_feature: Vec<Vec<f32>> = Vec::new();
let mut test_labels: Vec<f32> = Vec::new();
let mut test_feature: Vec<Vec<f64>> = Vec::new();
for result in rdr.records() {
let record = result?;
let label = record[0].parse::<i8>().unwrap();
let feature: Vec<f32> = record.iter().map(|x| x.parse::<f32>().unwrap()).collect::<Vec<f32>>()[1..].to_vec();
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
test_labels.push(label);
test_feature.push(feature);
}
Expand All @@ -32,11 +44,12 @@ fn main() -> std::io::Result<()> {

let mut tp = 0;
for (label, pred) in zip(&test_labels, &result){
if label == &(1 as i8) && pred > &(0.5 as f64) {
if label == &(1 as f32) && pred > &(0.5 as f64) {
tp = tp + 1;
} else if label == &(0 as i8) && pred <= &(0.5 as f64) {
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
tp = tp + 1;
}
println!("{}, {}", label, pred)
}
println!("{} / {}", &tp, result.len());
Ok(())
Expand Down
47 changes: 23 additions & 24 deletions src/booster.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use lightgbm_sys;

use libc::{c_char, c_int, c_float, c_double, c_long, c_void};
use libc::{c_char, c_double, c_void, c_long};
use std::ffi::CString;
use std::convert::TryInto;
use std;

use super::{LGBMResult, Dataset};
Expand All @@ -18,49 +17,49 @@ impl Booster {
}

pub fn train(dataset: Dataset) -> LGBMResult<Self> {
let params = CString::new("objective=binary metric=auc").unwrap();
let mut handle = std::ptr::null_mut();
let mut params = CString::new("app=binary metric=auc num_leaves=31").unwrap();
unsafe {
lightgbm_sys::LGBM_BoosterCreate(
dataset.handle,
params.as_ptr() as *const c_char,
&mut handle);
&mut handle
);
}

// train
let mut is_finished: i32 = 0;
unsafe{
for n in 1..50 {
let ret = lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished);
for _ in 1..100 {
lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished);
}
}
Ok(Booster::new(handle)?)
}

pub fn predict(&self, data: Vec<Vec<f32>>) -> LGBMResult<Vec<f64>> {
let data_length = data.len() as i32;
let feature_length = data[0].len() as i32;
let mut params = CString::new("").unwrap();
let mut out_len: c_long = 0;
// let mut out_result = Vec::with_capacity(data_length.try_into().unwrap());
let data_size = data_length.try_into().unwrap();
let mut out_result: Vec<f64> = vec![Default::default(); data_size];
pub fn predict(&self, data: Vec<Vec<f64>>) -> LGBMResult<Vec<f64>> {
let data_length = data.len();
let feature_length = data[0].len();
let params = CString::new("").unwrap();
let mut out_length: c_long = 0;
let out_result: Vec<f64> = vec![Default::default(); data.len()];
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();

unsafe {
lightgbm_sys::LGBM_BoosterPredictForMat(
self.handle,
data.as_ptr() as * mut c_void,
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap(),
data_length,
feature_length,
0,
0,
0,
0,
flat_data.as_ptr() as *const c_void,
lightgbm_sys::C_API_DTYPE_FLOAT64 as i32,
data_length as i32,
feature_length as i32,
1 as i32,
0 as i32,
0 as i32,
-1 as i32,
params.as_ptr() as *const c_char,
&mut out_len,
&mut out_length,
out_result.as_ptr() as *mut c_double
);
);
}
Ok(out_result)
}
Expand Down
40 changes: 22 additions & 18 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use libc::{c_void,c_char};

use std;
use std::convert::TryInto;
use std::ffi::CString;
use lightgbm_sys;

Expand All @@ -17,45 +16,50 @@ impl Dataset {
Ok(Dataset{handle})
}

pub fn from_mat(data: Vec<Vec<f32>>, label: Vec<f32>) -> LGBMResult<Self> {
let mut handle = std::ptr::null_mut();
let data_length = data.len() as i32;
let feature_length = data[0].len() as i32;
pub fn from_mat(data: Vec<Vec<f64>>, label: Vec<f32>) -> LGBMResult<Self> {
let data_length = data.len();
let feature_length = data[0].len();
let params = CString::new("").unwrap();
let label_str = CString::new("label").unwrap();
let reference = std::ptr::null_mut(); // not use
let mut handle = std::ptr::null_mut();
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();

unsafe{
lightgbm_sys::LGBM_DatasetCreateFromMat(
data.as_ptr() as * mut c_void,
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap(),
data_length,
feature_length,
1,
flat_data.as_ptr() as *const c_void,
lightgbm_sys::C_API_DTYPE_FLOAT64 as i32,
data_length as i32,
feature_length as i32,
1 as i32,
params.as_ptr() as *const c_char,
std::ptr::null_mut(),
&mut handle);
reference,
&mut handle
);

lightgbm_sys::LGBM_DatasetSetField(
handle,
label_str.as_ptr() as *const c_char,
label.as_ptr() as * mut c_void,
data_length,
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap());
label.as_ptr() as *const c_void,
data_length as i32,
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32
);
}
Ok(Dataset::new(handle)?)
}

pub fn from_file(file_path: String) -> LGBMResult<Self> {
let mut handle = std::ptr::null_mut();
let file_path_str = CString::new(file_path).unwrap();
let params = CString::new("").unwrap();
let mut handle = std::ptr::null_mut();

unsafe {
lightgbm_sys::LGBM_DatasetCreateFromFile(
file_path_str.as_ptr() as * const c_char,
file_path_str.as_ptr() as *const c_char,
params.as_ptr() as *const c_char,
std::ptr::null_mut(),
&mut handle);
&mut handle
);
}
Ok(Dataset::new(handle)?)
}
Expand Down

0 comments on commit e2663f9

Please sign in to comment.