diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 60719544e525..c9ec91b1559f 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -2,7 +2,7 @@ resolver = "2" members = [ "bunker", - "backend-burn", + "backend", "codegen", "models", ] \ No newline at end of file diff --git a/backend/rust/backend-burn/src/main.rs b/backend/rust/backend-burn/src/main.rs deleted file mode 100644 index 6aadfaba69e6..000000000000 --- a/backend/rust/backend-burn/src/main.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; - -use bunker::pb::Result as PbResult; -use bunker::pb::{ - EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, - PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, - TranscriptResult, TtsRequest, -}; - -use bunker::BackendService; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; - -use async_trait::async_trait; - -use tracing::{event, span, Level}; -use tracing_subscriber::filter::LevelParseError; - -use std::fs; -use std::process::{Command,id}; - -use models::*; -// implement BackendService trait in bunker - -#[derive(Default, Debug)] -struct BurnBackend; - -#[async_trait] -impl BackendService for BurnBackend { - type PredictStreamStream = ReceiverStream>; - - #[tracing::instrument] - async fn health(&self, request: Request) -> Result, Status> { - // return a Result,Status> - let reply = Reply { - message: "OK".into(), - }; - let res = Response::new(reply); - Ok(res) - } - - #[tracing::instrument] - async fn predict(&self, request: Request) -> Result, Status> { - let mut models: Vec> = vec![Box::new(models::MNINST::new())]; - let result = models[0].predict(request.into_inner()); - - match result { - Ok(res) => { - let reply = Reply { - message: res.into(), - }; - let res = Response::new(reply); - Ok(res) - } - Err(e) => { - let reply = Reply { - message: e.to_string().into(), - }; - let res = Response::new(reply); - Ok(res) - } - } - } - - #[tracing::instrument] - async fn load_model( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn predict_stream( - &self, - request: Request, - ) -> Result>>, Status> { - todo!() - } - - #[tracing::instrument] - async fn embedding( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn generate_image( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn audio_transcription( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn tts(&self, request: Request) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn tokenize_string( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn status( - &self, - request: Request, - ) -> Result, Status> { - - // Here we do not need to cover the windows platform - let mut breakdown = HashMap::new(); - let mut memory_usage: u64=0; - - #[cfg(target_os = "linux")] - { - let pid =id(); - let stat = fs::read_to_string(format!("/proc/{}/stat", pid)).expect("Failed to read stat file"); - - let stats: Vec<&str> = stat.split_whitespace().collect(); - memory_usage = stats[23].parse::().expect("Failed to parse RSS"); - } - - #[cfg(target_os="macos")] - { - let output=Command::new("ps") - .arg("-p") - .arg(id().to_string()) - .arg("-o") - .arg("rss=") - .output() - .expect("failed to 
execute process"); - - memory_usage = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::() - .expect("Failed to parse memory usage"); - - } - breakdown.insert("RSS".to_string(), memory_usage); - - let memory_usage = Option::from(MemoryUsageData { - total: memory_usage, - breakdown, - }); - - let reponse = StatusResponse { - state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 - memory: memory_usage, - }; - - Ok(Response::new(reponse)) - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let subscriber = tracing_subscriber::fmt() - .compact() - .with_file(true) - .with_line_number(true) - .with_thread_ids(true) - .with_target(false) - .finish(); - - tracing::subscriber::set_global_default(subscriber)?; - - // call bunker::run with BurnBackend - let burn_backend = BurnBackend {}; - let addr = "[::1]:50051" - .parse::() - .expect("Failed to parse address"); - - // Implmenet Into for addr - let result = bunker::run(burn_backend, addr).await?; - - event!(Level::INFO, "Burn Server is starting"); - - let span = span!(Level::INFO, "Burn Server"); - let _enter = span.enter(); - - event!(Level::INFO, "Burn Server started successfully"); - - Ok(result) -} diff --git a/backend/rust/backend-burn/Cargo.toml b/backend/rust/backend/Cargo.toml similarity index 94% rename from backend/rust/backend-burn/Cargo.toml rename to backend/rust/backend/Cargo.toml index f97347d324b8..4f32f6e1640a 100644 --- a/backend/rust/backend-burn/Cargo.toml +++ b/backend/rust/backend/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "backend-burn" +name = "backend" version = "0.1.0" edition = "2021" diff --git a/backend/rust/backend-burn/Makefile b/backend/rust/backend/Makefile similarity index 100% rename from backend/rust/backend-burn/Makefile rename to backend/rust/backend/Makefile diff --git a/backend/rust/backend/src/main.rs b/backend/rust/backend/src/main.rs new file mode 100644 index 000000000000..7f6b442c3b6c --- /dev/null +++ b/backend/rust/backend/src/main.rs @@ -0,0 +1,309 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::process::{id, Command}; +use std::sync::{Arc, Mutex}; + +use bunker::pb::Result as PbResult; +use bunker::pb::{ + EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, + PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, + TranscriptResult, TtsRequest, +}; + +use bunker::BackendService; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use async_trait::async_trait; + +use tracing::{event, span, Level}; + +use models::*; +// implement BackendService trait in bunker + +#[derive(Default, Debug)] +pub struct BurnBackend; + +#[async_trait] +impl BackendService for BurnBackend { + type PredictStreamStream = ReceiverStream>; + + #[tracing::instrument] + async fn health(&self, request: Request) -> Result, Status> { + // return a Result,Status> + let reply = Reply { + message: "OK".into(), + }; + let res = Response::new(reply); + Ok(res) + } + + #[tracing::instrument] + async fn predict(&self, request: Request) -> Result, Status> { + // TODO: How to get model from load_model function? 
+ let mut model= MNINST::new("model.bin"); + let result = model.predict(request.get_ref().clone()); + match result { + Ok(output) => { + let reply = Reply { + message: output.into_bytes(), + }; + let res = Response::new(reply); + Ok(res) + } + Err(e) => { + let result = PbResult { + message: format!("Failed to predict: {}", e), + success: false, + }; + Err(Status::internal(result.message)) + } + } + } + + #[tracing::instrument] + async fn load_model( + &self, + request: Request, + ) -> Result, Status> { + let result= match request.get_ref().model.as_str() { + "mnist" => { + let mut model = MNINST::new("model.bin"); + let result = model.load_model(request.get_ref().clone()); + match result { + Ok(_) => { + let model = Arc::new(Mutex::new(model)); + let model = model.clone(); + let result = PbResult { + message: "Model loaded successfully".into(), + success: true, + }; + Ok(Response::new(result)) + } + Err(e) => { + let result = PbResult { + message: format!("Failed to load model: {}", e), + success: false, + }; + Err(Status::internal(result.message)) + } + } + } + _ => { + let result = PbResult { + message: format!("Model {} not found", request.get_ref().model), + success: false, + }; + Err(Status::internal(result.message)) + } + }; + // TODO: add model to backend, how to transfer model to backend and let predict funciton can use it? + result + } + + #[tracing::instrument] + async fn predict_stream( + &self, + request: Request, + ) -> Result>>, Status> { + todo!() + } + + #[tracing::instrument] + async fn embedding( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn generate_image( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn audio_transcription( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tts(&self, request: Request) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tokenize_string( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn status( + &self, + request: Request, + ) -> Result, Status> { + // Here we do not need to cover the windows platform + let mut breakdown = HashMap::new(); + let mut memory_usage: u64 = 0; + + #[cfg(target_os = "linux")] + { + let pid = id(); + let stat = fs::read_to_string(format!("/proc/{}/stat", pid)) + .expect("Failed to read stat file"); + + let stats: Vec<&str> = stat.split_whitespace().collect(); + memory_usage = stats[23].parse::().expect("Failed to parse RSS"); + } + + #[cfg(target_os = "macos")] + { + let output = Command::new("ps") + .arg("-p") + .arg(id().to_string()) + .arg("-o") + .arg("rss=") + .output() + .expect("failed to execute process"); + + memory_usage = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::() + .expect("Failed to parse memory usage"); + } + breakdown.insert("RSS".to_string(), memory_usage); + + let memory_usage = Option::from(MemoryUsageData { + total: memory_usage, + breakdown, + }); + + let reponse = StatusResponse { + state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 + memory: memory_usage, + }; + + Ok(Response::new(reponse)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tonic::Request; + + #[tokio::test] + async fn test_health() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = 
backend.health(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = String::from_utf8(response.get_ref().message.clone()).unwrap(); + assert_eq!(message_str, "OK"); + } + #[tokio::test] + async fn test_status() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.status(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let state = response.get_ref().state; + assert_eq!(state, 0); + } + + #[tokio::test] + async fn test_load_model() { + let backend = BurnBackend::default(); + let request = Request::new(ModelOptions { + model: "test".to_string(), + context_size: 0, + seed: 0, + n_batch: 0, + f16_memory: false, + m_lock: false, + m_map: false, + vocab_only: false, + low_vram: false, + embeddings: false, + numa: false, + ngpu_layers: 0, + main_gpu: "".to_string(), + tensor_split: "".to_string(), + threads: 1, + library_search_path: "".to_string(), + rope_freq_base: 0.0, + rope_freq_scale: 0.0, + rms_norm_eps: 0.0, + ngqa: 0, + model_file: "".to_string(), + device: "".to_string(), + use_triton: false, + model_base_name: "".to_string(), + use_fast_tokenizer: false, + pipeline_type: "".to_string(), + scheduler_type: "".to_string(), + cuda: false, + cfg_scale: 0.0, + img2img: false, + clip_model: "".to_string(), + clip_subfolder: "".to_string(), + clip_skip: 0, + tokenizer: "".to_string(), + lora_base: "".to_string(), + lora_adapter: "".to_string(), + no_mul_mat_q: false, + draft_model: "".to_string(), + audio_path: "".to_string(), + quantization: "".to_string(), + }); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + //TO_DO: add test for response + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_target(false) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + + // call bunker::run with BurnBackend + let burn_backend = BurnBackend {}; + let addr = "[::1]:50051" + .parse::() + .expect("Failed to parse address"); + + // Implmenet Into for addr + let result = bunker::run(burn_backend, addr).await?; + + event!(Level::INFO, "Burn Server is starting"); + + let span = span!(Level::INFO, "Burn Server"); + let _enter = span.enter(); + + event!(Level::INFO, "Burn Server started successfully"); + + Ok(result) +} diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index f3302e83ef73..b739bf0b0d51 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs @@ -1,9 +1,16 @@ +use bunker::pb::{ModelOptions, PredictOptions}; + pub(crate) mod mnist; pub use mnist::mnist::MNINST; -use bunker::pb::{ModelOptions, PredictOptions}; - +/// Trait for implementing a Language Model. pub trait LLM { + /// Loads the model from the given options. fn load_model(&mut self, request: ModelOptions) -> Result>; + /// Predicts the output for the given input options. fn predict(&mut self, request: PredictOptions) -> Result>; } + +pub struct LLModel { + model: Box, +} diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs index 995b2706ed05..7a727bbbf441 100644 --- a/backend/rust/models/src/mnist/mnist.rs +++ b/backend/rust/models/src/mnist/mnist.rs @@ -4,7 +4,7 @@ //! 
Adapter by Aisuko use burn::{ - backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice}, + backend::wgpu::{AutoGraphicsApi, WgpuDevice}, module::Module, nn::{self, BatchNorm, PaddingConfig2d}, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, @@ -12,7 +12,6 @@ use burn::{ }; // https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin -static STATE_ENCODED: &[u8] = include_bytes!("model.bin"); const NUM_CLASSES: usize = 10; @@ -36,7 +35,7 @@ pub struct MNINST { } impl MNINST { - pub fn new() -> Self { + pub fn new(model_name: &str) -> Self { let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size @@ -59,8 +58,9 @@ impl MNINST { fc2: fc2, activation: nn::GELU::new(), }; + let state_encoded: &[u8] = &std::fs::read(model_name).expect("Failed to load model"); let record = BinBytesRecorder::::default() - .load(STATE_ENCODED.to_vec()) + .load(state_encoded.to_vec()) .expect("Failed to decode state"); instance.load_record(record) @@ -178,7 +178,7 @@ mod tests { pub type Backend = burn::backend::NdArrayBackend; #[test] fn test_inference() { - let mut model = MNINST::::new(); + let mut model = MNINST::::new("model.bin"); let output = model.inference(&[0.0; 28 * 28]).unwrap(); assert_eq!(output.len(), 10); } diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs index d53b76c6c7a4..8cc85ce0e520 100644 --- a/backend/rust/models/src/mnist/mod.rs +++ b/backend/rust/models/src/mnist/mod.rs @@ -1,14 +1,20 @@ use crate::LLM; -use bunker::pb::{ModelOptions, PredictOptions}; pub(crate) mod mnist; +use mnist::MNINST; + +use bunker::pb::{ModelOptions, PredictOptions}; + #[cfg(feature = "ndarray")] pub type Backend = burn::backend::NdArrayBackend; -impl LLM for mnist::MNINST { +impl LLM for MNINST { fn load_model(&mut self, request: ModelOptions) -> Result> { - todo!("load model") + let model = request.model_file; + let instance = MNINST::::new(&model); + *self = instance; + Ok("".to_string()) } fn predict(&mut self, pre_ops: PredictOptions) -> Result> {
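
The TODOs in `predict` and `load_model` above ask how the model built in `load_model` can be reused by `predict`. A minimal sketch of one approach, not part of this diff: keep the loaded model inside `BurnBackend` behind a mutex so both handlers share the same instance instead of `predict` rebuilding `MNINST::new("model.bin")` on every call. The field and helper names below are illustrative assumptions, and the `B` alias should be replaced by whatever concrete burn backend the `models` crate actually exposes.

use std::sync::{Arc, Mutex};

use models::MNINST;

// Assumed concrete burn backend for the sketch; substitute the crate's own
// alias (the ndarray backend used in models/src/mnist/mod.rs) if it is re-exported.
type B = burn::backend::NdArrayBackend<f32>;

#[derive(Default)]
pub struct BurnBackend {
    // Written by `load_model`, read by `predict`.
    model: Arc<Mutex<Option<MNINST<B>>>>,
}

impl BurnBackend {
    // Build the network from the requested file and store it (call from `load_model`).
    fn set_model(&self, model_file: &str) {
        let instance = MNINST::<B>::new(model_file);
        *self.model.lock().expect("model mutex poisoned") = Some(instance);
    }

    // Run a closure against the stored model, if one has been loaded (call from `predict`).
    fn with_model<R>(&self, f: impl FnOnce(&mut MNINST<B>) -> R) -> Option<R> {
        self.model
            .lock()
            .expect("model mutex poisoned")
            .as_mut()
            .map(f)
    }
}

With something like this in place, `load_model` could call `self.set_model(&request.get_ref().model_file)` after validating the model name, and `predict` could use `self.with_model(|m| m.predict(request.get_ref().clone()))`, mapping the `None` case to `Status::failed_precondition` instead of reconstructing the network per request. If the struct no longer derives `Debug`, the existing `#[tracing::instrument]` attributes would need `skip(self)` or an explicit `Debug` impl.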