From 050ba79e9ae0376abc883230652ca638bfc75150 Mon Sep 17 00:00:00 2001 From: Paulo Matos Date: Thu, 15 Jun 2023 16:51:21 +0200 Subject: [PATCH] Implement closures --- src/ASTRuntime.cpp | 42 ++++++++++ src/AnalysisFreeVars.cpp | 134 +++++++++++++++++++++++++++++++ src/CMakeLists.txt | 2 + src/include/ASTRuntime.h | 36 +++++++++ src/include/ASTVisitor.h | 1 + src/include/AnalysisFreeVars.h | 42 ++++++++++ src/include/ast.h | 30 +++++-- src/include/ast_fwd.h | 1 + src/include/environment.h | 4 + src/include/interpreter.h | 2 + src/interpreter.cpp | 54 ++++++++----- test/integration/closure.rkt | 8 ++ test/integration/let-values3.rkt | 7 ++ 13 files changed, 335 insertions(+), 28 deletions(-) create mode 100644 src/ASTRuntime.cpp create mode 100644 src/AnalysisFreeVars.cpp create mode 100644 src/include/ASTRuntime.h create mode 100644 src/include/AnalysisFreeVars.h create mode 100644 test/integration/closure.rkt create mode 100644 test/integration/let-values3.rkt diff --git a/src/ASTRuntime.cpp b/src/ASTRuntime.cpp new file mode 100644 index 0000000..aa7ab27 --- /dev/null +++ b/src/ASTRuntime.cpp @@ -0,0 +1,42 @@ +#include "ASTRuntime.h" + +#include "AnalysisFreeVars.h" + +#include + +using namespace ast; + +Closure::Closure(const Lambda &Lbd, const std::vector &Envs) + : ClonableNode(ASTNodeKind::AST_Closure), + L(std::unique_ptr(static_cast(Lbd.clone()))) { + + // To create a closure we need to: + + // 1. Find the free variables in the lambda. + AnalysisFreeVars AFV; + L->accept(AFV); + auto const &FreeVars = AFV.getResult(); + + // 2. Find in the current environment, the values of the free variables + // and save them. + for (auto const &Var : FreeVars) { + for (auto const &E : llvm::reverse(Envs)) { + auto const &Val = E.lookup(Var); + if (Val) { + Env.add(Var, std::unique_ptr(Val->clone())); + break; + } + } + } +} + +Closure::Closure(const Closure &Other) + : ClonableNode(ASTNodeKind::AST_Closure), + L(std::unique_ptr(static_cast(Other.L->clone()))) { + for (auto const &E : Other.Env) { + Env.add(E.first, std::unique_ptr(E.second->clone())); + } +} + +void Closure::dump() const {} +void Closure::write() const {} \ No newline at end of file diff --git a/src/AnalysisFreeVars.cpp b/src/AnalysisFreeVars.cpp new file mode 100644 index 0000000..0bb2a45 --- /dev/null +++ b/src/AnalysisFreeVars.cpp @@ -0,0 +1,134 @@ +#include "AnalysisFreeVars.h" + +void AnalysisFreeVars::visit(ast::Identifier const &Id) { + // If the identifier is not in the environment, then it is a free variable. + + for (auto const &Var : llvm::reverse(Vars)) { + if (Var.count(Id) == 0) { + Result.insert(Id); + } + } +} + +void AnalysisFreeVars::visit(ast::Integer const &Int) { + // Integers do not have free variables. + // Nothing to do. +} + +void AnalysisFreeVars::visit(ast::Linklet const &Linklet) { + llvm::errs() << "Free variable analysis only applies to expressions.\n"; +} + +void AnalysisFreeVars::visit(ast::DefineValues const &DV) { + llvm::errs() << "Free variable analysis only applies to expressions.\n"; +} + +void AnalysisFreeVars::visit(ast::Values const &V) { + // Need to check for free variable in each expression of the Values + // expression. + for (auto const &Expr : V.getExprs()) { + Expr->accept(*this); + } +} + +void AnalysisFreeVars::visit(ast::Void const &Vd) { + // Void expressions have no free variables. + // Nothing to do. +} + +void AnalysisFreeVars::visit(ast::Lambda const &L) { + const ast::Formal &F = L.getFormals(); + std::set FormalVars; + + if (F.getType() == ast::Formal::Type::Identifier) { + auto IF = static_cast(F); + FormalVars.insert(IF.getIdentifier()); + } else if (F.getType() == ast::Formal::Type::List) { + auto LF = static_cast(F); + for (auto const &Id : LF.getIds()) { + FormalVars.insert(Id); + } + } else if (F.getType() == ast::Formal::Type::ListRest) { + auto LRF = static_cast(F); + for (auto const &Id : LRF.getIds()) { + FormalVars.insert(Id); + } + FormalVars.insert(LRF.getRestFormal()); + } + + // Save the current environment. + Vars.push_back(FormalVars); + + // Check for free variables in the body of the lambda. + L.getBody().accept(*this); + + // Restore the environment. + Vars.pop_back(); +} + +void AnalysisFreeVars::visit(ast::Closure const &L) { + // Closures by definition do not have free variables. + // Nothing to do. +} + +void AnalysisFreeVars::visit(ast::Begin const &B) { + // Iterate through all the begin expressions and check for free variables. + for (auto const &Expr : B.getBody()) { + Expr->accept(*this); + } +} + +void AnalysisFreeVars::visit(ast::List const &L) { + // Iterate through all the List expressions and check for free variables. + for (auto const &Expr : L.values()) { + Expr->accept(*this); + } +} + +void AnalysisFreeVars::visit(ast::Application const &A) { + // Iterate through all the Application expressions and check for free + // variables. + for (auto const &Expr : A.getExprs()) { + Expr->accept(*this); + } +} + +void AnalysisFreeVars::visit(ast::SetBang const &SB) { + // Check for free variables on the right hand side expression of SetBang + // expression. + SB.getExpr().accept(*this); +} + +void AnalysisFreeVars::visit(ast::IfCond const &If) { + // Check for free variables on the condition expression of IfCond expression. + If.getCond().accept(*this); + // Check for free variables on the consequent expression of IfCond expression. + If.getThen().accept(*this); + // Check for free variables on the alternative expression of IfCond + // expression. + If.getElse().accept(*this); +} + +void AnalysisFreeVars::visit(ast::BooleanLiteral const &Bool) { + // Boolean literals have no free variables. + // Nothing to do. +} + +void AnalysisFreeVars::visit(ast::LetValues const &LV) { + std::set LVVars; + for (size_t Idx = 0; Idx < LV.bindingCount(); Idx++) + for (auto const &Var : LV.getBindingIds(Idx)) + LVVars.insert(Var); + + Vars.push_back(LVVars); + + for (size_t Idx = 0; Idx < LV.bodyCount(); Idx++) + LV.getBodyExpr(Idx).accept(*this); + + Vars.pop_back(); +} + +void AnalysisFreeVars::visit(ast::RuntimeFunction const &LV) { + // Runtime Functions have no free variables. + // Nothing to do. +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a51b558..f226075 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,8 @@ add_llvm_executable(norac main.cpp environment.cpp ast.cpp + ASTRuntime.cpp + AnalysisFreeVars.cpp idpool.cpp Parse.cpp Lex.cpp diff --git a/src/include/ASTRuntime.h b/src/include/ASTRuntime.h new file mode 100644 index 0000000..7350877 --- /dev/null +++ b/src/include/ASTRuntime.h @@ -0,0 +1,36 @@ +#pragma once + +#include "ast.h" +#include "environment.h" + +#include + +namespace ast { +// +// This file includes the structures that are used in addition to +// those in ast.h during runtime interpretation. +// +// The simplest example is the Closure. + +// A Closure is a runtime manifestation of a Lambda. +class Closure : public ClonableNode { +public: + Closure(const Lambda &Lbd, const std::vector &Envs); + Closure(const Closure &Other); + + static bool classof(const ASTNode *N) { + return N->getKind() == ASTNodeKind::AST_Closure; + } + + void dump() const override; + void write() const override; + + const Lambda &getLambda() const { return *L; } + const Environment &getEnvironment() const { return Env; } + +private: + std::unique_ptr L; + Environment Env; +}; + +}; // namespace ast \ No newline at end of file diff --git a/src/include/ASTVisitor.h b/src/include/ASTVisitor.h index 6c02ce1..daf54e1 100644 --- a/src/include/ASTVisitor.h +++ b/src/include/ASTVisitor.h @@ -14,6 +14,7 @@ class ASTVisitor { virtual void visit(ast::Values const &V) = 0; virtual void visit(ast::Void const &Vd) = 0; virtual void visit(ast::Lambda const &L) = 0; + virtual void visit(ast::Closure const &L) = 0; virtual void visit(ast::Begin const &B) = 0; virtual void visit(ast::List const &L) = 0; virtual void visit(ast::Application const &A) = 0; diff --git a/src/include/AnalysisFreeVars.h b/src/include/AnalysisFreeVars.h new file mode 100644 index 0000000..9060f2a --- /dev/null +++ b/src/include/AnalysisFreeVars.h @@ -0,0 +1,42 @@ +#pragma once + +#include "ASTVisitor.h" +#include "ast.h" + +#include + +#include +#include +#include +#include +#include + +// File implementing free variable analysis for expressions. + +class AnalysisFreeVars : public ASTVisitor { +public: + virtual void visit(ast::Identifier const &Id) override; + virtual void visit(ast::Integer const &Int) override; + virtual void visit(ast::Linklet const &Linklet) override; + virtual void visit(ast::DefineValues const &DV) override; + virtual void visit(ast::Values const &V) override; + virtual void visit(ast::Void const &Vd) override; + virtual void visit(ast::Lambda const &L) override; + virtual void visit(ast::Closure const &L) override; + virtual void visit(ast::Begin const &B) override; + virtual void visit(ast::List const &L) override; + virtual void visit(ast::Application const &A) override; + virtual void visit(ast::SetBang const &SB) override; + virtual void visit(ast::IfCond const &If) override; + virtual void visit(ast::BooleanLiteral const &Bool) override; + virtual void visit(ast::LetValues const &LV) override; + virtual void visit(ast::RuntimeFunction const &LV) override; + + // Get the current saved result. + std::set getResult() const { return Result; }; + +private: + std::set Result; /// List of free variables. + llvm::SmallVector> + Vars; /// Environment map for identifiers. +}; diff --git a/src/include/ast.h b/src/include/ast.h index df56052..b38bbb4 100644 --- a/src/include/ast.h +++ b/src/include/ast.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -42,6 +43,7 @@ class ASTNode { AST_BooleanLiteral, AST_Integer, AST_Lambda, + AST_Closure, // result of evaluating a Lambda expression AST_List, AST_Values, AST_Void, @@ -206,6 +208,7 @@ class Linklet : public ClonableNode { class Application : public ClonableNode { public: Application() : ClonableNode(ASTNodeKind::AST_Application) {} + Application &operator=(Application &&) = delete; Application(const Application &); Application(Application &&) = default; Application &operator=(const Application &) = delete; @@ -221,8 +224,12 @@ class Application : public ClonableNode { return N->getKind() == ASTNodeKind::AST_Application; } + const llvm::SmallVector> &getExprs() const { + return Exprs; + } + private: - std::vector> Exprs; + llvm::SmallVector> Exprs; }; // AST Node representing a begin or begin0 expression. @@ -234,7 +241,8 @@ class Begin : public ClonableNode { Begin &operator=(const Begin &B) = delete; ~Begin() = default; - [[nodiscard]] const std::vector> &getBody() const { + [[nodiscard]] const llvm::SmallVector> & + getBody() const { return Body; } [[nodiscard]] size_t bodyCount() const { return Body.size(); } @@ -249,7 +257,7 @@ class Begin : public ClonableNode { } private: - std::vector> Body; + llvm::SmallVector> Body; bool Zero = false; }; @@ -495,6 +503,8 @@ class Lambda : public ClonableNode { void dump() const override; void write() const override; + llvm::SmallVector findFreeVariables() const; + static bool classof(const ASTNode *N) { return N->getKind() == ASTNodeKind::AST_Lambda; } @@ -507,6 +517,8 @@ class Lambda : public ClonableNode { class LetValues : public ClonableNode { public: LetValues() : ClonableNode(ASTNodeKind::AST_LetValues) {} + LetValues &operator=(const LetValues &) = delete; + LetValues &operator=(LetValues &&) = delete; LetValues(const LetValues &DV); LetValues(LetValues &&DV) = default; ~LetValues() = default; @@ -575,8 +587,10 @@ class List : public ClonableNode { return N->getKind() == ASTNodeKind::AST_List; } + auto const &values() const { return Values; } + private: - std::vector> Values; + llvm::SmallVector> Values; }; class SetBang : public ClonableNode { @@ -671,8 +685,12 @@ class RuntimeFunction : public ValueNode { virtual std::unique_ptr operator()(const std::vector &Args) const = 0; - void dump() const override { std::cerr << "#"; } - void write() const override { std::cout << "#"; } + void dump() const override { + llvm::errs() << "#"; + } + void write() const override { + llvm::outs() << "#"; + } static bool classof(const ASTNode *N) { return N->getKind() == ASTNodeKind::AST_RuntimeFunction; diff --git a/src/include/ast_fwd.h b/src/include/ast_fwd.h index 75a76e5..6f05f07 100644 --- a/src/include/ast_fwd.h +++ b/src/include/ast_fwd.h @@ -8,6 +8,7 @@ class DefineValues; class Values; class Void; class Lambda; +class Closure; class Begin; class List; class Application; diff --git a/src/include/environment.h b/src/include/environment.h index fcd7d80..9724c4b 100644 --- a/src/include/environment.h +++ b/src/include/environment.h @@ -20,6 +20,10 @@ class Environment { // Lookup an identifier in the environment. std::unique_ptr lookup(ast::Identifier const &Id) const; + // Implement range style access to the Env map. + auto begin() const { return Env.begin(); } + auto end() const { return Env.end(); } + private: // Environment map for identifiers. std::map> Env; diff --git a/src/include/interpreter.h b/src/include/interpreter.h index c8ccc5e..50997d9 100644 --- a/src/include/interpreter.h +++ b/src/include/interpreter.h @@ -26,6 +26,7 @@ class Interpreter : public ASTVisitor { virtual void visit(ast::Values const &V) override; virtual void visit(ast::Void const &Vd) override; virtual void visit(ast::Lambda const &L) override; + virtual void visit(ast::Closure const &L) override; virtual void visit(ast::Begin const &B) override; virtual void visit(ast::List const &L) override; virtual void visit(ast::Application const &A) override; @@ -40,6 +41,7 @@ class Interpreter : public ASTVisitor { // Get the current saved result. std::unique_ptr getResult() const { + assert(Result && "No result has been recorded during interpretation."); return std::unique_ptr(Result->clone()); }; std::unique_ptr diff --git a/src/interpreter.cpp b/src/interpreter.cpp index 3dc26d7..6076486 100644 --- a/src/interpreter.cpp +++ b/src/interpreter.cpp @@ -5,11 +5,13 @@ #include #include +#include #include #include #include #include +#include "ASTRuntime.h" #include "Casting.h" #include "ast_fwd.h" #include "llvm/Support/raw_ostream.h" @@ -45,11 +47,8 @@ void Interpreter::visit(ast::Identifier const &Id) { llvm::dbgs() << Name << "\n"; }); - // FIXME: why is it that : - // for (auto &Env : Envs | std::ranges::reverse) { - // does not work here? - for (auto Env = Envs.rbegin(); Env != Envs.rend(); ++Env) { - auto V = Env->lookup(Id); + for (auto &Env : llvm::reverse(Envs)) { + auto V = Env.lookup(Id); if (V) { Result = std::move(V); return; @@ -173,8 +172,15 @@ void Interpreter::visit(ast::Void const &Vd) { } void Interpreter::visit(ast::Lambda const &L) { + // The interpretation of a lambda expression is a closure, + // even if no variables are captured. LLVM_DEBUG(llvm::dbgs() << "Interpreting Lambda\n"); - Result = std::unique_ptr(L.clone()); + Result = std::make_unique(L, Envs); +} + +void Interpreter::visit(ast::Closure const &C) { + LLVM_DEBUG(llvm::dbgs() << "Interpreting Closure\n"); + Result = std::unique_ptr(C.clone()); } void Interpreter::visit(ast::Begin const &B) { @@ -214,14 +220,15 @@ void Interpreter::visit(ast::Application const &A) { A[0].accept(*this); std::unique_ptr D = std::move(Result); - // Error out if not a lambda expression or runtime expression. - std::unique_ptr L = dyn_castU(D); - if (!L) { + // Error out if not a Closure expression or Runtime expression. + std::unique_ptr C = dyn_castU(D); + if (!C) { // maybe a runtime function? std::unique_ptr RF = dyn_castU(D); if (!RF) { - std::cerr << "Expected lambda expression in Application.\n"; + llvm::errs() << "Expected closure or runtime function expression in " + "application.\n"; return; } @@ -233,21 +240,19 @@ void Interpreter::visit(ast::Application const &A) { // will contain pointers to the results in ArgHolder. This sucks a bit but // at this point, I am not sure if there's a point in focusing on optimizing // this. - std::vector> ArgHolder; - std::vector Args; - ArgHolder.reserve(A.length() - 1); - Args.reserve(A.length() - 1); - for (size_t Idx = 1; Idx < A.length(); ++Idx) { - A[Idx].accept(*this); + std::vector> ArgHolder(A.length() - 1); + std::vector Args(A.length() - 1); + for (size_t Idx = 0; Idx < A.length() - 1; ++Idx) { + A[Idx + 1].accept(*this); assert(Result && "Expected result from expression."); - ArgHolder.emplace_back(std::move(Result)); - Args.emplace_back(ArgHolder.back().get()); + ArgHolder[Idx] = std::move(Result); + Args[Idx] = ArgHolder[Idx].get(); } LLVM_DEBUG({ llvm::dbgs() << "Calling runtime function: " << RF->getName() << "\n"; for (const ast::ValueNode *Arg : Args) { - assert(llvm::dyn_cast(Arg) && + assert(Arg && llvm::dyn_cast(Arg) && "Expected Integer in runtime function call."); llvm::dbgs() << " Arg: "; Arg->dump(); @@ -271,7 +276,8 @@ void Interpreter::visit(ast::Application const &A) { // If we have a list formals then, error out of args diff than formals. // If we have a list rest formals then, error out if args less than formals. // If it's identifier formals then it does not matter. - const ast::Formal &F = L->getFormals(); + const ast::Lambda &L = C->getLambda(); + const ast::Formal &F = L.getFormals(); if (F.getType() == ast::Formal::Type::List) { auto LF = static_cast(F); if (Args.size() != LF.size()) { @@ -323,9 +329,13 @@ void Interpreter::visit(ast::Application const &A) { llvm_unreachable("unknown formal type"); } - Envs.push_back(Env); + Envs.push_back(C->getEnvironment()); // Pushes the closure environment first. + Envs.push_back(Env); // Then pushes the environment with the args. + // 4. Return the result of the application. - L->getBody().accept(*this); + L.getBody().accept(*this); + + Envs.pop_back(); Envs.pop_back(); } diff --git a/test/integration/closure.rkt b/test/integration/closure.rkt new file mode 100644 index 0000000..71660cf --- /dev/null +++ b/test/integration/closure.rkt @@ -0,0 +1,8 @@ +;; RUN: norac %s | FileCheck %s +;; CHECK: 12 +(linklet () () + (define-values (fn) (values 0)) + (let-values (((x) (values 2))) + (set! fn (lambda (y) (+ x y)))) + (let-values (((x) (values 3))) + (fn 10))) diff --git a/test/integration/let-values3.rkt b/test/integration/let-values3.rkt new file mode 100644 index 0000000..307ce05 --- /dev/null +++ b/test/integration/let-values3.rkt @@ -0,0 +1,7 @@ +;; RUN: norac %s | FileCheck %s +;; CHECK: 8 +(linklet () () + (let-values ([(x) 5]) + (let-values ([(f) (lambda (y) (+ x y))]) + (let-values ([(x) 7]) + (f 3)))))