Skip to content

Commit

Permalink
Put array as first arg in margin proc
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Jul 29, 2024
1 parent 0b835e1 commit e612311
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
28 changes: 12 additions & 16 deletions auto_editor/lang/palet.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,42 +620,36 @@ def make_array(dtype: Sym, size: int, v: int = 0) -> np.ndarray:
raise MyError(f"number too large to be converted to {dtype}")


def minclip(oarr: BoolList, _min: int) -> BoolList:
def minclip(oarr: BoolList, _min: int, /) -> BoolList:
arr = np.copy(oarr)
mut_remove_small(arr, _min, replace=1, with_=0)
return arr


def mincut(oarr: BoolList, _min: int) -> BoolList:
def mincut(oarr: BoolList, _min: int, /) -> BoolList:
arr = np.copy(oarr)
mut_remove_small(arr, _min, replace=0, with_=1)
return arr


def maxclip(oarr: BoolList, _min: int) -> BoolList:
def maxclip(oarr: BoolList, _min: int, /) -> BoolList:
arr = np.copy(oarr)
mut_remove_large(arr, _min, replace=1, with_=0)
return arr


def maxcut(oarr: BoolList, _min: int) -> BoolList:
def maxcut(oarr: BoolList, _min: int, /) -> BoolList:
arr = np.copy(oarr)
mut_remove_large(arr, _min, replace=0, with_=1)
return arr


def margin(a: int, b: Any, c: Any = None) -> BoolList:
if c is None:
check_args("margin", [a, b], (2, 2), (is_int, is_boolarr))
oarr = b
start, end = a, a
else:
check_args("margin", [a, b, c], (3, 3), (is_int, is_int, is_boolarr))
oarr = c
start, end = a, b

def margin(oarr: BoolList, start: int, end: int | None = None, /) -> BoolList:
arr = np.copy(oarr)
mut_margin(arr, start, end)
if end is None:
mut_margin(arr, start, start)
else:
mut_margin(arr, start, end)
return arr


Expand Down Expand Up @@ -1741,6 +1735,8 @@ def my_eval(env: Env, node: object) -> Any:
"round": Proc("round", round, (1, 1), is_real),
"max": Proc("max", lambda *v: max(v), (1, None), is_real),
"min": Proc("min", lambda *v: min(v), (1, None), is_real),
"max-seq": Proc("max-seq", max, (1, 1), is_sequence),
"min-seq": Proc("min-seq", min, (1, 1), is_sequence),
"mod": Proc("mod", mod, (2, 2), is_int),
"modulo": Proc("modulo", mod, (2, 2), is_int),
# symbols
Expand Down Expand Up @@ -1796,7 +1792,7 @@ def my_eval(env: Env, node: object) -> Any:
"bool-array": Proc(
"bool-array", lambda *a: np.array(a, dtype=np.bool_), (1, None), is_nat
),
"margin": Proc("margin", margin, (2, 3)),
"margin": Proc("margin", margin, (2, 3), is_boolarr, is_int),
"mincut": Proc("mincut", mincut, (2, 2), is_boolarr, is_nat),
"minclip": Proc("minclip", minclip, (2, 2), is_boolarr, is_nat),
"maxcut": Proc("maxcut", maxcut, (2, 2), is_boolarr, is_nat),
Expand Down
2 changes: 2 additions & 0 deletions auto_editor/lib/data_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def display_str(val: object) -> str:
return f"{val.real}{join}{val.imag}i"
if type(val) is np.bool_:
return "1" if val else "0"
if type(val) is np.float64 or type(val) is np.float32:
return f"{float(val)}"
if type(val) is Fraction:
return f"{val.numerator}/{val.denominator}"

Expand Down
4 changes: 2 additions & 2 deletions auto_editor/subcommands/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,11 +628,11 @@ def cases(*cases: tuple[str, Any]) -> None:
("(string #\\a #\\b)", "ab"),
("(string #\\a #\\b #\\c)", "abc"),
(
"(margin 0 (bool-array 0 0 0 1 0 0 0))",
"(margin (bool-array 0 0 0 1 0 0 0) 0)",
np.array([0, 0, 0, 1, 0, 0, 0], dtype=np.bool_),
),
(
"(margin -2 2 (bool-array 0 0 1 1 0 0 0))",
"(margin (bool-array 0 0 1 1 0 0 0) -2 2)",
np.array([0, 0, 0, 0, 1, 1, 0], dtype=np.bool_),
),
("(equal? 3 3)", True),
Expand Down

0 comments on commit e612311

Please sign in to comment.