diff --git a/Cargo.lock b/Cargo.lock index 79883868c4..b9486139e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4908,6 +4908,17 @@ dependencies = [ "sp1-prover", ] +[[package]] +name = "sp1-ffi" +version = "0.1.0" +dependencies = [ + "bincode", + "fibonacci-script", + "sp1-core", + "sp1-prover", + "sp1-sdk", +] + [[package]] name = "sp1-helper" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 2dca15cd78..1325020b63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "core", "derive", "eval", + "ffi", "helper", "primitives", "prover", @@ -42,11 +43,13 @@ debug = true debug-assertions = true [workspace.dependencies] +fibonacci-script = { path = "examples/fibonacci/script" } sp1-build = { path = "build", version = "1.0.1" } sp1-derive = { path = "derive", version = "1.0.1" } sp1-core = { path = "core", version = "1.0.1" } sp1-cli = { path = "cli", version = "1.0.1", default-features = false } sp1-eval = { path = "eval", version = "1.0.1", default-features = false } +sp1-ffi = { path = "ffi" } sp1-helper = { path = "helper", version = "1.0.1", default-features = false } sp1-primitives = { path = "primitives", version = "1.0.1" } sp1-prover = { path = "prover", version = "1.0.1" } @@ -80,7 +83,7 @@ p3-uni-stark = "0.1.3-succinct" p3-maybe-rayon = "0.1.3-succinct" p3-bn254-fr = "0.1.3-succinct" -# For local development. +# For local development. # p3-air = { path = "../Plonky3/air" } # p3-field = { path = "../Plonky3/field" } diff --git a/ffi/Cargo.toml b/ffi/Cargo.toml new file mode 100644 index 0000000000..32e9e80122 --- /dev/null +++ b/ffi/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "sp1-ffi" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "sp1_ffi" +crate-type = ["cdylib"] + +[dependencies] +bincode = "1.3.3" + +fibonacci-script = { workspace = true } +sp1-core = { workspace = true } +sp1-prover = { workspace = true } +sp1-sdk = { workspace = true } diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs new file mode 100644 index 0000000000..79b0469dda --- /dev/null +++ b/ffi/src/lib.rs @@ -0,0 +1,8 @@ +pub const FIBONACCI_ELF: &[u8] = + include_bytes!("../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); + +pub mod operator; +pub mod worker; + +pub use operator::*; +pub use worker::*; diff --git a/ffi/src/operator.rs b/ffi/src/operator.rs new file mode 100644 index 0000000000..a9f950bb4f --- /dev/null +++ b/ffi/src/operator.rs @@ -0,0 +1,453 @@ +use crate::FIBONACCI_ELF; +use fibonacci_script::FibonacciArgs; +use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2}; +use sp1_prover::{ReduceProgramType, SP1ReduceProof}; +use sp1_sdk::mmp::{ + common::ProveArgs, + operator::{ + operator_absorb_commits, operator_construct_sp1_core_proof, + operator_prepare_compress_input_chunks, operator_prepare_compress_inputs, + operator_prepare_plonk_witness, operator_prove_plonk, operator_prove_shrink, + operator_split_into_checkpoints, + }, + scenario::{compress_prove, core_prove, plonk_prove}, +}; + +#[no_mangle] +pub extern "C" fn operator_split_into_checkpoints_c( + o_public_values_stream_len_ptr: *mut usize, + o_public_values_stream_ptr: *mut *mut u8, + o_public_values_len_ptr: *mut usize, + o_public_values_ptr: *mut *mut u8, + o_num_checkpoints_ptr: *mut usize, + o_checkpoints_len_ptr: *mut *mut usize, + o_checkpoints_ptr: *mut *mut *mut u8, + o_cycles_ptr: *mut u64, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let mut o_public_values_stream = Vec::new(); + let mut o_public_values_bytes = Vec::new(); + let mut o_checkpoints = Vec::new(); + let mut o_cycles = 0; + + operator_split_into_checkpoints::>( + &args_bytes, + &mut o_public_values_stream, + &mut o_public_values_bytes, + &mut o_checkpoints, + &mut o_cycles, + ); + + let checkpoints_len = o_checkpoints + .iter() + .map(|checkpoint| checkpoint.len()) + .collect::>(); + let checkpoints = o_checkpoints + .into_iter() + .map(|checkpoint| Box::into_raw(checkpoint.into_boxed_slice()) as *mut u8) + .collect::>(); + + unsafe { + *o_public_values_stream_len_ptr = o_public_values_stream.len(); + *o_public_values_stream_ptr = + Box::into_raw(o_public_values_stream.into_boxed_slice()) as *mut u8; + *o_public_values_len_ptr = o_public_values_bytes.len(); + *o_public_values_ptr = Box::into_raw(o_public_values_bytes.into_boxed_slice()) as *mut u8; + *o_num_checkpoints_ptr = checkpoints.len(); + *o_checkpoints_len_ptr = Box::into_raw(checkpoints_len.into_boxed_slice()) as *mut usize; + *o_checkpoints_ptr = Box::into_raw(checkpoints.into_boxed_slice()) as *mut *mut u8; + *o_cycles_ptr = o_cycles; + } +} + +#[no_mangle] +pub extern "C" fn operator_absorb_commits_c( + num_checkpoints: usize, + num_commitments_ptr: *const usize, + commitments_lens_ptr: *const *const usize, + commitments_vec_ptr: *const *const *const u8, + records_len_ptr: *const *const usize, + records_ptr: *const *const *const u8, + o_challenger_state_len_ptr: *mut usize, + o_challenger_state_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let mut commitments_lens = Vec::new(); + for i in 0..num_checkpoints { + let num_commitments = unsafe { *num_commitments_ptr.add(i) }; + let commitments_len = + unsafe { std::slice::from_raw_parts(*commitments_lens_ptr.add(i), num_commitments) }; + commitments_lens.push(commitments_len.to_vec()); + } + + let mut commitments_vec = Vec::new(); + for i in 0..num_checkpoints { + let num_commitments = unsafe { *num_commitments_ptr.add(i) }; + let commitments_ptr = + unsafe { std::slice::from_raw_parts(*commitments_vec_ptr.add(i), num_commitments) }; + let mut commitments = Vec::new(); + for j in 0..num_commitments { + let commitment = + unsafe { std::slice::from_raw_parts(commitments_ptr[j], commitments_lens[i][j]) }; + commitments.push(commitment.to_vec()); + } + commitments_vec.push(commitments); + } + + let mut records_lens = Vec::new(); + for i in 0..num_checkpoints { + let num_records = unsafe { *num_commitments_ptr.add(i) }; + let records_len = + unsafe { std::slice::from_raw_parts(*records_len_ptr.add(i), num_records) }; + records_lens.push(records_len.to_vec()); + } + + let mut records_vec = Vec::new(); + for i in 0..num_checkpoints { + let num_records = unsafe { *num_commitments_ptr.add(i) }; + let records_ptr = unsafe { std::slice::from_raw_parts(*records_ptr.add(i), num_records) }; + let mut records = Vec::new(); + for j in 0..num_records { + let record = unsafe { std::slice::from_raw_parts(records_ptr[j], records_lens[i][j]) }; + records.push(record.to_vec()); + } + records_vec.push(records); + } + + let mut o_challenger_state = Vec::new(); + operator_absorb_commits::>( + &args_bytes, + &commitments_vec, + &records_vec, + &mut o_challenger_state, + ); + + unsafe { + *o_challenger_state_len_ptr = o_challenger_state.len(); + *o_challenger_state_ptr = Box::into_raw(o_challenger_state.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_construct_sp1_core_proof_c( + num_checkpoints: usize, + num_shard_proofs_ptr: *const usize, + shard_proofs_lens_ptr: *const *const usize, + shard_proofs_ptr: *const *const *const u8, + public_values_stream_len: usize, + public_values_stream_ptr: *const u8, + cycles: u64, + o_proof_len_ptr: *mut usize, + o_proof_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let mut shard_proofs_vec = Vec::new(); + for i in 0..num_checkpoints { + let num_shard_proofs = unsafe { *num_shard_proofs_ptr.add(i) }; + let shard_proofs_len = + unsafe { std::slice::from_raw_parts(*shard_proofs_lens_ptr.add(i), num_shard_proofs) }; + let shard_proofs_ptr = + unsafe { std::slice::from_raw_parts(*shard_proofs_ptr.add(i), num_shard_proofs) }; + let mut shard_proofs = Vec::new(); + for j in 0..num_shard_proofs { + let shard_proof = + unsafe { std::slice::from_raw_parts(shard_proofs_ptr[j], shard_proofs_len[j]) }; + shard_proofs.push(shard_proof.to_vec()); + } + shard_proofs_vec.push(shard_proofs); + } + + let public_values_stream = + unsafe { std::slice::from_raw_parts(public_values_stream_ptr, public_values_stream_len) } + .to_vec(); + + let mut o_proof = Vec::new(); + operator_construct_sp1_core_proof::>( + &args_bytes, + &shard_proofs_vec, + &public_values_stream, + cycles, + &mut o_proof, + ); + + // TODO(TomTaehoonKim): Remove this when verification c api is implemented. + core_prove::scenario_end(&args, &o_proof).expect("verification failed"); + + unsafe { + *o_proof_len_ptr = o_proof.len(); + *o_proof_ptr = Box::into_raw(o_proof.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_prepare_compress_inputs_c( + core_proof_len: usize, + core_proof_ptr: *const u8, + o_num_rec_layouts_ptr: *mut usize, + o_rec_layouts_lens_ptr: *mut *mut usize, + o_rec_layouts_ptr: *mut *mut *mut u8, + o_num_def_layouts_ptr: *mut usize, + o_def_layouts_lens_ptr: *mut *mut usize, + o_def_layouts_ptr: *mut *mut *mut u8, + o_last_proof_public_values_len_ptr: *mut usize, + o_last_proof_public_values_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let core_proof = unsafe { std::slice::from_raw_parts(core_proof_ptr, core_proof_len) }.to_vec(); + + let mut o_rec_layouts = Vec::new(); + let mut o_def_layouts = Vec::new(); + let mut o_last_proof_public_values = Vec::new(); + operator_prepare_compress_inputs::>( + &args_bytes, + &core_proof, + &mut o_rec_layouts, + &mut o_def_layouts, + &mut o_last_proof_public_values, + ); + + let rec_layouts_lens = o_rec_layouts + .iter() + .map(|rec_layout| rec_layout.len()) + .collect::>(); + let rec_layouts = o_rec_layouts + .into_iter() + .map(|rec_layout| Box::into_raw(rec_layout.into_boxed_slice()) as *mut u8) + .collect::>(); + + let def_layouts_lens = o_def_layouts + .iter() + .map(|def_layout| def_layout.len()) + .collect::>(); + let def_layouts = o_def_layouts + .into_iter() + .map(|def_layout| Box::into_raw(def_layout.into_boxed_slice()) as *mut u8) + .collect::>(); + + unsafe { + *o_num_rec_layouts_ptr = rec_layouts.len(); + *o_rec_layouts_lens_ptr = Box::into_raw(rec_layouts_lens.into_boxed_slice()) as *mut usize; + *o_rec_layouts_ptr = Box::into_raw(rec_layouts.into_boxed_slice()) as *mut *mut u8; + *o_num_def_layouts_ptr = def_layouts.len(); + *o_def_layouts_lens_ptr = Box::into_raw(def_layouts_lens.into_boxed_slice()) as *mut usize; + *o_def_layouts_ptr = Box::into_raw(def_layouts.into_boxed_slice()) as *mut *mut u8; + *o_last_proof_public_values_len_ptr = o_last_proof_public_values.len(); + *o_last_proof_public_values_ptr = + Box::into_raw(o_last_proof_public_values.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_prepare_compress_input_chunks_c( + num_compressed_proofs: usize, + compressed_proofs_lens_ptr: *const usize, + compressed_proofs_ptr: *const *const u8, + o_num_red_layouts_ptr: *mut usize, + o_red_layout_len_ptr: *mut *mut usize, + o_red_layout_ptr: *mut *mut *mut u8, +) { + let compressed_proofs_lens = + unsafe { std::slice::from_raw_parts(compressed_proofs_lens_ptr, num_compressed_proofs) }; + let mut compressed_proofs = Vec::new(); + for i in 0..num_compressed_proofs { + let compressed_proof = unsafe { + std::slice::from_raw_parts(*compressed_proofs_ptr.add(i), compressed_proofs_lens[i]) + }; + compressed_proofs.push(compressed_proof.to_vec()); + } + + let mut o_red_layout = Vec::new(); + operator_prepare_compress_input_chunks(&compressed_proofs, &mut o_red_layout); + + let red_layout_lens = o_red_layout + .iter() + .map(|red_layout| red_layout.len()) + .collect::>(); + let red_layout = o_red_layout + .into_iter() + .map(|red_layout| Box::into_raw(red_layout.into_boxed_slice()) as *mut u8) + .collect::>(); + + unsafe { + *o_num_red_layouts_ptr = red_layout.len(); + *o_red_layout_len_ptr = Box::into_raw(red_layout_lens.into_boxed_slice()) as *mut usize; + *o_red_layout_ptr = Box::into_raw(red_layout.into_boxed_slice()) as *mut *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_verify_sp1_compressed_proof_c( + core_proof_len: usize, + core_proof_ptr: *const u8, + compressed_proof_len: usize, + compressed_proof_ptr: *const u8, + o_compressed_proof_len_ptr: *mut usize, + o_compressed_proof_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + + let core_proof = unsafe { std::slice::from_raw_parts(core_proof_ptr, core_proof_len) }.to_vec(); + let compressed_proof = + unsafe { std::slice::from_raw_parts(compressed_proof_ptr, compressed_proof_len) }.to_vec(); + + let compressed_shard_proofs_obj: (ShardProof, ReduceProgramType) = + bincode::deserialize(&compressed_proof).unwrap(); + let compressed_proof = SP1ReduceProof { + proof: compressed_shard_proofs_obj.0, + }; + let compressed_proof = bincode::serialize(&compressed_proof).unwrap(); + + compress_prove::scenario_end(&args, &core_proof, &compressed_proof) + .expect("verification failed"); + + unsafe { + *o_compressed_proof_len_ptr = compressed_proof.len(); + *o_compressed_proof_ptr = Box::into_raw(compressed_proof.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_prove_shrink_c( + compressed_proof_len: usize, + compressed_proof_ptr: *const u8, + o_shrink_proof_len_ptr: *mut usize, + o_shrink_proof_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let compressed_proof = + unsafe { std::slice::from_raw_parts(compressed_proof_ptr, compressed_proof_len) }.to_vec(); + + let mut o_shrink_proof = Vec::new(); + operator_prove_shrink::>(&args_bytes, &compressed_proof, &mut o_shrink_proof); + + unsafe { + *o_shrink_proof_len_ptr = o_shrink_proof.len(); + *o_shrink_proof_ptr = Box::into_raw(o_shrink_proof.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_prepare_plonk_witness_c( + shrink_proof_len: usize, + shrink_proof_ptr: *const u8, + o_plonk_witness_len_ptr: *mut usize, + o_plonk_witness_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let shrink_proof = + unsafe { std::slice::from_raw_parts(shrink_proof_ptr, shrink_proof_len) }.to_vec(); + + let mut o_plonk_witness = Vec::new(); + operator_prepare_plonk_witness::>(&args_bytes, &shrink_proof, &mut o_plonk_witness); + + unsafe { + *o_plonk_witness_len_ptr = o_plonk_witness.len(); + *o_plonk_witness_ptr = Box::into_raw(o_plonk_witness.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_prove_plonk_c( + shrink_proof_len: usize, + shrink_proof_ptr: *const u8, + o_plonk_proof_len_ptr: *mut usize, + o_plonk_proof_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let shrink_proof = + unsafe { std::slice::from_raw_parts(shrink_proof_ptr, shrink_proof_len) }.to_vec(); + + let mut o_plonk_proof = Vec::new(); + operator_prove_plonk::>(&args_bytes, &shrink_proof, &mut o_plonk_proof); + + unsafe { + *o_plonk_proof_len_ptr = o_plonk_proof.len(); + *o_plonk_proof_ptr = Box::into_raw(o_plonk_proof.into_boxed_slice()) as *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn operator_verify_sp1_plonk_proof_c( + core_proof_len: usize, + core_proof_ptr: *const u8, + plonk_proof_len: usize, + plonk_proof_ptr: *const u8, + o_plonk_proof_with_public_values_len_ptr: *mut usize, + o_plonk_proof_with_public_values_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + + let core_proof = unsafe { std::slice::from_raw_parts(core_proof_ptr, core_proof_len) }.to_vec(); + let plonk_proof = + unsafe { std::slice::from_raw_parts(plonk_proof_ptr, plonk_proof_len) }.to_vec(); + + let plonk_proof_with_public_values = + plonk_prove::scenario_end(&args, &core_proof, &plonk_proof).expect("verification failed"); + let plonk_proof_with_public_values = + bincode::serialize(&plonk_proof_with_public_values).unwrap(); + + unsafe { + *o_plonk_proof_with_public_values_len_ptr = plonk_proof_with_public_values.len(); + *o_plonk_proof_with_public_values_ptr = + Box::into_raw(plonk_proof_with_public_values.into_boxed_slice()) as *mut u8; + } +} diff --git a/ffi/src/worker.rs b/ffi/src/worker.rs new file mode 100644 index 0000000000..c1b04a91c8 --- /dev/null +++ b/ffi/src/worker.rs @@ -0,0 +1,160 @@ +use crate::FIBONACCI_ELF; +use fibonacci_script::FibonacciArgs; +use sp1_sdk::mmp::{ + common::ProveArgs, + worker::{worker_commit_checkpoint, worker_compress_proofs, worker_prove_checkpoint}, +}; + +#[no_mangle] +pub extern "C" fn worker_commit_checkpoint_c( + idx: u32, + checkpoint_len: usize, + checkpoint_ptr: *const u8, + is_last_checkpoint: bool, + public_values_len: usize, + public_values: *const u8, + o_num_commitments_ptr: *mut usize, + o_commitments_len_ptr: *mut *mut usize, + o_commitments_ptr: *mut *mut *mut u8, + o_records_len_ptr: *mut *mut usize, + o_records_ptr: *mut *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let mut o_commitments = Vec::new(); + let mut o_records = Vec::new(); + + let checkpoint = unsafe { std::slice::from_raw_parts(checkpoint_ptr, checkpoint_len) }.to_vec(); + let public_values = unsafe { std::slice::from_raw_parts(public_values, public_values_len) }; + worker_commit_checkpoint::>( + &args_bytes, + idx, + &checkpoint, + is_last_checkpoint, + public_values, + &mut o_commitments, + &mut o_records, + ); + + let commitments_len = o_commitments + .iter() + .map(|commitment| commitment.len()) + .collect::>(); + let commitments = o_commitments + .into_iter() + .map(|commitment| Box::into_raw(commitment.into_boxed_slice()) as *mut u8) + .collect::>(); + let records_len = o_records + .iter() + .map(|record| record.len()) + .collect::>(); + let records = o_records + .into_iter() + .map(|record| Box::into_raw(record.into_boxed_slice()) as *mut u8) + .collect::>(); + + unsafe { + *o_num_commitments_ptr = commitments.len(); + *o_commitments_len_ptr = Box::into_raw(commitments_len.into_boxed_slice()) as *mut usize; + *o_commitments_ptr = Box::into_raw(commitments.into_boxed_slice()) as *mut *mut u8; + *o_records_len_ptr = Box::into_raw(records_len.into_boxed_slice()) as *mut usize; + *o_records_ptr = Box::into_raw(records.into_boxed_slice()) as *mut *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn worker_prove_checkpoint_c( + challenger_state_len: usize, + challenger_state_ptr: *const u8, + num_records: usize, + records_lens: *const usize, + records_ptr: *const *const u8, + o_shard_proofs_len_ptr: *mut *mut usize, + o_shard_proofs_ptr: *mut *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let challenger_state = + unsafe { std::slice::from_raw_parts(challenger_state_ptr, challenger_state_len) }.to_vec(); + + let mut records = Vec::new(); + for i in 0..num_records { + records.push( + unsafe { std::slice::from_raw_parts(*records_ptr.add(i), *records_lens.add(i)) } + .to_vec(), + ); + } + + let mut o_shard_proofs = Vec::new(); + worker_prove_checkpoint::>( + &args_bytes, + &challenger_state, + &records, + &mut o_shard_proofs, + ); + + let shard_proofs_len = o_shard_proofs + .iter() + .map(|shard_proof| shard_proof.len()) + .collect::>(); + let shard_proofs = o_shard_proofs + .into_iter() + .map(|shard_proof| Box::into_raw(shard_proof.into_boxed_slice()) as *mut u8) + .collect::>(); + + unsafe { + *o_shard_proofs_len_ptr = Box::into_raw(shard_proofs_len.into_boxed_slice()) as *mut usize; + *o_shard_proofs_ptr = Box::into_raw(shard_proofs.into_boxed_slice()) as *mut *mut u8; + } +} + +#[no_mangle] +pub extern "C" fn worker_compress_proofs_c( + layout_len: usize, + layout_ptr: *const u8, + layout_type: usize, + last_proof_public_values_len: usize, + last_proof_public_values_ptr: *const u8, + o_proof_len_ptr: *mut usize, + o_proof_ptr: *mut *mut u8, +) { + // TODO(TomTaehoonKim): Remove this when args are passed from the caller. + let fibonacci_args = FibonacciArgs { n: 20, evm: false }; + let args = ProveArgs { + zkvm_input: fibonacci_args.n.to_le_bytes().to_vec(), + elf: FIBONACCI_ELF.to_vec(), + }; + let args_bytes = args.to_bytes(); + + let layout = unsafe { std::slice::from_raw_parts(layout_ptr, layout_len) }.to_vec(); + let last_proof_public_values = unsafe { + std::slice::from_raw_parts(last_proof_public_values_ptr, last_proof_public_values_len) + } + .to_vec(); + + let mut o_proof = Vec::new(); + worker_compress_proofs::>( + &args_bytes, + &layout, + layout_type, + Some(&last_proof_public_values), + &mut o_proof, + ); + + unsafe { + *o_proof_len_ptr = o_proof.len(); + *o_proof_ptr = Box::into_raw(o_proof.into_boxed_slice()) as *mut u8; + } +} diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 39c6e82442..ca18891d55 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -22,6 +22,8 @@ reqwest = { version = "0.12.4", features = [ "stream", ] } anyhow = "1.0.83" +sp1-recursion-circuit = { workspace = true } +sp1-recursion-gnark-ffi = { workspace = true } sp1-prover = { workspace = true } sp1-core = { workspace = true } futures = "0.3.30" @@ -58,7 +60,7 @@ sysinfo = "0.30.13" default = ["network"] neon = ["sp1-core/neon"] native-gnark = ["sp1-prover/native-gnark"] -# TODO: Once alloy has a 1.* release, we can likely remove this feature flag, as there will be less +# TODO: Once alloy has a 1.* release, we can likely remove this feature flag, as there will be less # dependency resolution issues. network = ["dep:alloy-sol-types"] diff --git a/sdk/src/mmp/operator/mod.rs b/sdk/src/mmp/operator/mod.rs index e3009344e4..0c2e08af95 100644 --- a/sdk/src/mmp/operator/mod.rs +++ b/sdk/src/mmp/operator/mod.rs @@ -17,7 +17,8 @@ use std::borrow::Borrow; use steps::{ construct_sp1_core_proof_impl, operator_absorb_commits_impl, operator_prepare_compress_input_chunks_impl, operator_prepare_compress_inputs_impl, - operator_prove_plonk_impl, operator_prove_shrink_impl, operator_split_into_checkpoints_impl, + operator_prepare_plonk_witness_impl, operator_prove_plonk_impl, operator_prove_shrink_impl, + operator_split_into_checkpoints_impl, }; use utils::{read_bin_file_to_vec, ChallengerState}; @@ -189,6 +190,20 @@ pub fn operator_prove_shrink( *o_shrink_proof = bincode::serialize(&shrink_proof).unwrap(); } +pub fn operator_prepare_plonk_witness( + args: &Vec, + shrink_proof: &[u8], + o_plonk_witness: &mut Vec, +) { + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); + let shrink_proof_obj: SP1ReduceProof = + bincode::deserialize(shrink_proof).unwrap(); + + let plonk_witness = operator_prepare_plonk_witness_impl(&args_obj, shrink_proof_obj).unwrap(); + + *o_plonk_witness = serde_json::to_vec(&plonk_witness).unwrap(); +} + pub fn operator_prove_plonk( args: &Vec, shrink_proof: &[u8], diff --git a/sdk/src/mmp/operator/steps.rs b/sdk/src/mmp/operator/steps.rs index c75dd748f8..9491a1be6e 100644 --- a/sdk/src/mmp/operator/steps.rs +++ b/sdk/src/mmp/operator/steps.rs @@ -15,12 +15,15 @@ use sp1_core::{ stark::{MachineProof, MachineProver, ShardProof, StarkGenericConfig}, utils::{BabyBearPoseidon2, SP1CoreProverError}, }; +use sp1_prover::build::Witness; use sp1_prover::{ PlonkBn254Proof, ReduceProgramType, SP1CoreProof, SP1CoreProofData, SP1DeferredMemoryLayout, SP1ProofWithMetadata, SP1Prover, SP1PublicValues, SP1RecursionMemoryLayout, SP1ReduceProof, SP1Stdin, SP1VerifyingKey, }; +use sp1_recursion_circuit::witness::Witnessable; use sp1_recursion_core::stark::RecursionAir; +use sp1_recursion_gnark_ffi::witness::GnarkWitness; use tracing::info_span; fn operator_split_into_checkpoints( @@ -238,6 +241,36 @@ pub fn operator_prove_shrink_impl( .map_err(|e| anyhow::anyhow!(e)) } +pub fn operator_prepare_plonk_witness_impl( + args: &ProveArgs, + shrink_proof: SP1ReduceProof, +) -> Result { + let (client, _, pk, _) = common::init_client(args); + let (_, opts, _) = common::bootstrap(&client, &pk).unwrap(); + let sp1_prover = client.prover.sp1_prover(); + + let outer_proof = sp1_prover.wrap_bn254(shrink_proof, opts).unwrap(); + + let vkey_digest = outer_proof.sp1_vkey_digest_bn254(); + let commited_values_digest = outer_proof.sp1_commited_values_digest_bn254(); + + std::fs::write("vkey_digest.txt", format!("{:?}", vkey_digest)).unwrap(); + std::fs::write( + "commited_values_digest.txt", + format!("{:?}", commited_values_digest), + ) + .unwrap(); + + let mut witness = Witness::default(); + outer_proof.proof.write(&mut witness); + witness.write_commited_values_digest(commited_values_digest); + witness.write_vkey_hash(vkey_digest); + + let gnark_witness = GnarkWitness::new(witness); + + Ok(gnark_witness) +} + pub fn operator_prove_plonk_impl( args: &ProveArgs, shrink_proof: SP1ReduceProof, diff --git a/sdk/src/mmp/scenario/compress_prove.rs b/sdk/src/mmp/scenario/compress_prove.rs index 43ab48193b..e5f82bfab5 100644 --- a/sdk/src/mmp/scenario/compress_prove.rs +++ b/sdk/src/mmp/scenario/compress_prove.rs @@ -11,7 +11,7 @@ use anyhow::Result; use serde::de::DeserializeOwned; use serde::Serialize; use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2}; -use sp1_prover::SP1ReduceProof; +use sp1_prover::{ReduceProgramType, SP1ReduceProof}; use tracing::info_span; pub fn mpc_prove_compress( @@ -85,9 +85,11 @@ pub fn mpc_prove_compress( } }; - let shard_proof: ShardProof = + let shard_proof: (ShardProof, ReduceProgramType) = bincode::deserialize(&compressed_proof).unwrap(); - let proof = SP1ReduceProof { proof: shard_proof }; + let proof = SP1ReduceProof { + proof: shard_proof.0, + }; let proof = bincode::serialize(&proof).unwrap(); tracing::info!("proof size: {:?}", proof.len()); @@ -112,7 +114,9 @@ pub fn scenario_end( sp1_version: client.prover.version().to_string(), }; - client.verify(&proof, &vk).unwrap(); + client + .verify(&proof, &vk) + .expect("failed to verify compress proof"); tracing::info!("Successfully verified compress proof"); Ok(proof) diff --git a/sdk/src/mmp/scenario/core_prove.rs b/sdk/src/mmp/scenario/core_prove.rs index 86be988c54..48308a3dc1 100644 --- a/sdk/src/mmp/scenario/core_prove.rs +++ b/sdk/src/mmp/scenario/core_prove.rs @@ -110,8 +110,10 @@ pub fn scenario_end( sp1_version: client.prover.version().to_string(), }; - client.verify(&proof, &vk).expect("failed to verify proof"); - tracing::info!("Successfully generated core-proof(verified)"); + client + .verify(&proof, &vk) + .expect("failed to verify core proof"); + tracing::info!("Successfully verified core proof"); Ok(proof) } diff --git a/sdk/src/mmp/scenario/plonk_prove.rs b/sdk/src/mmp/scenario/plonk_prove.rs index 6a52061417..903ba1f4f8 100644 --- a/sdk/src/mmp/scenario/plonk_prove.rs +++ b/sdk/src/mmp/scenario/plonk_prove.rs @@ -47,8 +47,10 @@ pub fn scenario_end( sp1_version: client.prover.version().to_string(), }; - client.verify(&proof, &vk).unwrap(); - tracing::info!("Successfully verified compress proof"); + client + .verify(&proof, &vk) + .expect("failed to verify plonk proof"); + tracing::info!("Successfully verified plonk proof"); Ok(proof) }