diff --git a/auto_editor/lang/palet.py b/auto_editor/lang/palet.py index 7a1adbacb..cfd6aa021 100644 --- a/auto_editor/lang/palet.py +++ b/auto_editor/lang/palet.py @@ -535,9 +535,9 @@ def _sqrt(v: Number) -> Number: def _xor(*vals: Any) -> bool | BoolList: if is_boolarr(vals[0]): - check_args("xor", vals, (2, None), [is_boolarr]) + check_args("xor", vals, (2, None), (is_boolarr,)) return reduce(lambda a, b: boolop(a, b, logical_xor), vals) - check_args("xor", vals, (2, None), [is_bool]) + check_args("xor", vals, (2, None), (is_bool,)) return reduce(lambda a, b: a ^ b, vals) @@ -625,11 +625,11 @@ def maxcut(oarr: BoolList, _min: int) -> BoolList: def margin(a: int, b: Any, c: Any = None) -> BoolList: if c is None: - check_args("margin", [a, b], (2, 2), [is_int, is_boolarr]) + check_args("margin", [a, b], (2, 2), (is_int, is_boolarr)) oarr = b start, end = a, a else: - check_args("margin", [a, b, c], (3, 3), [is_int, is_int, is_boolarr]) + check_args("margin", [a, b, c], (3, 3), (is_int, is_int, is_boolarr)) oarr = c start, end = a, b @@ -747,8 +747,8 @@ def __init__( env: Env, name: str, parms: list[str], + contracts: tuple[Any, ...], body: list, - contracts: list[Any] | None = None, ): self.env = env self.name = name @@ -886,7 +886,7 @@ def syn_lambda(env: Env, node: list) -> UserProc: parms.append(f"{item}") - return UserProc(env, "", parms, node[2:]) + return UserProc(env, "", parms, (), node[2:]) def syn_define(env: Env, node: list) -> None: @@ -910,7 +910,7 @@ def syn_define(env: Env, node: list) -> None: if type(item) is Sym: raise MyError(f"{node[0]}: {item} must be a keyword") if type(item) is not Keyword: - raise MyError(f"{node[0]}: must be an identifier or keyword") + raise MyError(f"{node[0]}: must be a keyword") kparms.append(item.val) else: if type(item) is Keyword: @@ -924,7 +924,7 @@ def syn_define(env: Env, node: list) -> None: if kw_only: env[n] = KeywordProc(env, n, parms, kparms, body, (len(parms), None)) else: - env[n] = UserProc(env, n, parms, body) + env[n] = UserProc(env, n, parms, (), body) return None elif type(node[1]) is not Sym: @@ -951,7 +951,7 @@ def syn_define(env: Env, node: list) -> None: parms.append(f"{item}") - env[n] = UserProc(env, n, parms, body) + env[n] = UserProc(env, n, parms, (), body) else: for item in node[2:-1]: @@ -971,7 +971,7 @@ def syn_definec(env: Env, node: list) -> None: n = node[1][0].val - contracts: list[Proc | Contract] = [] + contracts: list[Any] = [] parms: list[str] = [] for item in node[1][1:]: if item == Sym("->"): @@ -988,7 +988,7 @@ def syn_definec(env: Env, node: list) -> None: parms.append(f"{item[0]}") contracts.append(con) - env[n] = UserProc(env, n, parms, node[2:], contracts) + env[n] = UserProc(env, n, parms, tuple(contracts), node[2:]) return None @@ -1145,7 +1145,7 @@ def syn_and(env: Env, node: list) -> Any: if is_boolarr(first): vals = [first] + [my_eval(env, n) for n in node[2:]] - check_args(node[0], vals, (2, None), [is_boolarr]) + check_args(node[0], vals, (2, None), (is_boolarr,)) return reduce(lambda a, b: boolop(a, b, logical_and), vals) raise MyError(f"{node[0]} expects (or/c bool? bool-array?)") @@ -1169,7 +1169,7 @@ def syn_or(env: Env, node: list) -> Any: if is_boolarr(first): vals = [first] + [my_eval(env, n) for n in node[2:]] - check_args(node[0], vals, (2, None), [is_boolarr]) + check_args(node[0], vals, (2, None), (is_boolarr,)) return reduce(lambda a, b: boolop(a, b, logical_or), vals) raise MyError(f"{node[0]} expects (or/c bool? bool-array?)") @@ -1405,156 +1405,151 @@ def my_eval(env: Env, node: object) -> Any: "begin": Proc("begin", lambda *x: x[-1] if x else None, (0, None)), "void": Proc("void", lambda *v: None, (0, 0)), # control / b-arrays - "not": Proc("not", lambda v: not v if type(v) is bool else logical_not(v), (1, 1), [bool_or_barr]), + "not": Proc("not", lambda v: not v if type(v) is bool else logical_not(v), (1, 1), bool_or_barr), "and": Syntax(syn_and), "or": Syntax(syn_or), - "xor": Proc("xor", _xor, (2, None), [bool_or_barr]), + "xor": Proc("xor", _xor, (2, None), bool_or_barr), # booleans - ">": Proc(">", lambda a, b: a > b, (2, 2), [is_real, is_real]), - ">=": Proc(">=", lambda a, b: a >= b, (2, 2), [is_real, is_real]), - "<": Proc("<", lambda a, b: a < b, (2, 2), [is_real, is_real]), - "<=": Proc("<=", lambda a, b: a <= b, (2, 2), [is_real, is_real]), - "=": Proc("=", equal_num, (1, None), [is_num]), + ">": Proc(">", lambda a, b: a > b, (2, 2), is_real), + ">=": Proc(">=", lambda a, b: a >= b, (2, 2), is_real), + "<": Proc("<", lambda a, b: a < b, (2, 2), is_real), + "<=": Proc("<=", lambda a, b: a <= b, (2, 2), is_real), + "=": Proc("=", equal_num, (1, None), is_num), "eq?": Proc("eq?", lambda a, b: a is b, (2, 2)), "equal?": Proc("equal?", is_equal, (2, 2)), - "zero?": UserProc(env, "zero?", ["z"], [[Sym("="), Sym("z"), 0]], [is_num]), - "positive?": UserProc( - env, "positive?", ["x"], [[Sym(">"), Sym("x"), 0]], [is_real] - ), - "negative?": UserProc( - env, "negative?", ["x"], [[Sym("<"), Sym("x"), 0]], [is_real] - ), + "zero?": UserProc(env, "zero?", ["z"], (is_num,), [[Sym("="), Sym("z"), 0]]), + "positive?": UserProc(env, "positive?", ["x"], (is_real,), [[Sym(">"), Sym("x"), 0]]), + "negative?": UserProc(env, "negative?", ["x"], (is_real,), [[Sym("<"), Sym("x"), 0]]), "even?": UserProc( - env, "even?", ["n"], [[Sym("zero?"), [Sym("mod"), Sym("n"), 2]]], [is_int] + env, "even?", ["n"], (is_int,), [[Sym("zero?"), [Sym("mod"), Sym("n"), 2]]] ), "odd?": UserProc( - env, "odd?", ["n"], [[Sym("not"), [Sym("even?"), Sym("n")]]], [is_int] + env, "odd?", ["n"], (is_int,), [[Sym("not"), [Sym("even?"), Sym("n")]]] ), - ">=/c": Proc(">=/c", gte_c, (1, 1), [is_real]), - ">/c": Proc(">/c", gt_c, (1, 1), [is_real]), - "<=/c": Proc("<=/c", lte_c, (1, 1), [is_real]), - "=/c": Proc(">=/c", gte_c, (1, 1), is_real), + ">/c": Proc(">/c", gt_c, (1, 1), is_real), + "<=/c": Proc("<=/c", lte_c, (1, 1), is_real), + "string": Proc("symbol->string", str, (1, 1), [is_symbol]), - "string->symbol": Proc("string->symbol", Sym, (1, 1), [is_str]), + "symbol->string": Proc("symbol->string", str, (1, 1), is_symbol), + "string->symbol": Proc("string->symbol", Sym, (1, 1), is_str), # strings - "string": Proc("string", string_append, (0, None), [is_char]), - "&": Proc("&", string_append, (0, None), [is_str]), - "split": Proc("split", str.split, (1, 2), [is_str, is_str]), - "strip": Proc("strip", str.strip, (1, 1), [is_str]), - "str-repeat": Proc("str-repeat", lambda s, a: s * a, (2, 2), [is_str, is_int]), - "startswith": Proc("startswith", str.startswith, (2, 2), [is_str, is_str]), - "endswith": Proc("endswith", str.endswith, (2, 2), [is_str, is_str]), - "replace": Proc("replace", str.replace, (3, 4), [is_str, is_str, is_str, is_int]), - "title": Proc("title", str.title, (1, 1), [is_str]), - "lower": Proc("lower", str.lower, (1, 1), [is_str]), - "upper": Proc("upper", str.upper, (1, 1), [is_str]), + "string": Proc("string", string_append, (0, None), is_char), + "&": Proc("&", string_append, (0, None), is_str), + "split": Proc("split", str.split, (1, 2), is_str, is_str), + "strip": Proc("strip", str.strip, (1, 1), is_str), + "str-repeat": Proc("str-repeat", lambda s, a: s * a, (2, 2), is_str, is_int), + "startswith": Proc("startswith", str.startswith, (2, 2), is_str), + "endswith": Proc("endswith", str.endswith, (2, 2), is_str), + "replace": Proc("replace", str.replace, (3, 4), is_str, is_str, is_str, is_int), + "title": Proc("title", str.title, (1, 1), is_str), + "lower": Proc("lower", str.lower, (1, 1), is_str), + "upper": Proc("upper", str.upper, (1, 1), is_str), # format - "char->int": Proc("char->int", lambda c: ord(c.val), (1, 1), [is_char]), - "int->char": Proc("int->char", Char, (1, 1), [is_int]), + "char->int": Proc("char->int", lambda c: ord(c.val), (1, 1), is_char), + "int->char": Proc("int->char", Char, (1, 1), is_int), "~a": Proc("~a", lambda *v: "".join([display_str(a) for a in v]), (0, None)), "~s": Proc("~s", lambda *v: " ".join([display_str(a) for a in v]), (0, None)), "~v": Proc("~v", lambda *v: " ".join([print_str(a) for a in v]), (0, None)), # keyword "keyword?": is_keyw, - "keyword->string": Proc("keyword->string", lambda v: v.val.val, (1, 1), [is_keyw]), - "string->keyword": Proc("string->keyword", QuotedKeyword, (1, 1), [is_str]), + "keyword->string": Proc("keyword->string", lambda v: v.val.val, (1, 1), is_keyw), + "string->keyword": Proc("string->keyword", QuotedKeyword, (1, 1), is_str), # vectors "vector": Proc("vector", lambda *a: list(a), (0, None)), "make-vector": Proc( - "make-vector", lambda size, a=0: [a] * size, (1, 2), [is_uint, any_p] + "make-vector", lambda size, a=0: [a] * size, (1, 2), is_uint, any_p ), - "vector-append": Proc("vector-append", vector_append, (0, None), [is_vector]), - "vector-pop!": Proc("vector-pop!", list.pop, (1, 1), [is_vector]), - "vector-add!": Proc("vector-add!", list.append, (2, 2), [is_vector, any_p]), - "vector-set!": Proc("vector-set!", vector_set, (3, 3), [is_vector, is_int, any_p]), - "vector-extend!": Proc("vector-extend!", vector_extend, (2, None), [is_vector]), - "sort": Proc("sort", sorted, (1, 1), [is_vector]), - "sort!": Proc("sort!", list.sort, (1, 1), [is_vector]), + "vector-append": Proc("vector-append", vector_append, (0, None), is_vector), + "vector-pop!": Proc("vector-pop!", list.pop, (1, 1), is_vector), + "vector-add!": Proc("vector-add!", list.append, (2, 2), is_vector, any_p), + "vector-set!": Proc("vector-set!", vector_set, (3, 3), is_vector, is_int, any_p), + "vector-extend!": Proc("vector-extend!", vector_extend, (2, None), is_vector), + "sort": Proc("sort", sorted, (1, 1), is_vector), + "sort!": Proc("sort!", list.sort, (1, 1), is_vector), # arrays - "array": Proc("array", array_proc, (2, None), [is_symbol, is_real]), - "make-array": Proc("make-array", make_array, (2, 3), [is_symbol, is_uint, is_real]), + "array": Proc("array", array_proc, (2, None), is_symbol, is_real), + "make-array": Proc("make-array", make_array, (2, 3), is_symbol, is_uint, is_real), "array-splice!": Proc( - "array-splice!", splice, (2, 4), [is_array, is_real, is_int, is_int] + "array-splice!", splice, (2, 4), is_array, is_real, is_int, is_int ), - "array-copy": Proc("array-copy", np.copy, (1, 1), [is_array]), - "count-nonzero": Proc("count-nonzero", np.count_nonzero, (1, 1), [is_array]), + "array-copy": Proc("array-copy", np.copy, (1, 1), is_array), + "count-nonzero": Proc("count-nonzero", np.count_nonzero, (1, 1), is_array), # bool arrays "bool-array": Proc( - "bool-array", lambda *a: np.array(a, dtype=np.bool_), (1, None), [is_uint] + "bool-array", lambda *a: np.array(a, dtype=np.bool_), (1, None), is_uint ), - "margin": Proc("margin", margin, (2, 3), None), - "mincut": Proc("mincut", mincut, (2, 2), [is_boolarr, is_uint]), - "minclip": Proc("minclip", minclip, (2, 2), [is_boolarr, is_uint]), - "maxcut": Proc("maxcut", maxcut, (2, 2), [is_boolarr, is_uint]), - "maxclip": Proc("maxclip", maxclip, (2, 2), [is_boolarr, is_uint]), + "margin": Proc("margin", margin, (2, 3)), + "mincut": Proc("mincut", mincut, (2, 2), is_boolarr, is_uint), + "minclip": Proc("minclip", minclip, (2, 2), is_boolarr, is_uint), + "maxcut": Proc("maxcut", maxcut, (2, 2), is_boolarr, is_uint), + "maxclip": Proc("maxclip", maxclip, (2, 2), is_boolarr, is_uint), # ranges - "range": Proc("range", range, (1, 3), [is_int, is_int, int_not_zero]), + "range": Proc("range", range, (1, 3), is_int, is_int, int_not_zero), # generic iterables - "len": Proc("len", len, (1, 1), [is_iterable]), - "reverse": Proc("reverse", lambda v: v[::-1], (1, 1), [is_sequence]), - "ref": Proc("ref", ref, (2, 2), [is_sequence, is_int]), - "slice": Proc("slice", p_slice, (2, 4), [is_sequence, is_int]), + "len": Proc("len", len, (1, 1), is_iterable), + "reverse": Proc("reverse", lambda v: v[::-1], (1, 1), is_sequence), + "ref": Proc("ref", ref, (2, 2), is_sequence, is_int), + "slice": Proc("slice", p_slice, (2, 4), is_sequence, is_int), # procedures - "map": Proc("map", palet_map, (2, 2), [is_proc, is_sequence]), - "apply": Proc("apply", lambda p, s: p(*s), (2, 2), [is_proc, is_sequence]), - "and/c": Proc("and/c", andc, (1, None), [is_cont]), - "or/c": Proc("or/c", orc, (1, None), [is_cont]), - "not/c": Proc("not/c", notc, (1, 1), [is_cont]), + "map": Proc("map", palet_map, (2, 2), is_proc, is_sequence), + "apply": Proc("apply", lambda p, s: p(*s), (2, 2), is_proc, is_sequence), + "and/c": Proc("and/c", andc, (1, None), is_cont), + "or/c": Proc("or/c", orc, (1, None), is_cont), + "not/c": Proc("not/c", notc, (1, 1), is_cont), # hashs "hash": Proc("hash", palet_hash, (0, None)), - "hash-ref": Proc("hash", hash_ref, (2, 2), [is_hash, any_p]), - "hash-set!": Proc("hash-set!", hash_set, (3, 3), [is_hash, any_p, any_p]), - "has-key?": Proc("has-key?", lambda h, k: k in h, (2, 2), [is_hash, any_p]), - "hash-remove!": Proc("hash-remove!", hash_remove, (2, 2), [is_hash, any_p]), - "hash-update!": UserProc(env, "hash-update!", ["h", "v", "up"], + "hash-ref": Proc("hash", hash_ref, (2, 2), is_hash, any_p), + "hash-set!": Proc("hash-set!", hash_set, (3, 3), is_hash, any_p, any_p), + "has-key?": Proc("has-key?", lambda h, k: k in h, (2, 2), is_hash, any_p), + "hash-remove!": Proc("hash-remove!", hash_remove, (2, 2), is_hash, any_p), + "hash-update!": UserProc(env, "hash-update!", ["h", "v", "up"], (is_hash, any_p), [[Sym("hash-set!"), Sym("h"), Sym("v"), [Sym("up"), [Sym("hash-ref"), Sym("h"), Sym("v")]]]], - [is_hash, any_p, any_p], ), # actions - "assert": Proc("assert", palet_assert, (1, 2), [any_p, orc(is_str, False)]), + "assert": Proc("assert", palet_assert, (1, 2), any_p, orc(is_str, False)), "display": Proc("display", lambda v: print(display_str(v), end=""), (1, 1)), "displayln": Proc("displayln", lambda v: print(display_str(v)), (1, 1)), - "error": Proc("error", raise_, (1, 1), [is_str]), - "sleep": Proc("sleep", sleep, (1, 1), [is_int_or_float]), + "error": Proc("error", raise_, (1, 1), is_str), + "sleep": Proc("sleep", sleep, (1, 1), is_int_or_float), "print": Proc("print", lambda v: print(print_str(v), end=""), (1, 1)), "println": Proc("println", lambda v: print(print_str(v)), (1, 1)), - "system": Proc("system", palet_system, (1, 1), [is_str]), + "system": Proc("system", palet_system, (1, 1), is_str), # conversions - "number->string": Proc("number->string", number_to_string, (1, 1), [is_num]), + "number->string": Proc("number->string", number_to_string, (1, 1), is_num), "string->vector": Proc( - "string->vector", lambda s: [Char(c) for c in s], (1, 1), [is_str] + "string->vector", lambda s: [Char(c) for c in s], (1, 1), is_str ), - "range->vector": Proc("range->vector", list, (1, 1), [is_range]), + "range->vector": Proc("range->vector", list, (1, 1), is_range), # reflexion - "var-exists?": Proc("var-exists?", lambda sym: sym.val in env, (1, 1), [is_symbol]), + "var-exists?": Proc("var-exists?", lambda sym: sym.val in env, (1, 1), is_symbol), "rename": Syntax(syn_rename), "delete": Syntax(syn_delete), }) diff --git a/auto_editor/lib/contracts.py b/auto_editor/lib/contracts.py index cb76e45aa..6158412c1 100644 --- a/auto_editor/lib/contracts.py +++ b/auto_editor/lib/contracts.py @@ -50,7 +50,7 @@ def check_args( o: str, values: list | tuple, arity: tuple[int, int | None], - cont: list[Contract] | None, + cont: tuple[Any, ...], ) -> None: lower, upper = arity amount = len(values) @@ -65,7 +65,7 @@ def check_args( if upper is not None and (amount > upper or amount < lower): raise MyError(f"{base}between {lower} and {upper}, got {amount}") - if cont is None: + if not cont: return for i, val in enumerate(values): @@ -75,12 +75,16 @@ def check_args( raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}") -@dataclass(slots=True) class Proc: - name: str - proc: Callable - arity: tuple[int, int | None] = (1, None) - contracts: list[Any] | None = None + __slots__ = ("name", "proc", "arity", "contracts") + + def __init__( + self, n: str, p: Callable, a: tuple[int, int | None] = (1, None), *c: Any + ): + self.name = n + self.proc = p + self.arity = a + self.contracts: tuple[Any, ...] = c def __call__(self, *args: Any) -> Any: check_args(self.name, args, self.arity, self.contracts) @@ -134,13 +138,13 @@ def is_contract(c: object) -> bool: def andc(*cs: object) -> Proc: return Proc( - "flat-and/c", lambda v: all([check_contract(c, v) for c in cs]), (1, 1), [any_p] + "flat-and/c", lambda v: all([check_contract(c, v) for c in cs]), (1, 1), any_p ) def orc(*cs: object) -> Proc: return Proc( - "flat-or/c", lambda v: any([check_contract(c, v) for c in cs]), (1, 1), [any_p] + "flat-or/c", lambda v: any([check_contract(c, v) for c in cs]), (1, 1), any_p )