Skip to content

Commit

Permalink
Trying to call mnist model in main
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <[email protected]>
  • Loading branch information
Aisuko committed Nov 18, 2023
1 parent 660cc49 commit a6ff963
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 52 deletions.
144 changes: 102 additions & 42 deletions backend/rust/backend-burn/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::process::{id, Command};
use std::sync::{Arc, Mutex};

use bunker::pb::Result as PbResult;
use bunker::pb::{
Expand All @@ -15,16 +17,12 @@ use tonic::{Request, Response, Status};
use async_trait::async_trait;

use tracing::{event, span, Level};
use tracing_subscriber::filter::LevelParseError;

use std::fs;
use std::process::{Command,id};

use models::*;
// implement BackendService trait in bunker

#[derive(Default, Debug)]
struct BurnBackend;
pub struct BurnBackend;

#[async_trait]
impl BackendService for BurnBackend {
Expand All @@ -42,33 +40,15 @@ impl BackendService for BurnBackend {

#[tracing::instrument]
async fn predict(&self, request: Request<PredictOptions>) -> Result<Response<Reply>, Status> {
let mut models: Vec<Box<dyn LLM>> = vec![Box::new(models::MNINST::new())];
let result = models[0].predict(request.into_inner());

match result {
Ok(res) => {
let reply = Reply {
message: res.into(),
};
let res = Response::new(reply);
Ok(res)
}
Err(e) => {
let reply = Reply {
message: e.to_string().into(),
};
let res = Response::new(reply);
Ok(res)
}
}
todo!("predict")
}

#[tracing::instrument]
async fn load_model(
&self,
request: Request<ModelOptions>,
) -> Result<Response<PbResult>, Status> {
todo!()
todo!("load_model")
}

#[tracing::instrument]
Expand Down Expand Up @@ -121,35 +101,34 @@ impl BackendService for BurnBackend {
&self,
request: Request<HealthMessage>,
) -> Result<Response<StatusResponse>, Status> {

// Here we do not need to cover the windows platform
let mut breakdown = HashMap::new();
let mut memory_usage: u64=0;
let mut memory_usage: u64 = 0;

#[cfg(target_os = "linux")]
{
let pid =id();
let stat = fs::read_to_string(format!("/proc/{}/stat", pid)).expect("Failed to read stat file");
let pid = id();
let stat = fs::read_to_string(format!("/proc/{}/stat", pid))
.expect("Failed to read stat file");

let stats: Vec<&str> = stat.split_whitespace().collect();
memory_usage = stats[23].parse::<u64>().expect("Failed to parse RSS");
}

#[cfg(target_os="macos")]
#[cfg(target_os = "macos")]
{
let output=Command::new("ps")
.arg("-p")
.arg(id().to_string())
.arg("-o")
.arg("rss=")
.output()
.expect("failed to execute process");

memory_usage = String::from_utf8_lossy(&output.stdout)
.trim()
.parse::<u64>()
.expect("Failed to parse memory usage");
let output = Command::new("ps")
.arg("-p")
.arg(id().to_string())
.arg("-o")
.arg("rss=")
.output()
.expect("failed to execute process");

memory_usage = String::from_utf8_lossy(&output.stdout)
.trim()
.parse::<u64>()
.expect("Failed to parse memory usage");
}
breakdown.insert("RSS".to_string(), memory_usage);

Expand All @@ -167,6 +146,87 @@ impl BackendService for BurnBackend {
}
}

#[cfg(test)]
mod tests {
use super::*;
use tonic::Request;

#[tokio::test]
async fn test_health() {
let backend = BurnBackend::default();
let request = Request::new(HealthMessage {});
let response = backend.health(request).await;

assert!(response.is_ok());
let response = response.unwrap();
let message_str = String::from_utf8(response.get_ref().message.clone()).unwrap();
assert_eq!(message_str, "OK");
}
#[tokio::test]
async fn test_status() {
let backend = BurnBackend::default();
let request = Request::new(HealthMessage {});
let response = backend.status(request).await;

assert!(response.is_ok());
let response = response.unwrap();
let state = response.get_ref().state;
assert_eq!(state, 0);
}

#[tokio::test]
async fn test_load_model() {
let backend = BurnBackend::default();
let request = Request::new(ModelOptions {
model: "test".to_string(),
context_size: 0,
seed: 0,
n_batch: 0,
f16_memory: false,
m_lock: false,
m_map: false,
vocab_only: false,
low_vram: false,
embeddings: false,
numa: false,
ngpu_layers: 0,
main_gpu: "".to_string(),
tensor_split: "".to_string(),
threads: 1,
library_search_path: "".to_string(),
rope_freq_base: 0.0,
rope_freq_scale: 0.0,
rms_norm_eps: 0.0,
ngqa: 0,
model_file: "".to_string(),
device: "".to_string(),
use_triton: false,
model_base_name: "".to_string(),
use_fast_tokenizer: false,
pipeline_type: "".to_string(),
scheduler_type: "".to_string(),
cuda: false,
cfg_scale: 0.0,
img2img: false,
clip_model: "".to_string(),
clip_subfolder: "".to_string(),
clip_skip: 0,
tokenizer: "".to_string(),
lora_base: "".to_string(),
lora_adapter: "".to_string(),
no_mul_mat_q: false,
draft_model: "".to_string(),
audio_path: "".to_string(),
quantization: "".to_string(),
});
let response = backend.load_model(request).await;

assert!(response.is_ok());
let response = response.unwrap();
//TO_DO: add test for response
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let subscriber = tracing_subscriber::fmt()
Expand Down
11 changes: 9 additions & 2 deletions backend/rust/models/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use bunker::pb::{ModelOptions, PredictOptions};

pub(crate) mod mnist;
pub use mnist::mnist::MNINST;

use bunker::pb::{ModelOptions, PredictOptions};

/// Trait for implementing a Language Model.
pub trait LLM {
/// Loads the model from the given options.
fn load_model(&mut self, request: ModelOptions) -> Result<String, Box<dyn std::error::Error>>;
/// Predicts the output for the given input options.
fn predict(&mut self, request: PredictOptions) -> Result<String, Box<dyn std::error::Error>>;
}

pub struct LLModel {
model: Box<dyn LLM + 'static>,
}
10 changes: 5 additions & 5 deletions backend/rust/models/src/mnist/mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
//! Adapter by Aisuko
use burn::{
backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice},
backend::wgpu::{AutoGraphicsApi, WgpuDevice},
module::Module,
nn::{self, BatchNorm, PaddingConfig2d},
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
tensor::{backend::Backend, Tensor},
};

// https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin
static STATE_ENCODED: &[u8] = include_bytes!("model.bin");

const NUM_CLASSES: usize = 10;

Expand All @@ -36,7 +35,7 @@ pub struct MNINST<B: Backend> {
}

impl<B: Backend> MNINST<B> {
pub fn new() -> Self {
pub fn new(model_name: &str) -> Self {
let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size
let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size
let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size
Expand All @@ -59,8 +58,9 @@ impl<B: Backend> MNINST<B> {
fc2: fc2,
activation: nn::GELU::new(),
};
let state_encoded: &[u8] = &std::fs::read(model_name).expect("Failed to load model");
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(STATE_ENCODED.to_vec())
.load(state_encoded.to_vec())
.expect("Failed to decode state");

instance.load_record(record)
Expand Down Expand Up @@ -178,7 +178,7 @@ mod tests {
pub type Backend = burn::backend::NdArrayBackend<f32>;
#[test]
fn test_inference() {
let mut model = MNINST::<Backend>::new();
let mut model = MNINST::<Backend>::new("model.bin");
let output = model.inference(&[0.0; 28 * 28]).unwrap();
assert_eq!(output.len(), 10);
}
Expand Down
12 changes: 9 additions & 3 deletions backend/rust/models/src/mnist/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
use crate::LLM;
use bunker::pb::{ModelOptions, PredictOptions};

pub(crate) mod mnist;

use mnist::MNINST;

use bunker::pb::{ModelOptions, PredictOptions};

#[cfg(feature = "ndarray")]
pub type Backend = burn::backend::NdArrayBackend<f32>;

impl LLM for mnist::MNINST<Backend> {
impl LLM for MNINST<Backend> {
fn load_model(&mut self, request: ModelOptions) -> Result<String, Box<dyn std::error::Error>> {
todo!("load model")
let model = request.model_file;
let instance = MNINST::<Backend>::new(&model);
*self = instance;
Ok("".to_string())
}

fn predict(&mut self, pre_ops: PredictOptions) -> Result<String, Box<dyn std::error::Error>> {
Expand Down

0 comments on commit a6ff963

Please sign in to comment.