Skip to content

Commit

Permalink
Add import, math module
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Jan 25, 2024
1 parent a9fa6ad commit 7316d58
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 16 deletions.
23 changes: 23 additions & 0 deletions auto_editor/lang/libmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import math

from auto_editor.lib.contracts import Proc, andc, gt_c, is_real, between_c


def all() -> dict[str, object]:
return {
"exp": Proc("exp", math.exp, (1, 1), is_real),
"ceil": Proc("ceil", math.ceil, (1, 1), is_real),
"floor": Proc("floor", math.floor, (1, 1), is_real),
"sin": Proc("sin", math.sin, (1, 1), is_real),
"cos": Proc("cos", math.cos, (1, 1), is_real),
"tan": Proc("tan", math.tan, (1, 1), is_real),
"asin": Proc("asin", math.asin, (1, 1), between_c(-1, 1)),
"acos": Proc("acos", math.acos, (1, 1), between_c(-1, 1)),
"atan": Proc("atan", math.atan, (1, 1), is_real),
"log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))),
"pi": math.pi,
"e": math.e,
"tau": math.tau,
}
32 changes: 22 additions & 10 deletions auto_editor/lang/palet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from __future__ import annotations

import cmath
import math
from cmath import sqrt as complex_sqrt
from dataclasses import dataclass
from difflib import get_close_matches
from fractions import Fraction
Expand Down Expand Up @@ -549,7 +548,7 @@ def int_div(n: int, *m: int) -> int:


def _sqrt(v: Number) -> Number:
r = cmath.sqrt(v)
r = complex_sqrt(v)
if r.imag == 0:
if int(r.real) == r.real:
return int(r.real)
Expand Down Expand Up @@ -1396,6 +1395,25 @@ def syn_let_star(env: Env, node: Node) -> Any:
return my_eval(inner_env, node[-1])


def syn_import(env: Env, node: Node) -> None:
guard_term(node, 2, 2)

if type(node[1]) is not Sym:
raise MyError("class name must be an identifier")

module = node[1].val
error = MyError(f"No module named `{module}`")

if module != "math":
raise error
try:
obj = __import__("auto_editor.lang.libmath", fromlist=["lang"])
except ImportError:
raise error

env.update(obj.all())


def syn_class(env: Env, node: Node) -> None:
if len(node) < 2:
raise MyError(f"{node[0]}: Expects at least 1 term")
Expand Down Expand Up @@ -1544,6 +1562,7 @@ def my_eval(env: Env, node: object) -> Any:
"case": Syntax(syn_case),
"let": Syntax(syn_let),
"let*": Syntax(syn_let_star),
"import": Syntax(syn_import),
"class": Syntax(syn_class),
"@r": Syntax(attr),
# loops
Expand Down Expand Up @@ -1615,17 +1634,10 @@ def my_eval(env: Env, node: object) -> Any:
"imag-part": Proc("imag-part", lambda v: v.imag, (1, 1), is_num),
# reals
"pow": Proc("pow", pow, (2, 2), is_real),
"exp": Proc("exp", math.exp, (1, 1), is_real),
"abs": Proc("abs", abs, (1, 1), is_real),
"ceil": Proc("ceil", math.ceil, (1, 1), is_real),
"floor": Proc("floor", math.floor, (1, 1), is_real),
"round": Proc("round", round, (1, 1), is_real),
"max": Proc("max", lambda *v: max(v), (1, None), is_real),
"min": Proc("min", lambda *v: min(v), (1, None), is_real),
"sin": Proc("sin", math.sin, (1, 1), is_real),
"cos": Proc("cos", math.cos, (1, 1), is_real),
"log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))),
"tan": Proc("tan", math.tan, (1, 1), is_real),
"mod": Proc("mod", mod, (2, 2), is_int),
"modulo": Proc("modulo", mod, (2, 2), is_int),
# symbols
Expand Down
7 changes: 1 addition & 6 deletions auto_editor/subcommands/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,12 +577,6 @@ def cases(*cases: tuple[str, Any]) -> None:
("(pow 4 0.5)", 2.0),
("(abs 1.0)", 1.0),
("(abs -1)", 1),
("(round 3.5)", 4),
("(round 2.5)", 2),
("(ceil 2.1)", 3),
("(ceil 2.9)", 3),
("(floor 2.1)", 2),
("(floor 2.9)", 2),
("(bool? #t)", True),
("(bool? #f)", True),
("(bool? 0)", False),
Expand Down Expand Up @@ -693,6 +687,7 @@ def palet_scripts():
run.raw(["palet", "resources/scripts/maxcut.pal"])
run.raw(["palet", "resources/scripts/scope.pal"])
run.raw(["palet", "resources/scripts/case.pal"])
run.raw(["palet", "resources/scripts/testmath.pal"])

tests = []

Expand Down
28 changes: 28 additions & 0 deletions resources/scripts/testmath.pal
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env auto-editor palet
#lang palet

(import math)

(assert (equal? (round 3.5) 4))
(assert (equal? (round 2.5) 2))
(assert (equal? (ceil 2.1) 3))
(assert (equal? (ceil 2.9) 3))
(assert (equal? (floor 2.1) 2))
(assert (equal? (floor 2.9) 2))

(assert (equal? (sin 0) 0.0))
(assert (equal? (sin 0/1) 0.0))
(assert (equal? (sin (/ pi 2)) 1.0))

(assert (equal? (cos 0) 1.0))
(assert (equal? (cos (* pi 2)) 1.0))
(assert (equal? (cos pi) -1.0))
(assert (equal? (cos tau) 1.0))

(assert (equal? (asin 0) 0.0))
(assert (equal? (asin 0/1) 0.0))
(assert (equal? (acos 1) 0.0))
(assert (equal? (acos -1) pi))

(assert (equal? (log 1) 0.0))
(assert (equal? (log e) 1.0))

0 comments on commit 7316d58

Please sign in to comment.