diff --git a/auto_editor/lang/palet.py b/auto_editor/lang/palet.py index a4f90b90d..ed51c1298 100644 --- a/auto_editor/lang/palet.py +++ b/auto_editor/lang/palet.py @@ -9,6 +9,7 @@ import cmath import math import random +from dataclasses import dataclass from difflib import get_close_matches from fractions import Fraction from functools import reduce @@ -17,6 +18,7 @@ from typing import TYPE_CHECKING import numpy as np +from numpy import logical_and, logical_not, logical_or, logical_xor from auto_editor.analyze import edit_method, mut_remove_large, mut_remove_small from auto_editor.lib.contracts import * @@ -64,17 +66,10 @@ class ClosingError(MyError): } +@dataclass(slots=True) class Token: - __slots__ = ("type", "value") - - def __init__(self, type: str, value: Any): - self.type = type - self.value = value - - def __str__(self) -> str: - return f"(Token {print_str(self.type)} {print_str(self.value)})" - - __repr__ = __str__ + type: str + value: Any class Lexer: @@ -374,14 +369,12 @@ def handle_strings() -> bool: ############################################################################### +@dataclass(slots=True) class Method: - __slots__ = "val" - - def __init__(self, val: str): - self.val = val + val: str def __str__(self) -> str: - return f'(Method "{self.val}")' + return f"#" __repr__ = __str__ @@ -474,17 +467,16 @@ def check_args( ) -> None: lower, upper = arity amount = len(values) - if upper is not None and lower > upper: - raise ValueError("lower must be less than upper") - if lower == upper and len(values) != lower: - raise MyError(f"{o}: Arity mismatch. Expected {lower}, got {amount}") + assert not (upper is not None and lower > upper) + base = f"`{o}` has an arity mismatch. Expected " + + if lower == upper and len(values) != lower: + raise MyError(f"{base}{lower}, got {amount}") if upper is None and amount < lower: - raise MyError(f"{o}: Arity mismatch. Expected at least {lower}, got {amount}") + raise MyError(f"{base}at least {lower}, got {amount}") if upper is not None and (amount > upper or amount < lower): - raise MyError( - f"{o}: Arity mismatch. Expected between {lower} and {upper}, got {amount}" - ) + raise MyError(f"{base}between {lower} and {upper}, got {amount}") if cont is None: return @@ -493,7 +485,7 @@ def check_args( check = cont[-1] if i >= len(cont) else cont[i] if not check_contract(check, val): exp = f"{check}" if callable(check) else print_str(check) - raise MyError(f"{o} expected a {exp}, got {print_str(val)}") + raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}") is_cont = Contract("contract?", is_contract) @@ -575,7 +567,7 @@ def _sqrt(v: Number) -> Number: def _xor(*vals: Any) -> bool | BoolList: if is_boolarr(vals[0]): check_args("xor", vals, (2, None), [is_boolarr]) - return reduce(lambda a, b: boolop(a, b, np.logical_xor), vals) + return reduce(lambda a, b: boolop(a, b, logical_xor), vals) check_args("xor", vals, (2, None), [is_bool]) return reduce(lambda a, b: a ^ b, vals) @@ -829,6 +821,63 @@ def __call__(self, *args: Any) -> Any: return my_eval(inner_env, self.body[-1]) +@dataclass(slots=True) +class KeywordProc: + env: Env + name: str + parms: list[str] + kw_parms: list[str] + body: list + arity: tuple[int, None] + contracts: list[Any] | None = None + + def __call__(self, *args: Any) -> Any: + env = {} + + for i, parm in enumerate(self.parms): + if type(args[i]) is Keyword: + raise MyError(f"Invalid keyword `{args[i]}`") + env[parm] = args[i] + + remain_args = args[len(self.parms) :] + + allow_pos = True + pos_index = 0 + key = "" + for arg in remain_args: + if type(arg) is Keyword: + if key: + raise MyError("Expected value for keyword but got another keyword") + key = arg.val + allow_pos = False + elif key: + env[key] = arg + key = "" + else: + if not allow_pos: + raise MyError("Positional argument not allowed here") + if pos_index >= len(self.kw_parms): + base = f"`{self.name}` has an arity mismatch. Expected" + upper = len(self.parms) + len(self.kw_parms) + raise MyError(f"{base} at most {upper}") + + env[self.kw_parms[pos_index]] = arg + pos_index += 1 + + inner_env = Env(env, self.env) + + for item in self.body[0:-1]: + my_eval(inner_env, item) + + return my_eval(inner_env, self.body[-1]) + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"#" + + class Syntax: __slots__ = "syn" @@ -898,13 +947,29 @@ def syn_define(env: Env, node: list) -> None: n = term[0].val parms: list[str] = [] - for item in term[1:]: - if type(item) is not Sym: - raise MyError(f"{node[0]}: must be an identifier") + kparms: list[str] = [] + kw_only = False - parms.append(f"{item}") + for item in term[1:]: + if kw_only: + if type(item) is Sym: + raise MyError(f"{node[0]}: {item} must be a keyword") + if type(item) is not Keyword: + raise MyError(f"{node[0]}: must be an identifier or keyword") + kparms.append(item.val) + else: + if type(item) is Keyword: + kw_only = True + kparms.append(item.val) + elif type(item) is Sym: + parms.append(item.val) + else: + raise MyError(f"{node[0]}: must be an identifier") - env[n] = UserProc(env, n, parms, body) + if kw_only: + env[n] = KeywordProc(env, n, parms, kparms, body, (len(parms), None)) + else: + env[n] = UserProc(env, n, parms, body) return None elif type(node[1]) is not Sym: @@ -1124,7 +1189,7 @@ def syn_and(env: Env, node: list) -> Any: if is_boolarr(first): vals = [first] + [my_eval(env, n) for n in node[2:]] check_args(node[0], vals, (2, None), [is_boolarr]) - return reduce(lambda a, b: boolop(a, b, np.logical_and), vals) + return reduce(lambda a, b: boolop(a, b, logical_and), vals) raise MyError(f"{node[0]} expects (or/c bool? bool-array?)") @@ -1148,7 +1213,7 @@ def syn_or(env: Env, node: list) -> Any: if is_boolarr(first): vals = [first] + [my_eval(env, n) for n in node[2:]] check_args(node[0], vals, (2, None), [is_boolarr]) - return reduce(lambda a, b: boolop(a, b, np.logical_or), vals) + return reduce(lambda a, b: boolop(a, b, logical_or), vals) raise MyError(f"{node[0]} expects (or/c bool? bool-array?)") @@ -1390,12 +1455,7 @@ def my_eval(env: Env, node: object) -> Any: "begin": Proc("begin", lambda *x: x[-1] if x else None, (0, None)), "void": Proc("void", lambda *v: None, (0, 0)), # control / b-arrays - "not": Proc( - "not", - lambda v: not v if type(v) is bool else np.logical_not(v), - (1, 1), - [bool_or_barr], - ), + "not": Proc("not", lambda v: not v if type(v) is bool else logical_not(v), (1, 1), [bool_or_barr]), "and": Syntax(syn_and), "or": Syntax(syn_or), "xor": Proc("xor", _xor, (2, None), [bool_or_barr]), diff --git a/resources/scripts/scope.pal b/resources/scripts/scope.pal index d1761ce5b..6232c7ddf 100644 --- a/resources/scripts/scope.pal +++ b/resources/scripts/scope.pal @@ -1,13 +1,11 @@ #!/usr/bin/env auto-editor palet - #lang palet -; Enforce lexical scoping +; Enforce lexical scoping (define (f x) (lambda (y) (+ x y))) (assert (equal? ((f 10) 12) 22)) ; Test that variables do not leak scope - (define (outer a) (define (inner1 b) (define (inner2 c) c) @@ -21,6 +19,26 @@ (assert (not (var-exists? 'b))) (assert (not (var-exists? 'c))) +; Test keyword arguments +(define (f1 a b c) (vector a b c)) +(define (f2 a #:b #:c) (vector a b c)) +;(define (f3 a #:b [#:c 0]) (vector a b c)) +;(define (f4 [a 2] [#:b 1] [#:c 0]) (vector a b c)) + +; Invalid defines +; (define (f [a 2] b c) (vector a b c)) +; (define (f a #:b c) (vector a b c)) +; (define (f a [#:b 1] #:c) (vector a b c)) +; (define (f [a 2] #:b [#:c 0]) (vector a b c)) +; (define (f a a #:b) (void)) +; (define (f a #:a #:b) (void)) + +(assert (equal? (f1 3 2 1) #(3 2 1))) +(assert (equal? (f2 3 2 1) #(3 2 1))) +(assert (equal? (f2 3 2 #:c 1) #(3 2 1))) +(assert (equal? (f2 3 #:b 2 #:c 1) #(3 2 1))) +(assert (equal? (f2 3 #:c 1 #:b 2) #(3 2 1))) + ; Test `let` and `let*` (assert (equal? (let ([x 5]) x) 5))