Skip to content

Commit

Permalink
Traits unification (powdr-labs#1625)
Browse files Browse the repository at this point in the history
Co-authored-by: chriseth <[email protected]>
  • Loading branch information
gzanitti and chriseth authored Sep 9, 2024
1 parent 4340d46 commit 73f49a4
Show file tree
Hide file tree
Showing 16 changed files with 701 additions and 46 deletions.
1 change: 1 addition & 0 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum StatementIdentifier {
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct Analyzed<T> {
pub definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
pub solved_impls: HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
pub public_declarations: HashMap<String, PublicDeclaration>,
pub intermediate_columns: HashMap<String, (Symbol, Vec<AlgebraicExpression<T>>)>,
pub identities: Vec<Identity<SelectedExpressions<AlgebraicExpression<T>>>>,
Expand Down
42 changes: 39 additions & 3 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ pub mod types;
pub mod visitor;

use std::{
collections::HashMap,
iter::{empty, once},
ops,
str::FromStr,
sync::Arc,
};

use auto_enums::auto_enum;
Expand All @@ -27,6 +29,8 @@ use self::{
visitor::{Children, ExpressionVisitable},
};

use crate::parsed::types::TupleType;

#[derive(Display, Clone, Copy, PartialEq, Eq)]
pub enum SymbolCategory {
/// A value, which has a type and can be referenced in expressions (a variable, function, constant, ...).
Expand Down Expand Up @@ -348,8 +352,36 @@ impl<R> Children<Expression<R>> for EnumVariant<Expression<R>> {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TraitImplementation<Expr> {
pub name: SymbolPath,
pub source_ref: SourceRef,
pub type_scheme: TypeScheme,
pub functions: Vec<NamedExpression<Expr>>,
pub functions: Vec<NamedExpression<Arc<Expr>>>,
}

impl<R> TraitImplementation<Expression<R>> {
pub fn function_by_name(&self, name: &str) -> Option<&NamedExpression<Arc<Expression<R>>>> {
self.functions.iter().find(|f| f.name == name)
}

pub fn type_of_function(&self, trait_decl: &TraitDeclaration, fn_name: &str) -> Type {
let Type::Tuple(TupleType { items }) = &self.type_scheme.ty else {
panic!("Expected tuple type for trait implementation");
};

let type_var_mapping: HashMap<String, Type> = trait_decl
.type_vars
.iter()
.cloned()
.zip(items.iter().cloned())
.collect();

let trait_fn = trait_decl
.function_by_name(fn_name)
.expect("Function not found in trait declaration");

let mut trait_type = trait_fn.ty.clone();
trait_type.substitute_type_vars(&type_var_mapping);
trait_type
}
}

impl<R> Children<Expression<R>> for TraitImplementation<Expression<R>> {
Expand All @@ -360,15 +392,15 @@ impl<R> Children<Expression<R>> for TraitImplementation<Expression<R>> {
Box::new(
self.functions
.iter_mut()
.flat_map(|m| m.body.children_mut()),
.map(|named_expr| Arc::get_mut(&mut named_expr.body).unwrap()),
)
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NamedExpression<Expr> {
pub name: String,
pub body: Box<Expr>,
pub body: Expr,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
Expand All @@ -382,6 +414,10 @@ impl TraitDeclaration<u64> {
pub fn function_by_name(&self, name: &str) -> Option<&TraitFunction> {
self.functions.iter().find(|f| f.name == name)
}

pub fn function_by_name_mut(&mut self, name: &str) -> Option<&mut TraitFunction> {
self.functions.iter_mut().find(|f| f.name == name)
}
}

impl<R> Children<Expression<R>> for TraitDeclaration<Expression<R>> {
Expand Down
12 changes: 10 additions & 2 deletions executor/src/constant_evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn};
use itertools::Itertools;
use powdr_ast::{
analyzed::{Analyzed, FunctionValueDefinition, Symbol, TypedExpression},
analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression},
parsed::{
types::{ArrayType, Type},
IndexAccess,
Expand Down Expand Up @@ -58,6 +58,7 @@ fn generate_values<T: FieldElement>(
) -> Vec<T> {
let symbols = CachedSymbols {
symbols: &analyzed.definitions,
solved_impls: &analyzed.solved_impls,
cache: Arc::new(RwLock::new(Default::default())),
degree,
};
Expand Down Expand Up @@ -146,6 +147,7 @@ type SymbolCache<'a, T> = HashMap<String, BTreeMap<Option<Vec<Type>>, Arc<Value<
#[derive(Clone)]
pub struct CachedSymbols<'a, T> {
symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
solved_impls: &'a HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
cache: Arc<RwLock<SymbolCache<'a, T>>>,
degree: DegreeType,
}
Expand All @@ -165,7 +167,13 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for CachedSymbols<'a, T> {
{
return Ok(v.clone());
}
let result = Definitions::lookup_with_symbols(self.symbols, name, type_args, self)?;
let result = Definitions::lookup_with_symbols(
self.symbols,
self.solved_impls,
name,
type_args,
self,
)?;
self.cache
.write()
.unwrap()
Expand Down
1 change: 1 addition & 0 deletions executor/src/witgen/query_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> {
}
None => Definitions::lookup_with_symbols(
&self.fixed_data.analyzed.definitions,
&self.fixed_data.analyzed.solved_impls,
name,
type_args,
self,
Expand Down
4 changes: 2 additions & 2 deletions importer/src/path_canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ impl<'a> Folder for Canonicalizer<'a> {
.map(|value| value.map(|value| SymbolDefinition { name, value }.into()))
}
ModuleStatement::TraitImplementation(mut trait_impl) => {
for f in &mut trait_impl.functions {
canonicalize_inside_expression(&mut f.body, &self.path, self.paths)
for f in trait_impl.children_mut() {
canonicalize_inside_expression(f, &self.path, self.paths)
}
Some(Ok(ModuleStatement::TraitImplementation(trait_impl)))
}
Expand Down
25 changes: 18 additions & 7 deletions parser/src/powdr.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use powdr_ast::parsed::{*, asm::*, types::*};
use powdr_number::{BigInt, BigUint};
use crate::{ParserContext, unescape_string};
use powdr_parser_util::Error;
use std::sync::Arc;

grammar(ctx: &ParserContext);

Expand Down Expand Up @@ -685,6 +686,15 @@ StructExpression: Box<Expression> = {
<start:@L> <name:Identifier> "{" <fields:NamedExpressions> "}" <end:@R> => Box::new(Expression::Tuple(ctx.source_ref(start, end), vec![])),
}

NamedExpressions: Vec<NamedExpression<Expression>> = {
=> vec![],
<mut list:( <NamedExpression> "," )*> <end:NamedExpression> ","? => { list.push(end); list }
}

NamedExpression: NamedExpression<Expression> = {
<name:Identifier> ":" <body:Expression> => NamedExpression { name, body }
}

// ---------------------------- Pattern -----------------------------

Pattern: Pattern = {
Expand Down Expand Up @@ -747,21 +757,22 @@ TraitFunction: TraitFunction<Expression> = {
}

TraitImplementation: TraitImplementation<Expression> = {
"impl" <type_scheme: GenericTraitName> "{" <functions:NamedExpressions> "}" => TraitImplementation { name: type_scheme.0, type_scheme: type_scheme.1, functions }
<start:@L> "impl" <type_scheme: GenericTraitName> "{" <functions:NamedArcExpressions> "}" <end:@L> => TraitImplementation { name: type_scheme.0, source_ref: ctx.source_ref(start, end), type_scheme: type_scheme.1, functions }
}

NamedExpressions: Vec<NamedExpression<Expression>> = {

NamedArcExpressions: Vec<NamedExpression<Arc<Expression>>> = {
=> vec![],
<mut list:( <NamedExpression> "," )*> <end:NamedExpression> ","? => { list.push(end); list }
<mut list:( <NamedArcExpression> "," )*> <end:NamedArcExpression> ","? => { list.push(end); list }
}

NamedExpression: NamedExpression<Expression> = {
<name:Identifier> ":" <body:BoxedExpression> => NamedExpression { name, body }
NamedArcExpression: NamedExpression<Arc<Expression>> = {
<name:Identifier> ":" <body:Expression> => NamedExpression { name, body: body.into() }
}

GenericTraitName: (SymbolPath, TypeScheme) = {
<vars:("<" <TypeVarBounds> ">")> <name:SymbolPath> <items:("<" <TypeTermList<ArrayLengthNumber>> ">")> =>
(name, TypeScheme{ vars, ty: Type::Tuple(TupleType{items}) })
<vars:("<" <TypeVarBounds> ">")?> <name:SymbolPath> <items:("<" <TypeTermList<ArrayLengthNumber>> ">")> =>
(name, TypeScheme{ vars: vars.unwrap_or_default(), ty: Type::Tuple(TupleType{items}) })
}


Expand Down
26 changes: 22 additions & 4 deletions pil-analyzer/src/condenser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ type AnalyzedIdentity<T> = Identity<SelectedExpressions<AlgebraicExpression<T>>>

pub fn condense<T: FieldElement>(
mut definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
solved_impls: HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
public_declarations: HashMap<String, PublicDeclaration>,
identities: &[ParsedIdentity],
source_order: Vec<StatementIdentifier>,
auto_added_symbols: HashSet<String>,
) -> Analyzed<T> {
let mut condenser = Condenser::new(&definitions);
let mut condenser = Condenser::new(&definitions, &solved_impls);

let mut condensed_identities = vec![];
let mut intermediate_columns = HashMap::new();
Expand Down Expand Up @@ -168,6 +169,7 @@ pub fn condense<T: FieldElement>(

Analyzed {
definitions,
solved_impls,
public_declarations,
intermediate_columns,
identities: condensed_identities,
Expand All @@ -182,6 +184,8 @@ pub struct Condenser<'a, T> {
degree: Option<DegreeRange>,
/// All the definitions from the PIL file.
symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
/// Pointers to expressions for all referenced trait implementations and the concrete types.
solved_impls: &'a HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
/// Evaluation cache.
symbol_values: SymbolCache<'a, T>,
/// Current namespace (for names of generated columns).
Expand All @@ -200,7 +204,10 @@ pub struct Condenser<'a, T> {
}

impl<'a, T: FieldElement> Condenser<'a, T> {
pub fn new(symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>) -> Self {
pub fn new(
symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
solved_impls: &'a HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
) -> Self {
let counters = Counters::with_existing(symbols.values().map(|(sym, _)| sym), None, None);
Self {
symbols,
Expand All @@ -213,6 +220,7 @@ impl<'a, T: FieldElement> Condenser<'a, T> {
new_intermediate_column_values: Default::default(),
new_symbols: HashSet::new(),
new_constraints: vec![],
solved_impls,
}
}

Expand Down Expand Up @@ -340,7 +348,13 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> {
{
return Ok(v.clone());
}
let value = Definitions::lookup_with_symbols(self.symbols, name, type_args, self)?;
let value = Definitions::lookup_with_symbols(
self.symbols,
self.solved_impls,
name,
type_args,
self,
)?;
self.symbol_values
.entry(name.to_string())
.or_default()
Expand All @@ -350,7 +364,11 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> {
}

fn lookup_public_reference(&self, name: &str) -> Result<Arc<Value<'a, T>>, EvalError> {
Definitions(self.symbols).lookup_public_reference(name)
Definitions {
definitions: self.symbols,
solved_impls: self.solved_impls,
}
.lookup_public_reference(name)
}

fn min_degree(&self) -> Result<Arc<Value<'a, T>>, EvalError> {
Expand Down
Loading

0 comments on commit 73f49a4

Please sign in to comment.