From 03b4165623f7ca579ecfb2835d68aa71ddb56ccb Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 16 Nov 2024 16:04:37 +0100 Subject: [PATCH 1/5] optimizer: deduplicate fixed columns --- pilopt/Cargo.toml | 1 + pilopt/src/lib.rs | 35 +++++++++++++++++++++++++++++++++++ pilopt/tests/optimizer.rs | 19 +++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/pilopt/Cargo.toml b/pilopt/Cargo.toml index 457b72def7..ffe628c4b9 100644 --- a/pilopt/Cargo.toml +++ b/pilopt/Cargo.toml @@ -13,6 +13,7 @@ powdr-number.workspace = true log = "0.4.17" pretty_assertions = "1.4.0" +itertools = "0.13.0" [dev-dependencies] powdr-pil-analyzer.workspace = true diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 68c9685c69..b6c53b3b75 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -4,6 +4,7 @@ use std::cmp::Ordering; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use itertools::Itertools; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, Analyzed, ConnectIdentity, Expression, @@ -24,6 +25,7 @@ pub fn optimize(mut pil_file: Analyzed) -> Analyzed { let col_count_pre = (pil_file.commitment_count(), pil_file.constant_count()); remove_unreferenced_definitions(&mut pil_file); remove_constant_fixed_columns(&mut pil_file); + deduplicate_fixed_columns(&mut pil_file); simplify_identities(&mut pil_file); extract_constant_lookups(&mut pil_file); remove_constant_witness_columns(&mut pil_file); @@ -215,6 +217,39 @@ fn constant_value(function: &FunctionValueDefinition) -> Option { } } +/// Deduplicate fixed columns of the same namespace which share the same value. 
+/// This uses the `Display` implementation of the function value, so `|i| i` is different from `|j| j` +/// This is enough for use cases where exactly the same function is inserted many times +/// This only replaces the references inside expressions and does not clean up the now unreachable fixed column definitions +fn deduplicate_fixed_columns(pil_file: &mut Analyzed) { + // build a map of `poly_id` to the `(name, poly_id)` they can be replaced by + let replacement_map: BTreeMap = pil_file + .constant_polys_in_source_order() + // group symbols by common namespace and displayed value + .into_group_map_by(|(symbol, value)| (symbol.absolute_name.split("::").next().unwrap(), value.as_ref().unwrap().to_string())) + .values() + // map all other symbols to the first one + .flat_map(|group| { + group[1..].iter().flat_map(|from| { + from.0 + .array_elements() + .map(|(_, from_id)| from_id) + .zip_eq(group[0].0.array_elements()) + }) + }) + .collect(); + + // substitute all occurences in expressions. + pil_file.post_visit_expressions_in_identities_mut(&mut |e| { + if let AlgebraicExpression::Reference(r) = e { + if let Some((new_name, new_id)) = replacement_map.get(&r.poly_id) { + r.name = new_name.clone(); + r.poly_id = *new_id; + } + }; + }); +} + /// Simplifies multiplications by zero and one. 
fn simplify_identities(pil_file: &mut Analyzed) { pil_file.post_visit_expressions_in_identities_mut(&mut simplify_expression_single); diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index cd6ba2bfe1..383513947c 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -24,6 +24,25 @@ fn replace_fixed() { assert_eq!(optimized, expectation); } +#[test] +fn deduplicate_fixed() { + let input = r#"namespace N(65536); + col fixed first = [1, 32]*; + col fixed second = [1, 32]*; + col witness X; + col witness Y; + X * first = Y * second; +"#; + let expectation = r#"namespace N(65536); + col fixed first = [1_fe, 32_fe]*; + col witness X; + col witness Y; + N::X * N::first = N::Y * N::first; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + #[test] fn replace_lookup() { let input = r#"namespace N(65536); From 75bcc8d8f328b3854b3b0fa4c924028b85b1d74f Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 17 Nov 2024 17:28:01 +0100 Subject: [PATCH 2/5] fmt --- pilopt/src/lib.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index b6c53b3b75..70bd2f93b2 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -226,7 +226,12 @@ fn deduplicate_fixed_columns(pil_file: &mut Analyzed) { let replacement_map: BTreeMap = pil_file .constant_polys_in_source_order() // group symbols by common namespace and displayed value - .into_group_map_by(|(symbol, value)| (symbol.absolute_name.split("::").next().unwrap(), value.as_ref().unwrap().to_string())) + .into_group_map_by(|(symbol, value)| { + ( + symbol.absolute_name.split("::").next().unwrap(), + value.as_ref().unwrap().to_string(), + ) + }) .values() // map all other symbols to the first one .flat_map(|group| { From 1e15180b1fb94b975f7bff7fdf1def2b47cc9e4e Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 17 Nov 2024 18:28:41 +0100 Subject: [PATCH 3/5] use syntactic equality 
--- ast/src/analyzed/mod.rs | 10 ++-- ast/src/parsed/mod.rs | 102 ++++++++++++++++++++++++++++++---------- ast/src/parsed/types.rs | 6 ++- parser-util/src/lib.rs | 5 ++ pilopt/src/lib.rs | 6 +-- 5 files changed, 95 insertions(+), 34 deletions(-) diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index c2dd00da3c..0188fd3f0c 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -796,7 +796,7 @@ pub enum SymbolKind { Other(), } -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)] pub enum FunctionValueDefinition { Array(ArrayExpression), Expression(TypedExpression), @@ -1152,7 +1152,9 @@ impl SelectedExpressions { pub type Expression = parsed::Expression; pub type TypedExpression = crate::parsed::TypedExpression; -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash, PartialOrd, Ord, +)] pub enum Reference { LocalVar(u64, String), Poly(PolynomialReference), @@ -1577,7 +1579,9 @@ impl From for AlgebraicExpression { /// Reference to a symbol with optional type arguments. /// Named `PolynomialReference` for historical reasons, it can reference /// any symbol. -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, Hash, +)] pub struct PolynomialReference { /// Absolute name of the symbol. 
pub name: String, diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 6b89e61277..f8371e091e 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -251,7 +251,9 @@ impl Children for PilStatement { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub enum TypeDeclaration { Enum(EnumDeclaration), Struct(StructDeclaration), @@ -289,7 +291,9 @@ impl Children> for TypeDeclaration { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct StructDeclaration { pub name: String, pub type_vars: TypeBounds, @@ -326,7 +330,9 @@ impl Children> for StructDeclaration { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct EnumDeclaration { pub name: String, pub type_vars: TypeBounds, @@ -351,7 +357,9 @@ impl Children> for EnumDeclaration> { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct EnumVariant { pub name: String, pub fields: Option>>, @@ -452,7 +460,9 @@ impl Children> for TraitImplementation> { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct NamedExpression { pub name: String, pub body: Expr, @@ -476,7 +486,9 @@ impl Children> for NamedExpression>> { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, 
PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct TraitDeclaration { pub name: String, pub type_vars: Vec, @@ -502,7 +514,9 @@ impl Children> for TraitDeclaration> { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct NamedType { pub name: String, pub ty: Type, @@ -548,7 +562,9 @@ impl Children> for SelectedExpressions> { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub enum Expression { Reference(SourceRef, Ref), PublicReference(SourceRef, String), @@ -641,7 +657,9 @@ impl_source_reference!( Expression ); -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct UnaryOperation> { pub op: UnaryOperator, pub expr: Box, @@ -663,7 +681,9 @@ impl Children for UnaryOperation { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct BinaryOperation> { pub left: Box, pub op: BinaryOperator, @@ -686,7 +706,9 @@ impl Children for BinaryOperation { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct Number { #[schemars(skip)] pub value: BigUint, @@ -739,7 +761,9 @@ impl Expression { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, 
Deserialize, JsonSchema, Hash, +)] pub struct MatchExpression> { pub scrutinee: Box, pub arms: Vec>, @@ -766,7 +790,9 @@ impl Children for MatchExpression { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct BlockExpression { pub statements: Vec>, pub expr: Option>, @@ -959,7 +985,9 @@ impl NamespacedPolynomialReference { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct LambdaExpression> { pub kind: FunctionKind, pub params: Vec, @@ -985,7 +1013,7 @@ impl Children for LambdaExpression { } #[derive( - Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, + Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash, )] pub enum FunctionKind { Pure, @@ -993,7 +1021,9 @@ pub enum FunctionKind { Query, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct ArrayLiteral> { pub items: Vec, } @@ -1156,7 +1186,9 @@ impl BinaryOperator { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct IndexAccess> { pub array: Box, pub index: Box, @@ -1178,7 +1210,9 @@ impl Children for IndexAccess { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct FunctionCall> { pub function: Box, pub arguments: Vec, @@ -1200,7 +1234,9 @@ impl 
Children for FunctionCall { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct MatchArm> { pub pattern: Pattern, pub value: E, @@ -1216,7 +1252,9 @@ impl Children for MatchArm { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct IfExpression> { pub condition: Box, pub body: Box, @@ -1249,7 +1287,9 @@ impl Children for IfExpression { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct StructExpression { pub name: Ref, pub fields: Vec>>>, @@ -1271,7 +1311,9 @@ impl Children> for StructExpression { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub enum StatementInsideBlock> { // TODO add a source ref here. 
LetStatement(LetStatementInsideBlock), @@ -1294,7 +1336,9 @@ impl Children for StatementInsideBlock { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct LetStatementInsideBlock> { pub pattern: Pattern, pub ty: Option>, @@ -1352,7 +1396,9 @@ impl Children for FunctionDefinition { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub enum ArrayExpression { Value(Vec>), RepeatedValue(Vec>), @@ -1511,7 +1557,9 @@ impl Children> for ArrayExpression { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub enum Pattern { CatchAll(SourceRef), // "_", matches a single value Ellipsis(SourceRef), // "..", matches a series of values, only valid inside array patterns @@ -1604,7 +1652,9 @@ impl SourceReference for Pattern { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct TypedExpression> { pub e: Expression, pub type_scheme: Option>, diff --git a/ast/src/parsed/types.rs b/ast/src/parsed/types.rs index 49bed0e11d..a2dd25b81b 100644 --- a/ast/src/parsed/types.rs +++ b/ast/src/parsed/types.rs @@ -422,7 +422,9 @@ impl From for Type { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash, +)] pub struct TypeScheme { /// Type variables and their trait bounds. 
pub vars: TypeBounds, @@ -481,7 +483,7 @@ impl From for TypeScheme { } #[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Default, Serialize, Deserialize, JsonSchema, + Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Default, Serialize, Deserialize, JsonSchema, Hash, )] // TODO bounds should be SymbolPaths in the future. pub struct TypeBounds(Vec<(String, BTreeSet)>); diff --git a/parser-util/src/lib.rs b/parser-util/src/lib.rs index 664a7e47d6..3e1cb979ed 100644 --- a/parser-util/src/lib.rs +++ b/parser-util/src/lib.rs @@ -4,6 +4,7 @@ use std::{ fmt::{self, Debug, Formatter}, + hash::Hash, sync::Arc, }; @@ -42,6 +43,10 @@ impl PartialEq for SourceRef { impl Eq for SourceRef {} +impl Hash for SourceRef { + fn hash(&self, _: &mut H) {} +} + impl SourceRef { pub fn unknown() -> Self { Default::default() diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 70bd2f93b2..47c4f3ffbe 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -218,18 +218,18 @@ fn constant_value(function: &FunctionValueDefinition) -> Option { } /// Deduplicate fixed columns of the same namespace which share the same value. 
-/// This uses the `Display` implementation of the function value, so `|i| i` is different from `|j| j` +/// This compares the function values, so `|i| i` is different from `|j| j` /// This is enough for use cases where exactly the same function is inserted many times /// This only replaces the references inside expressions and does not clean up the now unreachable fixed column definitions fn deduplicate_fixed_columns(pil_file: &mut Analyzed) { // build a map of `poly_id` to the `(name, poly_id)` they can be replaced by let replacement_map: BTreeMap = pil_file .constant_polys_in_source_order() - // group symbols by common namespace and displayed value + // group symbols by common namespace and function value .into_group_map_by(|(symbol, value)| { ( symbol.absolute_name.split("::").next().unwrap(), - value.as_ref().unwrap().to_string(), + value.as_ref().unwrap(), ) }) .values() From 19e31acf0795831607d531b712d2ef8dec2791b4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 17 Nov 2024 18:44:36 +0100 Subject: [PATCH 4/5] test across namespaces --- pilopt/tests/optimizer.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index 383513947c..e4cb67fe76 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -32,12 +32,23 @@ fn deduplicate_fixed() { col witness X; col witness Y; X * first = Y * second; + namespace M(65536); + col fixed first = [1, 32]*; + col fixed second = [1, 32]*; + col witness X; + col witness Y; + X * first = Y * second; "#; let expectation = r#"namespace N(65536); col fixed first = [1_fe, 32_fe]*; col witness X; col witness Y; N::X * N::first = N::Y * N::first; +namespace M(65536); + col fixed first = [1_fe, 32_fe]*; + col witness X; + col witness Y; + M::X * M::first = M::Y * M::first; "#; let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); From 4435e3f3e01be365a25dd29be3ccfffb2db223d4 Mon Sep 17 00:00:00 
2001 From: schaeff Date: Thu, 21 Nov 2024 15:54:28 +0100 Subject: [PATCH 5/5] address review comments --- pilopt/src/lib.rs | 54 ++++++++++++++++++++++++--------------- pilopt/tests/optimizer.rs | 6 +++-- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 47c4f3ffbe..35a262acbc 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -10,7 +10,7 @@ use powdr_ast::analyzed::{ AlgebraicUnaryOperation, AlgebraicUnaryOperator, Analyzed, ConnectIdentity, Expression, FunctionValueDefinition, Identity, LookupIdentity, PermutationIdentity, PhantomLookupIdentity, PhantomPermutationIdentity, PolyID, PolynomialIdentity, PolynomialReference, PolynomialType, - Reference, SymbolKind, + Reference, Symbol, SymbolKind, }; use powdr_ast::parsed::types::Type; use powdr_ast::parsed::visitor::{AllChildren, Children, ExpressionVisitable}; @@ -219,40 +219,52 @@ fn constant_value(function: &FunctionValueDefinition) -> Option { /// Deduplicate fixed columns of the same namespace which share the same value. 
/// This compares the function values, so `|i| i` is different from `|j| j` +fn extract_namespace(symbol: &Symbol) -> &str { + symbol.absolute_name.split("::").next().unwrap() +} + /// This is enough for use cases where exactly the same function is inserted many times /// This only replaces the references inside expressions and does not clean up the now unreachable fixed column definitions fn deduplicate_fixed_columns(pil_file: &mut Analyzed) { // build a map of `poly_id` to the `(name, poly_id)` they can be replaced by - let replacement_map: BTreeMap = pil_file - .constant_polys_in_source_order() - // group symbols by common namespace and function value - .into_group_map_by(|(symbol, value)| { - ( - symbol.absolute_name.split("::").next().unwrap(), - value.as_ref().unwrap(), - ) - }) - .values() - // map all other symbols to the first one - .flat_map(|group| { - group[1..].iter().flat_map(|from| { - from.0 - .array_elements() - .map(|(_, from_id)| from_id) - .zip_eq(group[0].0.array_elements()) + let (replacement_by_id, replacement_by_name): (BTreeMap, BTreeMap) = + pil_file + .constant_polys_in_source_order() + // group symbols by common namespace and function value + .into_group_map_by(|(symbol, value)| { + (extract_namespace(symbol), value.as_ref().unwrap()) }) - }) - .collect(); + .values() + // map all other symbols to the first one + .flat_map(|group| { + group[1..].iter().flat_map(|from| { + from.0 + .array_elements() + .zip_eq(group[0].0.array_elements()) + .map(|((name, from_id), to_id)| ((from_id, to_id.clone()), (name, to_id))) + }) + }) + .unzip(); // substitute all occurences in expressions. + pil_file.post_visit_expressions_in_identities_mut(&mut |e| { if let AlgebraicExpression::Reference(r) = e { - if let Some((new_name, new_id)) = replacement_map.get(&r.poly_id) { + if let Some((new_name, new_id)) = replacement_by_id.get(&r.poly_id) { r.name = new_name.clone(); r.poly_id = *new_id; } }; }); + + // substitute all occurrences in definitions. 
+ pil_file.post_visit_expressions_in_definitions_mut(&mut |e| { + if let Expression::Reference(_, Reference::Poly(reference)) = e { + if let Some((replacement_name, _)) = replacement_by_name.get(&reference.name) { + reference.name = replacement_name.clone(); + } + }; + }); } /// Simplifies multiplications by zero and one. diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index e4cb67fe76..053d75f38b 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -29,9 +29,10 @@ fn deduplicate_fixed() { let input = r#"namespace N(65536); col fixed first = [1, 32]*; col fixed second = [1, 32]*; + col i = first * second; col witness X; col witness Y; - X * first = Y * second; + X * first = Y * second + i; namespace M(65536); col fixed first = [1, 32]*; col fixed second = [1, 32]*; @@ -41,9 +42,10 @@ fn deduplicate_fixed() { "#; let expectation = r#"namespace N(65536); col fixed first = [1_fe, 32_fe]*; + col i = N::first * N::first; col witness X; col witness Y; - N::X * N::first = N::Y * N::first; + N::X * N::first = N::Y * N::first + N::i; namespace M(65536); col fixed first = [1_fe, 32_fe]*; col witness X;