Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prove req pattern, fix async stuff #1857

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ network-v2 = [
"dep:backoff",
]
cuda = ["sp1-cuda"]
blocking = []

[build-dependencies]
vergen = { version = "8", default-features = false, features = [
Expand Down
6 changes: 3 additions & 3 deletions crates/sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ProverClient {
}
}

pub async fn setup(&self, elf: Arc<[u8]>) -> (Arc<SP1ProvingKey>) {
pub async fn setup(&self, elf: Arc<[u8]>) -> Arc<SP1ProvingKey> {
self.inner.setup(elf).await
}

Expand Down Expand Up @@ -97,12 +97,12 @@ impl<T: BuildableProver> ProverClientBuilder<T> {
}

impl ProverClientBuilder<NetworkProverBuilder> {
pub fn with_rpc_url(mut self, url: String) -> Self {
pub fn rpc_url(mut self, url: String) -> Self {
self.inner_builder = self.inner_builder.rpc_url(url);
self
}

pub fn with_private_key(mut self, key: String) -> Self {
pub fn private_key(mut self, key: String) -> Self {
self.inner_builder = self.inner_builder.private_key(key);
self
}
Expand Down
127 changes: 82 additions & 45 deletions crates/sdk/src/local/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl LocalProver {
fn sp1_prover(&self) -> &SP1Prover {
&self.prover
}

pub fn prove(&self, pk: Arc<SP1ProvingKey>, stdin: SP1Stdin) -> LocalProofRequest {
LocalProofRequest::new(self, &pk, stdin)
}
}

pub struct LocalProverBuilder {}
Expand All @@ -61,7 +65,7 @@ impl LocalProverBuilder {

pub struct LocalProofRequest<'a> {
pub prover: &'a LocalProver,
pub pk: Arc<SP1ProvingKey>,
pub pk: &'a Arc<SP1ProvingKey>,
pub stdin: SP1Stdin,
pub mode: Mode,
pub timeout: u64,
Expand All @@ -70,7 +74,7 @@ pub struct LocalProofRequest<'a> {
}

impl<'a> LocalProofRequest<'a> {
pub fn new(prover: &'a LocalProver, pk: Arc<SP1ProvingKey>, stdin: SP1Stdin) -> Self {
pub fn new(prover: &'a LocalProver, pk: &'a Arc<SP1ProvingKey>, stdin: SP1Stdin) -> Self {
Self {
prover,
pk,
Expand Down Expand Up @@ -117,9 +121,24 @@ impl<'a> LocalProofRequest<'a> {
self
}

pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
Self::run_inner(
&self.prover.prover,
&**self.pk,
self.stdin,
self.mode,
self.timeout,
self.version,
self.sp1_prover_opts,
)
}
}

impl LocalProver {
fn run_inner(
prover: Arc<SP1Prover<DefaultProverComponents>>,
pk: Arc<SP1ProvingKey>,
prover: &SP1Prover<DefaultProverComponents>,
pk: &SP1ProvingKey,
stdin: SP1Stdin,
mode: Mode,
timeout: u64,
Expand Down Expand Up @@ -202,19 +221,6 @@ impl<'a> LocalProofRequest<'a> {

unreachable!()
}

pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
Self::run_inner(
Arc::clone(&self.prover.prover),
self.pk,
self.stdin,
self.mode,
self.timeout,
self.version,
self.sp1_prover_opts,
)
}
}

#[async_trait]
Expand All @@ -227,12 +233,14 @@ impl Prover for LocalProver {
})
.await
.unwrap();

result
}

#[cfg(feature = "blocking")]
fn setup_sync(&self, elf: &[u8]) -> Arc<SP1ProvingKey> {
let (pk, _vk) = self.prover.setup(elf);

Arc::new(pk)
}

Expand Down Expand Up @@ -262,23 +270,25 @@ impl Prover for LocalProver {
stdin: SP1Stdin,
opts: ProofOpts,
) -> Result<SP1ProofWithPublicValues> {
let prover = Arc::clone(&self.prover);
let mut req = self.prove(pk, stdin);

task::spawn_blocking(move || {
let context = SP1Context::default();
if let Some(mode) = opts.mode {
req.mode = mode;
}

LocalProofRequest::run_inner(
prover,
pk,
stdin,
opts.mode,
opts.timeout,
SP1_CIRCUIT_VERSION.to_string(),
SP1ProverOpts::default(),
)
})
.await
.unwrap()
if let Some(timeout) = opts.timeout {
req.timeout = timeout;
}

if let Some(version) = opts.version {
req.version = version;
}

if let Some(sp1_prover_opts) = opts.sp1_prover_opts {
req.sp1_prover_opts = sp1_prover_opts;
}

req.await
}

#[cfg(feature = "blocking")]
Expand All @@ -288,18 +298,25 @@ impl Prover for LocalProver {
stdin: SP1Stdin,
opts: ProofOpts,
) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
let mut req = self.prove(pk, stdin);

LocalProofRequest::run_inner(
Arc::clone(&self.prover),
pk,
stdin,
opts.mode,
opts.timeout,
SP1_CIRCUIT_VERSION.to_string(),
SP1ProverOpts::default(),
context,
)
if let Some(mode) = opts.mode {
req.mode = mode;
}

if let Some(timeout) = opts.timeout {
req.timeout = timeout;
}

if let Some(version) = opts.version {
req.version = version;
}

if let Some(sp1_prover_opts) = opts.sp1_prover_opts {
req.sp1_prover_opts = sp1_prover_opts;
}

req.run()
}

async fn verify(
Expand All @@ -308,6 +325,7 @@ impl Prover for LocalProver {
vk: Arc<SP1VerifyingKey>,
) -> Result<(), SP1VerificationError> {
let prover = Arc::clone(&self.prover);

task::spawn_blocking(move || verify::verify(&prover, SP1_CIRCUIT_VERSION, &proof, &vk))
.await
.unwrap()
Expand All @@ -331,9 +349,28 @@ impl Default for LocalProver {

impl<'a> IntoFuture for LocalProofRequest<'a> {
type Output = Result<SP1ProofWithPublicValues>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { self.run() })
let LocalProofRequest { prover, pk, stdin, mode, timeout, version, sp1_prover_opts } = self;

let pk = Arc::clone(pk);
let prover = prover.prover.clone();

Box::pin(async move {
task::spawn_blocking(move || {
LocalProofRequest::run_inner(
&prover,
&**pk,
stdin,
mode,
timeout,
version,
sp1_prover_opts,
)
})
.await
.expect("To be able to join prove handle")
})
}
}
95 changes: 21 additions & 74 deletions crates/sdk/src/network-v2/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,63 +62,6 @@ impl NetworkProver {
}
}

/// Sets the proof mode to core.
pub fn core(mut self) -> Self {
self.network_client.mode = Mode::Core;
self
}

/// Sets the proof mode to compressed.
pub fn compressed(mut self) -> Self {
self.network_client.mode = Mode::Compressed;
self
}

/// Sets the proof mode to plonk.
pub fn plonk(mut self) -> Self {
self.network_client.mode = Mode::Plonk;
self
}

/// Sets the proof mode to groth16.
pub fn groth16(mut self) -> Self {
self.network_client.mode = Mode::Groth16;
self
}

/// Sets the RPC URL for the prover network.
///
/// This configures the endpoint that will be used for all network operations.
/// If not set, the default RPC URL will be used.
pub fn timeout(mut self, timeout: u64) -> Self {
self.network_client.timeout = Some(timeout);
self
}

/// Sets the cycle limit for the prover network.
///
/// See `get_cycle_limit` for more details the final cycle limit is determined.
pub fn cycle_limit(mut self, limit: u64) -> Self {
self.network_client.cycle_limit = Some(limit);
self
}

/// Skips simulation when determining the cycle limit.
///
/// See `get_cycle_limit` for more details the final cycle limit is determined.
pub fn skip_simulation(mut self, skip: bool) -> Self {
self.network_client.skip_simulation = skip;
self
}

/// Sets the fulfillment strategy for the prover network.
///
/// See `request_proof` for more details the final cycle limit is determined.
pub fn strategy(mut self, strategy: FulfillmentStrategy) -> Self {
self.network_client.strategy = Some(strategy);
self
}

/// Get the cycle limit to used for a proof request.
///
/// The cycle limit is determined according to the following priority:
Expand Down Expand Up @@ -152,17 +95,13 @@ impl NetworkProver {
}

/// Registers a program if it is not already registered.
pub async fn register_program(
&self,
vk: &SP1VerifyingKey,
elf: &[u8],
) -> Result<VerifyingKeyHash> {
async fn register_program(&self, vk: &SP1VerifyingKey, elf: &[u8]) -> Result<VerifyingKeyHash> {
self.network_client.register_program(vk, elf).await
}

/// Requests a proof from the prover network, returning the request ID.
#[allow(clippy::too_many_arguments)]
pub async fn request_proof(
async fn request_proof(
&self,
vk_hash: &VerifyingKeyHash,
stdin: &SP1Stdin,
Expand Down Expand Up @@ -197,7 +136,7 @@ impl NetworkProver {
///
/// The proof request must have already been submitted. This function will return a
/// `RequestTimedOut` error if the request does not received a response within the timeout.
pub async fn wait_proof<P: DeserializeOwned>(
async fn wait_proof<P: DeserializeOwned>(
&self,
request_id: &RequestId,
timeout_secs: u64,
Expand Down Expand Up @@ -254,14 +193,6 @@ impl NetworkProver {
}
}

pub fn prove_with_options<'a>(
&'a self,
pk: Arc<SP1ProvingKey>,
stdin: SP1Stdin,
) -> NetworkProofRequest<'a> {
NetworkProofRequest::new(self, pk, stdin)
}

/// Creates a new network prover builder. See [`NetworkProverBuilder`] for more details.
pub fn builder() -> NetworkProverBuilder {
NetworkProverBuilder::new()
Expand Down Expand Up @@ -327,8 +258,23 @@ impl<'a> NetworkProofRequest<'a> {
}
}

pub fn with_mode(mut self, mode: Mode) -> Self {
self.mode = mode.into();
pub fn groth16(mut self) -> Self {
self.mode = ProofMode::Groth16;
self
}

pub fn plonk(mut self) -> Self {
self.mode = ProofMode::Plonk;
self
}

pub fn core(mut self) -> Self {
self.mode = ProofMode::Core;
self
}

pub fn compressed(mut self) -> Self {
self.mode = ProofMode::Compressed;
self
}

Expand Down Expand Up @@ -445,6 +391,7 @@ impl Prover for NetworkProver {
.with_mode(opts.mode)
.with_timeout(opts.timeout)
.with_cycle_limit(opts.cycle_limit);

request.run_inner().await
}

Expand Down
2 changes: 1 addition & 1 deletion crates/sdk/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ impl<'a> IntoFuture for DynProofRequest<'a> {
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(self.prover.prove_with_options(self.pk, self.stdin, self.opts))
self.prover.prove_with_options(self.pk, self.stdin, self.opts)
}
}
2 changes: 1 addition & 1 deletion crates/sdk/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use strum_macros::EnumString;
use thiserror::Error;

use crate::install::try_install_circuit_artifacts;
use crate::opts::ProofOpts;
use crate::local::SP1VerificationError;
use crate::opts::ProofOpts;
use crate::{proof::SP1Proof, proof::SP1ProofKind, proof::SP1ProofWithPublicValues};

/// Verify that an SP1 proof is valid given its vkey and metadata.
Expand Down
Loading