Skip to content

Commit

Permalink
Merge pull request #9 from vaaaaanquish/add_regression
Browse files Browse the repository at this point in the history
add regression examples
  • Loading branch information
vaaaaanquish authored Jan 16, 2021
2 parents 5cfaa71 + 6e6be3b commit 573565b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ lightgbm-sys/target
# example
examples/binary_classification/target/
examples/multiclass_classification/target/
examples/regression/target/
11 changes: 11 additions & 0 deletions examples/regression/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "lightgbm-example-regression"
version = "0.1.0"
authors = ["vaaaaanquish <[email protected]>"]
publish = false

[dependencies]
lightgbm = { path = "../../" }
csv = "1.1.5"
itertools = "0.9.0"
serde_json = "1.0.59"
55 changes: 55 additions & 0 deletions examples/regression/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
extern crate lightgbm;
extern crate csv;
extern crate serde_json;
extern crate itertools;


use itertools::zip;
use lightgbm::{Dataset, Booster};
use serde_json::json;


fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
let mut labels: Vec<f32> = Vec::new();
let mut features: Vec<Vec<f64>> = Vec::new();
for result in rdr.unwrap().records() {
let record = result.unwrap();
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
labels.push(label);
features.push(feature);
}
(features, labels)
}


fn main() -> std::io::Result<()> {
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.train");
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.test");
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();

let params = json!{
{
"num_iterations": 100,
"objective": "regression",
"metric": "l2"
}
};

let booster = Booster::train(train_dataset, &params).unwrap();
let result = booster.predict(test_features).unwrap();


let mut tp = 0;
for (label, pred) in zip(&test_labels, &result[0]){
if label == &(1 as f32) && pred > &(0.5 as f64) {
tp = tp + 1;
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
tp = tp + 1;
}
println!("{}, {}", label, pred)
}
println!("{} / {}", &tp, result[0].len());
Ok(())
}

0 comments on commit 573565b

Please sign in to comment.