diff --git a/infer/src/pulse/PulseAbstractValue.ml b/infer/src/pulse/PulseAbstractValue.ml index 8a4f9aac236..bc92a89b1fc 100644 --- a/infer/src/pulse/PulseAbstractValue.ml +++ b/infer/src/pulse/PulseAbstractValue.ml @@ -39,6 +39,14 @@ let pp f v = if is_restricted v then F.fprintf f "a%d" (-v) else F.fprintf f "v% let yojson_of_t l = `String (F.asprintf "%a" pp l) +let compare_unrestricted_first v1 v2 = + if is_restricted v1 then + if is_restricted v2 then (* compare absolute values *) compare v2 v1 + else (* unrestricted [v2] first *) 1 + else if is_restricted v2 then (* unrestricted [v1] first *) -1 + else compare v1 v2 + + module PPKey = struct type nonrec t = t [@@deriving compare] diff --git a/infer/src/pulse/PulseAbstractValue.mli b/infer/src/pulse/PulseAbstractValue.mli index fb8b638dd21..c1919550290 100644 --- a/infer/src/pulse/PulseAbstractValue.mli +++ b/infer/src/pulse/PulseAbstractValue.mli @@ -31,6 +31,9 @@ val is_unrestricted : t -> bool val pp : F.formatter -> t -> unit +val compare_unrestricted_first : t -> t -> int +(** an alternative comparison function that sorts unrestricted variables before restricted variables *) + module Set : PrettyPrintable.PPSet with type elt = t module Map : sig diff --git a/infer/src/pulse/PulseFormula.ml b/infer/src/pulse/PulseFormula.ml index e23035fd679..0e46802c599 100644 --- a/infer/src/pulse/PulseFormula.ml +++ b/infer/src/pulse/PulseFormula.ml @@ -15,8 +15,9 @@ module ValueHistory = PulseValueHistory module Var = struct include PulseAbstractValue - (* This is used by [var_eqs] to prefer restricted variables as representants. *) let is_simpler_than v1 v2 = PulseAbstractValue.compare v1 v2 < 0 + + let is_simpler_or_equal v1 v2 = PulseAbstractValue.compare v1 v2 <= 0 end module Q = QSafeCapped @@ -92,8 +93,8 @@ module LinArith : sig val mult : Q.t -> t -> t val solve_eq : t -> t -> (Var.t * t) option SatUnsat.t - (** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l] (with x larger than all variables - of l), [Sat None] if [l1 = l2] is always true, and [Unsat] if it is always false *) + (** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l], [Sat None] if [l1 = l2] is always + true, and [Unsat] if it is always false *) val of_q : Q.t -> t @@ -124,8 +125,8 @@ module LinArith : sig val subst_variable : Var.t -> _ subst_target -> t -> t (** same as above for a single variable to substitute (more optimized) *) - val get_largest : t -> Var.t option - (** the largest [v∊l], according to [Var.compare] *) + val get_simplest : t -> Var.t option + (** the smallest [v∊l] according to [Var.is_simpler_than] *) (** {2 Tableau-Specific Operations} *) @@ -133,9 +134,10 @@ module LinArith : sig (** [true] iff all the variables involved in the expression satisfy {!Var.is_restricted} *) val solve_for_unrestricted : Var.t -> t -> (Var.t * t) option - (** [solve_for_unrestricted u l] returns [Some (x, l')] where [x] is the largest unrestricted - variable in [u=l] and [u=l <=> x=l']. If there are no unrestricted variables in [u=l], it - returns [None]. *) + (** if [l] contains at least one unrestricted variable then [solve_for_unrestricted u l] is + [Some (x, l')] where [x] is the smallest unrestricted variable in [l] and [u=l <=> x=l']. If + there are no unrestricted variables in [l] then [solve_for_unrestricted u l] is [None]. + Assumes [u∉l]. *) val pivot : Var.t * Q.t -> t -> t (** [pivot (v, q) l] assumes [v] appears in [l] with coefficient [q] and returns [l'] such that @@ -155,12 +157,28 @@ module LinArith : sig val is_minimized : t -> bool (** [is_minimized l] iff [classify_minimized_maximized l] is either [`Minimized] or [`Constant] *) end = struct + (* define our own var map to get a custom order: we want to place unrestricted variables first in + the map so that [solve_for_unrestricted] can be implemented in terms of [solve_eq] easily *) + module VarMap = struct + include PrettyPrintable.MakePPMap (struct + type t = Var.t + + let pp = Var.pp + + let compare v1 v2 = Var.compare_unrestricted_first v1 v2 + end) + + (* unpleasant that we have to duplicate this definition from [Var.Map] here *) + let yojson_of_t yojson_of_val m = + `List (List.map ~f:(fun (k, v) -> `List [Var.yojson_of_t k; yojson_of_val v]) (bindings m)) + end + (** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *) - type t = Q.t * Q.t Var.Map.t [@@deriving compare, equal] + type t = Q.t * Q.t VarMap.t [@@deriving compare, equal] - let yojson_of_t (c, vs) = `List [Var.Map.yojson_of_t Q.yojson_of_t vs; Q.yojson_of_t c] + let yojson_of_t (c, vs) = `List [VarMap.yojson_of_t Q.yojson_of_t vs; Q.yojson_of_t c] - let fold (_, vs) ~init ~f = IContainer.fold_of_pervasives_map_fold Var.Map.fold vs ~init ~f + let fold (_, vs) ~init ~f = IContainer.fold_of_pervasives_map_fold VarMap.fold vs ~init ~f type 'term_t subst_target = | QSubst of Q.t @@ -170,7 +188,7 @@ end = struct | NonLinearTermSubst of 'term_t let pp pp_var fmt (c, vs) = - if Var.Map.is_empty vs then Q.pp_print fmt c + if VarMap.is_empty vs then Q.pp_print fmt c else let pp_c fmt c = if not (Q.is_zero c) then @@ -186,7 +204,7 @@ end = struct in let pp_vs fmt vs = Pp.collection ~sep:"@;" - ~fold:(IContainer.fold_of_pervasives_map_fold Var.Map.fold) + ~fold:(IContainer.fold_of_pervasives_map_fold VarMap.fold) (fun fmt (v, q) -> F.fprintf fmt "%a%a" pp_coeff q pp_var v ; is_first := false ) @@ -197,33 +215,33 @@ end = struct let add (c1, vs1) (c2, vs2) = ( Q.add c1 c2 - , Var.Map.union + , VarMap.union (fun _v c1 c2 -> let c = Q.add c1 c2 in if Q.is_zero c then None else Some c ) vs1 vs2 ) - let minus (c, vs) = (Q.neg c, Var.Map.map (fun c -> Q.neg c) vs) + let minus (c, vs) = (Q.neg c, VarMap.map (fun c -> Q.neg c) vs) let subtract l1 l2 = add l1 (minus l2) - let zero = (Q.zero, Var.Map.empty) + let zero = (Q.zero, VarMap.empty) - let is_zero (c, vs) = Q.is_zero c && Var.Map.is_empty vs + let is_zero (c, vs) = Q.is_zero c && VarMap.is_empty vs let mult q ((c, vs) as l) = if Q.is_zero q then (* needed for correctness: coeffs cannot be zero *) zero else if Q.is_one q then (* purely an optimisation *) l - else (Q.mul q c, Var.Map.map (fun c -> Q.mul q c) vs) + else (Q.mul q c, VarMap.map (fun c -> Q.mul q c) vs) let pivot (x, coeff) (c, vs) = let d = Q.neg coeff in let vs' = - Var.Map.fold - (fun v' coeff' vs' -> if Var.equal v' x then vs' else Var.Map.add v' (Q.div coeff' d) vs') - vs Var.Map.empty + VarMap.fold + (fun v' coeff' vs' -> if Var.equal v' x then vs' else VarMap.add v' (Q.div coeff' d) vs') + vs VarMap.empty in (* note: [d≠0] by the invariant of the coefficient map [vs] *) let c' = Q.div c d in @@ -231,7 +249,7 @@ end = struct let solve_eq_zero ((c, vs) as l) = - match Var.Map.max_binding_opt vs with + match VarMap.min_binding_opt vs with | None -> if Q.is_zero c then Sat None else ( @@ -243,15 +261,15 @@ end = struct let solve_eq l1 l2 = solve_eq_zero (subtract l1 l2) - let of_var v = (Q.zero, Var.Map.singleton v Q.one) + let of_var v = (Q.zero, VarMap.singleton v Q.one) - let of_q q = (q, Var.Map.empty) + let of_q q = (q, VarMap.empty) - let get_as_const (c, vs) = if Var.Map.is_empty vs then Some c else None + let get_as_const (c, vs) = if VarMap.is_empty vs then Some c else None let get_as_var (c, vs) = if Q.is_zero c then - match Var.Map.is_singleton_or_more vs with + match VarMap.is_singleton_or_more vs with | Singleton (x, cx) when Q.is_one cx -> Some x | _ -> @@ -262,7 +280,7 @@ end = struct let get_as_variable_difference (c, vs) = if Q.is_zero c then (* the coefficient has to be 0 *) - let vs_seq = Var.Map.to_seq vs in + let vs_seq = VarMap.to_seq vs in let open IOption.Let_syntax in (* check that the expression consists of exactly two variables with coeffs 1 and -1 *) let* (x, cx), vs_seq = Seq.uncons vs_seq in @@ -276,7 +294,7 @@ end = struct let get_constant_part (c, _) = c - let get_coefficient v (_, vs) = Var.Map.find_opt v vs + let get_coefficient v (_, vs) = VarMap.find_opt v vs let of_subst_target v0 = function | QSubst q -> @@ -292,7 +310,7 @@ end = struct let fold_subst_variables ((c, vs_foreign) as l0) ~init ~f = let changed = ref false in let acc_f, l' = - Var.Map.fold + VarMap.fold (fun v_foreign q0 (acc_f, l) -> let acc_f, op = f acc_f v_foreign in ( match op with @@ -304,7 +322,7 @@ end = struct changed := true ) ; (acc_f, add (mult q0 (of_subst_target v_foreign op)) l) ) vs_foreign - (init, (c, Var.Map.empty)) + (init, (c, VarMap.empty)) in let l' = if !changed then l' else l0 in (acc_f, l') @@ -314,38 +332,40 @@ end = struct (* OPTIM: for a single variable we can avoid iterating over the coefficient map *) let subst_variable x subst_target ((c, vs) as l0) = - match Var.Map.find_opt x vs with + match VarMap.find_opt x vs with | None -> l0 | Some q -> - let vs' = Var.Map.remove x vs in + let vs' = VarMap.remove x vs in add (mult q (of_subst_target x subst_target)) (c, vs') - let get_variables (_, vs) = Var.Map.to_seq vs |> Seq.map fst + let get_variables (_, vs) = VarMap.to_seq vs |> Seq.map fst - let get_largest l = Var.Map.max_binding_opt (snd l) |> Option.map ~f:fst + let get_simplest l = VarMap.min_binding_opt (snd l) |> Option.map ~f:fst (** {2 Tableau-Specific Operations} *) - let is_restricted (_q, l) = - (* HACK: since restricted < unrestricted, it's enough to check the maximum *) - match Var.Map.max_binding_opt l with None -> true | Some (x, _) -> Var.is_restricted x + let is_restricted l = + (* HACK: unrestricted variables come first so we first test if there exists any unrestricted + variable in the map by checking its min element *) + not (get_simplest l |> Option.exists ~f:Var.is_unrestricted) let solve_for_unrestricted w l = - match solve_eq (of_var w) l with - | Unsat | Sat None -> - None - | Sat (Some (x, _) as r) when Var.is_unrestricted x -> - r - | Sat (Some _) -> - None + if not (is_restricted l) then ( + match solve_eq l (of_var w) with + | Unsat | Sat None -> + None + | Sat (Some (x, _) as r) -> + assert (Var.is_unrestricted x) ; + r ) + else None let classify_minimized_maximized (_, vs) = let all_pos, all_neg = - Var.Map.fold + VarMap.fold (fun _ coeff (all_pos, all_neg) -> (Q.(coeff >= zero) && all_pos, Q.(coeff <= zero) && all_neg) ) vs (true, true) @@ -2063,13 +2083,6 @@ module Formula = struct (** opaque because we need to normalize variables in the co-domain of term equalities on the fly *) type term_eqs - (* NOTE on variable orders. Both [var_eqs] and [linear_eqs] act as substitutions. We want these - substitutions to leave us with as many restricted variables as possible, so that we derive - more facts for the tableau. As representants in [var_eqs], we prefer restricted variables; - as keys in the [linear_eqs] map, we prefer unrestricted variables. We do so based on the - ASSUMPTION that the order on {!PulseAbstractValue.t} says that restricted variables are - smaller than unrestricted variables: for [var_eqs] we pick the representant to be the minimum, - for [linear_eqs] we pick the key to be the maximum. *) type t = private { var_eqs: var_eqs (** Equality relation between variables. We want to only use canonical representatives @@ -2093,8 +2106,8 @@ module Formula = struct [domain(linear_eqs) ∩ range(linear_eqs) = ∅], when seeing [linear_eqs] as a map [x->l] - 2. for all [x=l ∊ linear_eqs], [x > max({x'|x'∊l})] according to [is_simpler_than] - (in other words: [x] is the most complex variable in [x=l]). *) + 2. for all [x=l ∊ linear_eqs], [x < min({x'|x'∊l})] according to [is_simpler_than] + (in other words: [x] is the simplest variable in [x=l]). *) ; term_eqs: term_eqs (** Equalities of the form [t = x], used to detect when two abstract values are equal to the same term (hence equal). Together with [var_eqs] and [linear_eqs] this gives a @@ -2259,8 +2272,6 @@ module Formula = struct -> atoms_occurrences:AtomMapOccurrences.t -> t (** escape hatch *) - - val check_invariant : t -> unit end = struct type term_eqs = Term.VarMap.t_ [@@deriving compare, equal, yojson_of] @@ -2296,16 +2307,6 @@ module Formula = struct ; atoms_occurrences= Var.Map.empty } - let check_invariant_linear_eq_lhs_max v l = - LinArith.fold l ~init:() ~f:(fun () (w, _) -> - if Var.compare v w <= 0 then - L.die InternalError "linear_eqs: lhs should be strictly bigger that all vars in rhs" ) - - - let check_invariant_linear_eqs linear_eqs = - Var.Map.iter check_invariant_linear_eq_lhs_max linear_eqs - - let get_repr phi x = VarUF.find phi.var_eqs x let get_repr_as_var phi x = (get_repr phi x :> Var.t) @@ -2446,12 +2447,6 @@ module Formula = struct F.pp_close_box fmt () - let check_invariant phi = - if Debug.debug then ( - Debug.p "Checking invariant of %a@\n" (pp_with_pp_var Var.pp) phi ; - check_invariant_linear_eqs phi.linear_eqs ) - - (* {2 mutations} *) let add_const_eq v t phi = @@ -2642,7 +2637,6 @@ module Formula = struct let add_linear_eq v l phi = Debug.p "add_linear_eq %a=%a@\n" Var.pp v (LinArith.pp Var.pp) l ; - if Debug.debug then check_invariant_linear_eq_lhs_max v l ; let phi = match Var.Map.find_opt v phi.linear_eqs with | Some l_old -> @@ -2954,7 +2948,6 @@ module Formula = struct [l1] and [l2] should have already been through {!normalize_linear} (w.r.t. [phi]) *) let rec solve_normalized_lin_eq ~fuel ?(force_no_tableau = false) new_eqs l1 l2 phi = - Unsafe.check_invariant phi ; Debug.p "solve_normalized_lin_eq: %a=%a@\n" (LinArith.pp Var.pp) l1 (LinArith.pp Var.pp) l2 ; LinArith.solve_eq l1 l2 >>= function @@ -3260,15 +3253,15 @@ module Formula = struct Var.pp v (LinArith.pp Var.pp) lv ; let r = let lv' = LinArith.subst_variable x (LinSubst lx) lv in - (* check the invariant that [v] is (strictly) larger than any variable in + (* check the invariant that [v] is (strictly) simpler than any variable in [lv']; because the invariant was true before it's enough to check that it's - larger than any variable in [lx] *) + simpler than any variable in [lx] *) let needs_pivot = - match LinArith.get_largest lx with + match LinArith.get_simplest lx with | None -> false | Some min -> - Var.compare v min <= 0 + Var.is_simpler_or_equal min v in Debug.p "needs_pivot= %b@\n" needs_pivot ; if needs_pivot then diff --git a/infer/src/pulse/unit/PulseFormulaTest.ml b/infer/src/pulse/unit/PulseFormulaTest.ml index 3e61aaef029..9f1d3411a86 100644 --- a/infer/src/pulse/unit/PulseFormulaTest.ml +++ b/infer/src/pulse/unit/PulseFormulaTest.ml @@ -247,7 +247,7 @@ let%test_module "normalization" = test (x < y) ; [%expect {| - conditions: (empty) phi: linear_eqs: y = a1 +x +1 && term_eqs: [a1 +x +1]=y|}] + conditions: (empty) phi: linear_eqs: x = y -a1 -1 && term_eqs: [y -a1 -1]=x|}] let%expect_test _ = @@ -315,8 +315,8 @@ let%test_module "normalization" = [%expect {| conditions: (empty) - phi: linear_eqs: v7 = x +v6 ∧ v8 = x +v6 +1 ∧ v10 = 0 - && term_eqs: 0=v10∧[x +v6]=v7∧[x +v6 +1]=v8∧(z×v8)=v9∧(v×y)=v6∧(v9÷w)=v10 + phi: linear_eqs: x = -v6 +v8 -1 ∧ v7 = v8 -1 ∧ v10 = 0 + && term_eqs: 0=v10∧[-v6 +v8 -1]=x∧[v8 -1]=v7∧(z×v8)=v9∧(v×y)=v6∧(v9÷w)=v10 && intervals: v10=0|}] @@ -327,8 +327,8 @@ let%test_module "normalization" = {| conditions: (empty) phi: var_eqs: v8=v9=v10 - && linear_eqs: y = -1/3·x -1/3 ∧ v6 = -x -1 ∧ v7 = -1 ∧ v8 = 0 - && term_eqs: (-1)=v7∧0=v8∧[-x -1]=v6∧[-1/3·x -1/3]=y∧[12·x +36·y +12]=v8 + && linear_eqs: x = -v6 -1 ∧ y = 1/3·v6 ∧ v7 = -1 ∧ v8 = 0 + && term_eqs: (-1)=v7∧0=v8∧[-v6 -1]=x∧[1/3·v6]=y && intervals: v8=0|}] @@ -339,10 +339,9 @@ let%test_module "normalization" = {| conditions: (empty) phi: var_eqs: v8=v9=v10 - && linear_eqs: y = -1/3·x -1/3 ∧ z = 12 ∧ w = 1 ∧ v = 3 ∧ v6 = -x -1 + && linear_eqs: x = -v6 -1 ∧ y = 1/3·v6 ∧ z = 12 ∧ w = 1 ∧ v = 3 ∧ v7 = -1 ∧ v8 = 0 - && term_eqs: (-1)=v7∧0=v8∧1=w∧3=v∧12=z∧[-x -1]=v6∧[-1/3·x -1/3]=y - ∧[12·x +36·y +12]=v8 + && term_eqs: (-1)=v7∧0=v8∧1=w∧3=v∧12=z∧[-v6 -1]=x∧[1/3·v6]=y && intervals: z=12 ∧ w=1 ∧ v=3 ∧ v8=0|}] @@ -353,10 +352,10 @@ let%test_module "normalization" = {| conditions: (empty) phi: var_eqs: z=v7 - && linear_eqs: x = 2 ∧ y = -42 ∧ w = z -2 ∧ v6 = 4 - && term_eqs: (-42)=y∧2=x∧4=v6∧[z -2]=w + && linear_eqs: x = 2 ∧ y = -42 ∧ z = w +2 ∧ v6 = 4 + && term_eqs: (-42)=y∧2=x∧4=v6∧[w +2]=z && intervals: y=-42 ∧ v6=4 - && atoms: {is_int(z) = 1} + && atoms: {is_int([w +2]) = 1} |}] @@ -420,11 +419,11 @@ let%test_module "variable elimination" = Formula: conditions: (empty) phi: var_eqs: x=v6 ∧ z=w=v7 ∧ v=v8 - && linear_eqs: y = x +1 ∧ z = -1 ∧ v = 0 - && term_eqs: (-1)=z∧0=v∧[x +1]=y + && linear_eqs: x = y -1 ∧ z = -1 ∧ v = 0 + && term_eqs: (-1)=z∧0=v∧[y -1]=x && intervals: v=0 Result: changed - conditions: (empty) phi: term_eqs: (-1)=z∧[x +1]=y|}] + conditions: (empty) phi: term_eqs: (-1)=z∧[y -1]=x|}] let%expect_test _ = @@ -434,11 +433,11 @@ let%test_module "variable elimination" = Formula: conditions: (empty) phi: var_eqs: x=v6 ∧ v=v9 - && linear_eqs: z = x -y ∧ w = -x -y ∧ v = -x -y +1 ∧ v7 = -y ∧ v8 = 0 - && term_eqs: 0=v8∧[-x -y]=w∧[x -y]=z∧[-y]=v7∧[-x -y +1]=v + && linear_eqs: x = -v +v7 +1 ∧ y = -v7 ∧ z = -v +2·v7 +1 ∧ w = v -1 ∧ v8 = 0 + && term_eqs: 0=v8∧[v -1]=w∧[-v7]=y∧[-v +v7 +1]=x∧[-v +2·v7 +1]=z && intervals: v8=0 Result: changed - conditions: (empty) phi: term_eqs: [-x -y]=w∧[x -y]=z∧[-y]=v7∧[-x -y +1]=v|}] + conditions: (empty) phi: term_eqs: [v -1]=w∧[-v7]=y∧[-v +v7 +1]=x∧[-v +2·v7 +1]=z|}] let%expect_test _ = @@ -446,9 +445,9 @@ let%test_module "variable elimination" = [%expect {| Formula: - conditions: (empty) phi: var_eqs: x=w=v6 ∧ y=z && linear_eqs: y = x -4 && term_eqs: [x -4]=y + conditions: (empty) phi: var_eqs: x=w=v6 ∧ y=z && linear_eqs: x = y +4 && term_eqs: [y +4]=x Result: changed - conditions: (empty) phi: term_eqs: [x -4]=y|}] + conditions: (empty) phi: term_eqs: [y +4]=x|}] end ) @@ -481,8 +480,8 @@ let%test_module "non-linear simplifications" = {| conditions: (empty) phi: var_eqs: z=v8 ∧ w=v7 - && linear_eqs: y = 2 ∧ z = 2·x ∧ w = 4·x -3 ∧ v6 = 4·x - && term_eqs: 2=y∧[4·x -3]=w∧[2·x]=z∧[4·x]=v6 + && linear_eqs: x = 1/4·v6 ∧ y = 2 ∧ z = 1/2·v6 ∧ w = v6 -3 + && term_eqs: 2=y∧[v6 -3]=w∧[1/4·v6]=x∧[1/2·v6]=z && intervals: y=2|}] end ) @@ -505,8 +504,8 @@ let%test_module "inequalities" = {| conditions: (empty) phi: var_eqs: a3=z ∧ a2=y ∧ a1=x - && linear_eqs: a2 = a5 +a3 +3 ∧ a1 = -a5 +a4 -a3 -1 ∧ v6 = a4 +2 ∧ v7 = -a5 -3 - && term_eqs: [-a5 -3]=v7∧[-a5 +a4 -a3 -1]=a1∧[a4 +2]=v6∧[a5 +a3 +3]=a2 + && linear_eqs: a2 = a3 +a5 +3 ∧ a1 = -a3 +a4 -a5 -1 ∧ v6 = a4 +2 ∧ v7 = -a5 -3 + && term_eqs: [-a5 -3]=v7∧[-a3 +a4 -a5 -1]=a1∧[a4 +2]=v6∧[a3 +a5 +3]=a2 && intervals: a3≥0 ∧ a2≥0 ∧ a1≥0 ∧ v6≥2 ∧ v7≤-3 |}]