Skip to content

Commit

Permalink
Move to TypeRegistry (#544)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer authored May 11, 2024
1 parent 65f09aa commit d56dafd
Show file tree
Hide file tree
Showing 39 changed files with 912 additions and 568 deletions.
60 changes: 44 additions & 16 deletions media/test-project/test.spice
Original file line number Diff line number Diff line change
@@ -1,26 +1,54 @@
type T int|long;
import "std/data/doubly-linked-list";

type TestStruct<T> struct {
T _f1
unsigned long length
f<int> main() {
DoublyLinkedList<String> list;
assert list.getSize() == 0;
assert list.isEmpty();
list.pushBack(String("Hello"));
assert list.getSize() == 1;
assert !list.isEmpty();
String var = String("World");
list.pushBack(var);
assert list.getSize() == 2;
list.pushFront(String("Hi"));
assert list.getSize() == 3;
assert list.getFront() == String("Hi");
assert list.getBack() == String("World");
list.removeFront();
assert list.getSize() == 2;
assert list.getFront() == String("Hello");
list.removeBack();
assert list.getSize() == 1;
assert list.getBack() == String("Hello");
list.pushBack(String("World"));
list.pushFront(String("Hi"));
list.pushBack(String("Programmers"));
assert list.getSize() == 4;
list.remove(String("World"));
assert list.getSize() == 3;
assert list.get(0) == String("Hi");
assert list.get(1) == String("Hello");
assert list.get(2) == String("Programmers");
list.removeAt(1);
assert list.getSize() == 2;
assert list.get(0) == String("Hi");
assert list.get(1) == String("Programmers");
printf("All assertions passed!\n");
}

p TestStruct.ctor(const unsigned long initialLength) {
this.length = initialLength;
}
/*import "std/iterator/number-iterator";

p TestStruct.printLength() {
printf("%d\n", this.length);
}
f<int> main() {
NumberIterator<int> itInt = range(1, 10);
dyn idxAndValueInt = itInt.getIdx();
assert idxAndValueInt.getSecond() == 4;
}*/

type Alias alias TestStruct<long>;
/*import "std/data/hash-table";

f<int> main() {
Alias a = Alias{12345l, (unsigned long) 54321l};
a.printLength();
dyn b = Alias(12l);
b.printLength();
}
HashTable<int, string> map = HashTable<int, string>(3l);
}*/

/*import "bootstrap/util/block-allocator";
import "bootstrap/util/memory";
Expand Down
2 changes: 1 addition & 1 deletion src-bootstrap/bindings/llvm/llvm.spice
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ public f<Value> Builder.getStruct(const Vector<Value>& values, bool packed = fal
public f<Value> Builder.getArray(Type ty, const Vector<Value>& values) {
unsafe {
LLVMValueRef* valuesRef = (LLVMValueRef*) values.getDataPtr();
LLVMValueRef valueRef = LLVMConstArray2(ty.self, valuesRef, (unsigned long) values.getSize());
LLVMValueRef valueRef = LLVMConstArray2(ty.self, valuesRef, values.getSize());
return Value{ valueRef };
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ set(SOURCES
symboltablebuilder/SymbolTable.h
symboltablebuilder/SymbolTableEntry.cpp
symboltablebuilder/SymbolTableEntry.h
symboltablebuilder/QualType.cpp
symboltablebuilder/QualType.h
symboltablebuilder/Capture.cpp
symboltablebuilder/Capture.h
symboltablebuilder/Type.cpp
symboltablebuilder/Type.h
symboltablebuilder/TypeChain.cpp
symboltablebuilder/TypeChain.h
symboltablebuilder/TypeSpecifiers.cpp
symboltablebuilder/TypeSpecifiers.h
symboltablebuilder/Lifecycle.cpp
Expand Down Expand Up @@ -143,6 +146,8 @@ set(SOURCES
util/CommonUtil.h
util/FileUtil.cpp
util/FileUtil.h
util/CustomHashFunctions.cpp
util/CustomHashFunctions.h
util/CodeLoc.cpp
util/CodeLoc.h
util/CompilerWarning.cpp
Expand All @@ -153,8 +158,6 @@ set(SOURCES
util/Memory.h
util/RawStringOStream.cpp
util/RawStringOStream.h
symboltablebuilder/QualType.cpp
symboltablebuilder/QualType.h
)

add_executable(spice ${SOURCES} ${ANTLR_Spice_CXX_OUTPUTS})
Expand Down
6 changes: 4 additions & 2 deletions src/SourceFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <exception/AntlrThrowingErrorListener.h>
#include <exception/CompilerError.h>
#include <global/GlobalResourceManager.h>
#include <global/TypeRegistry.h>
#include <importcollector/ImportCollector.h>
#include <irgenerator/IRGenerator.h>
#include <iroptimizer/IROptimizer.h>
Expand Down Expand Up @@ -554,6 +555,7 @@ void SourceFile::runBackEnd() { // NOLINT(misc-no-recursion)
CHECK_ABORT_FLAG_V()
std::cout << "\nSuccessfully compiled " << std::to_string(resourceManager.sourceFiles.size()) << " source file(s)";
std::cout << " or " << std::to_string(resourceManager.getTotalLineCount()) << " lines in total.\n";
std::cout << "Total number of types: " << std::to_string(TypeRegistry::getTypeCount()) << "\n";
std::cout << "Total compile time: " << std::to_string(resourceManager.totalTimer.getDurationMilliseconds()) << " ms\n";
}
}
Expand Down Expand Up @@ -582,8 +584,8 @@ bool SourceFile::imports(const SourceFile *sourceFile) const {
return std::ranges::any_of(dependencies, [=](const auto &dependency) { return dependency.second == sourceFile; });
}

bool SourceFile::isAlreadyImported(const std::string &filePathSearch,
std::vector<const SourceFile *> &circle) const { // NOLINT(misc-no-recursion)
bool SourceFile::isAlreadyImported(const std::string &filePathSearch, // NOLINT(misc-no-recursion)
std::vector<const SourceFile *> &circle) const {
circle.push_back(this);
// Check if the current source file corresponds to the path to search
if (std::filesystem::equivalent(filePath, filePathSearch))
Expand Down
4 changes: 4 additions & 0 deletions src/SourceFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
#include <string>
#include <utility>

// Ignore some warnings in ANTLR generated code
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#include <SpiceLexer.h>
#include <SpiceParser.h>
#include <Token.h>
#pragma GCC diagnostic pop

#include <ast/ASTNodes.h>
#include <exception/AntlrThrowingErrorListener.h>
Expand Down
7 changes: 6 additions & 1 deletion src/ast/ASTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
#include <functional>
#include <utility>

#include <CompilerPass.h>
// Ignore some warnings in ANTLR generated code
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#include <SpiceVisitor.h>
#pragma GCC diagnostic pop

#include <CompilerPass.h>
#include <util/CodeLoc.h>
#include <util/GlobalDefinitions.h>

Expand Down
10 changes: 9 additions & 1 deletion src/global/GlobalResourceManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include <SourceFile.h>
#include <ast/ASTNodes.h>
#include <global/TypeRegistry.h>
#include <typechecker/FunctionManager.h>
#include <typechecker/StructManager.h>
#include <util/FileUtil.h>

#include <llvm/MC/TargetRegistry.h>
Expand Down Expand Up @@ -61,7 +64,12 @@ GlobalResourceManager::GlobalResourceManager(const CliOptions &cliOptions)
}

GlobalResourceManager::~GlobalResourceManager() {
// Cleanup all global LLVM resources
// Cleanup all statics
TypeRegistry::clear();
FunctionManager::clear();
StructManager::clear();
InterfaceManager::clear();
// Cleanup all LLVM statics
llvm::llvm_shutdown();
}

Expand Down
72 changes: 61 additions & 11 deletions src/global/TypeRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,81 @@

#include "TypeRegistry.h"

#include <symboltablebuilder/Type.h>
#include <util/CustomHashFunctions.h>

namespace spice::compiler {

// Static member initialization
std::unordered_map<std::string, std::unique_ptr<Type>> TypeRegistry::types = {};
std::unordered_map<uint64_t, std::unique_ptr<Type>> TypeRegistry::types = {};

const Type *TypeRegistry::get(const std::string &name) {
const auto it = types.find(name);
return it != types.end() ? it->second.get() : nullptr;
}

const Type *TypeRegistry::getOrInsert(Type type) {
const std::string name = type.getName();
/**
* Get or insert a type into the type registry
*
* @param type The type to insert
* @return The inserted type
*/
const Type *TypeRegistry::getOrInsert(const Type &&type) {
const uint64_t hash = std::hash<Type>{}(type);

// Check if type already exists
const auto it = types.find(name);
const auto it = types.find(hash);
if (it != types.end())
return it->second.get();

// Create new type
const auto insertedElement = types.emplace(name, std::make_unique<Type>(std::move(type)));
const auto insertedElement = types.emplace(hash, std::make_unique<Type>(type));
return insertedElement.first->second.get();
}

/**
* Get or insert a type into the type registry
*
* @param superType The super type of the type
* @return The inserted type
*/
const Type *TypeRegistry::getOrInsert(SuperType superType) { return getOrInsert(Type(superType)); }

/**
* Get or insert a type into the type registry
*
* @param superType The super type of the type
* @param subType The sub type of the type
* @return The inserted type
*/
const Type *TypeRegistry::getOrInsert(SuperType superType, const std::string &subType) {
return getOrInsert(Type(superType, subType));
}

/**
* Get or insert a type into the type registry
*
* @param superType The super type of the type
* @param subType The sub type of the type
* @param typeId The type ID of the type
* @param data The data of the type
* @param templateTypes The template types of the type
* @return The inserted type
*/
const Type *TypeRegistry::getOrInsert(SuperType superType, const std::string &subType, uint64_t typeId,
const TypeChainElementData &data, const QualTypeList &templateTypes) {
return getOrInsert(Type(superType, subType, typeId, data, templateTypes));
}

/**
* Get or insert a type into the type registry
*
* @param typeChain The type chain of the type
* @return The inserted type
*/
const Type *TypeRegistry::getOrInsert(const TypeChain &typeChain) { return getOrInsert(Type(typeChain)); }

/**
* Get the number of types in the type registry
*
* @return The number of types in the type registry
*/
size_t TypeRegistry::getTypeCount() { return types.size(); }

void TypeRegistry::clear() { types.clear();}

} // namespace spice::compiler
20 changes: 13 additions & 7 deletions src/global/TypeRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace spice::compiler {
#include <symboltablebuilder/Type.h>

// Forward declarations
class Type;
enum SuperType : uint8_t;
namespace spice::compiler {

class TypeRegistry {
public:
Expand All @@ -19,13 +18,20 @@ class TypeRegistry {
TypeRegistry(const TypeRegistry &) = delete;

// Public methods
static const Type *get(const std::string &name);
static const Type *getOrInsert(Type type);
static const Type *getOrInsert(SuperType superType);
static const Type *getOrInsert(SuperType superType, const std::string &subType);
static const Type *getOrInsert(SuperType superType, const std::string &subType, uint64_t typeId,
const TypeChainElementData &data, const QualTypeList &templateTypes);
static const Type *getOrInsert(const TypeChain& typeChain);
static size_t getTypeCount();
static void clear();

private:
// Private members
static std::unordered_map<std::string, std::unique_ptr<Type>> types;
static std::unordered_map<uint64_t, std::unique_ptr<Type>> types;

// Private methods
static const Type *getOrInsert(const Type &&type);
};

} // namespace spice::compiler
4 changes: 0 additions & 4 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ std::any IRGenerator::visitTernaryExpr(const TernaryExprNode *node) {
trueValue = condValue;
falseValue = resolveValue(node->operands()[1]);
} else {
const QualType &op1Type = node->operands()[1]->getEvaluatedSymbolType(manIdx);
const QualType &op2Type = node->operands()[2]->getEvaluatedSymbolType(manIdx);
llvm::Type *op1Ty = op1Type.toLLVMType(context, currentScope);
llvm::Type *op2Ty = op2Type.toLLVMType(context, currentScope);
trueValue = resolveValue(node->operands()[1]);
falseValue = resolveValue(node->operands()[2]);
}
Expand Down
4 changes: 2 additions & 2 deletions src/irgenerator/GenImplicit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ void IRGenerator::generateTestMain() {

// Prepare entry for test main
QualType functionType(TY_FUNCTION);
functionType.getSpecifiers() = TypeSpecifiers::of(TY_FUNCTION);
functionType.getSpecifiers().isPublic = true;
functionType.setSpecifiers(TypeSpecifiers::of(TY_FUNCTION));
functionType.makePublic();
SymbolTableEntry entry(MAIN_FUNCTION_NAME, functionType, rootScope, nullptr, 0, false);

// Prepare test main function
Expand Down
14 changes: 4 additions & 10 deletions src/irgenerator/GenTopLevelDefinitions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,8 @@ std::any IRGenerator::visitFctDef(const FctDefNode *node) {
assert(paramSymbol != nullptr);
const QualType paramSymbolType = manifestation->getParamTypes().at(argIdx);
// Pass the information if captures are taken for function/procedure types
if (paramSymbolType.isOneOf({TY_FUNCTION, TY_PROCEDURE}) && paramSymbolType.hasLambdaCaptures()) {
QualType paramSymbolSymbolType = paramSymbol->getQualType();
paramSymbolSymbolType.setHasLambdaCaptures(true);
paramSymbol->updateType(paramSymbolSymbolType, true);
}
if (paramSymbolType.isOneOf({TY_FUNCTION, TY_PROCEDURE}) && paramSymbolType.hasLambdaCaptures())
paramSymbol->updateType(paramSymbol->getQualType().getWithLambdaCaptures(), true);
// Retrieve type of param
llvm::Type *paramType = paramSymbolType.toLLVMType(context, currentScope);
// Add it to the lists
Expand Down Expand Up @@ -347,11 +344,8 @@ std::any IRGenerator::visitProcDef(const ProcDefNode *node) {
assert(paramSymbol != nullptr);
const QualType paramSymbolType = manifestation->getParamTypes().at(argIdx);
// Pass the information if captures are taken for function/procedure types
if (paramSymbolType.isOneOf({TY_FUNCTION, TY_PROCEDURE}) && paramSymbolType.hasLambdaCaptures()) {
QualType paramSymbolSymbolType = paramSymbol->getQualType();
paramSymbolSymbolType.setHasLambdaCaptures(true);
paramSymbol->updateType(paramSymbolSymbolType, true);
}
if (paramSymbolType.isOneOf({TY_FUNCTION, TY_PROCEDURE}) && paramSymbolType.hasLambdaCaptures())
paramSymbol->updateType(paramSymbol->getQualType().getWithLambdaCaptures(), true);
// Retrieve type of param
llvm::Type *paramType = paramSymbolType.toLLVMType(context, currentScope);
// Add it to the lists
Expand Down
Loading

0 comments on commit d56dafd

Please sign in to comment.