From 934f0ac3b288f6aaf77a2c2c527678d3059da3f8 Mon Sep 17 00:00:00 2001 From: benjamin ellis Date: Thu, 28 Oct 2021 20:08:50 +0100 Subject: [PATCH] adding integration with polars --- .github/workflows/ci.yml | 6 +-- Cargo.toml | 6 +++ src/dataset.rs | 90 ++++++++++++++++++++++++++++++++++++++++ src/error.rs | 14 ++++++- src/lib.rs | 3 ++ 5 files changed, 115 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8de554c..4c36bde 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,15 +27,15 @@ jobs: run: | brew install cmake brew install libomp - cargo build + cargo build --all-features - name: Build for ubuntu if: matrix.os == 'ubuntu-latest' run: | sudo apt-get update sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib - cargo build + cargo build --all-features - name: Run tests - run: cargo test + run: cargo test --all-features continue-on-error: ${{ matrix.rust == 'nightly' }} - name: Run Clippy uses: actions-rs/clippy-check@v1 diff --git a/Cargo.toml b/Cargo.toml index fd58a95..4a3f2b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,9 @@ lightgbm-sys = { path = "lightgbm-sys", version = "0.3.0" } libc = "0.2.81" derive_builder = "0.5.1" serde_json = "1.0.59" +polars = {version = "0.16.0", optional = true} + + +[features] +default = [] +dataframe = ["polars"] diff --git a/src/dataset.rs b/src/dataset.rs index aa56974..b6e1d6f 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -3,6 +3,9 @@ use lightgbm_sys; use std; use std::ffi::CString; +#[cfg(feature = "dataframe")] +use polars::prelude::*; + use crate::{Error, Result}; /// Dataset used throughout LightGBM for training. @@ -118,6 +121,76 @@ impl Dataset { Ok(Self::new(handle)) } + + /// Create a new `Dataset` from a polars DataFrame. + /// + /// Note: the feature ```dataframe``` is required for this method + /// + /// Example + /// + #[cfg_attr( + feature = "dataframe", + doc = r##" + extern crate polars; + + use lightgbm::Dataset; + use polars::prelude::*; + use polars::df; + + let df: DataFrame = df![ + "feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1], + "feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7], + "feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1], + "feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9], + "label" => [0.0, 0.0, 0.0, 1.0, 1.0] + ].unwrap(); + let dataset = Dataset::from_dataframe(df, String::from("label")).unwrap(); + "## + )] + #[cfg(feature = "dataframe")] + pub fn from_dataframe(mut dataframe: DataFrame, label_column: String) -> Result { + let label_col_name = label_column.as_str(); + + let (m, n) = dataframe.shape(); + + let label_series = &dataframe.select_series(label_col_name)?[0].cast::()?; + + if label_series.null_count() != 0 { + panic!("Cannot create a dataset with null values, encountered nulls when creating the label array") + } + + dataframe.drop_in_place(label_col_name)?; + + let mut label_values = Vec::with_capacity(m); + + let label_values_ca = label_series.unpack::()?; + + label_values_ca + .into_no_null_iter() + .enumerate() + .for_each(|(_row_idx, val)| { + label_values.push(val); + }); + + let mut feature_values = Vec::with_capacity(m); + for _i in 0..m { + feature_values.push(Vec::with_capacity(n)); + } + + for (_col_idx, series) in dataframe.get_columns().iter().enumerate() { + if series.null_count() != 0 { + panic!("Cannot create a dataset with null values, encountered nulls when creating the features array") + } + + let series = series.cast::()?; + let ca = series.unpack::()?; + + ca.into_no_null_iter() + .enumerate() + .for_each(|(row_idx, val)| feature_values[row_idx].push(val)); + } + Self::from_mat(feature_values, label_values) + } } impl Drop for Dataset { @@ -151,4 +224,21 @@ mod tests { let dataset = Dataset::from_mat(data, label); assert!(dataset.is_ok()); } + + #[cfg(feature = "dataframe")] + #[test] + fn from_dataframe() { + use polars::df; + let df: DataFrame = df![ + "feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1], + "feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7], + "feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1], + "feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9], + "label" => [0.0, 0.0, 0.0, 1.0, 1.0] + ] + .unwrap(); + + let df_dataset = Dataset::from_dataframe(df, String::from("label")); + assert!(df_dataset.is_ok()); + } } diff --git a/src/error.rs b/src/error.rs index 1c6bee3..8564eb8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,10 +2,13 @@ use std::error; use std::ffi::CStr; -use std::fmt::{self, Display}; +use std::fmt::{self, Debug, Display}; use lightgbm_sys; +#[cfg(feature = "dataframe")] +use polars::prelude::*; + /// Convenience return type for most operations which can return an `LightGBM`. pub type Result = std::result::Result; @@ -49,6 +52,15 @@ impl Display for Error { } } +#[cfg(feature = "dataframe")] +impl From for Error { + fn from(pe: PolarsError) -> Self { + Self { + desc: pe.to_string(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index d7c1721..2867d47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,9 @@ extern crate libc; extern crate lightgbm_sys; extern crate serde_json; +#[cfg(feature = "dataframe")] +extern crate polars; + #[macro_use] macro_rules! lgbm_call { ($x:expr) => {