Skip to content

Commit

Permalink
Implement zstd compression for multipart payloads on the Rust side.
Browse files Browse the repository at this point in the history
I've tested this with both the Python and Rust changes applied, and I've successfully sent zstd-compressed traces to LangSmith servers.

This introduces breaking changes in the APIs for both `langsmith-tracing-client` and `langsmith-pyo3`, so I'm bumping the left-most non-zero version number in each of their manifests. In Rust, the left-most non-zero number is considered a "major" version -- in other words, leading zeroes are ignored for SemVer purposes.

This will require us to publish a new `langsmith-pyo3` Python package version with the new changes. I'd like to trigger the publish workflow after this is merged, so I can make the corresponding Python changes to enable zstd end-to-end. Let me know if you have concerns or want me to hold off on publishing a new version.
  • Loading branch information
obi1kenobi committed Jan 9, 2025
1 parent bdcca18 commit c007aed
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 5 deletions.
51 changes: 49 additions & 2 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ resolver = "2"
chrono = "0.4.38"
flate2 = "1.0.34"
futures = "0.3.31"
http = "1.2.0"
rayon = "1.10.0"
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
Expand All @@ -20,6 +21,7 @@ tokio = { version = "1", features = ["full"] }
tokio-util = "0.7.12"
ureq = "2.10.1"
uuid = { version = "1.11.0", features = ["v4"] }
zstd = { version = "0.13.2", features = ["zstdmt"] }

# Use rustls instead of OpenSSL, because OpenSSL is a nightmare when compiling across platforms.
# OpenSSL is a default feature, so we have to disable all default features, then re-add
Expand Down
2 changes: 1 addition & 1 deletion rust/crates/langsmith-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "langsmith-pyo3"
version = "0.1.0-rc5"
version = "0.2.0-rc1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 2 additions & 0 deletions rust/crates/langsmith-pyo3/src/blocking_tracing_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ impl BlockingTracingClient {
batch_size: usize,
batch_timeout_millis: u64,
worker_threads: usize,
compression_level: i32,
) -> PyResult<Self> {
let config = langsmith_tracing_client::client::blocking::ClientConfig {
endpoint,
Expand All @@ -39,6 +40,7 @@ impl BlockingTracingClient {

headers: None, // TODO: support custom headers
num_worker_threads: worker_threads,
compression_level,
};

let client = RustTracingClient::new(config)
Expand Down
4 changes: 3 additions & 1 deletion rust/crates/langsmith-tracing-client/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "langsmith-tracing-client"
version = "0.1.0"
version = "0.2.0"
edition = "2021"

[dependencies]
Expand All @@ -17,6 +17,8 @@ futures = { workspace = true }
rayon = { workspace = true }
ureq = { workspace = true }
flate2 = { workspace = true }
zstd = { workspace = true }
http = { workspace = true }

[dev-dependencies]
multer = "3.1.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{mpsc, Arc, Mutex};
use std::thread::available_parallelism;
use std::time::{Duration, Instant};

use rayon::iter::{IntoParallelIterator, ParallelIterator};
Expand Down Expand Up @@ -266,14 +267,53 @@ impl RunProcessor {
for (part_name, part) in json_parts.into_iter().chain(attachment_parts) {
form = form.part(part_name, part);
}
let content_type = format!("multipart/form-data; boundary={}", form.boundary());

// We want to use as many threads as available cores to compress data.
// However, we have to be mindful of special values in the zstd library:
// - A setting of `0` here means "use the current thread only."
// - A setting of `1` means "use a separate thread, but only one."
//
// `1` isn't a useful setting for us, so turn `1` into `0` while
// keeping higher numbers the same.
let n_workers = match available_parallelism() {
Ok(num) => {
if num.get() == 1 {
0
} else {
num.get() as u32
}
}
Err(_) => {
// We failed to query the available number of cores.
// Use only the current single thread, to be safe.
0
}
};
let compressed_data = {
let mut buffer = Vec::with_capacity(4096);
let mut encoder = zstd::Encoder::new(&mut buffer, self.config.compression_level)
.and_then(|mut encoder| {
encoder.multithread(n_workers)?;
Ok(encoder)
})
.map_err(|e| TracingClientError::IoError(format!("{e}")))?;
std::io::copy(&mut form.reader(), &mut encoder)
.map_err(|e| TracingClientError::IoError(format!("{e}")))?;
encoder.finish().map_err(|e| TracingClientError::IoError(format!("{e}")))?;

buffer
};

// send the multipart POST request
let start_send_batch = Instant::now();
let response = self
.http_client
.post(format!("{}/runs/multipart", self.config.endpoint))
.multipart(form)
.headers(self.config.headers.as_ref().cloned().unwrap_or_default())
.header(http::header::CONTENT_TYPE, content_type)
.header(http::header::CONTENT_ENCODING, "zstd")
.body(compressed_data)
.send()?;
// println!("Sending batch took {:?}", start_send_batch.elapsed());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct ClientConfig {
pub batch_timeout: Duration,
pub headers: Option<HeaderMap>,
pub num_worker_threads: usize,
pub compression_level: i32,
}

pub struct TracingClient {
Expand Down

0 comments on commit c007aed

Please sign in to comment.