Skip to content

Commit

Permalink
[+] Boster::predict_row implemented.
Browse files Browse the repository at this point in the history
  • Loading branch information
npatsakula committed Oct 31, 2023
1 parent 16b71a8 commit f469129
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/booster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,35 @@ impl Booster {
Ok(reshaped_output)
}

/// Predict for single row.
pub fn predict_row(&self, data: Vec<f64>) -> Result<Vec<f64>> {
let feature_length = data.len();
let params = CString::new("").unwrap();

let mut num_class = 0;
lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumClasses(
self.handle,
&mut num_class
))?;
let mut out_result = vec![Default::default(); num_class as usize];

lgbm_call!(lightgbm_sys::LGBM_BoosterPredictForMatSingleRow(
self.handle,
data.as_ptr().cast(),
lightgbm_sys::C_API_DTYPE_FLOAT64 as _,
feature_length as _,
1,
lightgbm_sys::C_API_PREDICT_NORMAL as _,
0,
-1,
params.as_ptr().cast(),
&mut 0,
out_result.as_mut_ptr(),
))?;

Ok(out_result)
}

/// Get Feature Num.
pub fn num_feature(&self) -> Result<i32> {
let mut out_len = 0;
Expand Down Expand Up @@ -280,6 +309,26 @@ mod tests {
assert_eq!(normalized_result, vec![0, 0, 1]);
}

#[test]
fn predict_single_row() {
let params = json! {
{
"num_iterations": 10,
"objective": "binary",
"metric": "auc",
"data_random_seed": 0
}
};
let bst = _train_booster(&params);
let feature = vec![0.9; 28];
let result = bst.predict_row(feature).unwrap();
let mut normalized_result = Vec::new();
for r in &result {
normalized_result.push(if r > &0.5 { 1 } else { 0 });
}
assert_eq!(normalized_result, vec![1]);
}

#[test]
fn num_feature() {
let params = _default_params();
Expand Down

0 comments on commit f469129

Please sign in to comment.