From b1bfa9dba3024bac878466f8796120898955a165 Mon Sep 17 00:00:00 2001 From: nhtyy Date: Wed, 11 Dec 2024 16:26:34 -0800 Subject: [PATCH] fix: prove req pattern, fix async stuff --- crates/sdk/Cargo.toml | 1 + crates/sdk/src/client.rs | 6 +- crates/sdk/src/local/prover.rs | 127 ++++++++++++++++++---------- crates/sdk/src/network-v2/prover.rs | 95 +++++---------------- crates/sdk/src/request.rs | 2 +- crates/sdk/src/verify.rs | 2 +- 6 files changed, 109 insertions(+), 124 deletions(-) diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index f6fe5acdd..b0cad306e 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -87,6 +87,7 @@ network-v2 = [ "dep:backoff", ] cuda = ["sp1-cuda"] +blocking = [] [build-dependencies] vergen = { version = "8", default-features = false, features = [ diff --git a/crates/sdk/src/client.rs b/crates/sdk/src/client.rs index b53182f22..8aaea2a96 100644 --- a/crates/sdk/src/client.rs +++ b/crates/sdk/src/client.rs @@ -51,7 +51,7 @@ impl ProverClient { } } - pub async fn setup(&self, elf: Arc<[u8]>) -> (Arc) { + pub async fn setup(&self, elf: Arc<[u8]>) -> Arc { self.inner.setup(elf).await } @@ -97,12 +97,12 @@ impl ProverClientBuilder { } impl ProverClientBuilder { - pub fn with_rpc_url(mut self, url: String) -> Self { + pub fn rpc_url(mut self, url: String) -> Self { self.inner_builder = self.inner_builder.rpc_url(url); self } - pub fn with_private_key(mut self, key: String) -> Self { + pub fn private_key(mut self, key: String) -> Self { self.inner_builder = self.inner_builder.private_key(key); self } diff --git a/crates/sdk/src/local/prover.rs b/crates/sdk/src/local/prover.rs index e4c052a7f..75f16fec1 100644 --- a/crates/sdk/src/local/prover.rs +++ b/crates/sdk/src/local/prover.rs @@ -45,6 +45,10 @@ impl LocalProver { fn sp1_prover(&self) -> &SP1Prover { &self.prover } + + pub fn prove(&self, pk: Arc, stdin: SP1Stdin) -> LocalProofRequest { + LocalProofRequest::new(self, &pk, stdin) + } } pub struct LocalProverBuilder {} @@ -61,7 +65,7 @@ impl LocalProverBuilder { pub struct LocalProofRequest<'a> { pub prover: &'a LocalProver, - pub pk: Arc, + pub pk: &'a Arc, pub stdin: SP1Stdin, pub mode: Mode, pub timeout: u64, @@ -70,7 +74,7 @@ pub struct LocalProofRequest<'a> { } impl<'a> LocalProofRequest<'a> { - pub fn new(prover: &'a LocalProver, pk: Arc, stdin: SP1Stdin) -> Self { + pub fn new(prover: &'a LocalProver, pk: &'a Arc, stdin: SP1Stdin) -> Self { Self { prover, pk, @@ -117,9 +121,24 @@ impl<'a> LocalProofRequest<'a> { self } + pub fn run(self) -> Result { + let context = SP1Context::default(); + Self::run_inner( + &self.prover.prover, + &**self.pk, + self.stdin, + self.mode, + self.timeout, + self.version, + self.sp1_prover_opts, + ) + } +} + +impl LocalProver { fn run_inner( - prover: Arc>, - pk: Arc, + prover: &SP1Prover, + pk: &SP1ProvingKey, stdin: SP1Stdin, mode: Mode, timeout: u64, @@ -202,19 +221,6 @@ impl<'a> LocalProofRequest<'a> { unreachable!() } - - pub fn run(self) -> Result { - let context = SP1Context::default(); - Self::run_inner( - Arc::clone(&self.prover.prover), - self.pk, - self.stdin, - self.mode, - self.timeout, - self.version, - self.sp1_prover_opts, - ) - } } #[async_trait] @@ -227,12 +233,14 @@ impl Prover for LocalProver { }) .await .unwrap(); + result } #[cfg(feature = "blocking")] fn setup_sync(&self, elf: &[u8]) -> Arc { let (pk, _vk) = self.prover.setup(elf); + Arc::new(pk) } @@ -262,23 +270,25 @@ impl Prover for LocalProver { stdin: SP1Stdin, opts: ProofOpts, ) -> Result { - let prover = Arc::clone(&self.prover); + let mut req = self.prove(pk, stdin); - task::spawn_blocking(move || { - let context = SP1Context::default(); + if let Some(mode) = opts.mode { + req.mode = mode; + } - LocalProofRequest::run_inner( - prover, - pk, - stdin, - opts.mode, - opts.timeout, - SP1_CIRCUIT_VERSION.to_string(), - SP1ProverOpts::default(), - ) - }) - .await - .unwrap() + if let Some(timeout) = opts.timeout { + req.timeout = timeout; + } + + if let Some(version) = opts.version { + req.version = version; + } + + if let Some(sp1_prover_opts) = opts.sp1_prover_opts { + req.sp1_prover_opts = sp1_prover_opts; + } + + req.await } #[cfg(feature = "blocking")] @@ -288,18 +298,25 @@ impl Prover for LocalProver { stdin: SP1Stdin, opts: ProofOpts, ) -> Result { - let context = SP1Context::default(); + let mut req = self.prove(pk, stdin); - LocalProofRequest::run_inner( - Arc::clone(&self.prover), - pk, - stdin, - opts.mode, - opts.timeout, - SP1_CIRCUIT_VERSION.to_string(), - SP1ProverOpts::default(), - context, - ) + if let Some(mode) = opts.mode { + req.mode = mode; + } + + if let Some(timeout) = opts.timeout { + req.timeout = timeout; + } + + if let Some(version) = opts.version { + req.version = version; + } + + if let Some(sp1_prover_opts) = opts.sp1_prover_opts { + req.sp1_prover_opts = sp1_prover_opts; + } + + req.run() } async fn verify( @@ -308,6 +325,7 @@ impl Prover for LocalProver { vk: Arc, ) -> Result<(), SP1VerificationError> { let prover = Arc::clone(&self.prover); + task::spawn_blocking(move || verify::verify(&prover, SP1_CIRCUIT_VERSION, &proof, &vk)) .await .unwrap() @@ -331,9 +349,28 @@ impl Default for LocalProver { impl<'a> IntoFuture for LocalProofRequest<'a> { type Output = Result; - type IntoFuture = Pin + Send + 'a>>; + type IntoFuture = Pin>>; fn into_future(self) -> Self::IntoFuture { - Box::pin(async move { self.run() }) + let LocalProofRequest { prover, pk, stdin, mode, timeout, version, sp1_prover_opts } = self; + + let pk = Arc::clone(pk); + let prover = prover.prover.clone(); + + Box::pin(async move { + task::spawn_blocking(move || { + LocalProofRequest::run_inner( + &prover, + &**pk, + stdin, + mode, + timeout, + version, + sp1_prover_opts, + ) + }) + .await + .expect("To be able to join prove handle") + }) } } diff --git a/crates/sdk/src/network-v2/prover.rs b/crates/sdk/src/network-v2/prover.rs index 184d8a28d..99af25ca6 100644 --- a/crates/sdk/src/network-v2/prover.rs +++ b/crates/sdk/src/network-v2/prover.rs @@ -62,63 +62,6 @@ impl NetworkProver { } } - /// Sets the proof mode to core. - pub fn core(mut self) -> Self { - self.network_client.mode = Mode::Core; - self - } - - /// Sets the proof mode to compressed. - pub fn compressed(mut self) -> Self { - self.network_client.mode = Mode::Compressed; - self - } - - /// Sets the proof mode to plonk. - pub fn plonk(mut self) -> Self { - self.network_client.mode = Mode::Plonk; - self - } - - /// Sets the proof mode to groth16. - pub fn groth16(mut self) -> Self { - self.network_client.mode = Mode::Groth16; - self - } - - /// Sets the RPC URL for the prover network. - /// - /// This configures the endpoint that will be used for all network operations. - /// If not set, the default RPC URL will be used. - pub fn timeout(mut self, timeout: u64) -> Self { - self.network_client.timeout = Some(timeout); - self - } - - /// Sets the cycle limit for the prover network. - /// - /// See `get_cycle_limit` for more details the final cycle limit is determined. - pub fn cycle_limit(mut self, limit: u64) -> Self { - self.network_client.cycle_limit = Some(limit); - self - } - - /// Skips simulation when determining the cycle limit. - /// - /// See `get_cycle_limit` for more details the final cycle limit is determined. - pub fn skip_simulation(mut self, skip: bool) -> Self { - self.network_client.skip_simulation = skip; - self - } - - /// Sets the fulfillment strategy for the prover network. - /// - /// See `request_proof` for more details the final cycle limit is determined. - pub fn strategy(mut self, strategy: FulfillmentStrategy) -> Self { - self.network_client.strategy = Some(strategy); - self - } - /// Get the cycle limit to used for a proof request. /// /// The cycle limit is determined according to the following priority: @@ -152,17 +95,13 @@ impl NetworkProver { } /// Registers a program if it is not already registered. - pub async fn register_program( - &self, - vk: &SP1VerifyingKey, - elf: &[u8], - ) -> Result { + async fn register_program(&self, vk: &SP1VerifyingKey, elf: &[u8]) -> Result { self.network_client.register_program(vk, elf).await } /// Requests a proof from the prover network, returning the request ID. #[allow(clippy::too_many_arguments)] - pub async fn request_proof( + async fn request_proof( &self, vk_hash: &VerifyingKeyHash, stdin: &SP1Stdin, @@ -197,7 +136,7 @@ impl NetworkProver { /// /// The proof request must have already been submitted. This function will return a /// `RequestTimedOut` error if the request does not received a response within the timeout. - pub async fn wait_proof( + async fn wait_proof( &self, request_id: &RequestId, timeout_secs: u64, @@ -254,14 +193,6 @@ impl NetworkProver { } } - pub fn prove_with_options<'a>( - &'a self, - pk: Arc, - stdin: SP1Stdin, - ) -> NetworkProofRequest<'a> { - NetworkProofRequest::new(self, pk, stdin) - } - /// Creates a new network prover builder. See [`NetworkProverBuilder`] for more details. pub fn builder() -> NetworkProverBuilder { NetworkProverBuilder::new() @@ -327,8 +258,23 @@ impl<'a> NetworkProofRequest<'a> { } } - pub fn with_mode(mut self, mode: Mode) -> Self { - self.mode = mode.into(); + pub fn groth16(mut self) -> Self { + self.mode = ProofMode::Groth16; + self + } + + pub fn plonk(mut self) -> Self { + self.mode = ProofMode::Plonk; + self + } + + pub fn core(mut self) -> Self { + self.mode = ProofMode::Core; + self + } + + pub fn compressed(mut self) -> Self { + self.mode = ProofMode::Compressed; self } @@ -445,6 +391,7 @@ impl Prover for NetworkProver { .with_mode(opts.mode) .with_timeout(opts.timeout) .with_cycle_limit(opts.cycle_limit); + request.run_inner().await } diff --git a/crates/sdk/src/request.rs b/crates/sdk/src/request.rs index 715fc8d8c..8820a0cab 100644 --- a/crates/sdk/src/request.rs +++ b/crates/sdk/src/request.rs @@ -73,6 +73,6 @@ impl<'a> IntoFuture for DynProofRequest<'a> { type IntoFuture = Pin + Send + 'a>>; fn into_future(self) -> Self::IntoFuture { - Box::pin(self.prover.prove_with_options(self.pk, self.stdin, self.opts)) + self.prover.prove_with_options(self.pk, self.stdin, self.opts) } } diff --git a/crates/sdk/src/verify.rs b/crates/sdk/src/verify.rs index 2e8a0fe44..8bcdbe338 100644 --- a/crates/sdk/src/verify.rs +++ b/crates/sdk/src/verify.rs @@ -15,8 +15,8 @@ use strum_macros::EnumString; use thiserror::Error; use crate::install::try_install_circuit_artifacts; -use crate::opts::ProofOpts; use crate::local::SP1VerificationError; +use crate::opts::ProofOpts; use crate::{proof::SP1Proof, proof::SP1ProofKind, proof::SP1ProofWithPublicValues}; /// Verify that an SP1 proof is valid given its vkey and metadata.