Skip to content

Commit

Permalink
adding integration with polars
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminjellis committed Oct 28, 2021
1 parent fe6cadc commit 934f0ac
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 4 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ jobs:
run: |
brew install cmake
brew install libomp
cargo build
cargo build --all-features
- name: Build for ubuntu
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib
cargo build
cargo build --all-features
- name: Run tests
run: cargo test
run: cargo test --all-features
continue-on-error: ${{ matrix.rust == 'nightly' }}
- name: Run Clippy
uses: actions-rs/clippy-check@v1
Expand Down
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ lightgbm-sys = { path = "lightgbm-sys", version = "0.3.0" }
libc = "0.2.81"
derive_builder = "0.5.1"
serde_json = "1.0.59"
polars = {version = "0.16.0", optional = true}


[features]
default = []
dataframe = ["polars"]
90 changes: 90 additions & 0 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use lightgbm_sys;
use std;
use std::ffi::CString;

#[cfg(feature = "dataframe")]
use polars::prelude::*;

use crate::{Error, Result};

/// Dataset used throughout LightGBM for training.
Expand Down Expand Up @@ -118,6 +121,76 @@ impl Dataset {

Ok(Self::new(handle))
}

/// Create a new `Dataset` from a polars DataFrame.
///
/// Note: the feature ```dataframe``` is required for this method
///
/// Example
///
#[cfg_attr(
feature = "dataframe",
doc = r##"
extern crate polars;
use lightgbm::Dataset;
use polars::prelude::*;
use polars::df;
let df: DataFrame = df![
"feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1],
"feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7],
"feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1],
"feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9],
"label" => [0.0, 0.0, 0.0, 1.0, 1.0]
].unwrap();
let dataset = Dataset::from_dataframe(df, String::from("label")).unwrap();
"##
)]
#[cfg(feature = "dataframe")]
pub fn from_dataframe(mut dataframe: DataFrame, label_column: String) -> Result<Self> {
let label_col_name = label_column.as_str();

let (m, n) = dataframe.shape();

let label_series = &dataframe.select_series(label_col_name)?[0].cast::<Float32Type>()?;

if label_series.null_count() != 0 {
panic!("Cannot create a dataset with null values, encountered nulls when creating the label array")
}

dataframe.drop_in_place(label_col_name)?;

let mut label_values = Vec::with_capacity(m);

let label_values_ca = label_series.unpack::<Float32Type>()?;

label_values_ca
.into_no_null_iter()
.enumerate()
.for_each(|(_row_idx, val)| {
label_values.push(val);
});

let mut feature_values = Vec::with_capacity(m);
for _i in 0..m {
feature_values.push(Vec::with_capacity(n));
}

for (_col_idx, series) in dataframe.get_columns().iter().enumerate() {
if series.null_count() != 0 {
panic!("Cannot create a dataset with null values, encountered nulls when creating the features array")
}

let series = series.cast::<Float64Type>()?;
let ca = series.unpack::<Float64Type>()?;

ca.into_no_null_iter()
.enumerate()
.for_each(|(row_idx, val)| feature_values[row_idx].push(val));
}
Self::from_mat(feature_values, label_values)
}
}

impl Drop for Dataset {
Expand Down Expand Up @@ -151,4 +224,21 @@ mod tests {
let dataset = Dataset::from_mat(data, label);
assert!(dataset.is_ok());
}

#[cfg(feature = "dataframe")]
#[test]
fn from_dataframe() {
use polars::df;
let df: DataFrame = df![
"feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1],
"feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7],
"feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1],
"feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9],
"label" => [0.0, 0.0, 0.0, 1.0, 1.0]
]
.unwrap();

let df_dataset = Dataset::from_dataframe(df, String::from("label"));
assert!(df_dataset.is_ok());
}
}
14 changes: 13 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
use std::error;
use std::ffi::CStr;
use std::fmt::{self, Display};
use std::fmt::{self, Debug, Display};

use lightgbm_sys;

#[cfg(feature = "dataframe")]
use polars::prelude::*;

/// Convenience return type for most operations which can return an `LightGBM`.
pub type Result<T> = std::result::Result<T, Error>;

Expand Down Expand Up @@ -49,6 +52,15 @@ impl Display for Error {
}
}

#[cfg(feature = "dataframe")]
impl From<PolarsError> for Error {
fn from(pe: PolarsError) -> Self {
Self {
desc: pe.to_string(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ extern crate libc;
extern crate lightgbm_sys;
extern crate serde_json;

#[cfg(feature = "dataframe")]
extern crate polars;

#[macro_use]
macro_rules! lgbm_call {
($x:expr) => {
Expand Down

0 comments on commit 934f0ac

Please sign in to comment.