Skip to content

Commit

Permalink
Support intermediate polynomials in p3 (powdr-labs#1995)
Browse files Browse the repository at this point in the history
Following discussions about large expressions with a lot of common
subexpressions, support intermediate polynomials in plonky3 so that
these subexpressions can be explicitly declared as intermediate
polynomials and cached.

With the following program:

```rust
let N: int = 2**15;

machine Main with degree: N {
    col witness a;
    let i0: inter = a + a;
    let i1: inter = i0 + i0;
    let i2: inter = i1 + i1;
    let i3: inter = i2 + i2;
    let i4: inter = i3 + i3;
    let i5: inter = i4 + i4;
    let i6: inter = i5 + i5;
    let i7: inter = i6 + i6;
    let i8: inter = i7 + i7;
    let i9: inter = i8 + i8;
    let i10: inter = i9 + i9;
    let i11: inter = i10 + i10;
    let i12: inter = i11 + i11;
    let i13: inter = i12 + i12;
    let i14: inter = i13 + i13;
    let i15: inter = i14 + i14;
    i15 = 78;
}
```

On main, 3s on quotient polynomial:
```console
Setup took 0.03163829s
INFO     prove [ 3.24s | 0.03% / 100.00% ]
INFO     ┝━ commit to stage {stage} data [ 25.5ms | 0.71% / 0.79% ]
INFO     │  ┕━ coset_lde_batch [ 2.36ms | 0.07% ] dims: 2x32768
INFO     ┝━ infer log of constraint degree [ 10.7ms | 0.33% ]
INFO     ┝━ compute quotient polynomial [ 3.08s | 94.79% ]
INFO     ┝━ infer log of constraint degree [ 9.19ms | 0.28% ]
INFO     ┝━ infer log of constraint degree [ 8.15ms | 0.25% ]
INFO     ┝━ commit to quotient poly chunks [ 23.8ms | 0.68% / 0.73% ]
INFO     │  ┕━ coset_lde_batch [ 1.59ms | 0.05% ] dims: 2x32768
INFO     ┝━ infer log of constraint degree [ 8.89ms | 0.27% ]
INFO     ┝━ compute_inverse_denominators [ 1.27ms | 0.04% ]
INFO     ┝━ reduce matrix quotient [ 1.24ms | 0.00% / 0.04% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 771µs | 0.02% ]
INFO     │  ┕━ reduce rows [ 455µs | 0.01% ]
INFO     ┝━ reduce matrix quotient [ 1.06ms | 0.00% / 0.03% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 689µs | 0.02% ]
INFO     │  ┕━ reduce rows [ 368µs | 0.01% ]
INFO     ┝━ reduce matrix quotient [ 1.06ms | 0.00% / 0.03% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 699µs | 0.02% ]
INFO     │  ┕━ reduce rows [ 355µs | 0.01% ]
INFO     ┝━ FRI prover [ 69.4ms | 0.00% / 2.14% ]
INFO     │  ┝━ commit phase [ 29.4ms | 0.91% ]
INFO     │  ┝━ grind for proof-of-work witness [ 39.3ms | 1.21% ]
INFO     │  ┕━ query phase [ 623µs | 0.02% ]
INFO     ┕━ infer log of constraint degree [ 8.05ms | 0.25% ]
INFO     verify [ 49.9ms | 30.69% / 100.00% ]
INFO     ┝━ infer log of constraint degree [ 9.05ms | 18.11% ]
INFO     ┝━ infer log of constraint degree [ 9.47ms | 18.95% ]
INFO     ┝━ infer log of constraint degree [ 7.94ms | 15.90% ]
INFO     ┕━ infer log of constraint degree [ 8.16ms | 16.34% ]
Proof generation took 3.2949023s
```

With this change, 10ms spent on quotient polynomial:
```console
Setup took 0.004602916s
INFO     prove [ 117ms | 0.90% / 100.00% ]
INFO     ┝━ commit to stage {stage} data [ 27.5ms | 21.66% / 23.59% ]
INFO     │  ┕━ coset_lde_batch [ 2.25ms | 1.93% ] dims: 2x32768
INFO     ┝━ infer log of constraint degree [ 67.2µs | 0.06% ]
INFO     ┝━ compute quotient polynomial [ 9.95ms | 8.52% ]
INFO     ┝━ infer log of constraint degree [ 10.4µs | 0.01% ]
INFO     ┝━ infer log of constraint degree [ 4.92µs | 0.00% ]
INFO     ┝━ commit to quotient poly chunks [ 30.6ms | 24.10% / 26.17% ]
INFO     │  ┕━ coset_lde_batch [ 2.42ms | 2.07% ] dims: 2x32768
INFO     ┝━ infer log of constraint degree [ 13.2µs | 0.01% ]
INFO     ┝━ compute_inverse_denominators [ 1.53ms | 1.31% ]
INFO     ┝━ reduce matrix quotient [ 1.47ms | 0.01% / 1.26% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 912µs | 0.78% ]
INFO     │  ┕━ reduce rows [ 545µs | 0.47% ]
INFO     ┝━ reduce matrix quotient [ 2.14ms | 0.01% / 1.83% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 887µs | 0.76% ]
INFO     │  ┕━ reduce rows [ 1.25ms | 1.07% ]
INFO     ┝━ reduce matrix quotient [ 2.25ms | 0.00% / 1.93% ] dims: 2x65536
INFO     │  ┝━ compute opened values with Lagrange interpolation [ 1.87ms | 1.60% ]
INFO     │  ┕━ reduce rows [ 380µs | 0.33% ]
INFO     ┝━ FRI prover [ 40.1ms | 0.09% / 34.38% ]
INFO     │  ┝━ commit phase [ 33.9ms | 29.01% ]
INFO     │  ┝━ grind for proof-of-work witness [ 5.33ms | 4.56% ]
INFO     │  ┕━ query phase [ 843µs | 0.72% ]
INFO     ┕━ infer log of constraint degree [ 11.0µs | 0.01% ]
INFO     verify [ 14.8ms | 99.82% / 100.00% ]
INFO     ┝━ infer log of constraint degree [ 8.62µs | 0.06% ]
INFO     ┝━ infer log of constraint degree [ 10.2µs | 0.07% ]
INFO     ┝━ infer log of constraint degree [ 4.58µs | 0.03% ]
INFO     ┕━ infer log of constraint degree [ 3.79µs | 0.03% ]
Proof generation took 0.132404s
```
  • Loading branch information
Schaeff authored Oct 31, 2024
1 parent bd1506a commit 45504a9
Showing 1 changed file with 58 additions and 8 deletions.
66 changes: 58 additions & 8 deletions plonky3/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub struct ConstraintSystem<T> {
witness_columns: BTreeMap<PolyID, (usize, usize)>,
// for each fixed column, the index of this column in the fixed columns
fixed_columns: BTreeMap<PolyID, usize>,
// for each intermediate polynomial, the expression
intermediates: BTreeMap<PolyID, AlgebraicExpression<T>>,
identities: Vec<Identity<T>>,
// for each public column, the name, poly_id, index in the witness columns, and stage
pub(crate) publics_by_stage: Vec<Vec<(String, PolyID, usize)>>,
Expand All @@ -51,7 +53,7 @@ pub struct ConstraintSystem<T> {

impl<T: FieldElement> From<&Analyzed<T>> for ConstraintSystem<T> {
fn from(analyzed: &Analyzed<T>) -> Self {
let identities = analyzed.identities_with_inlined_intermediate_polynomials();
let identities = analyzed.identities.clone();
let constant_count = analyzed.constant_count();
let stage_widths = (0..analyzed.stage_count() as u32)
.map(|stage| {
Expand All @@ -72,6 +74,16 @@ impl<T: FieldElement> From<&Analyzed<T>> for ConstraintSystem<T> {
.map(|(index, (_, id))| (id, index))
.collect();

let intermediates = analyzed
.intermediate_polys_in_source_order()
.flat_map(|(symbol, definitions)| {
symbol
.array_elements()
.zip_eq(definitions)
.map(|((_, id), expr)| (id, expr.clone()))
})
.collect();

let witness_columns = analyzed
.definitions_in_source_order(PolynomialType::Committed)
.into_group_map_by(|(s, _)| s.stage.unwrap_or_default())
Expand Down Expand Up @@ -118,6 +130,7 @@ impl<T: FieldElement> From<&Analyzed<T>> for ConstraintSystem<T> {
stage_widths,
witness_columns,
fixed_columns,
intermediates,
challenges_by_stage,
}
}
Expand Down Expand Up @@ -235,6 +248,7 @@ where
e: &AlgebraicExpression<T>,
traces_by_stage: &[AB::M],
fixed: &AB::M,
intermediate_cache: &mut BTreeMap<u64, AB::Expr>,
publics: &BTreeMap<&String, <AB as MultistageAirBuilder>::PublicVar>,
challenges: &[BTreeMap<&u64, <AB as MultistageAirBuilder>::Challenge>],
) -> AB::Expr {
Expand All @@ -255,7 +269,22 @@ where
fixed.row_slice(r.next as usize)[index].into()
}
PolynomialType::Intermediate => {
unreachable!("intermediate polynomials should have been inlined")
if let Some(expr) = intermediate_cache.get(&poly_id.id) {
expr.clone()
} else {
let value = self.to_plonky3_expr::<AB>(
&self.constraint_system.intermediates[&poly_id],
traces_by_stage,
fixed,
intermediate_cache,
publics,
challenges,
);
assert!(intermediate_cache
.insert(poly_id.id, value.clone())
.is_none());
value
}
}
}
}
Expand All @@ -274,6 +303,7 @@ where
left,
traces_by_stage,
fixed,
intermediate_cache,
publics,
challenges,
);
Expand All @@ -285,10 +315,22 @@ where
_ => unimplemented!("pow with non-constant exponent"),
},
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => {
let left =
self.to_plonky3_expr::<AB>(left, traces_by_stage, fixed, publics, challenges);
let right =
self.to_plonky3_expr::<AB>(right, traces_by_stage, fixed, publics, challenges);
let left = self.to_plonky3_expr::<AB>(
left,
traces_by_stage,
fixed,
intermediate_cache,
publics,
challenges,
);
let right = self.to_plonky3_expr::<AB>(
right,
traces_by_stage,
fixed,
intermediate_cache,
publics,
challenges,
);

match op {
Add => left + right,
Expand All @@ -298,8 +340,14 @@ where
}
}
AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => {
let expr: <AB as AirBuilder>::Expr =
self.to_plonky3_expr::<AB>(expr, traces_by_stage, fixed, publics, challenges);
let expr: <AB as AirBuilder>::Expr = self.to_plonky3_expr::<AB>(
expr,
traces_by_stage,
fixed,
intermediate_cache,
publics,
challenges,
);

match op {
AlgebraicUnaryOperator::Minus => -expr,
Expand Down Expand Up @@ -341,6 +389,7 @@ where
let traces_by_stage: Vec<AB::M> =
(0..stage_count).map(|i| builder.stage_trace(i)).collect();
let fixed = builder.preprocessed();
let mut intermediate_cache = BTreeMap::new();
let public_input_values_by_stage = (0..stage_count)
.map(|i| builder.stage_public_values(i))
.collect_vec();
Expand Down Expand Up @@ -405,6 +454,7 @@ where
&identity.expression,
&traces_by_stage,
&fixed,
&mut intermediate_cache,
&public_vals_by_id,
&challenges_by_stage,
);
Expand Down

0 comments on commit 45504a9

Please sign in to comment.