From cb4237bdef5259914725518e191116a7f5d94f9e Mon Sep 17 00:00:00 2001 From: "Chris Hawblitzel (Microsoft)" Date: Wed, 14 Jun 2023 15:53:33 -0700 Subject: [PATCH] Replace height intrinsic with is_smaller_than and make height a partial order (#570) * Replace height intrinsic with is_smaller_than, and in the SMT encoding, encode height as a Z3 partial order rather than an int to allow decreases on infinite maps. * Update .github/workflows/get-z3.sh to Z3 4.12.2 * Allow decreases on FnSpec fields * Add is_smaller_than_lexicographic and some comments * Add decreases_to macro * Split is_smaller_than handling into separate function * Restrict usage of decreases through FnSpec and Map, erring on the side of caution, inspired by https://github.com/FStarLang/FStar/pull/2954 --- .github/workflows/get-z3.sh | 2 +- source/air/src/ast.rs | 23 +- source/air/src/closure.rs | 2 +- source/air/src/parser.rs | 32 +- source/air/src/printer.rs | 16 +- source/air/src/tests.rs | 30 ++ source/air/src/typecheck.rs | 8 +- source/builtin/src/lib.rs | 34 +- source/pervasive/map.rs | 35 ++ source/pervasive/seq.rs | 10 + .../example/summer_school/chapter-1-22.rs | 3 +- source/rust_verify/src/consts.rs | 2 +- source/rust_verify/src/erase.rs | 4 +- source/rust_verify/src/lifetime_generate.rs | 24 +- source/rust_verify/src/rust_to_vir_expr.rs | 88 ++++- source/rust_verify_test/tests/recursion.rs | 338 +++++++++++++++++- source/tools/get-z3.ps1 | 2 +- source/tools/get-z3.sh | 2 +- source/vir/src/ast.rs | 13 +- source/vir/src/ast_visitor.rs | 5 +- source/vir/src/datatype_to_air.rs | 130 +++++-- source/vir/src/def.rs | 5 + source/vir/src/interpreter.rs | 14 +- source/vir/src/modes.rs | 14 +- source/vir/src/poly.rs | 22 +- source/vir/src/prelude.rs | 66 +++- source/vir/src/recursion.rs | 44 ++- source/vir/src/split_expression.rs | 1 - source/vir/src/sst.rs | 1 + source/vir/src/sst_to_air.rs | 63 +++- source/vir/src/sst_util.rs | 6 +- source/vir/src/sst_visitor.rs | 1 - source/vir/src/triggers.rs | 31 +- source/vir/src/triggers_auto.rs | 3 +- 34 files changed, 932 insertions(+), 142 deletions(-) diff --git a/.github/workflows/get-z3.sh b/.github/workflows/get-z3.sh index 6c732c7e56..7e1ab73179 100755 --- a/.github/workflows/get-z3.sh +++ b/.github/workflows/get-z3.sh @@ -1,5 +1,5 @@ #! /bin/bash -z3_version="4.10.1" +z3_version="4.12.2" filename=z3-$z3_version-x64-glibc-2.31 wget https://github.com/Z3Prover/z3/releases/download/z3-$z3_version/$filename.zip diff --git a/source/air/src/ast.rs b/source/air/src/ast.rs index 2aefc17d4c..c1be708d6c 100644 --- a/source/air/src/ast.rs +++ b/source/air/src/ast.rs @@ -49,6 +49,22 @@ pub enum UnaryOp { BitExtract(u32, u32), } +/// These are Z3 special relations x <= y that are documented at +/// https://microsoft.github.io/z3guide/docs/theories/Special%20Relations/ +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum Relation { + /// reflexive, transitive, antisymmetric + PartialOrder, + /// reflexive, transitive, antisymmetric, and for all x, y. (x <= y or y <= x) + LinearOrder, + /// reflexive, transitive, antisymmetric, and for all x, y, z. (y <= x and z <= x) ==> (y <= z or z <= y) + TreeOrder, + /// reflexive, transitive, antisymmetric, and: + /// - for all x, y, z. (x <= y and x <= z) ==> (y <= z or z <= y) + /// - for all x, y, z. (y <= x and z <= x) ==> (y <= z or z <= y) + PiecewiseLinearOrder, +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum BinaryOp { Implies, @@ -59,7 +75,12 @@ pub enum BinaryOp { Gt, EuclideanDiv, EuclideanMod, - + /// Z3 special relations (see Relation above) + /// The u64 is the Z3 unique name ("index") for each relation that the user wants + /// ("To create a different relation that is also a partial order use a different index, + /// such as (_ partial-order 1)", according to + /// https://microsoft.github.io/z3guide/docs/theories/Special%20Relations/ .) + Relation(Relation, u64), BitXor, BitAnd, BitOr, diff --git a/source/air/src/closure.rs b/source/air/src/closure.rs index 4346e36d32..2b62db8864 100644 --- a/source/air/src/closure.rs +++ b/source/air/src/closure.rs @@ -455,7 +455,7 @@ fn simplify_expr(ctxt: &mut Context, state: &mut State, expr: &Expr) -> (Typ, Ex ExprX::Binary(op, e1, e2) => { let (es, ts) = simplify_exprs_ref(ctxt, state, &vec![e1, e2]); let typ = match op { - BinaryOp::Implies | BinaryOp::Eq => Arc::new(TypX::Bool), + BinaryOp::Implies | BinaryOp::Eq | BinaryOp::Relation(..) => Arc::new(TypX::Bool), BinaryOp::Le | BinaryOp::Ge | BinaryOp::Lt | BinaryOp::Gt => Arc::new(TypX::Bool), BinaryOp::EuclideanDiv | BinaryOp::EuclideanMod => Arc::new(TypX::Int), BinaryOp::BitUGt | BinaryOp::BitULt | BinaryOp::BitUGe | BinaryOp::BitULe => { diff --git a/source/air/src/parser.rs b/source/air/src/parser.rs index cc71ac1152..1c727d8973 100644 --- a/source/air/src/parser.rs +++ b/source/air/src/parser.rs @@ -1,7 +1,7 @@ use crate::ast::{ BinaryOp, BindX, Binder, BinderX, Binders, Command, CommandX, Commands, Constant, Decl, DeclX, - Decls, Expr, ExprX, Exprs, MultiOp, Qid, Quant, QueryX, Span, Stmt, StmtX, Stmts, Trigger, - Triggers, Typ, TypX, UnaryOp, + Decls, Expr, ExprX, Exprs, MultiOp, Qid, Quant, QueryX, Relation, Span, Stmt, StmtX, Stmts, + Trigger, Triggers, Typ, TypX, UnaryOp, }; use crate::def::mk_skolem_id; use crate::messages::{error_from_labels, error_from_spans, MessageLabel, MessageLabels}; @@ -45,6 +45,27 @@ fn underscore_atom_atom_expr(s1: &str, s2: &str) -> Option { None } +fn relation_binary_op(n1: &Node, n2: &Node) -> Option { + match (n1, n2) { + (Node::Atom(s1), Node::Atom(s2)) => { + if let Ok(n) = s2.parse::() { + match s1.as_str() { + "partial-order" => Some(BinaryOp::Relation(Relation::PartialOrder, n)), + "linear-order" => Some(BinaryOp::Relation(Relation::LinearOrder, n)), + "tree-order" => Some(BinaryOp::Relation(Relation::TreeOrder, n)), + "piecewise-linear-order" => { + Some(BinaryOp::Relation(Relation::PiecewiseLinearOrder, n)) + } + _ => None, + } + } else { + None + } + } + _ => None, + } +} + fn map_nodes_to_vec(nodes: &[Node], f: &F) -> Result>, String> where F: Fn(&Node) -> Result, @@ -229,6 +250,13 @@ impl Parser { Node::Atom(s) if s.to_string() == "bvlshr" => Some(BinaryOp::LShr), Node::Atom(s) if s.to_string() == "bvshl" => Some(BinaryOp::Shl), Node::Atom(s) if s.to_string() == "concat" => Some(BinaryOp::BitConcat), + Node::List(nodes) + if nodes.len() == 3 + && nodes[0] == Node::Atom("_".to_string()) + && relation_binary_op(&nodes[1], &nodes[2]).is_some() => + { + relation_binary_op(&nodes[1], &nodes[2]) + } _ => None, }; let lop = match &nodes[0] { diff --git a/source/air/src/printer.rs b/source/air/src/printer.rs index a72f76626f..a57bc39558 100644 --- a/source/air/src/printer.rs +++ b/source/air/src/printer.rs @@ -158,6 +158,18 @@ impl Printer { _ => Node::List(vec![str_to_node(sop), self.expr_to_node(expr)]), } } + ExprX::Binary(BinaryOp::Relation(relation, n), lhs, rhs) => { + use crate::ast::Relation; + let s = match relation { + Relation::PartialOrder => "partial-order", + Relation::LinearOrder => "linear-order", + Relation::TreeOrder => "tree-order", + Relation::PiecewiseLinearOrder => "piecewise-linear-order", + }; + let op = + Node::List(vec![str_to_node("_"), str_to_node(s), Node::Atom(n.to_string())]); + Node::List(vec![op, self.expr_to_node(lhs), self.expr_to_node(rhs)]) + } ExprX::Binary(op, lhs, rhs) => { let sop = match op { BinaryOp::Implies => "=>", @@ -168,7 +180,7 @@ impl Printer { BinaryOp::Gt => ">", BinaryOp::EuclideanDiv => "div", BinaryOp::EuclideanMod => "mod", - + BinaryOp::Relation(..) => unreachable!(), BinaryOp::BitXor => "bvxor", BinaryOp::BitAnd => "bvand", BinaryOp::BitOr => "bvor", @@ -440,7 +452,7 @@ impl NodeWriter { { brk = true; } - Node::Atom(a) if a == ":pattern" => { + Node::Atom(a) if a == ":pattern" || a == ":qid" || a == ":skolemid" => { was_pattern = true; } _ => {} diff --git a/source/air/src/tests.rs b/source/air/src/tests.rs index 73e2347f45..d1c923c0a0 100644 --- a/source/air/src/tests.rs +++ b/source/air/src/tests.rs @@ -1600,3 +1600,33 @@ fn no_choose5() { ) ) } + +#[test] +fn yes_partial_order() { + yes!( + (declare-sort X 0) + (declare-const c1 X) + (declare-const c2 X) + (declare-const c3 X) + (check-valid + (axiom ((_ partial-order 77) c1 c2)) + (axiom ((_ partial-order 77) c2 c3)) + (assert ((_ partial-order 77) c1 c3)) + ) + ) +} + +#[test] +fn no_partial_order() { + no!( + (declare-sort X 0) + (declare-const c1 X) + (declare-const c2 X) + (declare-const c3 X) + (check-valid + (axiom ((_ partial-order 77) c1 c2)) + (axiom ((_ partial-order 76) c2 c3)) + (assert ((_ partial-order 77) c1 c3)) + ) + ) +} diff --git a/source/air/src/typecheck.rs b/source/air/src/typecheck.rs index 5f7b31c460..f23008093b 100644 --- a/source/air/src/typecheck.rs +++ b/source/air/src/typecheck.rs @@ -225,14 +225,18 @@ fn check_expr(typing: &mut Typing, expr: &Expr) -> Result { ExprX::Binary(BinaryOp::Implies, e1, e2) => { check_exprs(typing, "=>", &[bt(), bt()], &bt(), &[e1.clone(), e2.clone()]) } - ExprX::Binary(BinaryOp::Eq, e1, e2) => { + ExprX::Binary(op @ (BinaryOp::Eq | BinaryOp::Relation(..)), e1, e2) => { let t1 = check_expr(typing, e1)?; let t2 = check_expr(typing, e2)?; if typ_eq(&t1, &t2) { Ok(bt()) } else { Err(format!( - "in equality, left expression has type {} and right expression has different type {}", + "in {}, left expression has type {} and right expression has different type {}", + match op { + BinaryOp::Eq => "equality", + _ => "relation", + }, typ_name(&t1), typ_name(&t2) )) diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index d769ad8afc..a383076dc1 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -1055,6 +1055,38 @@ pub fn arch_word_bits() -> nat { unimplemented!(); } -pub fn height(_a: A) -> nat { +pub fn is_smaller_than(_: A, _: B) -> bool { unimplemented!(); } + +pub fn is_smaller_than_lexicographic(_: A, _: B) -> bool { + unimplemented!(); +} + +pub fn is_smaller_than_recursive_function_field(_: A, _: B) -> bool { + unimplemented!(); +} + +#[macro_export] +macro_rules! decreases_to_internal { + ($($x:expr),* $(,)? => $($y:expr),* $(,)?) => { + $crate::is_smaller_than_lexicographic(($($y,)*), ($($x,)*)) + } +} + +/// decreases_to!(b => a) means that height(a) < height(b), so that b can decrease to a +/// in decreases clauses. +/// decreases_to!(b1, ..., bn => a1, ..., am) can compare lexicographically ordered values, +/// which can be useful when making assertions about decreases clauses. +/// Notes: +/// - decreases_to! desugars to a call to is_smaller_than_lexicographic. +/// - you can write #[trigger](decreases_to!(b => a)) to trigger on height(a). +/// (in the SMT encoding, height is a function call and is a useful trigger, +/// while is_smaller_than/is_smaller_than_lexicographic is not a function call +/// and is not a useful trigger.) +#[macro_export] +macro_rules! decreases_to { + ($($x:tt)*) => { + ::builtin_macros::verus_proof_macro_exprs!($crate::decreases_to_internal!($($x)*)) + }; +} diff --git a/source/pervasive/map.rs b/source/pervasive/map.rs index 7c86b26b21..115eff9baf 100644 --- a/source/pervasive/map.rs +++ b/source/pervasive/map.rs @@ -300,6 +300,41 @@ impl Map { // Trusted axioms +/* REVIEW: this is simpler than the two separate axioms below -- would this be ok? +#[verifier(external_body)] +#[verifier(broadcast_forall)] +pub proof fn axiom_map_index_decreases(m: Map, key: K) + requires + m.dom().contains(key), + ensures + #[trigger](decreases_to!(m => m[key])), +{ +} +*/ + +#[verifier(external_body)] +#[verifier(broadcast_forall)] +pub proof fn axiom_map_index_decreases_finite(m: Map, key: K) + requires + m.dom().finite(), + m.dom().contains(key), + ensures + #[trigger](decreases_to!(m => m[key])), +{ +} + +// REVIEW: this is currently a special case that is hard-wired into the verifier +// It implements a version of https://github.com/FStarLang/FStar/pull/2954 . +#[verifier(external_body)] +#[verifier(broadcast_forall)] +pub proof fn axiom_map_index_decreases_infinite(m: Map, key: K) + requires + m.dom().contains(key), + ensures + #[trigger] is_smaller_than_recursive_function_field(m[key], m), +{ +} + #[verifier(external_body)] #[verifier(broadcast_forall)] pub proof fn axiom_map_empty() diff --git a/source/pervasive/seq.rs b/source/pervasive/seq.rs index 853ff1aac4..8e607e721c 100644 --- a/source/pervasive/seq.rs +++ b/source/pervasive/seq.rs @@ -173,6 +173,16 @@ impl Seq { // Trusted axioms +#[verifier(external_body)] +#[verifier(broadcast_forall)] +pub proof fn axiom_seq_index_decreases(s: Seq, i: int) + requires + 0 <= i < s.len(), + ensures + #[trigger](decreases_to!(s => s[i])), +{ +} + #[verifier(external_body)] #[verifier(broadcast_forall)] pub proof fn axiom_seq_empty() diff --git a/source/rust_verify/example/summer_school/chapter-1-22.rs b/source/rust_verify/example/summer_school/chapter-1-22.rs index 2644c58e77..66738adffa 100644 --- a/source/rust_verify/example/summer_school/chapter-1-22.rs +++ b/source/rust_verify/example/summer_school/chapter-1-22.rs @@ -136,9 +136,8 @@ fn check_is_sorted_tree(tree: &Tree) -> (ret: TreeSortedness) proof { sorted_tree_means_sorted_sequence(**left); + sorted_tree_means_sorted_sequence(**right); } - // sorted_tree_means_sorted_sequence(**right); // TODO: why is only one of these calls - // necessary? // assert(equal(tree@, left@.add(seq![*value as int]).add(right@))); // assert(tree@.len() > 0); diff --git a/source/rust_verify/src/consts.rs b/source/rust_verify/src/consts.rs index 49ef2c4c0b..1a3a7f7a7b 100644 --- a/source/rust_verify/src/consts.rs +++ b/source/rust_verify/src/consts.rs @@ -1 +1 @@ -pub const EXPECTED_SOLVER_VERSION: &str = "4.10.1"; +pub const EXPECTED_SOLVER_VERSION: &str = "4.12.2"; diff --git a/source/rust_verify/src/erase.rs b/source/rust_verify/src/erase.rs index 0fbfd76260..9568a347a7 100644 --- a/source/rust_verify/src/erase.rs +++ b/source/rust_verify/src/erase.rs @@ -6,7 +6,7 @@ use rustc_span::SpanData; use vir::ast::{AutospecUsage, Fun, Krate, Mode, Path, Pattern}; use vir::modes::ErasureModes; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum CompilableOperator { IntIntrinsic, Implies, @@ -25,7 +25,7 @@ pub enum CompilableOperator { } /// Information about each call in the AST (each ExprKind::Call). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum ResolvedCall { /// The call is to a spec or proof function, and should be erased Spec, diff --git a/source/rust_verify/src/lifetime_generate.rs b/source/rust_verify/src/lifetime_generate.rs index aa2a536d1e..34d6684912 100644 --- a/source/rust_verify/src/lifetime_generate.rs +++ b/source/rust_verify/src/lifetime_generate.rs @@ -2004,7 +2004,13 @@ pub(crate) fn gen_check_tracked_lifetimes<'tcx>( ctxt.ignored_functions.insert(*id); } for (hir_id, span, call) in &erasure_hints.resolved_calls { - ctxt.calls.insert(*hir_id, call.clone()).map(|_| panic!("{:?}", span)); + if ctxt.calls.contains_key(hir_id) { + if &ctxt.calls[hir_id] != call { + panic!("inconsistent resolved_calls: {:?}", span); + } + } else { + ctxt.calls.insert(*hir_id, call.clone()); + } } for (span, mode) in &erasure_hints.erasure_modes.condition_modes { if crate::spans::from_raw_span(&span.raw_span).is_none() { @@ -2015,7 +2021,13 @@ pub(crate) fn gen_check_tracked_lifetimes<'tcx>( panic!("missing id_to_hir"); } for hir_id in &id_to_hir[&span.id] { - ctxt.condition_modes.insert(*hir_id, *mode).map(|_| panic!("{:?}", span)); + if ctxt.condition_modes.contains_key(hir_id) { + if &ctxt.condition_modes[hir_id] != mode { + panic!("inconsistent condition_modes: {:?}", span); + } + } else { + ctxt.condition_modes.insert(*hir_id, *mode); + } } } for (span, mode) in &erasure_hints.erasure_modes.var_modes { @@ -2027,7 +2039,13 @@ pub(crate) fn gen_check_tracked_lifetimes<'tcx>( panic!("missing id_to_hir"); } for hir_id in &id_to_hir[&span.id] { - ctxt.var_modes.insert(*hir_id, *mode).map(|v| panic!("{:?} {:?}", span, v)); + if ctxt.var_modes.contains_key(hir_id) { + if &ctxt.var_modes[hir_id] != mode { + panic!("inconsistent var_modes: {:?}", span); + } + } else { + ctxt.var_modes.insert(*hir_id, *mode); + } } } for (hir_id, mode) in &erasure_hints.direct_var_modes { diff --git a/source/rust_verify/src/rust_to_vir_expr.rs b/source/rust_verify/src/rust_to_vir_expr.rs index 3ab898a3c5..455f6c47ba 100644 --- a/source/rust_verify/src/rust_to_vir_expr.rs +++ b/source/rust_verify/src/rust_to_vir_expr.rs @@ -381,6 +381,66 @@ fn check_lit_int( } } +fn mk_is_smaller_than<'tcx>( + bctx: &BodyCtxt<'tcx>, + span: Span, + args0: Vec<&'tcx Expr>, + args1: Vec<&'tcx Expr>, + recursive_function_field: bool, +) -> Result { + // convert is_smaller_than((x0, y0, z0), (x1, y1, z1)) into + // x0 < x1 || (x0 == x1 && (y0 < y1 || (y0 == y1 && z0 < z1))) + // see also check_decrease in recursion.rs + let tbool = Arc::new(TypX::Bool); + let tint = Arc::new(TypX::Int(IntRange::Int)); + let when_equalx = ExprX::Const(Constant::Bool(args1.len() < args0.len())); + let when_equal = bctx.spanned_typed_new(span, &tbool, when_equalx); + let mut dec_exp: vir::ast::Expr = when_equal; + for (i, (exp0, exp1)) in args0.iter().zip(args1.iter()).rev().enumerate() { + let mk_bop = |op: BinaryOp, e1: vir::ast::Expr, e2: vir::ast::Expr| { + bctx.spanned_typed_new(span, &tbool, ExprX::Binary(op, e1, e2)) + }; + let mk_cmp = |lt: bool| -> Result { + let e0 = expr_to_vir(bctx, exp0, ExprModifier::REGULAR)?; + let e1 = expr_to_vir(bctx, exp1, ExprModifier::REGULAR)?; + if vir::recursion::height_is_int(&e0.typ) { + if lt { + // 0 <= x < y + let zerox = ExprX::Const(vir::ast_util::const_int_from_u128(0)); + let zero = bctx.spanned_typed_new(span, &tint, zerox); + let op0 = BinaryOp::Inequality(InequalityOp::Le); + let cmp0 = mk_bop(op0, zero, e0); + let op1 = BinaryOp::Inequality(InequalityOp::Lt); + let e0 = expr_to_vir(bctx, exp0, ExprModifier::REGULAR)?; + let cmp1 = mk_bop(op1, e0, e1); + Ok(mk_bop(BinaryOp::And, cmp0, cmp1)) + } else { + Ok(mk_bop(BinaryOp::Eq(Mode::Spec), e0, e1)) + } + } else { + let cmp = BinaryOp::HeightCompare { strictly_lt: lt, recursive_function_field }; + Ok(mk_bop(cmp, e0, e1)) + } + }; + if i == 0 { + // i == 0 means last shared exp0/exp1, which we visit first + if args1.len() < args0.len() { + // if z0 == z1, we can ignore the extra args0: + // z0 < z1 || z0 == z1 + dec_exp = mk_bop(BinaryOp::Or, mk_cmp(true)?, mk_cmp(false)?); + } else { + // z0 < z1 + dec_exp = mk_cmp(true)?; + } + } else { + // x0 < x1 || (x0 == x1 && dec_exp) + let and = mk_bop(BinaryOp::And, mk_cmp(false)?, dec_exp); + dec_exp = mk_bop(BinaryOp::Or, mk_cmp(true)?, and); + } + } + return Ok(dec_exp); +} + pub(crate) fn expr_to_vir_inner<'tcx>( bctx: &BodyCtxt<'tcx>, expr: &Expr<'tcx>, @@ -574,7 +634,9 @@ fn fn_call_to_vir<'tcx>( let is_signed_max = f_name == "builtin::signed_max"; let is_unsigned_max = f_name == "builtin::unsigned_max"; let is_arch_word_bits = f_name == "builtin::arch_word_bits"; - let is_height = f_name == "builtin::height"; + let is_smaller_than = f_name == "builtin::is_smaller_than"; + let is_smaller_than_lex = f_name == "builtin::is_smaller_than_lexicographic"; + let is_smaller_than_rec_fun = f_name == "builtin::is_smaller_than_recursive_function_field"; let is_reveal_strlit = tcx.is_diagnostic_item(Symbol::intern("builtin::reveal_strlit"), f); let is_strslice_len = tcx.is_diagnostic_item(Symbol::intern("builtin::strslice_len"), f); @@ -708,7 +770,9 @@ fn fn_call_to_vir<'tcx>( || is_strslice_is_ascii || is_closure_to_fn_spec || is_arch_word_bits - || is_height; + || is_smaller_than + || is_smaller_than_lex + || is_smaller_than_rec_fun; let is_spec_allow_proof_args_pre = is_spec_op || is_builtin_add || is_builtin_sub @@ -1160,10 +1224,22 @@ fn fn_call_to_vir<'tcx>( return mk_expr(ExprX::UnaryOpr(UnaryOpr::IntegerTypeBound(kind, Mode::Spec), arg)); } - if is_height { - assert!(args.len() == 1); - let arg = expr_to_vir(bctx, &args[0], ExprModifier::REGULAR)?; - return mk_expr(ExprX::UnaryOpr(UnaryOpr::Height, arg)); + if is_smaller_than || is_smaller_than_lex || is_smaller_than_rec_fun { + assert!(args.len() == 2); + let (args0, args1) = if is_smaller_than_lex { + match (&args[0].kind, &args[1].kind) { + (ExprKind::Tup(_), ExprKind::Tup(_)) => { + (extract_tuple(args[0]), extract_tuple(args[1])) + } + _ => unsupported_err!( + expr.span, + "is_smaller_than_lexicographic requires tuple arguments" + ), + } + } else { + (vec![args[0]], vec![args[1]]) + }; + return mk_is_smaller_than(bctx, expr.span, args0, args1, is_smaller_than_rec_fun); } if is_smartptr_new { diff --git a/source/rust_verify_test/tests/recursion.rs b/source/rust_verify_test/tests/recursion.rs index bcad7c72ed..b1a78422fe 100644 --- a/source/rust_verify_test/tests/recursion.rs +++ b/source/rust_verify_test/tests/recursion.rs @@ -382,7 +382,9 @@ test_verify_one_file! { decreases i { if 0 < i { + assert(decreases_to!(i => i - 1)); dec1((i - 1) as nat); + assert(decreases_to!(i => i, 100 * i)); dec2(i, 100 * i); } } @@ -391,10 +393,13 @@ test_verify_one_file! { decreases j, k { if 0 < k { + assert(decreases_to!(j, k => j, k - 1)); dec2(j, (k - 1) as nat); } if 0 < j { + assert(decreases_to!(j, k => j - 1, 100 * j + k)); dec2((j - 1) as nat, 100 * j + k); + assert(decreases_to!(j, k => j - 1)); dec1((j - 1) as nat); } } @@ -426,6 +431,33 @@ test_verify_one_file! { } => Err(err) => assert_fails(err, 2) } +test_verify_one_file! { + #[test] multidecrease1_fail1_assert verus_code! { + proof fn dec1(i: nat) + decreases i + { + if 0 < i { + let tmp = decreases_to!(i => i); + assert(tmp); // FAILS + dec2(i, 100 * i); + } + } + + proof fn dec2(j: nat, k: nat) + decreases j, k + { + if 0 < k { + let tmp = decreases_to!(j, k => j, k); + assert(tmp); // FAILS + } + if 0 < j { + dec2((j - 1) as nat, 100 * j + k); + dec1((j - 1) as nat); + } + } + } => Err(err) => assert_fails(err, 2) +} + test_verify_one_file! { #[test] multidecrease1_fail2 verus_code! { proof fn dec1(i: nat) @@ -451,6 +483,33 @@ test_verify_one_file! { } => Err(err) => assert_fails(err, 2) } +test_verify_one_file! { + #[test] multidecrease1_fail2_assert verus_code! { + proof fn dec1(i: nat) + decreases i + { + if 0 < i { + dec1((i - 1) as nat); + let tmp = decreases_to!(i => i + 1, 100 * i); + assert(tmp); // FAILS + } + } + + proof fn dec2(j: nat, k: nat) + decreases j, k + { + if 0 < k { + dec2(j, (k - 1) as nat); + } + if 0 < j { + let tmp = decreases_to!(j, k => j, 100 * j + k); + assert(tmp); // FAILS + dec1((j - 1) as nat); + } + } + } => Err(err) => assert_fails(err, 2) +} + test_verify_one_file! { #[test] multidecrease1_fail3 verus_code! { proof fn dec1(i: nat) @@ -476,6 +535,32 @@ test_verify_one_file! { } => Err(err) => assert_one_fails(err) } +test_verify_one_file! { + #[test] multidecrease1_fail3_assert verus_code! { + proof fn dec1(i: nat) + decreases i + { + if 0 < i { + dec1((i - 1) as nat); + dec2(i, 100 * i); + } + } + + proof fn dec2(j: nat, k: nat) + decreases j, k + { + if 0 < k { + dec2(j, (k - 1) as nat); + } + if 0 < j { + dec2((j - 1) as nat, 100 * j + k); + let tmp = decreases_to!(j, k => j); + assert(tmp); // FAILS + } + } + } => Err(err) => assert_one_fails(err) +} + test_verify_one_file! { #[test] multidecrease1_fail4 verus_code! { proof fn dec1(i: nat) { @@ -1283,6 +1368,239 @@ test_verify_one_file! { } => Ok(()) } +test_verify_one_file! { + #[test] decrease_through_seq verus_code! { + use vstd::prelude::*; + + struct S { + x: Seq>, + } + + spec fn f(s: S) -> int + decreases s + { + if s.x.len() > 0 { + f(*s.x[0]) + } else { + 0 + } + } + + proof fn p(s: S) + decreases s + { + if s.x.len() > 0 { + p(*s.x[0]); + assert(false); // FAILS + } + } + + proof fn q(s: S) + decreases s + { + q(*s.x[0]); // FAILS + } + } => Err(e) => assert_fails(e, 2) +} + +test_verify_one_file! { + #[test] decrease_through_map verus_code! { + use vstd::prelude::*; + + struct S { + x: Map>, + } + + spec fn f(s: S) -> int + decreases s + { + if s.x.dom().contains(3) { + f(*s.x[3]) + } else { + 0 + } + } + + proof fn p(s: S) + decreases s + { + if s.x.dom().contains(3) { + p(*s.x[3]); + assert(false); // FAILS + } + } + + proof fn q(s: S) + decreases s + { + q(*s.x[3]); // FAILS + } + } => Err(e) => assert_fails(e, 2) +} + +test_verify_one_file! { + #[test] decrease_through_my_map verus_code! { + // Err on the side of caution; see https://github.com/FStarLang/FStar/pull/2954 + use vstd::prelude::*; + + #[verifier::reject_recursive_types(A)] + #[verifier::accept_recursive_types(B)] + struct MyMap(Map); + struct S { + x: MyMap>, + } + + spec fn f(s: S) -> int + decreases s + { + if s.x.0.dom().contains(3) { f(*s.x.0[3]) } else { 0 } // FAILS + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] decrease_through_function verus_code! { + enum E { + Nil, + F(FnSpec(int) -> E), + } + + proof fn p(e: E) + decreases e + { + if let E::F(f) = e { + p(f(0)); + } + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] decrease_through_function_fails verus_code! { + enum E { + Nil, + F(FnSpec(int) -> E), + } + + proof fn p(e: E) + decreases e + { + if let E::F(f) = e { + p(f(0)); + assert(false); // FAILS + } + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] decrease_through_function_bad verus_code! { + struct S { + x: FnSpec(int) -> S, + } + + proof fn p(s: S) + ensures false + decreases s + { + p((s.x)(0)); + } + } => Err(e) => assert_vir_error_msg(e, "datatype must have at least one non-recursive variant") +} + +test_verify_one_file! { + #[test] decrease_through_my_fun verus_code! { + // Err on the side of caution; see https://github.com/FStarLang/FStar/pull/2954 + use vstd::prelude::*; + + #[verifier::reject_recursive_types(A)] + struct MyFun(FnSpec(A) -> B); + enum E { + Nil, + F(MyFun), + } + + proof fn p(e: E) + decreases e + { + if let E::F(f) = e { + p(f.0(0)); // FAILS + } + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] decrease_through_abstract_type verus_code! { + mod m1 { + use builtin::*; + pub struct S(A, B); + impl S { + pub closed spec fn get0(self) -> A { self.0 } + pub closed spec fn get1(self) -> B { self.1 } + } + // TODO: broadcast_forall + pub proof fn lemma_height_s(s: S) + ensures + decreases_to!(s => s.get0()), + decreases_to!(s => s.get1()), + { + } + } + + mod m2 { + use builtin::*; + use crate::m1::*; + enum Q { + Nil, + Cons(S>), + } + proof fn test(q: Q) + decreases q, + { + if let Q::Cons(s) = q { + lemma_height_s(s); + test(*s.get1()); + } + } + } + + mod m3 { + use builtin::*; + use crate::m1::*; + enum Q { + Nil, + Cons(S>), + } + proof fn test(q: Q) + decreases q, + { + if let Q::Cons(s) = q { + test(*s.get1()); // FAILS + } + } + } + + mod m4 { + use builtin::*; + use crate::m1::*; + enum Q { + Nil, + Cons(S>), + } + proof fn test(q: Q) + decreases q, + { + if let Q::Cons(s) = q { + lemma_height_s(s); + test(*s.get1()); + assert(false); // FAILS + } + } + } + } => Err(e) => assert_fails(e, 2) +} + test_verify_one_file! { #[test] height_intrinsic verus_code! { #[is_variant] @@ -1297,25 +1615,25 @@ test_verify_one_file! { assert(l == *x.get_Node_0()); assert(r == *x.get_Node_1()); - assert(height(x) > height(l)); - assert(height(x) > height(r)); - assert(height(x) > height(x.get_Node_0())); - - assert(height(l) >= 0); + assert(decreases_to!(x => l)); + assert(decreases_to!(x => r)); + assert(decreases_to!(x => x.get_Node_0())); } proof fn testing_fail(l: Tree, r: Tree) { - assert(height(l) > height(r)); // FAILS + let tmp = decreases_to!(l => r); + assert(tmp); // FAILS } proof fn testing_fail2(x: Tree) { - assert(height(x.get_Node_0()) < height(x)); // FAILS + let tmp = decreases_to!(x => x.get_Node_0()); + assert(tmp); // FAILS } proof fn testing3(x: Tree) requires x.is_Node(), { - assert(height(x.get_Node_0()) < height(x)); + assert(decreases_to!(x => x.get_Node_0())); } } => Err(e) => assert_fails(e, 2) } @@ -1329,9 +1647,9 @@ test_verify_one_file! { } fn test(tree: Tree) { - let x = height(tree); + let x = decreases_to!(tree => tree); } - } => Err(err) => assert_vir_error_msg(err, "cannot test 'height' in exec mode") + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode exec") } test_verify_one_file! { diff --git a/source/tools/get-z3.ps1 b/source/tools/get-z3.ps1 index e78c940382..93bdbf8f01 100644 --- a/source/tools/get-z3.ps1 +++ b/source/tools/get-z3.ps1 @@ -1,4 +1,4 @@ -$z3_version = "4.10.1" +$z3_version = "4.12.2" $filename = "z3-$z3_version-x64-win" $download_url = "https://github.com/Z3Prover/z3/releases/download/z3-$z3_version/$filename.zip" diff --git a/source/tools/get-z3.sh b/source/tools/get-z3.sh index 7d585d0a68..1ce33126cc 100755 --- a/source/tools/get-z3.sh +++ b/source/tools/get-z3.sh @@ -1,6 +1,6 @@ #! /bin/bash -z3_version="4.10.1" +z3_version="4.12.2" if [ `uname` == "Darwin" ]; then if [[ $(uname -m) == 'arm64' ]]; then diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index 2d234be7e6..ac90ce01b7 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -210,6 +210,12 @@ pub enum UnaryOp { /// Internal consistency check to make sure finalize_exp gets called /// (appears only briefly in SST before finalize_exp is called) MustBeFinalized, + /// We don't give users direct access to the "height" function and Height types. + /// However, it's useful to be able to trigger on the "height" function + /// when using HeightCompare. We manage this by having triggers.rs convert + /// HeightCompare triggers into HeightTrigger, which is eventually translated + /// into direct calls to the "height" function in the triggers. + HeightTrigger, /// Used only for handling builtin::strslice_len StrLen, /// Used only for handling builtin::strslice_is_ascii @@ -258,9 +264,6 @@ pub enum UnaryOpr { /// to hold the result. /// Mode is the minimum allowed mode (e.g., Spec for spec-only, Exec if allowed in exec). IntegerTypeBound(IntegerTypeBoundKind, Mode), - /// Height of a data structure for the purpose of decreases-checking. - /// Maps to the built-in intrinsic. - Height, /// Custom diagnostic message CustomErr(Arc), } @@ -318,12 +321,14 @@ pub enum BinaryOp { Xor, /// boolean implies (short-circuiting: right side is evaluated only if left side is true) Implies, + /// the is_smaller_than builtin, used for decreases (true for <, false for ==) + HeightCompare { strictly_lt: bool, recursive_function_field: bool }, /// SMT equality for any type -- two expressions are exactly the same value /// Some types support compilable equality (Mode == Exec); others only support spec equality (Mode == Spec) Eq(Mode), /// not Eq Ne, - /// + /// arithmetic inequality Inequality(InequalityOp), /// IntRange operations that may require overflow or divide-by-zero checks /// (None for InferMode means always mode Spec) diff --git a/source/vir/src/ast_visitor.rs b/source/vir/src/ast_visitor.rs index 0f56ffda0a..7d06d8baeb 100644 --- a/source/vir/src/ast_visitor.rs +++ b/source/vir/src/ast_visitor.rs @@ -37,9 +37,9 @@ where VisitorControlFlow::Recurse => { match &**typ { TypX::Bool + | TypX::Int(_) | TypX::StrSlice | TypX::Char - | TypX::Int(_) | TypX::TypParam(_) | TypX::TypeId | TypX::ConstInt(_) @@ -90,9 +90,9 @@ where { match &**typ { TypX::Bool + | TypX::Int(_) | TypX::StrSlice | TypX::Char - | TypX::Int(_) | TypX::TypParam(_) | TypX::TypeId | TypX::ConstInt(_) @@ -633,7 +633,6 @@ where UnaryOpr::TupleField { .. } => op.clone(), UnaryOpr::Field { .. } => op.clone(), UnaryOpr::IntegerTypeBound(_kind, _) => op.clone(), - UnaryOpr::Height => op.clone(), UnaryOpr::CustomErr(_) => op.clone(), }; let expr1 = map_expr_visitor_env(e1, map, env, fe, fs, ft)?; diff --git a/source/vir/src/datatype_to_air.rs b/source/vir/src/datatype_to_air.rs index 123dfd7fe5..dbb70ad345 100644 --- a/source/vir/src/datatype_to_air.rs +++ b/source/vir/src/datatype_to_air.rs @@ -254,13 +254,31 @@ fn datatype_or_fun_to_air_commands( // trigger on apply(x, args), has_type_f params.push(x_param(&datatyp)); pre.insert(0, has_box.clone()); - let trigs = vec![app, has_box.clone()]; + let trigs = vec![app.clone(), has_box.clone()]; let name = format!("{}_{}", path_as_rust_name(dpath), QID_APPLY); - let bind = func_bind_trig(ctx, name, tparams, &Arc::new(params), &trigs, false, false); + let aparams = Arc::new(params.clone()); + let bind = func_bind_trig(ctx, name, tparams, &aparams, &trigs, false, false); let imply = mk_implies(&mk_and(&pre), &has_app); let forall = mk_bind_expr(&bind, &imply); let axiom = Arc::new(DeclX::Axiom(forall)); axiom_commands.push(Arc::new(CommandX::Global(axiom))); + + // Lambda height axiom: + // forall typ1 ... typn, tret, arg1: Poly ... argn: Poly, x: Fun. + // has_type_f && has_type1 && ... && has_typen ==> + // height_lt(height(apply(x, args)), height(box(mk_fun(x)))) + // trigger on height(apply(x, args)), has_type_f + let height_app = str_apply(crate::def::HEIGHT, &vec![app]); + let from_rec_fun = str_apply(crate::def::HEIGHT_REC_FUN, &vec![box_mk_fun]); + let height_fun = str_apply(crate::def::HEIGHT, &vec![from_rec_fun]); + let height_lt = str_apply(crate::def::HEIGHT_LT, &vec![height_app.clone(), height_fun]); + let trigs = vec![height_app, has_box.clone()]; + let name = format!("{}_{}", path_as_rust_name(dpath), crate::def::QID_HEIGHT_APPLY); + let bind = func_bind_trig(ctx, name, tparams, &aparams, &trigs, false, false); + let imply = mk_implies(&mk_and(&pre), &height_lt); + let forall = mk_bind_expr(&bind, &imply); + let axiom = Arc::new(DeclX::Axiom(forall)); + axiom_commands.push(Arc::new(CommandX::Global(axiom))); } // constructor and field axioms @@ -370,33 +388,91 @@ fn datatype_or_fun_to_air_commands( for variant in variants.iter() { for field in variant.a.iter() { let typ = &field.a.0; - match &*crate::ast_util::undecorate_typ(typ) { - TypX::Datatype(path, _) if ctx.datatype_is_transparent[path] => { - let node = crate::prelude::datatype_height_axiom( - &dpath, - Some(&path), - &is_variant_ident(dpath, &*variant.name), - &variant_field_ident(&dpath, &variant.name, &field.name), - ); - let axiom = air::parser::Parser::new() - .node_to_command(&node) - .expect("internal error: malformed datatype axiom"); - axiom_commands.push(axiom); - } - TypX::TypParam(_) => { - let node = crate::prelude::datatype_height_axiom( - &dpath, - None, - &is_variant_ident(dpath, &*variant.name), - &variant_field_ident(&dpath, &variant.name, &field.name), - ); - let axiom = air::parser::Parser::new() - .node_to_command(&node) - .expect("internal error: malformed datatype axiom"); - axiom_commands.push(axiom); + let mut recursion_or_tparam = |t: &Typ| match &**t { + TypX::Datatype(path, _) + if ctx.global.datatype_graph.in_same_scc(path, dpath) => + { + Err(()) } - _ => {} + TypX::TypParam(_) => Err(()), + _ => Ok(()), + }; + let has_recursion_or_tparam = + crate::ast_visitor::typ_visitor_check(typ, &mut recursion_or_tparam).is_err(); + if !has_recursion_or_tparam { + continue; } + let typ = crate::ast_util::undecorate_typ(typ); + let field_box_path = match &*typ { + TypX::Lambda(typs, _) => Some(prefix_lambda_type(typs.len())), + TypX::Datatype(..) => crate::sst_to_air::datatype_box_prefix(ctx, &typ), + TypX::Boxed(_) => None, + TypX::TypParam(_) => None, + _ => continue, + }; + let unboxed = if let TypX::Boxed(t) = &*typ { t } else { &*typ }; + let fun_or_map_ret = { + match unboxed { + TypX::Lambda(_, ret) => Some(ret), + TypX::Datatype(d, targs) + if crate::ast_util::path_as_vstd_name(d) + == Some("map::Map".to_string()) + && targs.len() == 2 => + { + // HACK special case for the infinite map::Map type, + // which is like a FnSpec type + Some(&targs[1]) + } + _ => None, + } + }; + let recursive_function_field = if let Some(ret) = fun_or_map_ret { + // REVIEW: this is inspired by https://github.com/FStarLang/FStar/pull/2954 , + // which restricts decreases on FnSpec applications or Map lookups + // to the case where the FnSpec or Map is a field of a recursive datatype + // and the application or lookup returns a value of the recursive datatype. + // It's not clear that we need this restriction, since we don't have F*'s + // universes, but let's err on the side of cautious for now. + // We define recursive_function_field to be true when all of these hold: + // 1) the field is a FnSpec or Map type + // 2) the only use of type parameters in the FnSpec/Map return type + // is to instantiate the datatype with exactly its original parameters + // For example, recursive_function_field is true for field f here: + // struct S { a: A, b: B, f: FnSpec(int) -> Option> } + // but is false for field f here: + // struct S { a: A, b: B, f: FnSpec(int) -> Option<(A, B)> } + // because A and B appear in the return type, but not as part of S + // This suppresses decreases for a wrapper around a FnSpec or infinite Map: + // struct MyFun(FnSpec(A) -> B); + // TODO: allow recursive_function_field across mutually recursive datatypes + // that have type parameters (e.g. by inlining the recursive types). + let our_typ = Arc::new(TypX::Datatype(dpath.clone(), typ_args.clone())); + use crate::visitor::VisitorControlFlow; + let mut visitor = |t: &Typ| -> VisitorControlFlow<()> { + if crate::ast_util::types_equal(t, &our_typ) { + VisitorControlFlow::Return + } else if let TypX::TypParam(_) = &**t { + VisitorControlFlow::Stop(()) + } else { + VisitorControlFlow::Recurse + } + }; + let visit = crate::ast_visitor::typ_visitor_dfs(ret, &mut visitor); + visit == VisitorControlFlow::Recurse + } else { + false + }; + let nodes = crate::prelude::datatype_height_axioms( + &dpath, + &field_box_path, + &is_variant_ident(dpath, &*variant.name), + &variant_field_ident(&dpath, &variant.name, &field.name), + recursive_function_field, + ); + let axioms = air::parser::Parser::new() + .nodes_to_commands(&nodes) + .expect("internal error: malformed datatype axiom"); + axiom_commands.extend(axioms.iter().cloned()); } } } diff --git a/source/vir/src/def.rs b/source/vir/src/def.rs index dfa1b71217..9ef1ee0bd3 100644 --- a/source/vir/src/def.rs +++ b/source/vir/src/def.rs @@ -109,6 +109,7 @@ pub const EUC_MOD: &str = "EucMod"; pub const SNAPSHOT_CALL: &str = "CALL"; pub const SNAPSHOT_PRE: &str = "PRE"; pub const SNAPSHOT_ASSIGN: &str = "ASSIGN"; +pub const T_HEIGHT: &str = "Height"; pub const POLY: &str = "Poly"; pub const BOX_INT: &str = "I"; pub const BOX_BOOL: &str = "B"; @@ -141,7 +142,10 @@ pub const MK_FUN: &str = "mk_fun"; pub const CONST_INT: &str = "const_int"; pub const DUMMY_PARAM: &str = "no%param"; pub const CHECK_DECREASE_INT: &str = "check_decrease_int"; +pub const CHECK_DECREASE_HEIGHT: &str = "check_decrease_height"; pub const HEIGHT: &str = "height"; +pub const HEIGHT_LT: &str = "height_lt"; +pub const HEIGHT_REC_FUN: &str = "fun_from_recursive_field"; pub const CLOSURE_REQ: &str = "closure_req"; pub const CLOSURE_ENS: &str = "closure_ens"; pub const EXT_EQ: &str = "ext_eq"; @@ -160,6 +164,7 @@ pub const QID_CONSTRUCTOR_INNER: &str = "constructor_inner"; pub const QID_CONSTRUCTOR: &str = "constructor"; pub const QID_EXT_EQUAL: &str = "ext_equal"; pub const QID_APPLY: &str = "apply"; +pub const QID_HEIGHT_APPLY: &str = "height_apply"; pub const QID_ACCESSOR: &str = "accessor"; pub const QID_INVARIANT: &str = "invariant"; pub const QID_HAS_TYPE_ALWAYS: &str = "has_type_always"; diff --git a/source/vir/src/interpreter.rs b/source/vir/src/interpreter.rs index c9633e6b21..47f0d2aa59 100644 --- a/source/vir/src/interpreter.rs +++ b/source/vir/src/interpreter.rs @@ -924,6 +924,7 @@ fn eval_expr_internal(ctx: &Ctx, state: &mut State, exp: &Exp) -> Result bool_new(!b), BitNot | Clip { .. } + | HeightTrigger | Trigger(_) | CoerceMode { .. } | StrLen @@ -1029,9 +1030,13 @@ fn eval_expr_internal(ctx: &Ctx, state: &mut State, exp: &Exp) -> Result { panic!("Found MustBeFinalized op {:?} after calling finalize_exp", exp) } - Not | Trigger(_) | CoerceMode { .. } | StrLen | StrIsAscii | CharToInt => { - ok - } + Not + | HeightTrigger + | Trigger(_) + | CoerceMode { .. } + | StrLen + | StrIsAscii + | CharToInt => ok, } } // !(!(e_inner)) == e_inner @@ -1098,7 +1103,6 @@ fn eval_expr_internal(ctx: &Ctx, state: &mut State, exp: &Exp) -> Result ok, } } - Height => ok, CustomErr(_) => Ok(e), } } @@ -1341,7 +1345,7 @@ fn eval_expr_internal(ctx: &Ctx, state: &mut State, exp: &Exp) -> Result ok_e2(e2.clone()), + HeightCompare { .. } | StrGetChar => ok_e2(e2.clone()), } } BinaryOpr(op, e1, e2) => { diff --git a/source/vir/src/modes.rs b/source/vir/src/modes.rs index 289d49e690..5f6d794ebc 100644 --- a/source/vir/src/modes.rs +++ b/source/vir/src/modes.rs @@ -679,6 +679,9 @@ fn check_expr_handle_mut_arg( Ok(*to_mode) } } + ExprX::Unary(UnaryOp::HeightTrigger, _) => { + panic!("direct access to 'height' is not allowed") + } ExprX::Unary(_, e1) => check_expr(typing, outer_mode, erasure_mode, e1), ExprX::UnaryOpr(UnaryOpr::Box(_), e1) => check_expr(typing, outer_mode, erasure_mode, e1), ExprX::UnaryOpr(UnaryOpr::Unbox(_), e1) => check_expr(typing, outer_mode, erasure_mode, e1), @@ -716,13 +719,6 @@ fn check_expr_handle_mut_arg( let mode = check_expr(typing, joined_mode, erasure_mode, e1)?; Ok(mode_join(*min_mode, mode)) } - ExprX::UnaryOpr(UnaryOpr::Height, e1) => { - if typing.check_ghost_blocks && typing.block_ghostness == Ghost::Exec { - return error(&expr.span, "cannot test 'height' in exec mode"); - } - check_expr_has_mode(typing, Mode::Spec, e1, Mode::Spec)?; - Ok(Mode::Spec) - } ExprX::UnaryOpr(UnaryOpr::CustomErr(_), e1) => { check_expr_has_mode(typing, Mode::Spec, e1, Mode::Spec)?; Ok(Mode::Spec) @@ -733,6 +729,7 @@ fn check_expr_handle_mut_arg( ExprX::Binary(op, e1, e2) => { let op_mode = match op { BinaryOp::Eq(mode) => *mode, + BinaryOp::HeightCompare { .. } => Mode::Spec, _ => Mode::Exec, }; match op { @@ -744,7 +741,8 @@ fn check_expr_handle_mut_arg( } let outer_mode = match op { // because Implies isn't compiled, make it spec-only - BinaryOp::Implies => mode_join(outer_mode, Mode::Spec), + BinaryOp::Implies => Mode::Spec, + BinaryOp::HeightCompare { .. } => Mode::Spec, _ => outer_mode, }; let mode1 = check_expr(typing, outer_mode, erasure_mode, e1)?; diff --git a/source/vir/src/poly.rs b/source/vir/src/poly.rs index fc71cccec6..b890dcf1da 100644 --- a/source/vir/src/poly.rs +++ b/source/vir/src/poly.rs @@ -369,6 +369,7 @@ fn poly_expr(ctx: &Ctx, state: &mut State, expr: &Expr) -> Expr { let e1 = coerce_expr_to_native(ctx, &e1); mk_expr(ExprX::Unary(*op, e1)) } + UnaryOp::HeightTrigger => panic!("direct access to 'height' is not allowed"), UnaryOp::Trigger(_) | UnaryOp::CoerceMode { .. } => { mk_expr_typ(&e1.typ, ExprX::Unary(*op, e1.clone())) } @@ -388,10 +389,6 @@ fn poly_expr(ctx: &Ctx, state: &mut State, expr: &Expr) -> Expr { let e1 = coerce_expr_to_native(ctx, &e1); mk_expr(ExprX::UnaryOpr(op.clone(), e1)) } - UnaryOpr::Height => { - let e1 = coerce_expr_to_poly(ctx, &e1); - mk_expr(ExprX::UnaryOpr(op.clone(), e1)) - } UnaryOpr::CustomErr(_) => { let exprx = ExprX::UnaryOpr(op.clone(), e1.clone()); SpannedTyped::new(&e1.span, &e1.typ, exprx) @@ -423,17 +420,22 @@ fn poly_expr(ctx: &Ctx, state: &mut State, expr: &Expr) -> Expr { let e1 = poly_expr(ctx, state, e1); let e2 = poly_expr(ctx, state, e2); use BinaryOp::*; - let native = match op { - And | Or | Xor | Implies | Inequality(_) => true, - Arith(..) => true, - Eq(_) | Ne => false, - Bitwise(..) => true, - StrGetChar { .. } => true, + let (native, poly) = match op { + And | Or | Xor | Implies | Inequality(_) => (true, false), + HeightCompare { .. } => (false, true), + Arith(..) => (true, false), + Eq(_) | Ne => (false, false), + Bitwise(..) => (true, false), + StrGetChar { .. } => (true, false), }; if native { let e1 = coerce_expr_to_native(ctx, &e1); let e2 = coerce_expr_to_native(ctx, &e2); mk_expr(ExprX::Binary(*op, e1, e2)) + } else if poly { + let e1 = coerce_expr_to_poly(ctx, &e1); + let e2 = coerce_expr_to_poly(ctx, &e2); + mk_expr(ExprX::Binary(*op, e1, e2)) } else { let (e1, e2) = coerce_exprs_to_agree(ctx, &e1, &e2); mk_expr(ExprX::Binary(*op, e1, e2)) diff --git a/source/vir/src/prelude.rs b/source/vir/src/prelude.rs index edc4f704f9..928bc3938d 100644 --- a/source/vir/src/prelude.rs +++ b/source/vir/src/prelude.rs @@ -63,11 +63,17 @@ pub(crate) fn prelude_nodes(config: PreludeConfig) -> Vec { #[allow(non_snake_case)] let EucMod = str_to_node(EUC_MOD); let check_decrease_int = str_to_node(CHECK_DECREASE_INT); + let check_decrease_height = str_to_node(CHECK_DECREASE_HEIGHT); let height = str_to_node(HEIGHT); + let height_le = nodes!(_ partial-order 0); + let height_lt = str_to_node(HEIGHT_LT); + let height_rec_fun = str_to_node(HEIGHT_REC_FUN); let closure_req = str_to_node(CLOSURE_REQ); let closure_ens = str_to_node(CLOSURE_ENS); #[allow(non_snake_case)] let Poly = str_to_node(POLY); + #[allow(non_snake_case)] + let Height = str_to_node(T_HEIGHT); let box_int = str_to_node(BOX_INT); let box_bool = str_to_node(BOX_BOOL); let unbox_int = str_to_node(UNBOX_INT); @@ -151,11 +157,11 @@ pub(crate) fn prelude_nodes(config: PreludeConfig) -> Vec { (declare-fun [strslice_len] ([strslice]) Int) (declare-fun [strslice_get_char] ([strslice] Int) [Char]) (declare-fun [new_strlit] (Int) [strslice]) - (declare-fun [from_strlit] ([strslice]) Int) // Polymorphism (declare-sort [Poly] 0) + (declare-sort [Height] 0) (declare-fun [box_int] (Int) [Poly]) (declare-fun [box_bool] (Bool) [Poly]) (declare-fun [unbox_int] ([Poly]) Int) @@ -533,8 +539,16 @@ pub(crate) fn prelude_nodes(config: PreludeConfig) -> Vec { :skolemid skolem_prelude_to_unicode_bounds ))) - // Decreases + (declare-fun [height] ([Poly]) [Height]) + (declare-fun [height_lt] ([Height] [Height]) Bool) + (axiom (forall ((x [Height]) (y [Height])) (! + (= ([height_lt] x y) (and ([height_le] x y) (not (= x y)))) + :pattern (([height_lt] x y)) + :qid prelude_height_lt + :skolemid skolem_prelude_height_lt + ))) + (declare-fun [height_rec_fun] ([Poly]) [Poly]) (declare-fun [check_decrease_int] (Int Int Bool) Bool) (axiom (forall ((cur Int) (prev Int) (otherwise Bool)) (! (= ([check_decrease_int] cur prev otherwise) @@ -544,15 +558,20 @@ pub(crate) fn prelude_nodes(config: PreludeConfig) -> Vec { ) ) :pattern (([check_decrease_int] cur prev otherwise)) - :qid prelude_check_decreases - :skolemid skolem_prelude_check_decreases + :qid prelude_check_decrease_int + :skolemid skolem_prelude_check_decrease_int ))) - (declare-fun [height] ([Poly]) Int) - (axiom (forall ((x [Poly])) (! - (<= 0 ([height] x)) - :pattern (([height] x)) - :qid prelude_height - :skolemid skolem_prelude_height + (declare-fun [check_decrease_height] ([Poly] [Poly] Bool) Bool) + (axiom (forall ((cur [Poly]) (prev [Poly]) (otherwise Bool)) (! + (= ([check_decrease_height] cur prev otherwise) + (or + ([height_lt] ([height] cur) ([height] prev)) + (and (= ([height] cur) ([height] prev)) otherwise) + ) + ) + :pattern (([check_decrease_height] cur prev otherwise)) + :qid prelude_check_decrease_height + :skolemid skolem_prelude_check_decrease_height ))) // uninterpreted integer versions for bitvector Ops. first argument is bit-width @@ -601,28 +620,33 @@ pub(crate) fn prelude_nodes(config: PreludeConfig) -> Vec { pub(crate) fn datatype_height_axiom( typ_name1: &Path, - typ_name2: Option<&Path>, + typ_name2: &Option, is_variant_ident: &Ident, field: &Ident, + recursive_function_field: bool, ) -> Node { let height = str_to_node(HEIGHT); + let height_lt = str_to_node(HEIGHT_LT); + let height_rec_fun = str_to_node(HEIGHT_REC_FUN); let field = str_to_node(field.as_str()); let is_variant = str_to_node(is_variant_ident.as_str()); let typ1 = str_to_node(path_to_air_ident(typ_name1).as_str()); let box_t1 = str_to_node(prefix_box(typ_name1).as_str()); let field_of_x = match typ_name2 { Some(typ2) => { - let box_t2 = str_to_node(prefix_box(typ2).as_str()); + let box_t2 = str_to_node(prefix_box(&typ2).as_str()); node!(([box_t2] ([field] x))) } // for a field with generic type, [field]'s return type is already "Poly" None => node!(([field] x)), }; + let field_of_x = + if recursive_function_field { node!(([height_rec_fun][field_of_x])) } else { field_of_x }; node!( (axiom (forall ((x [typ1])) (! (=> ([is_variant] x) - (< + ([height_lt] ([height] [field_of_x]) ([height] ([box_t1] x)) ) @@ -633,3 +657,19 @@ pub(crate) fn datatype_height_axiom( ))) ) } + +pub(crate) fn datatype_height_axioms( + typ_name1: &Path, + typ_name2: &Option, + is_variant_ident: &Ident, + field: &Ident, + recursive_function_field: bool, +) -> Vec { + let axiom1 = datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, field, false); + if recursive_function_field { + let axiom2 = datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, field, true); + vec![axiom1, axiom2] + } else { + vec![axiom1] + } +} diff --git a/source/vir/src/recursion.rs b/source/vir/src/recursion.rs index 79055cec9a..ce46d58591 100644 --- a/source/vir/src/recursion.rs +++ b/source/vir/src/recursion.rs @@ -71,19 +71,37 @@ fn is_recursive_call(ctxt: &Ctxt, target: &Fun, resolved_method: &Option<(Fun, T } } -fn height_of_exp(ctxt: &Ctxt, exp: &Exp) -> Exp { +pub fn height_is_int(typ: &Typ) -> bool { + match &*crate::ast_util::undecorate_typ(typ) { + TypX::Int(_) => true, + _ => false, + } +} + +fn height_typ(ctxt: &Ctxt, exp: &Exp) -> Typ { + if height_is_int(&exp.typ) { + Arc::new(TypX::Int(IntRange::Int)) + } else { + if crate::poly::typ_is_poly(ctxt.ctx, &exp.typ) { + exp.typ.clone() + } else { + Arc::new(TypX::Boxed(exp.typ.clone())) + } + } +} + +fn exp_for_decrease(ctxt: &Ctxt, exp: &Exp) -> Exp { match &*crate::ast_util::undecorate_typ(&exp.typ) { TypX::Int(_) => exp.clone(), TypX::Datatype(..) => { - let arg = if crate::poly::typ_is_poly(ctxt.ctx, &exp.typ) { + if crate::poly::typ_is_poly(ctxt.ctx, &exp.typ) { exp.clone() } else { let op = UnaryOpr::Box(exp.typ.clone()); let argx = ExpX::UnaryOpr(op, exp.clone()); - SpannedTyped::new(&exp.span, &exp.typ, argx) - }; - let call = ExpX::UnaryOpr(UnaryOpr::Height, arg); - SpannedTyped::new(&exp.span, &Arc::new(TypX::Int(IntRange::Int)), call) + let typ = Arc::new(TypX::Boxed(exp.typ.clone())); + SpannedTyped::new(&exp.span, &typ, argx) + } } // TODO: non-panic error message _ => panic!("internal error: unsupported type for decreases {:?}", exp.typ), @@ -102,11 +120,15 @@ fn check_decrease(ctxt: &Ctxt, span: &Span, exps: &Vec) -> Exp { for (i, exp) in (0..ctxt.num_decreases).zip(exps.iter()).rev() { let decreases_at_entryx = ExpX::Var(unique_local(&decrease_at_entry(i))); let decreases_at_entry = - SpannedTyped::new(&exp.span, &Arc::new(TypX::Int(IntRange::Int)), decreases_at_entryx); + SpannedTyped::new(&exp.span, &height_typ(ctxt, exp), decreases_at_entryx); // 0 <= decreases_exp < decreases_at_entry - let args = vec![height_of_exp(ctxt, exp), decreases_at_entry, dec_exp]; + let args = vec![exp_for_decrease(ctxt, exp), decreases_at_entry, dec_exp]; let call = ExpX::Call( - CallFun::InternalFun(InternalFun::CheckDecreaseInt), + if height_is_int(&exp.typ) { + CallFun::InternalFun(InternalFun::CheckDecreaseInt) + } else { + CallFun::InternalFun(InternalFun::CheckDecreaseHeight) + }, Arc::new(vec![]), Arc::new(args), ); @@ -337,7 +359,7 @@ fn mk_decreases_at_entry(ctxt: &Ctxt, span: &Span, exps: &Vec) -> (Vec = Vec::new(); let mut stm_assigns: Vec = Vec::new(); for (i, exp) in exps.iter().enumerate() { - let typ = Arc::new(TypX::Int(IntRange::Int)); + let typ = height_typ(ctxt, exp); let decl = Arc::new(LocalDeclX { ident: unique_local(&decrease_at_entry(i)), typ: typ.clone(), @@ -351,7 +373,7 @@ fn mk_decreases_at_entry(ctxt: &Ctxt, span: &Span, exps: &Vec) -> (Vec Trace } UnaryOpr::HasType(_) | UnaryOpr::IntegerTypeBound(..) - | UnaryOpr::Height | UnaryOpr::IsVariant { .. } | UnaryOpr::TupleField { .. } | UnaryOpr::Field(_) => { diff --git a/source/vir/src/sst.rs b/source/vir/src/sst.rs index 62afb06eb5..4d4942f713 100644 --- a/source/vir/src/sst.rs +++ b/source/vir/src/sst.rs @@ -45,6 +45,7 @@ pub enum InternalFun { ClosureReq, ClosureEns, CheckDecreaseInt, + CheckDecreaseHeight, } #[derive(Debug, Clone, Hash)] diff --git a/source/vir/src/sst_to_air.rs b/source/vir/src/sst_to_air.rs index 560ac43e3a..6659603414 100644 --- a/source/vir/src/sst_to_air.rs +++ b/source/vir/src/sst_to_air.rs @@ -314,6 +314,23 @@ pub(crate) fn typ_invariant(ctx: &Ctx, typ: &Typ, expr: &Expr) -> Option { } } +pub(crate) fn datatype_box_prefix(ctx: &Ctx, typ: &Typ) -> Option { + match &**typ { + TypX::Datatype(path, _) => { + if ctx.datatype_is_transparent[path] { + Some(path.clone()) + } else { + if let Some(monotyp) = typ_as_mono(typ) { + Some(crate::sst_to_air::monotyp_to_path(&monotyp)) + } else { + None + } + } + } + _ => None, + } +} + fn try_box(ctx: &Ctx, expr: Expr, typ: &Typ) -> Option { let f_name = match &**typ { TypX::Bool => Some(str_ident(crate::def::BOX_BOOL)), @@ -321,16 +338,11 @@ fn try_box(ctx: &Ctx, expr: Expr, typ: &Typ) -> Option { TypX::Tuple(_) => None, TypX::Lambda(typs, _) => Some(prefix_box(&prefix_lambda_type(typs.len()))), TypX::AnonymousClosure(..) => unimplemented!(), - TypX::Datatype(path, _) => { - if ctx.datatype_is_transparent[path] { - Some(prefix_box(&path)) + TypX::Datatype(..) => { + if let Some(prefix) = datatype_box_prefix(ctx, typ) { + Some(prefix_box(&prefix)) } else { - if let Some(monotyp) = typ_as_mono(typ) { - let dpath = crate::sst_to_air::monotyp_to_path(&monotyp); - Some(prefix_box(&dpath)) - } else { - panic!("abstract datatype should be boxed") - } + panic!("abstract datatype should be boxed") } } TypX::Decorate(_, t) => return try_box(ctx, expr, t), @@ -718,6 +730,9 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< InternalFun::ClosureReq => str_ident(crate::def::CLOSURE_REQ), InternalFun::ClosureEns => str_ident(crate::def::CLOSURE_ENS), InternalFun::CheckDecreaseInt => str_ident(crate::def::CHECK_DECREASE_INT), + InternalFun::CheckDecreaseHeight => { + str_ident(crate::def::CHECK_DECREASE_HEIGHT) + } }, Arc::new(exprs), )) @@ -788,6 +803,7 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< return Ok(bv_e); } } + UnaryOp::HeightTrigger => panic!("internal error: unexpected HeightTrigger"), UnaryOp::Trigger(_) => exp_to_expr(ctx, arg, expr_ctxt)?, UnaryOp::CoerceMode { .. } => { panic!("internal error: TupleField should have been removed before here") @@ -821,6 +837,9 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< ); clip_bitwise_result(bit_expr, exp)? } + UnaryOp::HeightTrigger => { + str_apply(crate::def::HEIGHT, &vec![exp_to_expr(ctx, exp, expr_ctxt)?]) + } UnaryOp::Trigger(_) => exp_to_expr(ctx, exp, expr_ctxt)?, UnaryOp::Clip { range: IntRange::Int, .. } => exp_to_expr(ctx, exp, expr_ctxt)?, UnaryOp::Clip { range, .. } => { @@ -890,10 +909,6 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< let name = Arc::new(ARCH_SIZE.to_string()); Arc::new(ExprX::Var(name)) } - UnaryOpr::Height => Arc::new(ExprX::Apply( - str_ident(crate::def::HEIGHT), - Arc::new(vec![exp_to_expr(ctx, exp, expr_ctxt)?]), - )), UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _ }) => { let expr = exp_to_expr(ctx, exp, expr_ctxt)?; Arc::new(ExprX::Apply( @@ -915,6 +930,12 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< format!("error: cannot use bit-vector arithmetic on type {:?}", exp.typ), ); } + if let BinaryOp::HeightCompare { .. } = op { + return error( + &exp.span, + format!("error: cannot use bit-vector arithmetic on is_smaller_than"), + ); + } // disallow signed integer from bitvec reasoning. However, allow that for shift // TODO: sanity check for shift let _ = match op { @@ -950,6 +971,7 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< _ => (), }; let bop = match op { + BinaryOp::HeightCompare { .. } => unreachable!(), BinaryOp::Eq(_) => air::ast::BinaryOp::Eq, BinaryOp::Ne => unreachable!(), BinaryOp::Arith(ArithOp::Add, _) => air::ast::BinaryOp::BitAdd, @@ -998,6 +1020,20 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< BinaryOp::Implies => { return Ok(mk_implies(&lh, &rh)); } + BinaryOp::HeightCompare { strictly_lt, recursive_function_field } => { + let lhh = str_apply(crate::def::HEIGHT, &vec![lh]); + let rh = if *recursive_function_field { + str_apply(crate::def::HEIGHT_REC_FUN, &vec![rh]) + } else { + rh + }; + let rhh = str_apply(crate::def::HEIGHT, &vec![rh]); + if *strictly_lt { + return Ok(str_apply(crate::def::HEIGHT_LT, &vec![lhh, rhh])); + } else { + return Ok(mk_eq(&lhh, &rhh)); + } + } BinaryOp::Arith(ArithOp::Add, _) => { ExprX::Multi(MultiOp::Add, Arc::new(vec![lh, rh])) } @@ -1067,6 +1103,7 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< BinaryOp::Or => unreachable!(), BinaryOp::Xor => unreachable!(), BinaryOp::Implies => unreachable!(), + BinaryOp::HeightCompare { .. } => unreachable!(), BinaryOp::Eq(_) => air::ast::BinaryOp::Eq, BinaryOp::Ne => unreachable!(), BinaryOp::Inequality(InequalityOp::Le) => air::ast::BinaryOp::Le, diff --git a/source/vir/src/sst_util.rs b/source/vir/src/sst_util.rs index b2eb80d83b..2090af5d83 100644 --- a/source/vir/src/sst_util.rs +++ b/source/vir/src/sst_util.rs @@ -247,6 +247,7 @@ impl BinaryOp { Or => (6, 6, 7), Xor => (22, 22, 23), // Rust doesn't have a logical XOR, so this is consistent with BitXor Implies => (3, 4, 3), + HeightCompare { .. } => (90, 5, 5), Eq(_) | Ne => (10, 11, 11), Inequality(_) => (10, 10, 10), Arith(o, _) => match o { @@ -291,6 +292,7 @@ impl ExpX { Unary(op, exp) => match op { UnaryOp::Not | UnaryOp::BitNot => (format!("!{}", exp.x.to_string_prec(99)), 90), UnaryOp::Clip { .. } => (format!("clip({})", exp), 99), + UnaryOp::HeightTrigger => (format!("height_trigger({})", exp), 99), UnaryOp::StrLen => (format!("{}.len()", exp.x.to_string_prec(99)), 90), UnaryOp::StrIsAscii => (format!("{}.is_ascii()", exp.x.to_string_prec(99)), 90), UnaryOp::CharToInt => (format!("{} as char", exp.x.to_string_prec(99)), 90), @@ -303,7 +305,6 @@ impl ExpX { match op { Box(_) => (format!("box({})", exp), 99), Unbox(_) => (format!("unbox({})", exp), 99), - Height => (format!("height({})", exp), 99), HasType(t) => (format!("{}.has_type({:?})", exp, t), 99), IntegerTypeBound(kind, mode) => (format!("{:?}.{:?}({})", kind, mode, exp), 99), IsVariant { datatype: _, variant } => { @@ -327,6 +328,7 @@ impl ExpX { Or => "||", Xor => "^", Implies => "==>", + HeightCompare { .. } => "", Eq(_) => "==", Ne => "!=", Inequality(o) => match o { @@ -353,6 +355,8 @@ impl ExpX { }; if let BinaryOp::StrGetChar = op { (format!("{}.get_char({})", left, e2), prec_exp) + } else if let HeightCompare { .. } = op { + (format!("height_compare({left}, {right})"), prec_exp) } else { (format!("{} {} {}", left, op_str, right), prec_exp) } diff --git a/source/vir/src/sst_visitor.rs b/source/vir/src/sst_visitor.rs index 52810e4493..464f361c5b 100644 --- a/source/vir/src/sst_visitor.rs +++ b/source/vir/src/sst_visitor.rs @@ -516,7 +516,6 @@ where UnaryOpr::TupleField { .. } => op.clone(), UnaryOpr::Field { .. } => op.clone(), UnaryOpr::IntegerTypeBound(_, _) => op.clone(), - UnaryOpr::Height => op.clone(), UnaryOpr::CustomErr(_msg) => op.clone(), }; ok_exp(ExpX::UnaryOpr(op.clone(), fe(env, e1)?)) diff --git a/source/vir/src/triggers.rs b/source/vir/src/triggers.rs index 1c8be1df5e..7d4c436d15 100644 --- a/source/vir/src/triggers.rs +++ b/source/vir/src/triggers.rs @@ -1,6 +1,6 @@ use crate::ast::{ - BinaryOp, Ident, IntegerTypeBoundKind, TriggerAnnotation, Typ, TypX, UnaryOp, UnaryOpr, VarAt, - VirErr, + BinaryOp, Ident, IntegerTypeBoundKind, SpannedTyped, TriggerAnnotation, Typ, TypX, UnaryOp, + UnaryOpr, VarAt, VirErr, }; use crate::ast_util::error; use crate::context::Ctx; @@ -25,10 +25,20 @@ struct State { coverage: HashMap, HashSet>, } -fn remove_boxing(exp: &Exp) -> Exp { +fn preprocess_exp(exp: &Exp) -> Exp { match &exp.x { ExpX::UnaryOpr(UnaryOpr::Box(_), e) | ExpX::UnaryOpr(UnaryOpr::Unbox(_), e) => { - remove_boxing(e) + preprocess_exp(e) + } + ExpX::Binary(BinaryOp::HeightCompare { .. }, e1, _) => { + // We don't let users use the "height" function or Height type directly. + // However, when using HeightCompare, it's useful to trigger on "height", + // and it's not useful to trigger on HeightCompare, + // which is essentially a "<" operator on heights. + // Therefore, we replace HeightCompare triggers with height triggers. + // (Or rather, HeightCompare is the interface by which users write height triggers.) + let typ = Arc::new(TypX::Bool); // arbitrary type for trigger + SpannedTyped::new(&exp.span, &typ, ExpX::Unary(UnaryOp::HeightTrigger, e1.clone())) } _ => exp.clone(), } @@ -46,7 +56,7 @@ fn check_trigger_expr( | ExpX::CallLambda(..) | ExpX::UnaryOpr(UnaryOpr::Field { .. }, _) | ExpX::UnaryOpr(UnaryOpr::IsVariant { .. }, _) - | ExpX::Unary(UnaryOp::Trigger(_), _) => {} + | ExpX::Unary(UnaryOp::Trigger(_) | UnaryOp::HeightTrigger, _) => {} // allow triggers for bitvector operators ExpX::Binary(BinaryOp::Bitwise(_, _), _, _) | ExpX::Unary(UnaryOp::BitNot, _) => {} ExpX::BinaryOpr(crate::ast::BinaryOpr::ExtEq(..), _, _) => {} @@ -113,6 +123,7 @@ fn check_trigger_expr( UnaryOp::Trigger(_) | UnaryOp::Clip { .. } | UnaryOp::BitNot + | UnaryOp::HeightTrigger | UnaryOp::StrLen | UnaryOp::StrIsAscii | UnaryOp::CharToInt => Ok(()), @@ -126,7 +137,6 @@ fn check_trigger_expr( | UnaryOpr::IsVariant { .. } | UnaryOpr::TupleField { .. } | UnaryOpr::Field { .. } - | UnaryOpr::Height | UnaryOpr::CustomErr(_) | UnaryOpr::IntegerTypeBound( IntegerTypeBoundKind::SignedMin | IntegerTypeBoundKind::ArchWordBits, @@ -145,6 +155,10 @@ fn check_trigger_expr( And | Or | Xor | Implies | Eq(_) | Ne => { error(&exp.span, "triggers cannot contain boolean operators") } + HeightCompare { .. } => error( + &exp.span, + "triggers cannot contain interior is_smaller_than expressions", + ), Inequality(_) => Ok(()), Arith(..) => error( &exp.span, @@ -193,6 +207,7 @@ fn check_trigger_expr( | UnaryOp::CoerceMode { .. } | UnaryOp::CharToInt => true, UnaryOp::MustBeFinalized => true, + UnaryOp::HeightTrigger => false, UnaryOp::Not | UnaryOp::BitNot | UnaryOp::StrLen | UnaryOp::StrIsAscii => false, }, ExpX::Binary(op, _, _) => { @@ -230,7 +245,7 @@ fn get_manual_triggers(state: &mut State, exp: &Exp) -> Result<(), VirErr> { } ExpX::Unary(UnaryOp::Trigger(TriggerAnnotation::Trigger(group)), e1) => { let mut free_vars: HashSet = HashSet::new(); - let e1 = remove_boxing(&e1); + let e1 = preprocess_exp(&e1); check_trigger_expr(state, &e1, &mut free_vars, &lets)?; for x in &free_vars { if map.get(x) == Some(&true) && !state.trigger_vars.contains(x) { @@ -256,7 +271,7 @@ fn get_manual_triggers(state: &mut State, exp: &Exp) -> Result<(), VirErr> { for (n, trigger) in triggers.iter().enumerate() { let group = Some(n as u64); let mut coverage: HashSet = HashSet::new(); - let es: Vec = trigger.iter().map(remove_boxing).collect(); + let es: Vec = trigger.iter().map(preprocess_exp).collect(); for e in &es { let mut free_vars: HashSet = HashSet::new(); check_trigger_expr(state, e, &mut free_vars, &lets)?; diff --git a/source/vir/src/triggers_auto.rs b/source/vir/src/triggers_auto.rs index c578968380..88edd84390 100644 --- a/source/vir/src/triggers_auto.rs +++ b/source/vir/src/triggers_auto.rs @@ -316,6 +316,7 @@ fn gather_terms(ctxt: &mut Ctxt, ctx: &Ctx, exp: &Exp, depth: u64) -> (bool, Ter | UnaryOp::CoerceMode { .. } | UnaryOp::MustBeFinalized | UnaryOp::CharToInt => 0, + UnaryOp::HeightTrigger => 1, UnaryOp::Trigger(_) | UnaryOp::Clip { .. } | UnaryOp::BitNot => 1, UnaryOp::StrIsAscii | UnaryOp::StrLen => fail_on_strop(), }; @@ -330,7 +331,6 @@ fn gather_terms(ctxt: &mut Ctxt, ctx: &Ctx, exp: &Exp, depth: u64) -> (bool, Ter } ExpX::UnaryOpr(UnaryOpr::Box(_), e1) => gather_terms(ctxt, ctx, e1, depth), ExpX::UnaryOpr(UnaryOpr::Unbox(_), e1) => gather_terms(ctxt, ctx, e1, depth), - ExpX::UnaryOpr(UnaryOpr::Height, e1) => gather_terms(ctxt, ctx, e1, depth), ExpX::UnaryOpr(UnaryOpr::CustomErr(_), e1) => gather_terms(ctxt, ctx, e1, depth), ExpX::UnaryOpr(UnaryOpr::HasType(_), _) => { (false, Arc::new(TermX::App(ctxt.other(), Arc::new(vec![])))) @@ -362,6 +362,7 @@ fn gather_terms(ctxt: &mut Ctxt, ctx: &Ctx, exp: &Exp, depth: u64) -> (bool, Ter use BinaryOp::*; let depth = match op { And | Or | Xor | Implies | Eq(_) => 0, + HeightCompare { .. } => 1, Ne | Inequality(_) | Arith(..) => 1, Bitwise(..) => 1, StrGetChar => fail_on_strop(),