diff --git a/Cargo.lock b/Cargo.lock index 5125fc574a..5d4be36767 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -229,6 +229,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-trait" +version = "0.1.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "atoi" version = "2.0.0" @@ -296,6 +307,64 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "axum" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +dependencies = [ + "async-trait", + "axum-core", + "base64 0.22.1", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sha1", + "sync_wrapper 1.0.1", + "tokio", + "tokio-tungstenite", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.1", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backend-comparison" version = "0.16.0" @@ -524,6 +593,7 @@ dependencies = [ "indicatif", "rayon", "reqwest 0.12.9", + "serde", "tokio", "web-time", ] @@ -542,6 +612,7 @@ dependencies = [ "burn-derive", "burn-hip", "burn-ndarray", + "burn-remote", "burn-tch", "burn-tensor", "burn-wgpu", @@ -647,6 +718,7 @@ dependencies = [ "derive-new 0.7.0", "half", "log", + "paste", ] [[package]] @@ -729,15 +801,38 @@ dependencies = [ "serde", ] +[[package]] +name = "burn-remote" +version = "0.16.0" +dependencies = [ + "axum", + "burn-common", + "burn-remote", + "burn-router", + "burn-tensor", + "derive-new 0.7.0", + "futures-util", + "log", + "rmp-serde", + "serde", + "serde_bytes", + "tokio", + "tokio-tungstenite", + "tracing-core", + "tracing-subscriber", +] + [[package]] name = "burn-router" version = "0.16.0" dependencies = [ "burn-autodiff", + "burn-common", "burn-ndarray", "burn-tensor", "burn-wgpu", "hashbrown 0.15.0", + "log", "spin", ] @@ -3149,6 +3244,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -3693,6 +3789,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -6098,6 +6200,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_rusqlite" version = "0.36.0" @@ -6154,6 +6266,14 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "server" +version = "0.16.0" +dependencies = [ + "burn", + "cfg-if", +] + [[package]] name = "sha1" version = "0.10.6" @@ -6874,6 +6994,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.12" @@ -6936,6 +7068,28 @@ dependencies = [ "zip 0.6.6", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -6978,6 +7132,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -7051,6 +7206,24 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -7174,6 +7347,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index 261d85d2e3..8e48e1186b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,6 +119,7 @@ bincode = { version = "2.0.0-rc.3", features = [ # The following packages disable the "std" feature for no_std compatibility # derive-new = { version = "0.7.0", default-features = false } +cfg-if = "1.0.0" blas-src = { version = "0.10.0", default-features = false } half = { version = "2.4.1", features = [ diff --git a/crates/burn-common/Cargo.toml b/crates/burn-common/Cargo.toml index 02205f94e3..2ec4cdb3fc 100644 --- a/crates/burn-common/Cargo.toml +++ b/crates/burn-common/Cargo.toml @@ -23,6 +23,8 @@ getrandom = { workspace = true, features = ["js"] } web-time = { version = "1.1.0" } [dependencies] +serde = { workspace = true } + # Network downloader indicatif = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } diff --git a/crates/burn-common/src/id.rs b/crates/burn-common/src/id.rs index 6a90892249..81aceb752f 100644 --- a/crates/burn-common/src/id.rs +++ b/crates/burn-common/src/id.rs @@ -1,4 +1,5 @@ use crate::rand::gen_random; +use serde::{Deserialize, Serialize}; /// Simple ID generator. pub struct IdGenerator {} @@ -64,3 +65,49 @@ mod tests { assert_eq!(set.len(), EXPECTED_TOTAL_IDS); } } + +/// Unique identifier that can represent a stream based on the current thread id. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct StreamId { + /// The value representing the thread id. + pub value: u64, +} + +impl StreamId { + /// Get the current thread id. + pub fn current() -> Self { + Self { + #[cfg(feature = "std")] + value: Self::from_current_thread(), + #[cfg(not(feature = "std"))] + value: 0, + } + } + + #[cfg(feature = "std")] + fn from_current_thread() -> u64 { + use core::hash::Hash; + + std::thread_local! { + static ID: std::cell::OnceCell:: = const { std::cell::OnceCell::new() }; + }; + + // Getting the current thread is expensive, so we cache the value into a thread local + // variable, which is very fast. + ID.with(|cell| { + *cell.get_or_init(|| { + // A way to get a thread id encoded as u64. + let mut hasher = std::hash::DefaultHasher::default(); + let id = std::thread::current().id(); + id.hash(&mut hasher); + std::hash::Hasher::finish(&hasher) + }) + }) + } +} + +impl core::fmt::Display for StreamId { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("StreamId({:?})", self.value)) + } +} diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 88d52370a3..c82b27265a 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -39,6 +39,8 @@ doc = [ "hip-jit", "vision", "autodiff", + "remote", + "server", # Doc features "burn-candle/doc", "burn-common/doc", @@ -86,6 +88,8 @@ metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] template = ["burn-wgpu?/template"] +remote = ["burn-remote/client"] +server = ["burn-remote/server"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] @@ -131,6 +135,7 @@ burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default- burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } data-encoding = { workspace = true } uuid = { workspace = true } diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index 7aeef77120..5608ca6d1a 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -7,6 +7,11 @@ pub use ndarray::NdArray; #[cfg(feature = "autodiff")] pub use burn_autodiff as autodiff; +#[cfg(feature = "remote")] +pub use burn_remote as remote; +#[cfg(feature = "remote")] +pub use burn_remote::RemoteBackend; + #[cfg(feature = "autodiff")] pub use burn_autodiff::Autodiff; diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index ce118345ad..d1788d10cd 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -43,6 +43,9 @@ pub mod tensor; /// Backend module. pub mod backend; +#[cfg(feature = "server")] +pub use burn_remote::server; + extern crate alloc; #[cfg(all( diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 708f16227c..c366386b0e 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" version.workspace = true [features] +default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"] autotune = ["burn-jit/autotune"] -default = ["fusion", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] diff --git a/crates/burn-hip/Cargo.toml b/crates/burn-hip/Cargo.toml index 65ae4a0b01..d5f0bb70f5 100644 --- a/crates/burn-hip/Cargo.toml +++ b/crates/burn-hip/Cargo.toml @@ -34,6 +34,7 @@ derive-new = { workspace = true } burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ "export_tests", ] } +paste = { workspace = true } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index cc4b3172a5..35442dfa4f 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -24,6 +24,7 @@ mod tests { use burn_jit::JitBackend; pub type TestRuntime = cubecl::hip::HipRuntime; + pub use half::{bf16, f16}; burn_jit::testgen_all!(); } diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml new file mode 100644 index 0000000000..65ca12e1ee --- /dev/null +++ b/crates/burn-remote/Cargo.toml @@ -0,0 +1,51 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "Backend router decorator over websocket." +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "data"] +license.workspace = true +name = "burn-remote" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote" +documentation = "https://docs.rs/burn-router-remote" +version.workspace = true + +[features] +default = [] +doc = [] +client = ["tokio-tungstenite"] +server = ["axum", "tracing-core", "tracing-subscriber"] + + +[dependencies] +burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.16.0", default-features = true} +burn-router = { path = "../burn-router", version = "0.16.0", default-features = true} + +# Basic dependencies +derive-new = {workspace = true } +log = { workspace = true } + +# Shared dependencies +tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] } +serde = { workspace = true, features = ["derive"] } +serde_bytes = { workspace = true } +rmp-serde = { workspace = true } +futures-util = { version = "0.3" } + +# Client dependencies +tokio-tungstenite = { version = "0.24", optional = true } + +# Server dependencies +axum = { version = "0.7.5", features = ["ws"], optional = true } +tracing-core = { workspace = true, optional = true } +tracing-subscriber = { workspace = true, optional = true } + +[dev-dependencies] +# We activate the features client and server during dev. +burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] } + +[package.metadata.docs.rs] +features = ["doc"] +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/burn-remote/README.md b/crates/burn-remote/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/burn-remote/src/client/base.rs b/crates/burn-remote/src/client/base.rs new file mode 100644 index 0000000000..29057f7886 --- /dev/null +++ b/crates/burn-remote/src/client/base.rs @@ -0,0 +1,99 @@ +use super::worker::{ClientRequest, ClientWorker}; +use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent}; +use burn_common::id::StreamId; +use burn_tensor::repr::TensorId; +use std::{ + future::Future, + sync::{atomic::AtomicU64, Arc}, +}; +use tokio::sync::mpsc::Sender; + +pub use super::WsDevice; + +#[derive(Clone)] +pub struct WsClient { + pub(crate) device: WsDevice, + pub(crate) sender: Arc, + pub(crate) runtime: Arc, +} + +impl WsClient { + pub fn init(device: WsDevice) -> Self { + ClientWorker::start(device) + } + + pub(crate) fn new( + device: WsDevice, + sender: Sender, + runtime: Arc, + ) -> Self { + Self { + device, + runtime, + sender: Arc::new(WsSender { + sender, + position_counter: AtomicU64::new(0), + tensor_id_counter: AtomicU64::new(0), + }), + } + } +} + +pub(crate) struct WsSender { + sender: Sender, + position_counter: AtomicU64, + tensor_id_counter: AtomicU64, +} + +impl WsSender { + pub(crate) fn send(&self, task: ComputeTask) -> impl Future + Send { + let position = self + .position_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let stream_id = StreamId::current(); + let sender = self.sender.clone(); + + async move { + sender + .send(ClientRequest::WithoutCallback(Task::Compute( + task, + ConnectionId::new(position, stream_id), + ))) + .await + .unwrap(); + } + } + + pub(crate) fn new_tensor_id(&self) -> TensorId { + let val = self + .tensor_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + TensorId::new(val) + } + pub(crate) fn send_callback( + &self, + task: ComputeTask, + ) -> impl Future + Send { + let position = self + .position_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let stream_id = StreamId::current(); + let sender = self.sender.clone(); + let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1); + + async move { + sender + .send(ClientRequest::WithSyncCallback( + Task::Compute(task, ConnectionId::new(position, stream_id)), + callback_sender, + )) + .await + .unwrap(); + + match callback_recv.recv().await { + Some(val) => val, + None => panic!(""), + } + } + } +} diff --git a/crates/burn-remote/src/client/channel.rs b/crates/burn-remote/src/client/channel.rs new file mode 100644 index 0000000000..6c431702af --- /dev/null +++ b/crates/burn-remote/src/client/channel.rs @@ -0,0 +1,45 @@ +use burn_router::{RouterTensor, RunnerChannel, TensorHandle}; +use burn_tensor::repr::TensorDescription; + +use super::{ + runner::{WsBridge, WsDevice}, + WsClient, +}; + +/// A local channel with direct connection to the backend runner clients. +#[derive(Clone)] +pub struct WsChannel; + +impl RunnerChannel for WsChannel { + type Device = WsDevice; + type Bridge = WsBridge; + type Client = WsClient; + + type FloatElem = f32; + + type IntElem = i32; + + fn name() -> String { + "remote".into() + } + + fn init_client(device: &Self::Device) -> Self::Client { + WsClient::init(device.clone()) + } + + fn get_tensor_handle( + _tensor: &TensorDescription, + _client: &Self::Client, + ) -> TensorHandle { + panic!("Unsupported") + } + + fn register_tensor( + _client: &Self::Client, + _handle: TensorHandle, + _shape: Vec, + _dtype: burn_tensor::DType, + ) -> RouterTensor { + panic!("Unsupported") + } +} diff --git a/crates/burn-remote/src/client/mod.rs b/crates/burn-remote/src/client/mod.rs new file mode 100644 index 0000000000..55073f5625 --- /dev/null +++ b/crates/burn-remote/src/client/mod.rs @@ -0,0 +1,8 @@ +mod base; +mod channel; +mod runner; +mod worker; + +pub use base::*; +pub use channel::*; +pub use runner::WsDevice; diff --git a/crates/burn-remote/src/client/runner.rs b/crates/burn-remote/src/client/runner.rs new file mode 100644 index 0000000000..f75bc52173 --- /dev/null +++ b/crates/burn-remote/src/client/runner.rs @@ -0,0 +1,175 @@ +use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient}; +use burn_tensor::{ + backend::{DeviceId, DeviceOps}, + DType, TensorData, +}; +use std::sync::Arc; + +use crate::shared::{ComputeTask, TaskResponseContent}; + +use super::WsClient; + +// It is very important to block on any request made with the sender, since ordering is crucial +// when registering operation or creating tensors. +// +// The overhead is minimal, since we only wait for the task to be sent to the async +// channel, but not sent to the websocket server and even less processed by the server. +impl RunnerClient for WsClient { + type Device = WsDevice; + + fn register(&self, op: burn_tensor::repr::OperationDescription) { + let fut = self + .sender + .send(ComputeTask::RegisterOperation(Box::new(op))); + self.runtime.block_on(fut); + } + + fn read_tensor( + &self, + tensor: burn_tensor::repr::TensorDescription, + ) -> impl std::future::Future + Send { + // Important for ordering to call the creation of the future sync. + let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor)); + + async move { + match fut.await { + TaskResponseContent::ReadTensor(data) => data, + _ => panic!("Invalid message type"), + } + } + } + + fn register_tensor_data(&self, data: TensorData) -> RouterTensor { + let id = self.sender.new_tensor_id(); + let shape = data.shape.clone(); + let dtype = data.dtype; + + let fut = self.sender.send(ComputeTask::RegisterTensor(id, data)); + + self.runtime.block_on(fut); + + RouterTensor::new(Arc::new(id), shape, dtype, self.clone()) + } + + fn register_empty_tensor( + &self, + shape: Vec, + dtype: burn_tensor::DType, + ) -> RouterTensor { + let id = self.sender.new_tensor_id(); + + RouterTensor::new(Arc::new(id), shape, dtype, self.clone()) + } + + fn register_float_tensor( + &self, + shape: Vec, + _full_precision: bool, + ) -> RouterTensor { + self.register_empty_tensor(shape, DType::F32) + } + + fn device(&self) -> Self::Device { + self.device.clone() + } + + fn register_orphan(&self, id: &burn_tensor::repr::TensorId) { + let fut = self.sender.send(ComputeTask::RegisterOrphan(*id)); + self.runtime.block_on(fut); + } + + fn sync(&self) { + // Important for ordering to call the creation of the future sync. + let fut = self.sender.send_callback(ComputeTask::SyncBackend); + + let fut = async move { + match fut.await { + TaskResponseContent::SyncBackend => {} + _ => panic!("Invalid message type"), + }; + }; + + self.runtime.block_on(fut) + } + + fn seed(&self, _seed: u64) { + // TODO + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +/// The device contains the connection information of the server. +pub struct WsDevice { + pub(crate) address: Arc, +} + +impl WsDevice { + /// Create a device from an url. + pub fn new(url: &str) -> Self { + let mut address = String::new(); + + if !url.starts_with("ws://") { + address += "ws://"; + address += url; + } else { + address += url; + }; + + Self { + address: Arc::new(address), + } + } +} + +impl Default for WsDevice { + fn default() -> Self { + let address = match std::env::var("BURN_REMOTE_ADDRESS") { + Ok(address) => address, + Err(_) => String::from("ws://127.0.0.1:3000"), + }; + + Self { + address: Arc::new(address), + } + } +} + +impl DeviceOps for WsDevice { + fn id(&self) -> DeviceId { + DeviceId { + type_id: 0, + index_id: 0, + } + } +} + +pub struct WsBridge; + +impl MultiBackendBridge for WsBridge { + type TensorHandle = TensorData; + type Device = WsDevice; + + fn change_backend_float( + tensor: Self::TensorHandle, + _shape: burn_tensor::Shape, + _target_device: &Self::Device, + ) -> Self::TensorHandle { + tensor + } + + fn change_backend_int( + tensor: Self::TensorHandle, + _shape: burn_tensor::Shape, + _target_device: &Self::Device, + ) -> Self::TensorHandle { + tensor + } + + fn change_backend_bool( + tensor: Self::TensorHandle, + _shape: burn_tensor::Shape, + _target_device: &Self::Device, + ) -> Self::TensorHandle { + tensor + } +} diff --git a/crates/burn-remote/src/client/worker.rs b/crates/burn-remote/src/client/worker.rs new file mode 100644 index 0000000000..75209dfa2a --- /dev/null +++ b/crates/burn-remote/src/client/worker.rs @@ -0,0 +1,140 @@ +use super::{runner::WsDevice, WsClient}; +use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent}; +use futures_util::{SinkExt, StreamExt}; +use std::{collections::HashMap, sync::Arc}; +use tokio_tungstenite::{ + connect_async_with_config, + tungstenite::protocol::{Message, WebSocketConfig}, +}; + +pub type CallbackSender = tokio::sync::mpsc::Sender; + +pub enum ClientRequest { + WithSyncCallback(Task, CallbackSender), + WithoutCallback(Task), +} + +#[derive(Default)] +pub(crate) struct ClientWorker { + requests: HashMap, +} + +impl ClientWorker { + async fn on_response(&mut self, response: TaskResponse) { + match self.requests.remove(&response.id) { + Some(request) => { + request.send(response.content).await.unwrap(); + } + None => { + panic!("Can't ignore message from the server."); + } + } + } + + fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) { + self.requests.insert(id, callback); + } +} + +impl ClientWorker { + pub fn start(device: WsDevice) -> WsClient { + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_io() + .build() + .unwrap(), + ); + + let (sender, mut rec) = tokio::sync::mpsc::channel(10); + let address_request = format!("{}/{}", device.address.as_str(), "request"); + let address_response = format!("{}/{}", device.address.as_str(), "response"); + + const MB: usize = 1024 * 1024; + + #[allow(deprecated)] + runtime.spawn(async move { + log::info!("Connecting to {address_request} ..."); + let (mut stream_request, _) = connect_async_with_config( + address_request.clone(), + Some(WebSocketConfig { + max_send_queue: None, + write_buffer_size: 0, + max_write_buffer_size: usize::MAX, + max_message_size: None, + max_frame_size: Some(MB * 512), + accept_unmasked_frames: true, + }), + true, + ) + .await + .expect("Failed to connect"); + let (mut stream_response, _) = connect_async_with_config( + address_response, + Some(WebSocketConfig { + max_send_queue: None, + write_buffer_size: 0, + max_write_buffer_size: usize::MAX, + max_message_size: None, + max_frame_size: Some(MB * 512), + accept_unmasked_frames: true, + }), + true, + ) + .await + .expect("Failed to connect"); + + let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::default())); + + // Init the connection. + let session_id = SessionId::new(); + let bytes = rmp_serde::to_vec(&Task::Init(session_id)).expect("Can serialize tasks to bytes."); + stream_request.send(Message::Binary(bytes.clone())).await.expect("Can send the message on the websocket."); + stream_response.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket."); + + // Websocket async worker loading callback from the server. + let state_ws = state.clone(); + tokio::spawn(async move { + while let Some(msg) = stream_response.next().await { + let msg = match msg { + Ok(msg) => msg, + Err(err) => panic!("An error happened while receiving messages from the websocket: {err:?}"), + }; + + match msg { + Message::Binary(bytes) => { + let response: TaskResponse = rmp_serde::from_slice(&bytes).expect("Can deserialize messages from the websocket."); + let mut state = state_ws.lock().await; + state.on_response(response).await; + } + Message::Close(_) => { + log::warn!("Closed connection"); + return; + }, + _ => panic!("Unsupported websocket message: {msg:?}"), + }; + } + }); + + // Channel async worker sending operations to the server. + tokio::spawn(async move { + while let Some(req) = rec.recv().await { + let task = match req { + ClientRequest::WithSyncCallback(task, callback) => { + let mut state = state.lock().await; + if let Task::Compute(_content, id) = &task { + state.register_callback(*id, callback); + } + task + } + ClientRequest::WithoutCallback(task) => task, + + }; + let bytes = rmp_serde::to_vec(&task).expect("Can serialize tasks to bytes."); + stream_request.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket."); + } + }); + }); + + WsClient::new(device, sender, runtime) + } +} diff --git a/crates/burn-remote/src/lib.rs b/crates/burn-remote/src/lib.rs new file mode 100644 index 0000000000..6b920d298a --- /dev/null +++ b/crates/burn-remote/src/lib.rs @@ -0,0 +1,37 @@ +#[macro_use] +extern crate derive_new; + +#[cfg(feature = "client")] +pub(crate) mod client; + +#[cfg(feature = "server")] +pub mod server; + +pub(crate) mod shared; + +#[cfg(feature = "client")] +mod __client { + use super::*; + + use burn_router::BackendRouter; + use client::WsChannel; + + /// The remote backend allows you to run computation on a remote device. + /// + /// Make sure there is a running server before trying to connect to it. + /// + /// ```rust, ignore + /// fn main() { + /// let device = Default::default(); + /// let port = 3000; + /// + /// // You need to activate the `server` feature flag to have access to this function. + /// burn::server::start::(device, port); + /// } + ///``` + pub type RemoteBackend = BackendRouter; + + pub use client::WsDevice as RemoteDevice; +} +#[cfg(feature = "client")] +pub use __client::*; diff --git a/crates/burn-remote/src/server/base.rs b/crates/burn-remote/src/server/base.rs new file mode 100644 index 0000000000..169c262014 --- /dev/null +++ b/crates/burn-remote/src/server/base.rs @@ -0,0 +1,196 @@ +use std::{net::SocketAddr, sync::Arc}; + +use axum::{ + extract::{ + ws::{self, WebSocket, WebSocketUpgrade}, + State, + }, + response::IntoResponse, + routing::any, + Router, +}; + +use burn_tensor::{ + backend::{Backend, BackendBridge}, + repr::ReprBackend, + Device, +}; +use tracing_core::{Level, LevelFilter}; +use tracing_subscriber::prelude::*; +use tracing_subscriber::{filter::filter_fn, registry}; + +use crate::shared::{ComputeTask, Task}; + +use super::session::SessionManager; + +#[derive(Clone)] +pub struct WsServer { + state: Arc>, +} + +impl WsServer +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + /// Start the server on the given address. + pub async fn start(device: Device, port: u16) { + let layer = tracing_subscriber::fmt::layer() + .with_filter(LevelFilter::INFO) + .with_filter(filter_fn(|m| { + if let Some(path) = m.module_path() { + // The wgpu crate is logging too much, so we skip `info` level. + if path.starts_with("wgpu") && *m.level() >= Level::INFO { + return false; + } + } + true + })); + registry().with(layer).init(); + + let address = format!("0.0.0.0:{port}"); + log::info!("Start server {address} on device {device:?}"); + + let state = SessionManager::::new(device); + let state = Self { + state: Arc::new(state), + }; + + // build our application with some routes + let app = Router::new() + .route("/response", any(Self::handler_response)) + .route("/request", any(Self::handler_request)) + .with_state(state); + + // run it with hyper + let listener = tokio::net::TcpListener::bind(address).await.unwrap(); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .unwrap(); + } + + async fn handler_response( + ws: WebSocketUpgrade, + State(state): State, + ) -> impl IntoResponse { + ws.on_upgrade(move |socket| state.handle_socket_response(socket)) + } + + async fn handler_request(ws: WebSocketUpgrade, State(state): State) -> impl IntoResponse { + ws.on_upgrade(move |socket| state.handle_socket_request(socket)) + } + + async fn handle_socket_response(self, mut socket: WebSocket) { + log::info!("[Response Handler] On new connection."); + + let packet = socket.recv().await; + let msg = match packet { + Some(msg) => msg, + None => { + log::info!("Still no message"); + panic!(""); + } + }; + + if let Ok(ws::Message::Binary(bytes)) = msg { + let task = match rmp_serde::from_slice::(&bytes) { + Ok(val) => val, + Err(err) => { + log::info!("Only bytes messages are supported {err:?}"); + panic!(""); + } + }; + let id = match task { + Task::Init(id) => id, + _ => panic!(""), + }; + + let receiver = self.state.register_responder(id).await; + + log::info!("Response handler connection active"); + + while let Ok(callback) = receiver.recv() { + let response = callback.recv().unwrap(); + let bytes = rmp_serde::to_vec(&response).unwrap(); + + socket.send(ws::Message::Binary(bytes)).await.unwrap(); + } + } else { + panic!(""); + } + } + + async fn handle_socket_request(self, mut socket: WebSocket) { + log::info!("[Request Handler] On new connection."); + let mut session_id = None; + + loop { + let packet = socket.recv().await; + let msg = match packet { + Some(msg) => msg, + None => { + log::info!("Still no message"); + continue; + } + }; + + if let Ok(ws::Message::Binary(bytes)) = msg { + let task = match rmp_serde::from_slice::(&bytes) { + Ok(val) => val, + Err(err) => { + log::info!("Only bytes message in the json format are supported {err:?}"); + break; + } + }; + + let (stream, connection_id, task) = + match self.state.stream(&mut session_id, task).await { + Some(val) => val, + None => { + log::info!("Ops session activated {session_id:?}"); + continue; + } + }; + + match task { + ComputeTask::RegisterOperation(op) => { + stream.register_operation(op); + } + ComputeTask::RegisterTensor(id, data) => { + stream.register_tensor(id, data); + } + ComputeTask::RegisterOrphan(id) => { + stream.register_orphan(id); + } + ComputeTask::ReadTensor(tensor) => { + stream.read_tensor(connection_id, tensor); + } + ComputeTask::SyncBackend => { + stream.sync(connection_id); + } + } + } else { + log::info!("Not a binary message, closing, received {msg:?}"); + break; + }; + } + + log::info!("Closing connection"); + self.state.close(session_id).await; + } +} + +#[tokio::main] +/// Start the server on the given port and [device](Device). +pub async fn start(device: Device, port: u16) +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + WsServer::::start(device, port).await; +} diff --git a/crates/burn-remote/src/server/mod.rs b/crates/burn-remote/src/server/mod.rs new file mode 100644 index 0000000000..68d36882be --- /dev/null +++ b/crates/burn-remote/src/server/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod processor; +pub(crate) mod session; +pub(crate) mod stream; + +mod base; + +pub use base::start; diff --git a/crates/burn-remote/src/server/processor.rs b/crates/burn-remote/src/server/processor.rs new file mode 100644 index 0000000000..ee0cd52e65 --- /dev/null +++ b/crates/burn-remote/src/server/processor.rs @@ -0,0 +1,84 @@ +use burn_router::{Runner, RunnerClient}; +use burn_tensor::{ + backend::{Backend, BackendBridge}, + repr::{OperationDescription, ReprBackend, TensorDescription, TensorId}, + TensorData, +}; +use core::marker::PhantomData; +use std::sync::mpsc::Sender; + +use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent}; + +/// The goal of the processor is to asynchonously process compute tasks on it own thread. +pub struct Processor { + p: PhantomData, +} + +pub type Callback = Sender; + +pub enum ProcessorTask { + RegisterOperation(Box), + RegisterTensor(TensorId, TensorData), + ReadTensor(ConnectionId, TensorDescription, Callback), + Sync(ConnectionId, Callback), + Fence(Callback<()>), + RegisterOrphan(TensorId), + Close, +} + +impl Processor +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + pub fn start(runner: Runner) -> Sender { + let (sender, rec) = std::sync::mpsc::channel(); + + std::thread::spawn(move || { + for item in rec.iter() { + match item { + ProcessorTask::RegisterOperation(op) => { + runner.register(*op); + } + ProcessorTask::RegisterOrphan(id) => { + runner.register_orphan(&id); + } + ProcessorTask::Sync(id, callback) => { + runner.sync(); + callback + .send(TaskResponse { + content: TaskResponseContent::SyncBackend, + id, + }) + .unwrap(); + } + ProcessorTask::RegisterTensor(id, data) => { + runner.register_tensor_data_id(id, data); + } + ProcessorTask::ReadTensor(id, tensor, callback) => { + let tensor = burn_common::future::block_on(runner.read_tensor(tensor)); + callback + .send(TaskResponse { + content: TaskResponseContent::ReadTensor(tensor), + id, + }) + .unwrap(); + } + ProcessorTask::Close => { + let device = runner.device(); + runner.sync(); + core::mem::drop(runner); + B::sync(&device); + return; + } + ProcessorTask::Fence(sender) => { + sender.send(()).unwrap(); + } + } + } + }); + + sender + } +} diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs new file mode 100644 index 0000000000..e0ac508129 --- /dev/null +++ b/crates/burn-remote/src/server/session.rs @@ -0,0 +1,242 @@ +use burn_common::id::StreamId; +use burn_router::Runner; +use burn_tensor::{ + backend::{Backend, BackendBridge}, + repr::{ReprBackend, TensorDescription, TensorId, TensorStatus}, + Device, +}; +use std::{ + collections::HashMap, + sync::mpsc::{Receiver, Sender}, +}; +use tokio::sync::Mutex; + +use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse}; + +use super::stream::Stream; + +/// A session manager control the creation of sessions. +/// +/// Each session manages its own stream, spawning one thread per stream to mimic the same behavior +/// a native backend would have. +pub struct SessionManager { + runner: Runner, + sessions: tokio::sync::Mutex>>, +} + +struct Session { + runner: Runner, + tensors: HashMap>, + streams: HashMap>, + sender: Sender>, + receiver: Option>>, +} + +impl SessionManager +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + pub fn new(device: Device) -> Self { + Self { + runner: Runner::new(device), + sessions: Mutex::new(Default::default()), + } + } + + /// Register a new responder for the session. Only one responder can exist for a session for + /// now. + pub async fn register_responder( + &self, + session_id: SessionId, + ) -> Receiver> { + log::info!("Register responder for session {session_id}"); + let mut sessions = self.sessions.lock().await; + self.register_session(&mut sessions, session_id); + + let session = sessions.get_mut(&session_id).unwrap(); + session.init_responder() + } + + /// Get the stream for the current session and task. + pub async fn stream( + &self, + session_id: &mut Option, + task: Task, + ) -> Option<(Stream, ConnectionId, ComputeTask)> { + let mut sessions = self.sessions.lock().await; + + let session_id = match session_id { + Some(id) => *id, + None => match task { + Task::Init(id) => { + log::info!("Init requester for session {id}"); + *session_id = Some(id); + self.register_session(&mut sessions, id); + return None; + } + _ => panic!("The first message should initialize the session"), + }, + }; + + match sessions.get_mut(&session_id) { + Some(session) => { + let (task, connection_id) = match task { + Task::Compute(task, connection_id) => (task, connection_id), + _ => panic!("Only support compute tasks."), + }; + let stream = session.select(connection_id.stream_id, &task); + Some((stream, connection_id, task)) + } + None => { + panic!("To be initialized"); + } + } + } + + /// Close the session with the given id. + pub async fn close(&self, session_id: Option) { + if let Some(id) = session_id { + let mut sessions = self.sessions.lock().await; + if let Some(session) = sessions.get_mut(&id) { + session.close(); + } + } + } + + fn register_session(&self, sessions: &mut HashMap>, id: SessionId) { + sessions.entry(id).or_insert_with(|| { + log::info!("Creating a new session {id}"); + + Session::new(self.runner.clone()) + }); + } +} + +impl Session +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + fn new(runner: Runner) -> Self { + let (sender, reveiver) = std::sync::mpsc::channel(); + Self { + runner, + tensors: Default::default(), + streams: Default::default(), + sender, + receiver: Some(reveiver), + } + } + + fn init_responder(&mut self) -> Receiver> { + let mut receiver = None; + core::mem::swap(&mut receiver, &mut self.receiver); + receiver.expect("Only one responder per session is possible.") + } + + /// Select the current [stream](Stream) based on the given task. + fn select(&mut self, stream_id: StreamId, task: &ComputeTask) -> Stream { + // We have to check every streams involved in the last operation, making + // sure the backend is up-to-date with those operations. + // + // 1. We update the tensor status of all tensors in the task. + // 2. We don't keep track of tensors that are used for the last time. + let mut fences = Vec::new(); + for (tensor_id, status) in task.tensors_info() { + let tensor_stream_ids = match self.tensors.get(&tensor_id) { + Some(val) => val, + None => { + if status != TensorStatus::ReadWrite { + // Add the first stream that created the tensor that may be used by other + // streams later. + self.register_tensor(tensor_id, stream_id); + } + continue; + } + }; + + let current_stream_already_synced = tensor_stream_ids.contains(&stream_id); + + if !current_stream_already_synced { + // We only need to sync to the first stream that created the tensor. + if let Some(id) = tensor_stream_ids.iter().next() { + fences.push(*id); + } + } + + // We add the stream to the list of updated stream to avoid needed to flush other + // operations that might use this tensor. + self.register_tensor(tensor_id, stream_id); + + // If the tensor has the status `read_write`, it means no other stream can reuse it + // afterward, so we remove it from the state. + if status == TensorStatus::ReadWrite { + self.tensors.remove(&tensor_id); + } + } + + // Cleanup orphans. + if let ComputeTask::RegisterOrphan(tensor_id) = task { + self.tensors.remove(tensor_id); + } + + // We have to wait for the streams to be updated. + for stream_id in fences { + if let Some(stream) = self.streams.get(&stream_id) { + stream.fence_sync(); + } + } + + // We return the stream. + match self.streams.get(&stream_id) { + Some(stream) => stream.clone(), + None => { + let stream = Stream::::new(self.runner.clone(), self.sender.clone()); + self.streams.insert(stream_id, stream.clone()); + stream + } + } + } + + fn register_tensor(&mut self, tensor_id: TensorId, stream_id: StreamId) { + match self.tensors.get_mut(&tensor_id) { + Some(ids) => { + ids.push(stream_id); + } + None => { + self.tensors.insert(tensor_id, vec![stream_id]); + } + } + } + + // Close all streams created in the session. + fn close(&mut self) { + for (id, stream) in self.streams.drain() { + log::info!("Closing stream {id}"); + stream.close(); + } + } +} + +impl ComputeTask { + fn tensors_info(&self) -> Vec<(TensorId, TensorStatus)> { + fn from_descriptions(desc: &[&TensorDescription]) -> Vec<(TensorId, TensorStatus)> { + desc.iter().map(|t| (t.id, t.status.clone())).collect() + } + + match self { + ComputeTask::RegisterOperation(op) => from_descriptions(&op.nodes()), + ComputeTask::RegisterTensor(tensor_id, _tensor_data) => { + vec![(*tensor_id, TensorStatus::NotInit)] + } + ComputeTask::RegisterOrphan(tensor_id) => { + vec![(*tensor_id, TensorStatus::ReadWrite)] + } + ComputeTask::ReadTensor(tensor_description) => from_descriptions(&[tensor_description]), + ComputeTask::SyncBackend => vec![], + } + } +} diff --git a/crates/burn-remote/src/server/stream.rs b/crates/burn-remote/src/server/stream.rs new file mode 100644 index 0000000000..5ade2994c6 --- /dev/null +++ b/crates/burn-remote/src/server/stream.rs @@ -0,0 +1,94 @@ +use core::marker::PhantomData; +use std::sync::mpsc::{Receiver, Sender}; + +use crate::shared::{ConnectionId, TaskResponse}; + +use super::processor::{Processor, ProcessorTask}; +use burn_router::Runner; +use burn_tensor::{ + backend::{Backend, BackendBridge}, + repr::{OperationDescription, ReprBackend, TensorDescription, TensorId}, + TensorData, +}; + +/// A stream makes sure all operations registered are executed in the order they were sent to the +/// server, protentially waiting to reconstruct consistency. +#[derive(Clone)] +pub struct Stream { + compute_sender: Sender, + writer_sender: Sender>, + _p: PhantomData, +} + +impl Stream +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + pub fn new(runner: Runner, writer_sender: Sender>) -> Self { + let sender = Processor::start(runner); + + Self { + compute_sender: sender, + writer_sender, + _p: PhantomData, + } + } + + pub fn register_operation(&self, op: Box) { + self.compute_sender + .send(ProcessorTask::RegisterOperation(op)) + .unwrap(); + } + + pub fn register_tensor(&self, tensor_id: TensorId, data: TensorData) { + self.compute_sender + .send(ProcessorTask::RegisterTensor(tensor_id, data)) + .unwrap() + } + + pub fn register_orphan(&self, tensor_id: TensorId) { + self.compute_sender + .send(ProcessorTask::RegisterOrphan(tensor_id)) + .unwrap() + } + + pub fn read_tensor(&self, id: ConnectionId, desc: TensorDescription) { + let (callback_sender, callback_rec) = std::sync::mpsc::channel(); + + self.compute_sender + .send(ProcessorTask::ReadTensor(id, desc, callback_sender)) + .unwrap(); + + self.writer_sender.send(callback_rec).unwrap(); + } + + pub fn sync(&self, id: ConnectionId) { + let (callback_sender, callback_rec) = std::sync::mpsc::channel(); + + self.compute_sender + .send(ProcessorTask::Sync(id, callback_sender)) + .unwrap(); + + self.writer_sender.send(callback_rec).unwrap(); + } + + // Ensure that all tasks are sent to the backend. + // + // It doesn't mean that the computation is done, but it means the backend has received the + // tasks, which may be queued. + pub fn fence_sync(&self) { + let (callback_sender, callback_rec) = std::sync::mpsc::channel(); + + self.compute_sender + .send(ProcessorTask::Fence(callback_sender.clone())) + .unwrap(); + + callback_rec.recv().unwrap(); + } + + pub fn close(&self) { + self.compute_sender.send(ProcessorTask::Close).unwrap(); + } +} diff --git a/crates/burn-remote/src/shared/mod.rs b/crates/burn-remote/src/shared/mod.rs new file mode 100644 index 0000000000..1b324d5124 --- /dev/null +++ b/crates/burn-remote/src/shared/mod.rs @@ -0,0 +1,2 @@ +mod task; +pub(crate) use task::*; diff --git a/crates/burn-remote/src/shared/task.rs b/crates/burn-remote/src/shared/task.rs new file mode 100644 index 0000000000..68b6209d59 --- /dev/null +++ b/crates/burn-remote/src/shared/task.rs @@ -0,0 +1,68 @@ +use std::fmt::Display; + +use burn_common::id::{IdGenerator, StreamId}; +use burn_tensor::{ + repr::{OperationDescription, TensorDescription, TensorId}, + TensorData, +}; +use serde::{Deserialize, Serialize}; + +#[allow(missing_docs)] +#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] +pub struct ConnectionId { + pub position: u64, + pub stream_id: StreamId, +} + +/// Unique identifier that can represent a session. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct SessionId { + id: u64, +} + +impl Display for SessionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "SessionId({})", self.id) + } +} + +impl SessionId { + /// Create a new [session id](SessionId). + #[allow(dead_code)] + pub fn new() -> Self { + Self { + id: IdGenerator::generate(), + } + } +} + +#[allow(missing_docs)] +#[derive(Serialize, Deserialize, Debug)] +pub enum Task { + Compute(ComputeTask, ConnectionId), + Init(SessionId), +} + +#[allow(missing_docs)] +#[derive(Serialize, Deserialize, Debug)] +pub enum ComputeTask { + RegisterOperation(Box), + RegisterTensor(TensorId, TensorData), + RegisterOrphan(TensorId), + ReadTensor(TensorDescription), + SyncBackend, +} + +#[allow(missing_docs)] +#[derive(Serialize, Deserialize, Debug)] +pub struct TaskResponse { + pub content: TaskResponseContent, + pub id: ConnectionId, +} + +#[allow(missing_docs)] +#[derive(Serialize, Deserialize, Debug)] +pub enum TaskResponseContent { + ReadTensor(TensorData), + SyncBackend, +} diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml index efe2c3da04..f6df54e59f 100644 --- a/crates/burn-router/Cargo.toml +++ b/crates/burn-router/Cargo.toml @@ -13,14 +13,15 @@ version.workspace = true [features] default = ["std"] -std = [] +std = ["burn-tensor/std", "burn-common/std"] doc = ["default"] [dependencies] burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.16.0", default-features = false} hashbrown = { workspace = true } spin = { workspace = true } - +log = { workspace = true } [dev-dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ @@ -31,7 +32,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = ] } burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false } [package.metadata.docs.rs] diff --git a/crates/burn-router/src/channel/direct.rs b/crates/burn-router/src/channel/direct.rs index 9250ff017d..a0f8814607 100644 --- a/crates/burn-router/src/channel/direct.rs +++ b/crates/burn-router/src/channel/direct.rs @@ -14,26 +14,3 @@ impl Clone for DirectChannel { } } } - -// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type) -// impl From>> -// for RouterTensor> -// { -// fn from(value: RouterTensor>) -> Self { -// RouterTensor { -// desc: value.desc, -// client: MultiRunnerClient2::RunnerClient1(value.client), -// } -// } -// } - -// impl From>> -// for RouterTensor> -// { -// fn from(value: RouterTensor>) -> Self { -// RouterTensor { -// desc: value.desc, -// client: MultiRunnerClient2::RunnerClient2(value.client), -// } -// } -// } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 2f408155b4..2b4ebc9a63 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -58,7 +58,8 @@ where <::FullPrecisionBridge as BackendBridge>::Target: ReprBackend, { - pub(crate) fn new(device: B::Device) -> Self { + /// Create a new runner. + pub fn new(device: B::Device) -> Self { Self { context: Arc::new(Mutex::new(RunnerContext { handles: HandleContainer::new(), @@ -90,7 +91,29 @@ where RouterTensor::new(id, shape, dtype, client) } - pub(crate) fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription { + /// Register a tensor from its data and id. + pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) { + let mut ctx = self.context.lock(); + let dtype = data.dtype; + + if dtype.is_float() { + let tensor = B::float_from_data(data, &self.device); + ctx.handles.register_float_tensor::(&id, tensor) + } else if dtype.is_int() { + let tensor = B::int_from_data(data, &self.device); + ctx.handles.register_int_tensor::(&id, tensor) + } else if dtype.is_bool() { + let tensor = B::bool_from_data(data, &self.device); + ctx.handles.register_bool_tensor::(&id, tensor) + } else if let DType::QFloat(_) = dtype { + todo!(); + } + + core::mem::drop(ctx); + } + + /// Register a tensor and returns its description. + pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription { let mut ctx = self.context.lock(); let id = ctx.create_empty_handle(); let shape = data.shape.clone(); @@ -119,11 +142,8 @@ where } } - pub(crate) fn register_empty_tensor_desc( - &self, - shape: Vec, - dtype: DType, - ) -> TensorDescription { + /// Register an empty tensor and returns its description. + pub fn register_empty_tensor_desc(&self, shape: Vec, dtype: DType) -> TensorDescription { let mut ctx = self.context.lock(); let id = ctx.create_empty_handle(); core::mem::drop(ctx); diff --git a/crates/burn-router/src/tensor.rs b/crates/burn-router/src/tensor.rs index c53f2069ba..15faf90817 100644 --- a/crates/burn-router/src/tensor.rs +++ b/crates/burn-router/src/tensor.rs @@ -21,7 +21,8 @@ pub struct RouterTensor { } impl RouterTensor { - pub(crate) fn new(id: Arc, shape: Vec, dtype: DType, client: C) -> Self { + /// Create a new router tensor. + pub fn new(id: Arc, shape: Vec, dtype: DType, client: C) -> Self { Self { id, shape, diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 34038a9463..de28869e36 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -55,6 +55,8 @@ ndarray = ["burn-core/ndarray"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] wgpu-spirv = ["burn-core/wgpu-spirv"] +remote = ["burn-core/remote"] +server = ["burn-core/server"] # Network utils network = ["burn-core/network"] diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index 33aaed8ce3..b0ecf06a71 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -92,6 +92,7 @@ //! - `autodiff`: Makes available the Autodiff backend //! - Others: //! - `std`: Activates the standard library (deactivate for no_std) +//! - `server`: Enables the remote server. //! - `network`: Enables network utilities (currently, only a file downloader with progress bar) //! - `experimental-named-tensor`: Enables named tensors (experimental) //! diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index d3d5b30bbd..a418ede196 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -1,5 +1,5 @@ -use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu}; +use burn::backend::{Autodiff, Wgpu}; fn main() { - custom_training_loop::run::>(WgpuDevice::default()); + custom_training_loop::run::>(Default::default()); } diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml new file mode 100644 index 0000000000..5d06497e08 --- /dev/null +++ b/examples/server/Cargo.toml @@ -0,0 +1,18 @@ +[package] +authors = ["nathanielsimard "] +edition.workspace = true +license.workspace = true +name = "server" +publish = false +version.workspace = true + +[features] +default = ["wgpu"] +cuda-jit = ["burn/cuda-jit"] +wgpu = ["burn/wgpu"] +wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +ndarray = ["burn/ndarray"] + +[dependencies] +cfg-if = { workspace = true } +burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] } diff --git a/examples/server/examples/server.rs b/examples/server/examples/server.rs new file mode 100644 index 0000000000..8d79ebb8cb --- /dev/null +++ b/examples/server/examples/server.rs @@ -0,0 +1,3 @@ +fn main() { + server::start(); +} diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs new file mode 100644 index 0000000000..d206771ea0 --- /dev/null +++ b/examples/server/src/lib.rs @@ -0,0 +1,20 @@ +pub fn start() { + let port = std::env::var("REMOTE_BACKEND_PORT") + .map(|port| match port.parse::() { + Ok(val) => val, + Err(err) => panic!("Invalid port, got {port} with error {err}"), + }) + .unwrap_or(3000); + + cfg_if::cfg_if! { + if #[cfg(feature = "ndarray")]{ + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "wgpu")] { + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "cuda-jit")]{ + burn::server::start::(Default::default(), port); + } else { + panic!("No backend selected, can't start server on port {port}"); + } + } +} diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 1bbfaebcca..9ee0f6b8c4 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -17,6 +17,7 @@ tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +remote = ["burn/remote"] cuda-jit = ["burn/cuda-jit"] hip-jit = ["burn/hip-jit"] diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index d586fbd883..bf12a0b6d9 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -91,6 +91,16 @@ mod wgpu { } } +#[cfg(feature = "remote")] +mod remote { + use crate::{launch, ElemType}; + use burn::backend::{Autodiff, RemoteBackend}; + + pub fn run() { + launch::>(vec![Default::default()]); + } +} + #[cfg(feature = "cuda-jit")] mod cuda_jit { use crate::{launch, ElemType}; @@ -129,4 +139,6 @@ fn main() { cuda_jit::run(); #[cfg(feature = "hip-jit")] hip_jit::run(); + #[cfg(feature = "remote")] + remote::run(); } diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index f2b31ae4d1..fa07f09158 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -96,8 +96,6 @@ pub fn train( .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) - .metric_train_numeric(LossMetric::new()) - .metric_valid_numeric(LossMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .devices(devices)