diff --git a/builtin/builtin.go b/builtin/builtin.go index 43b08b78..d76cc9d0 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -972,33 +972,36 @@ var Builtins = []*ast.Function{ return arrayType, nil }, }, - { - Name: "bitand", - Func: func(args ...any) (any, error) { - return bitFunc("bitand", func(x, y int) (any, error) { - return x & y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, - { - Name: "bitor", - Func: func(args ...any) (any, error) { - return bitFunc("bitor", func(x, y int) (any, error) { - return x | y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, - { - Name: "bitxor", - Func: func(args ...any) (any, error) { - return bitFunc("bitxor", func(x, y int) (any, error) { - return x ^ y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, + bitFunc("bitand", func(x, y int) (any, error) { + return x & y, nil + }), + bitFunc("bitor", func(x, y int) (any, error) { + return x | y, nil + }), + bitFunc("bitxor", func(x, y int) (any, error) { + return x ^ y, nil + }), + bitFunc("bitnand", func(x, y int) (any, error) { + return x &^ y, nil + }), + bitFunc("bitshl", func(x, y int) (any, error) { + if y < 0 { + return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) + } + return x << y, nil + }), + bitFunc("bitshr", func(x, y int) (any, error) { + if y < 0 { + return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) + } + return x >> y, nil + }), + bitFunc("bitushr", func(x, y int) (any, error) { + if y < 0 { + return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) + } + return int(uint(x) >> y), nil + }), { Name: "bitnot", Func: func(args ...any) (any, error) { @@ -1011,51 +1014,6 @@ var Builtins = []*ast.Function{ } return ^x, nil }, - Types: types(new(func(int) any)), - }, - { - Name: "bitnand", - Func: func(args ...any) (any, error) { - return bitFunc("bitnand", func(x, y int) (any, error) { - return x &^ y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, - { - Name: "bitshr", - Func: func(args ...any) (any, error) { - return bitFunc("bitshr", func(x, y int) (any, error) { - if y < 0 { - return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) - } - return x >> y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, - { - Name: "bitshl", - Func: func(args ...any) (any, error) { - return bitFunc("bitshl", func(x, y int) (any, error) { - if y < 0 { - return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) - } - return x << y, nil - }, args) - }, - Types: types(new(func(int, int) int)), - }, - { - Name: "bitushr", - Func: func(args ...any) (any, error) { - return bitFunc("bitushr", func(x, y int) (any, error) { - if y < 0 { - return nil, fmt.Errorf("invalid operation: negative shift count %d (type int)", y) - } - return int(uint(x) >> y), nil - }, args) - }, - Types: types(new(func(int, int) int)), + Types: types(new(func(int) int)), }, } diff --git a/builtin/func.go b/builtin/func.go index cd347ba1..1c2546a2 100644 --- a/builtin/func.go +++ b/builtin/func.go @@ -6,6 +6,7 @@ import ( "reflect" "strconv" + "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/vm/runtime" ) @@ -273,17 +274,23 @@ func Min(args ...any) (any, error) { return min, nil } -func bitFunc(name string, fn func(x, y int) (any, error), args []any) (any, error) { - if len(args) != 2 { - return nil, fmt.Errorf("invalid number of arguments for %s (expected 2, got %d)", name, len(args)) +func bitFunc(name string, fn func(x, y int) (any, error)) *ast.Function { + return &ast.Function{ + Name: name, + Func: func(args ...any) (any, error) { + if len(args) != 2 { + return nil, fmt.Errorf("invalid number of arguments for %s (expected 2, got %d)", name, len(args)) + } + x, err := toInt(args[0]) + if err != nil { + return nil, fmt.Errorf("%v to call %s", err, name) + } + y, err := toInt(args[1]) + if err != nil { + return nil, fmt.Errorf("%v to call %s", err, name) + } + return fn(x, y) + }, + Types: types(new(func(int, int) int)), } - x, err := toInt(args[0]) - if err != nil { - return nil, fmt.Errorf("%v to call %s", err, name) - } - y, err := toInt(args[1]) - if err != nil { - return nil, fmt.Errorf("%v to call %s", err, name) - } - return fn(x, y) }