diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index e717a993..4ebc3b0a 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -4,6 +4,7 @@ use num_bigint::BigInt; use crate::{ field::Field, + frontend::dsl::trace::DSLTraceGenerator, interpreter::InterpreterTraceGenerator, parser::{ ast::{ @@ -269,6 +270,8 @@ impl Compiler { sbpir.machines.insert(machine_name.clone(), sbpir_machine); } + let sbpir = sbpir.transform_metadata(|_| ()); + sbpir.without_trace() } @@ -279,7 +282,7 @@ impl Compiler { setup: &MachineSetup, machine_name: &str, state_id: &str, - ) -> Vec> { + ) -> Vec> { let exprs = setup.get_poly_constraints(state_id).unwrap(); exprs @@ -300,40 +303,40 @@ impl Compiler { symbols: &SymTable, machine_name: &str, state_id: &str, - expr: &Expr, - ) -> Expr, ()> { + expr: &Expr, + ) -> Expr, DebugSymRef> { use Expr::*; match expr { - Const(v, _) => Const(*v, ()), - Sum(ses, _) => Sum( + Const(v, dsym) => Const(*v, dsym.clone()), + Sum(ses, dsym) => Sum( ses.iter() .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), - (), + dsym.clone(), ), - Mul(ses, _) => Mul( + Mul(ses, dsym) => Mul( ses.iter() .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), - (), + dsym.clone(), ), - Neg(se, _) => Neg( + Neg(se, dsym) => Neg( Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), - (), + dsym.clone(), ), - Pow(se, exp, _) => Pow( + Pow(se, exp, dsym) => Pow( Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), *exp, - (), + dsym.clone(), ), - MI(se, _) => MI( + MI(se, dsym) => MI( Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), - (), + dsym.clone(), ), - Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), - Query(id, _) => Query( + Halo2Expr(se, dsym) => Halo2Expr(se.clone(), dsym.clone()), + Query(id, dsym) => Query( self.translate_query(symbols, machine_name, state_id, id), - (), + dsym.clone(), ), } } @@ -428,10 +431,10 @@ impl Compiler { machine_name: &str, machine_setup: &MachineSetup, state_id: &str, - ) -> StepType { + ) -> StepType { let handler = self.mapping.get_step_type_handler(machine_name, state_id); - let mut step_type: StepType = + let mut step_type: StepType = StepType::new(handler.uuid(), handler.annotation.to_string()); self.add_internal_signals(symbols, machine_name, &mut step_type, state_id); @@ -448,7 +451,7 @@ impl Compiler { &mut self, symbols: &SymTable, machine_name: &str, - step_type: &mut StepType, + step_type: &mut StepType, state_id: &str, ) { let internal_ids = self.get_all_internals(symbols, machine_name, state_id); @@ -467,7 +470,7 @@ impl Compiler { fn add_step_type_handlers( &mut self, - machine: &mut SBPIRMachine, + machine: &mut SBPIRMachine, DebugSymRef>, symbols: &SymTable, machine_name: &str, ) { @@ -497,7 +500,7 @@ impl Compiler { fn add_forward_signals( &mut self, - machine: &mut SBPIRMachine, + machine: &mut SBPIRMachine, DebugSymRef>, symbols: &SymTable, machine_name: &str, ) { diff --git a/src/compiler/compiler_legacy.rs b/src/compiler/compiler_legacy.rs index c7bed275..01ed7e21 100644 --- a/src/compiler/compiler_legacy.rs +++ b/src/compiler/compiler_legacy.rs @@ -259,15 +259,16 @@ impl CompilerLegacy { setup .iter() .map(|(machine_id, machine)| { - let poly_constraints: HashMap>> = machine - .poly_constraints_iter() - .map(|(step_id, step)| { - let new_step: Vec> = - step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); + let poly_constraints: HashMap>> = + machine + .poly_constraints_iter() + .map(|(step_id, step)| { + let new_step: Vec> = + step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); - (step_id.clone(), new_step) - }) - .collect(); + (step_id.clone(), new_step) + }) + .collect(); let new_machine: MachineSetup = machine.replace_poly_constraints(poly_constraints); @@ -276,17 +277,25 @@ impl CompilerLegacy { .collect() } - fn map_pi_consts(expr: &Expr) -> Expr { + fn map_pi_consts( + expr: &Expr, + ) -> Expr { use Expr::*; match expr { - Const(v, _) => Const(F::from_big_int(v), ()), - Sum(ses, _) => Sum(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Mul(ses, _) => Mul(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Neg(se, _) => Neg(Box::new(Self::map_pi_consts(se)), ()), - Pow(se, exp, _) => Pow(Box::new(Self::map_pi_consts(se)), *exp, ()), - Query(q, _) => Query(q.clone(), ()), + Const(v, dsym) => Const(F::from_big_int(v), dsym.clone()), + Sum(ses, dsym) => Sum( + ses.iter().map(|se| Self::map_pi_consts(se)).collect(), + dsym.clone(), + ), + Mul(ses, dsym) => Mul( + ses.iter().map(|se| Self::map_pi_consts(se)).collect(), + dsym.clone(), + ), + Neg(se, dsym) => Neg(Box::new(Self::map_pi_consts(se)), dsym.clone()), + Pow(se, exp, dsym) => Pow(Box::new(Self::map_pi_consts(se)), *exp, dsym.clone()), + Query(q, dsym) => Query(q.clone(), dsym.clone()), Halo2Expr(_, _) => todo!(), - MI(se, _) => MI(Box::new(Self::map_pi_consts(se)), ()), + MI(se, dsym) => MI(Box::new(Self::map_pi_consts(se)), dsym.clone()), } } @@ -316,7 +325,7 @@ impl CompilerLegacy { poly_constraints.iter().for_each(|poly| { let constraint = Constraint { annotation: format!("{:?}", poly), - expr: poly.clone(), + expr: poly.transform_meta(|_| ()), typing: Typing::AntiBooly, }; ctx.constr(constraint); @@ -372,7 +381,7 @@ impl CompilerLegacy { setup: &Setup, machine_id: &str, state_id: &str, - ) -> Vec, ()>> { + ) -> Vec, DebugSymRef>> { let exprs = setup .get(machine_id) .unwrap() @@ -390,38 +399,41 @@ impl CompilerLegacy { symbols: &SymTable, machine_id: &str, state_id: &str, - expr: &Expr, - ) -> Expr, ()> { + expr: &Expr, + ) -> Expr, DebugSymRef> { use Expr::*; match expr { - Const(v, _) => Const(*v, ()), - Sum(ses, _) => Sum( + Const(v, dsym) => Const(*v, dsym.clone()), + Sum(ses, dsym) => Sum( ses.iter() .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) .collect(), - (), + dsym.clone(), ), - Mul(ses, _) => Mul( + Mul(ses, dsym) => Mul( ses.iter() .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) .collect(), - (), + dsym.clone(), ), - Neg(se, _) => Neg( + Neg(se, dsym) => Neg( Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), - (), + dsym.clone(), ), - Pow(se, exp, _) => Pow( + Pow(se, exp, dsym) => Pow( Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), *exp, - (), + dsym.clone(), ), - MI(se, _) => MI( + MI(se, dsym) => MI( Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), - (), + dsym.clone(), + ), + Halo2Expr(se, dsym) => Halo2Expr(se.clone(), dsym.clone()), + Query(id, dsym) => Query( + self.translate_query(symbols, machine_id, state_id, id), + dsym.clone(), ), - Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), - Query(id, _) => Query(self.translate_query(symbols, machine_id, state_id, id), ()), } } diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 30130eaa..e99b7498 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -32,7 +32,7 @@ pub(super) fn interpret(ast: &[TLDecl], _symbols: &SymTable) pub(super) type Setup = HashMap>; pub(super) struct MachineSetup { - poly_constraints: HashMap>>, + poly_constraints: HashMap>>, input_signals: Vec>, output_signals: Vec>, @@ -49,10 +49,10 @@ impl Default for MachineSetup { } impl MachineSetup { pub(crate) fn map_consts(&self) -> MachineSetup { - let poly_constraints: HashMap>> = self + let poly_constraints: HashMap>> = self .poly_constraints_iter() .map(|(step_id, step)| { - let new_step: Vec> = step + let new_step: Vec> = step .iter() .map(|pi| Self::convert_const_to_field(pi)) .collect(); @@ -66,28 +66,32 @@ impl MachineSetup { } fn convert_const_to_field( - expr: &Expr, - ) -> Expr { + expr: &Expr, + ) -> Expr { use Expr::*; match expr { - Const(v, _) => Const(F::from_big_int(v), ()), - Sum(ses, _) => Sum( + Const(v, dsym) => Const(F::from_big_int(v), dsym.clone()), + Sum(ses, dsym) => Sum( ses.iter() .map(|se| Self::convert_const_to_field(se)) .collect(), - (), + dsym.clone(), ), - Mul(ses, _) => Mul( + Mul(ses, dsym) => Mul( ses.iter() .map(|se| Self::convert_const_to_field(se)) .collect(), - (), + dsym.clone(), ), - Neg(se, _) => Neg(Box::new(Self::convert_const_to_field(se)), ()), - Pow(se, exp, _) => Pow(Box::new(Self::convert_const_to_field(se)), *exp, ()), - Query(q, _) => Query(q.clone(), ()), + Neg(se, dsym) => Neg(Box::new(Self::convert_const_to_field(se)), dsym.clone()), + Pow(se, exp, dsym) => Pow( + Box::new(Self::convert_const_to_field(se)), + *exp, + dsym.clone(), + ), + Query(q, dsym) => Query(q.clone(), dsym.clone()), Halo2Expr(_, _) => todo!(), - MI(se, _) => MI(Box::new(Self::convert_const_to_field(se)), ()), + MI(se, dsym) => MI(Box::new(Self::convert_const_to_field(se)), dsym.clone()), } } } @@ -125,7 +129,7 @@ impl MachineSetup { fn add_poly_constraints>( &mut self, state: S, - poly_constraints: Vec>, + poly_constraints: Vec>, ) { self.poly_constraints .get_mut(&state.into()) @@ -135,13 +139,13 @@ impl MachineSetup { pub(super) fn poly_constraints_iter( &self, - ) -> std::collections::hash_map::Iter>> { + ) -> std::collections::hash_map::Iter>> { self.poly_constraints.iter() } pub(super) fn replace_poly_constraints( &self, - poly_constraints: HashMap>>, + poly_constraints: HashMap>>, ) -> MachineSetup { MachineSetup { poly_constraints, @@ -157,7 +161,7 @@ impl MachineSetup { pub(super) fn get_poly_constraints>( &self, state: S, - ) -> Option<&Vec>> { + ) -> Option<&Vec>> { self.poly_constraints.get(&state.into()) } } @@ -253,10 +257,15 @@ impl SetupInterpreter { HyperTransition(_, _, _, _) => todo!("Implement compilation for hyper transitions"), }; - self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect()); + self.add_poly_constraints( + result + .into_iter() + .map(|cr| cr.anti_booly.transform_meta(|_| cr.dsym.clone())) + .collect(), + ); } - fn add_poly_constraints(&mut self, pis: Vec>) { + fn add_poly_constraints(&mut self, pis: Vec>) { self.setup .get_mut(&self.current_machine) .unwrap() diff --git a/src/poly/mod.rs b/src/poly/mod.rs index b7c026ff..0d69af78 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -91,14 +91,14 @@ impl Expr { } } -impl Expr { +impl Expr { pub fn transform_meta(&self, apply_meta: ApplyMetaFn) -> Expr where ApplyMetaFn: Fn(&Expr) -> N + Clone, { let new_meta = apply_meta(self); match self { - Expr::Const(v, _) => Expr::Const(*v, new_meta), + Expr::Const(v, _) => Expr::Const(v.clone(), new_meta), Expr::Sum(ses, _) => Expr::Sum( ses.iter() .map(|e| e.transform_meta(apply_meta.clone())) diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index e1f90b74..279cb217 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -275,8 +275,8 @@ pub struct SBPIR = DSLTraceGenerator, M: Clon pub identifiers: HashMap, } -impl> SBPIR { - pub(crate) fn default() -> SBPIR { +impl, M: Clone> SBPIR { + pub(crate) fn default() -> SBPIR { let machines = HashMap::new(); let identifiers = HashMap::new(); SBPIR { @@ -289,7 +289,7 @@ impl> SBPIR { &self, // TODO does it have to be the same trace across all the machines? trace: &TG2, - ) -> SBPIR { + ) -> SBPIR { let mut machines_with_trace = HashMap::new(); for (name, machine) in self.machines.iter() { let machine_with_trace = machine.with_trace(trace.clone()); @@ -301,7 +301,7 @@ impl> SBPIR { } } - pub(crate) fn without_trace(&self) -> SBPIR { + pub(crate) fn without_trace(&self) -> SBPIR { let mut machines_without_trace = HashMap::new(); for (name, machine) in self.machines.iter() { let machine_without_trace = machine.without_trace(); @@ -312,7 +312,28 @@ impl> SBPIR { identifiers: self.identifiers.clone(), } } +} +impl + Clone, M: Clone> SBPIR { + pub fn transform_metadata( + self, + apply_meta: ApplyMetaFn, + ) -> SBPIR + where + ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, + { + SBPIR { + machines: self + .machines + .into_iter() + .map(|(name, machine)| (name, machine.transform_meta(apply_meta.clone()))) + .collect(), + identifiers: self.identifiers, + } + } +} + +impl> SBPIR { /// Eliminate multiplicative inverses pub(crate) fn eliminate_mul_inv(mut self) -> SBPIR { for machine in self.machines.values_mut() { @@ -443,7 +464,7 @@ impl StepType { } } -impl StepType { +impl StepType { pub fn transform_meta(&self, apply_meta: ApplyMetaFn) -> StepType where ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, @@ -470,7 +491,7 @@ impl StepType { auto_signals: self .auto_signals .iter() - .map(|(k, v)| (*k, v.transform_meta(apply_meta.clone()))) + .map(|(k, v)| (k.clone(), v.transform_meta(apply_meta.clone()))) .collect(), annotations: self.annotations.clone(), } @@ -560,7 +581,7 @@ pub struct Constraint { pub expr: PIR, } -impl Constraint { +impl Constraint { pub fn transform_meta(&self, apply_meta: ApplyMetaFn) -> Constraint where ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, @@ -579,7 +600,7 @@ pub struct TransitionConstraint { pub expr: PIR, } -impl TransitionConstraint { +impl TransitionConstraint { pub fn transform_meta( &self, apply_meta: ApplyMetaFn, @@ -676,7 +697,7 @@ impl Lookup { } } -impl Lookup { +impl Lookup { pub fn transform_meta(&self, apply_meta: ApplyMetaFn) -> Lookup where ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index 1b1cc941..cdcc3789 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -4,12 +4,13 @@ use crate::{ trace::{DSLTraceGenerator, TraceContext}, StepTypeHandler, }, + poly::Expr, sbpir::Halo2Column, util::{uuid, UUID}, wit_gen::{FixedAssignment, NullTraceGenerator, TraceGenerator}, }; use halo2_proofs::plonk::{Advice, Fixed}; -use std::{collections::HashMap, fmt::Debug, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; use super::{ query::Queriable, ExposeOffset, FixedSignal, ForwardSignal, ImportedHalo2Advice, @@ -273,6 +274,45 @@ impl> SBPIRMachine { } } +impl + Clone, M: Clone> + SBPIRMachine +{ + pub fn transform_meta( + &self, + apply_meta: ApplyMetaFn, + ) -> SBPIRMachine + where + ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, + { + SBPIRMachine { + step_types: self + .step_types + .iter() + .map(|(uuid, step_type)| (*uuid, step_type.transform_meta(apply_meta.clone()))) + .collect(), + + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + + annotations: self.annotations.clone(), + + trace_generator: self.trace_generator.clone(), + fixed_assignments: self.fixed_assignments.clone(), + + first_step: self.first_step, + last_step: self.last_step, + num_steps: self.num_steps, + q_enable: self.q_enable, + + id: self.id, + } + } +} + #[cfg(test)] mod tests { use crate::wit_gen::NullTraceGenerator;