Skip to content

Commit

Permalink
Fixed recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
AdUhTkJm authored and mengzhuo committed Dec 29, 2024
1 parent c27618b commit 482e318
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 30 deletions.
20 changes: 7 additions & 13 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __exit__(self, exc_type, exc_value, traceback):
parser.add_argument("-i", "--build-index", action="store_true", help="build OCaml index and exit")
parser.add_argument("-b", "--build-only", action="store_true", help="build without testing")
parser.add_argument("-v", "--verbose", action="store_true", help="interpreter outputs detailed values")
parser.add_argument("-t", "--test", type=str, help="execute this test case only")

args = parser.parse_args()

Expand All @@ -34,15 +35,8 @@ def __exit__(self, exc_type, exc_value, traceback):
core = "~/.moon/lib/core"
bundled = f"{core}/target/wasm-gc/release/bundle"

if args.debug:
debug = "OCAMLRUNPARAM=b"
else:
debug = ""

if args.verbose:
verbose = "-DVERBOSE"
else:
verbose = ""
debug = "OCAMLRUNPARAM=b" if args.debug else ""
verbose = "-DVERBOSE" if args.verbose else ""

def try_remove(path):
if os.path.exists(path):
Expand Down Expand Up @@ -70,10 +64,10 @@ def try_remove(path):
exit(0)

with DirContext("test"):
os.makedirs("build", exist_ok=True);

cases = os.listdir("src");
os.makedirs("build", exist_ok=True)

cases = os.listdir("src") if args.test is None else [args.test]

for src in cases:
print(f"Execute task: {src}")
# Remove all previously compiled files.
Expand All @@ -88,7 +82,7 @@ def try_remove(path):
os.system(f"{debug} moonc link-core {bundled}/core.core build/{src}.core -o build/{dest} -pkg-config-path {src}/moon.pkg.json -pkg-sources {core}:{src} -target {target}")

# Test.
os.system(f"build/interpreter build/{dest}.ssa > build/output.txt")
os.system(f"build/interpreter build/{dest}.ssa > build/output.txt 2> build/debug.txt")
diff = os.system(f"diff build/output.txt src/{src}/{src}.ans")

if diff == 0:
Expand Down
59 changes: 42 additions & 17 deletions test/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::map<std::string, std::vector<std::string>> fns;

// Values of registers used when interpreting.
// TODO: currently no FP supported.
std::map<std::string, uint64_t> regs;
std::map<std::string, int64_t> regs;

std::vector<std::string> split(std::string s, std::string delim) {
size_t start = 0, end;
Expand Down Expand Up @@ -55,8 +55,8 @@ int int_of(std::string s) {
return atoi(s.c_str());
}

#define RTYPE(name, op) std::make_pair(name, [](uint64_t x, uint64_t y) { return x op y; })
#define MEM(name, type) std::make_pair(name, [](uint64_t x, int offset) { return *((type*) (x + offset)); })
#define RTYPE(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return x op y; })
#define MEM(name, type) std::make_pair(name, [](int64_t x, int offset) { return *((type*) (x + offset)); })
#define VAL(i) regs[args[i]]

#ifdef VERBOSE
Expand All @@ -68,8 +68,8 @@ int int_of(std::string s) {
#endif

// Argument `label` is where we start interpreting.
uint64_t interpret(std::string label) {
static std::map<std::string, std::function<uint64_t (uint64_t, uint64_t)>> rtype = {
int64_t interpret(std::string label) {
static std::map<std::string, std::function<int64_t (int64_t, int64_t)>> rtype = {
RTYPE("add", +),
RTYPE("sub", -),
RTYPE("mul", *),
Expand All @@ -88,11 +88,11 @@ uint64_t interpret(std::string label) {
RTYPE("shr", >>),
};

static std::map<std::string, std::function<uint64_t (uint64_t, int)>> load = {
static std::map<std::string, std::function<int64_t (int64_t, int)>> load = {
MEM("lb", char),
MEM("lh", char16_t),
MEM("lw", int),
MEM("ld", uint64_t),
MEM("ld", int64_t),
};

std::string prev;
Expand Down Expand Up @@ -132,53 +132,78 @@ uint64_t interpret(std::string label) {
if (op == "sw")
*((int*)(rs + offset)) = rd;
if (op == "sd")
*((uint64_t*)(rs + offset)) = rd;
*((int64_t*)(rs + offset)) = rd;

OUTPUT(args[1], VAL(1));
continue;
}

if (op == "call") {
auto before = regs;

auto fn = args[2];
for (int i = 0; i < fns[fn].size(); i++) {
regs[fns[fn][i]] = VAL(i + 3);
OUTPUT(fns[fn][i], VAL(i + 3));
}
VAL(1) = interpret(fn);

auto value = interpret(fn);
regs = before;
VAL(1) = value;

OUTPUT(args[1], VAL(1));
continue;
}

if (op == "call_indirect") {
auto before = regs;

// Remember, we store function names in the pointer
std::string fn(*(char**) VAL(2));
SAY("jump to " << fn);
for (int i = 0; i < fns[fn].size(); i++) {
regs[fns[fn][i]] = VAL(i + 3);
OUTPUT(fns[fn][i], VAL(i + 3));
}
VAL(1) = interpret(fn);

auto value = interpret(fn);
regs = before;
VAL(1) = value;

OUTPUT(args[1], VAL(1));
continue;
}

if (op == "call_libc") {
for (int i = 3; i < args.size(); i++)
OUTPUT(args[i], VAL(i));

if (args[2] == "puts") {
std::u16string utf16_str((char16_t*) VAL(3));
// Make the output string null-terminated
int len = *(int*) (VAL(3) - 4);
auto ptr = new char16_t[len + 1];
memcpy(ptr, (void*) VAL(3), len * 2);
ptr[len] = 0;

std::u16string utf16_str(ptr);

// Convert to UTF-8, so that cout can output it
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> convert;
std::string utf8_string = convert.to_bytes(utf16_str);

for (int i = 0; i < len * 2; i++)
OUTPUT("char of byte " << i, *(char*) (ptr + i));
std::cout << utf8_string << std::endl;
continue;
}

if (args[2] == "malloc") {
VAL(1) = (uint64_t) new char[VAL(3)];
VAL(1) = (int64_t) new char[VAL(3)];
continue;
}

if (args[2] == "strlen") {
VAL(1) = (uint64_t) strlen((char*) VAL(3));
VAL(1) = (int64_t) strlen((char*) VAL(3));
continue;
}

Expand All @@ -203,7 +228,7 @@ uint64_t interpret(std::string label) {
if (op == "malloc") {
auto len = int_of(args[2]);

VAL(1) = (uint64_t) new char[len];
VAL(1) = (int64_t) new char[len];
OUTPUT(args[1], VAL(1));
continue;
}
Expand Down Expand Up @@ -337,15 +362,15 @@ int main(int argc, char** argv) {
for (int i = 0; i < len; i++) {
char* name = new char[elems[i].size() + 1];
strcpy(name, elems[i].c_str());
*(uint64_t*)(space + i * 8) = (uint64_t) name;
*(int64_t*)(space + i * 8) = (int64_t) name;
}
}

if (!space) {
std::cerr << "Bad SSA: unrecognized global array type\n";
return 2;
}
regs[name] = (uint64_t) space;
regs[name] = (int64_t) space;
continue;
}

Expand All @@ -370,4 +395,4 @@ int main(int argc, char** argv) {

regs["_"] = unit;
interpret("_start");
}
}
1 change: 1 addition & 0 deletions test/src/fib/fib.ans
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
21
11 changes: 11 additions & 0 deletions test/src/fib/fib.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
fn fib(x: Int) -> Int {
if x <= 1 {
1
} else {
fib(x - 1) + fib(x - 2)
}
}

fn main {
println(fib(7))
}

0 comments on commit 482e318

Please sign in to comment.