diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst index 72abfebd46f2b9..36d50ef10e8296 100644 --- a/Doc/whatsnew/3.14.rst +++ b/Doc/whatsnew/3.14.rst @@ -210,6 +210,10 @@ configuration mechanisms). Other language changes ====================== +* Constant comparsion expressions are now folded and evaluated before runtime. + For example, expressions like: ``"str" in ("str",)`` or ``1 == 1.0 == True`` + are now pre-evaluated. + (Contributed by Yan Yanchii in :gh:`128706`.) * The :func:`map` built-in now has an optional keyword-only *strict* flag like :func:`zip` to check that all the iterables are of equal length. diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index c268a1f00f938e..2bd8aadadfe31a 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -3180,7 +3180,8 @@ def create_unaryop(operand): self.assert_ast(result_code, non_optimized_target, optimized_target) def test_folding_not(self): - code = "not (1 %s (1,))" + # use list as left-hand side to avoid folding constant expression to True/False + code = "not ([] %s (1,))" operators = { "in": ast.In(), "is": ast.Is(), @@ -3192,7 +3193,7 @@ def test_folding_not(self): def create_notop(operand): return ast.UnaryOp(op=ast.Not(), operand=ast.Compare( - left=ast.Constant(value=1), + left=ast.List(), ops=[operators[operand]], comparators=[ast.Tuple(elts=[ast.Constant(value=1)])] )) @@ -3201,7 +3202,7 @@ def create_notop(operand): result_code = code % op non_optimized_target = self.wrap_expr(create_notop(op)) optimized_target = self.wrap_expr( - ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))]) + ast.Compare(left=ast.List(), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))]) ) with self.subTest( @@ -3239,8 +3240,58 @@ def test_folding_tuple(self): self.assert_ast(code, non_optimized_target, optimized_target) - def test_folding_comparator(self): - code = "1 %s %s1%s" + def test_folding_compare(self): + true = self.wrap_expr(ast.Constant(value=True)) + false = self.wrap_expr(ast.Constant(value=False)) + + folded_cases = ( + ("3 > 2 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=2), ast.Constant(value=1)]), true), + ("3 > 4 > 1", (ast.Constant(3), [ast.Gt(), ast.Gt()], [ast.Constant(value=4), ast.Constant(value=1)]), false), + ("3 >= 3 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=3), ast.Constant(value=1)]), true), + ("3 >= 4 >= 1", (ast.Constant(3), [ast.GtE(), ast.GtE()], [ast.Constant(value=4), ast.Constant(value=1)]), false), + ("1 < 2 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 < 0 < 3", (ast.Constant(1), [ast.Lt(), ast.Lt()], [ast.Constant(value=0), ast.Constant(value=3)]), false), + ("1 <= 2 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 <= 0 <= 3", (ast.Constant(1), [ast.LtE(), ast.LtE()], [ast.Constant(value=0), ast.Constant(value=3)]), false), + ("1 == 1.0 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=1.0), ast.Constant(value=True)]), true), + ("1 == 2 == True", (ast.Constant(1), [ast.Eq(), ast.Eq()], [ast.Constant(value=2), ast.Constant(value=True)]), false), + ("1 != 2 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=2), ast.Constant(value=3)]), true), + ("1 != 1 != 3", (ast.Constant(1), [ast.NotEq(), ast.NotEq()], [ast.Constant(value=1), ast.Constant(value=3)]), false), + ("1 in [1, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), true), + ("1 in [2, 2]", (ast.Constant(1), [ast.In()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), false), + ("1 not in [1, 2]", (ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(1), ast.Constant(2)])]), false), + ("1 not in [2, 2]", (ast.Constant(1), [ast.NotIn()], [ast.List(elts=[ast.Constant(2), ast.Constant(2)])]), true), + ) + + for code, original, folded in folded_cases: + left, ops, comparators = original + unfolded = self.wrap_expr(ast.Compare(left=left, ops=ops, comparators=comparators)) + self.assert_ast(code=code, non_optimized_target=unfolded, optimized_target=folded) + + # these should stay as they were + unfolded_cases = ( + ("3 > 2 > []", ast.Compare(left=ast.Constant(3), ops=[ast.Gt(), ast.Gt()], comparators=[ast.Constant(2), ast.List()])), + ("1 > [] > 0", ast.Compare(left=ast.Constant(1), ops=[ast.Gt(), ast.Gt()], comparators=[ast.List(), ast.Constant(0)])), + ("1 >= [] >= 0", ast.Compare(left=ast.Constant(1), ops=[ast.GtE(), ast.GtE()], comparators=[ast.List(), ast.Constant(0)])), + ("1 < [] < 0", ast.Compare(left=ast.Constant(1), ops=[ast.Lt(), ast.Lt()], comparators=[ast.List(), ast.Constant(0)])), + ("1 <= [] <= 0", ast.Compare(left=ast.Constant(1), ops=[ast.LtE(), ast.LtE()], comparators=[ast.List(), ast.Constant(0)])), + ("1 == [] == 0", ast.Compare(left=ast.Constant(1), ops=[ast.Eq(), ast.Eq()], comparators=[ast.List(), ast.Constant(0)])), + ("1 != [] != 0", ast.Compare(left=ast.Constant(1), ops=[ast.NotEq(), ast.NotEq()], comparators=[ast.List(), ast.Constant(0)])), + ("1 is 1", ast.Compare(left=ast.Constant(1), ops=[ast.Is()], comparators=[ast.Constant(1)])), + ("1 is not 1", ast.Compare(left=ast.Constant(1), ops=[ast.IsNot()], comparators=[ast.Constant(1)])), + # invalid also should stay as they were + ("1 in 1", ast.Compare(left=ast.Constant(1), ops=[ast.In()], comparators=[ast.Constant(1)])), + ("1 not in 1", ast.Compare(left=ast.Constant(1), ops=[ast.NotIn()], comparators=[ast.Constant(1)])), + ) + + for code, expected in unfolded_cases: + self.assertTrue(ast.compare(ast.parse(code), self.wrap_expr(expected))) + + def test_folding_comparator_list_set_subst(self): + """Test substitution of list/set with tuple/frozenset in expressions like "1 in [1]" or "1 in {1}" """ + + # use list as left-hand side to avoid folding constant comparison expression to True/False + code = "[] %s %s1%s" operators = [("in", ast.In()), ("not in", ast.NotIn())] braces = [ ("[", "]", ast.List, (1,)), @@ -3249,11 +3300,11 @@ def test_folding_comparator(self): for left, right, non_optimized_comparator, optimized_comparator in braces: for op, node in operators: non_optimized_target = self.wrap_expr(ast.Compare( - left=ast.Constant(1), ops=[node], + left=ast.List(), ops=[node], comparators=[non_optimized_comparator(elts=[ast.Constant(1)])] )) optimized_target = self.wrap_expr(ast.Compare( - left=ast.Constant(1), ops=[node], + left=ast.List(), ops=[node], comparators=[ast.Constant(value=optimized_comparator)] )) self.assert_ast(code % (op, left, right), non_optimized_target, optimized_target) diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst new file mode 100644 index 00000000000000..1949a649d9f748 --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-01-10-18-42-57.gh-issue-128706.4T2XSt.rst @@ -0,0 +1 @@ +Add constant folding for constant comparisons. diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 01e208b88eca8b..06034f2d805832 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -639,6 +639,68 @@ fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) return 0; } } + + static const int richcompare_table[] = { + [Eq] = Py_EQ, + [NotEq] = Py_NE, + [Gt] = Py_GT, + [Lt] = Py_LT, + [GtE] = Py_GE, + [LtE] = Py_LE, + }; + + if (node->v.Compare.left->kind == Constant_kind) { + PyObject *lhs = node->v.Compare.left->v.Constant.value; + for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) { + expr_ty curr_expr = (expr_ty)asdl_seq_GET(args, i); + if (curr_expr->kind != Constant_kind) { + /* try to fold only if every comparator is constant */ + return 1; + } + int op = asdl_seq_GET(ops, i); + if (op == Is || op == IsNot) { + /* Do not fold "is" and "is not" expressions since this breaks + expected syntax warnings. For example: + >>> 1 is 1 + :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? + */ + return 1; + } + PyObject *rhs = curr_expr->v.Constant.value; + int res; + switch (op) { + case Eq: + case NotEq: + case Gt: + case Lt: + case GtE: + case LtE: { + res = PyObject_RichCompareBool(lhs, rhs, richcompare_table[op]); + break; + } + case In: + case NotIn: { + res = PySequence_Contains(rhs, lhs); + if (op == NotIn && res >= 0) { + res = !res; + } + break; + } + default: + Py_UNREACHABLE(); + } + if (res == 0) { + /* shortcut, whole expression is False */ + return make_const(node, Py_False, arena); + } + else if (res < 0) { + return make_const(node, NULL, arena); + } + lhs = rhs; + } + /* whole expression is True */ + return make_const(node, Py_True, arena); + } return 1; }