Skip to content

Commit

Permalink
Remote Backend (#2463)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Nov 7, 2024
1 parent 9b9b03c commit 099b6dc
Show file tree
Hide file tree
Showing 38 changed files with 1,585 additions and 39 deletions.
179 changes: 179 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ bincode = { version = "2.0.0-rc.3", features = [
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.7.0", default-features = false }
cfg-if = "1.0.0"

blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ getrandom = { workspace = true, features = ["js"] }
web-time = { version = "1.1.0" }

[dependencies]
serde = { workspace = true }

# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
Expand Down
47 changes: 47 additions & 0 deletions crates/burn-common/src/id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::rand::gen_random;
use serde::{Deserialize, Serialize};

/// Simple ID generator.
pub struct IdGenerator {}
Expand Down Expand Up @@ -64,3 +65,49 @@ mod tests {
assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
}
}

/// Unique identifier that can represent a stream based on the current thread id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct StreamId {
/// The value representing the thread id.
pub value: u64,
}

impl StreamId {
/// Get the current thread id.
pub fn current() -> Self {
Self {
#[cfg(feature = "std")]
value: Self::from_current_thread(),
#[cfg(not(feature = "std"))]
value: 0,
}
}

#[cfg(feature = "std")]
fn from_current_thread() -> u64 {
use core::hash::Hash;

std::thread_local! {
static ID: std::cell::OnceCell::<u64> = const { std::cell::OnceCell::new() };
};

// Getting the current thread is expensive, so we cache the value into a thread local
// variable, which is very fast.
ID.with(|cell| {
*cell.get_or_init(|| {
// A way to get a thread id encoded as u64.
let mut hasher = std::hash::DefaultHasher::default();
let id = std::thread::current().id();
id.hash(&mut hasher);
std::hash::Hasher::finish(&hasher)
})
})
}
}

impl core::fmt::Display for StreamId {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("StreamId({:?})", self.value))
}
}
5 changes: 5 additions & 0 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ doc = [
"hip-jit",
"vision",
"autodiff",
"remote",
"server",
# Doc features
"burn-candle/doc",
"burn-common/doc",
Expand Down Expand Up @@ -86,6 +88,8 @@ metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]
template = ["burn-wgpu?/template"]
remote = ["burn-remote/client"]
server = ["burn-remote/server"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
Expand Down Expand Up @@ -131,6 +135,7 @@ burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }

data-encoding = { workspace = true }
uuid = { workspace = true }
Expand Down
Loading

0 comments on commit 099b6dc

Please sign in to comment.