Skip to content

Commit

Permalink
Back out "Use only one order for abstract values"
Browse files Browse the repository at this point in the history
Reviewed By: rgrig

Differential Revision: D64102186

fbshipit-source-id: 70d4d8c77ab0afe041f91100dfd0f21d789a1155
  • Loading branch information
ngorogiannis authored and facebook-github-bot committed Oct 9, 2024
1 parent 8f937cc commit 06b81b1
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 101 deletions.
8 changes: 8 additions & 0 deletions infer/src/pulse/PulseAbstractValue.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ let pp f v = if is_restricted v then F.fprintf f "a%d" (-v) else F.fprintf f "v%

let yojson_of_t l = `String (F.asprintf "%a" pp l)

let compare_unrestricted_first v1 v2 =
if is_restricted v1 then
if is_restricted v2 then (* compare absolute values *) compare v2 v1
else (* unrestricted [v2] first *) 1
else if is_restricted v2 then (* unrestricted [v1] first *) -1
else compare v1 v2


module PPKey = struct
type nonrec t = t [@@deriving compare]

Expand Down
3 changes: 3 additions & 0 deletions infer/src/pulse/PulseAbstractValue.mli
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ val is_unrestricted : t -> bool

val pp : F.formatter -> t -> unit

val compare_unrestricted_first : t -> t -> int
(** an alternative comparison function that sorts unrestricted variables before restricted variables *)

module Set : PrettyPrintable.PPSet with type elt = t

module Map : sig
Expand Down
149 changes: 71 additions & 78 deletions infer/src/pulse/PulseFormula.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ module ValueHistory = PulseValueHistory
module Var = struct
include PulseAbstractValue

(* This is used by [var_eqs] to prefer restricted variables as representants. *)
let is_simpler_than v1 v2 = PulseAbstractValue.compare v1 v2 < 0

let is_simpler_or_equal v1 v2 = PulseAbstractValue.compare v1 v2 <= 0
end

module Q = QSafeCapped
Expand Down Expand Up @@ -92,8 +93,8 @@ module LinArith : sig
val mult : Q.t -> t -> t

val solve_eq : t -> t -> (Var.t * t) option SatUnsat.t
(** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l] (with x larger than all variables
of l), [Sat None] if [l1 = l2] is always true, and [Unsat] if it is always false *)
(** [solve_eq l1 l2] is [Sat (Some (x, l))] if [l1=l2 <=> x=l], [Sat None] if [l1 = l2] is always
true, and [Unsat] if it is always false *)

val of_q : Q.t -> t

Expand Down Expand Up @@ -124,18 +125,19 @@ module LinArith : sig
val subst_variable : Var.t -> _ subst_target -> t -> t
(** same as above for a single variable to substitute (more optimized) *)

val get_largest : t -> Var.t option
(** the largest [v∊l], according to [Var.compare] *)
val get_simplest : t -> Var.t option
(** the smallest [v∊l] according to [Var.is_simpler_than] *)

(** {2 Tableau-Specific Operations} *)

val is_restricted : t -> bool
(** [true] iff all the variables involved in the expression satisfy {!Var.is_restricted} *)

val solve_for_unrestricted : Var.t -> t -> (Var.t * t) option
(** [solve_for_unrestricted u l] returns [Some (x, l')] where [x] is the largest unrestricted
variable in [u=l] and [u=l <=> x=l']. If there are no unrestricted variables in [u=l], it
returns [None]. *)
(** if [l] contains at least one unrestricted variable then [solve_for_unrestricted u l] is
[Some (x, l')] where [x] is the smallest unrestricted variable in [l] and [u=l <=> x=l']. If
there are no unrestricted variables in [l] then [solve_for_unrestricted u l] is [None].
Assumes [u∉l]. *)

val pivot : Var.t * Q.t -> t -> t
(** [pivot (v, q) l] assumes [v] appears in [l] with coefficient [q] and returns [l'] such that
Expand All @@ -155,12 +157,28 @@ module LinArith : sig
val is_minimized : t -> bool
(** [is_minimized l] iff [classify_minimized_maximized l] is either [`Minimized] or [`Constant] *)
end = struct
(* define our own var map to get a custom order: we want to place unrestricted variables first in
the map so that [solve_for_unrestricted] can be implemented in terms of [solve_eq] easily *)
module VarMap = struct
include PrettyPrintable.MakePPMap (struct
type t = Var.t

let pp = Var.pp

let compare v1 v2 = Var.compare_unrestricted_first v1 v2
end)

(* unpleasant that we have to duplicate this definition from [Var.Map] here *)
let yojson_of_t yojson_of_val m =
`List (List.map ~f:(fun (k, v) -> `List [Var.yojson_of_t k; yojson_of_val v]) (bindings m))
end

(** invariant: the representation is always "canonical": coefficients cannot be [Q.zero] *)
type t = Q.t * Q.t Var.Map.t [@@deriving compare, equal]
type t = Q.t * Q.t VarMap.t [@@deriving compare, equal]

let yojson_of_t (c, vs) = `List [Var.Map.yojson_of_t Q.yojson_of_t vs; Q.yojson_of_t c]
let yojson_of_t (c, vs) = `List [VarMap.yojson_of_t Q.yojson_of_t vs; Q.yojson_of_t c]

let fold (_, vs) ~init ~f = IContainer.fold_of_pervasives_map_fold Var.Map.fold vs ~init ~f
let fold (_, vs) ~init ~f = IContainer.fold_of_pervasives_map_fold VarMap.fold vs ~init ~f

type 'term_t subst_target =
| QSubst of Q.t
Expand All @@ -170,7 +188,7 @@ end = struct
| NonLinearTermSubst of 'term_t

let pp pp_var fmt (c, vs) =
if Var.Map.is_empty vs then Q.pp_print fmt c
if VarMap.is_empty vs then Q.pp_print fmt c
else
let pp_c fmt c =
if not (Q.is_zero c) then
Expand All @@ -186,7 +204,7 @@ end = struct
in
let pp_vs fmt vs =
Pp.collection ~sep:"@;"
~fold:(IContainer.fold_of_pervasives_map_fold Var.Map.fold)
~fold:(IContainer.fold_of_pervasives_map_fold VarMap.fold)
(fun fmt (v, q) ->
F.fprintf fmt "%a%a" pp_coeff q pp_var v ;
is_first := false )
Expand All @@ -197,41 +215,41 @@ end = struct

let add (c1, vs1) (c2, vs2) =
( Q.add c1 c2
, Var.Map.union
, VarMap.union
(fun _v c1 c2 ->
let c = Q.add c1 c2 in
if Q.is_zero c then None else Some c )
vs1 vs2 )


let minus (c, vs) = (Q.neg c, Var.Map.map (fun c -> Q.neg c) vs)
let minus (c, vs) = (Q.neg c, VarMap.map (fun c -> Q.neg c) vs)

let subtract l1 l2 = add l1 (minus l2)

let zero = (Q.zero, Var.Map.empty)
let zero = (Q.zero, VarMap.empty)

let is_zero (c, vs) = Q.is_zero c && Var.Map.is_empty vs
let is_zero (c, vs) = Q.is_zero c && VarMap.is_empty vs

let mult q ((c, vs) as l) =
if Q.is_zero q then (* needed for correctness: coeffs cannot be zero *) zero
else if Q.is_one q then (* purely an optimisation *) l
else (Q.mul q c, Var.Map.map (fun c -> Q.mul q c) vs)
else (Q.mul q c, VarMap.map (fun c -> Q.mul q c) vs)


let pivot (x, coeff) (c, vs) =
let d = Q.neg coeff in
let vs' =
Var.Map.fold
(fun v' coeff' vs' -> if Var.equal v' x then vs' else Var.Map.add v' (Q.div coeff' d) vs')
vs Var.Map.empty
VarMap.fold
(fun v' coeff' vs' -> if Var.equal v' x then vs' else VarMap.add v' (Q.div coeff' d) vs')
vs VarMap.empty
in
(* note: [d≠0] by the invariant of the coefficient map [vs] *)
let c' = Q.div c d in
(c', vs')


let solve_eq_zero ((c, vs) as l) =
match Var.Map.max_binding_opt vs with
match VarMap.min_binding_opt vs with
| None ->
if Q.is_zero c then Sat None
else (
Expand All @@ -243,15 +261,15 @@ end = struct

let solve_eq l1 l2 = solve_eq_zero (subtract l1 l2)

let of_var v = (Q.zero, Var.Map.singleton v Q.one)
let of_var v = (Q.zero, VarMap.singleton v Q.one)

let of_q q = (q, Var.Map.empty)
let of_q q = (q, VarMap.empty)

let get_as_const (c, vs) = if Var.Map.is_empty vs then Some c else None
let get_as_const (c, vs) = if VarMap.is_empty vs then Some c else None

let get_as_var (c, vs) =
if Q.is_zero c then
match Var.Map.is_singleton_or_more vs with
match VarMap.is_singleton_or_more vs with
| Singleton (x, cx) when Q.is_one cx ->
Some x
| _ ->
Expand All @@ -262,7 +280,7 @@ end = struct
let get_as_variable_difference (c, vs) =
if Q.is_zero c then
(* the coefficient has to be 0 *)
let vs_seq = Var.Map.to_seq vs in
let vs_seq = VarMap.to_seq vs in
let open IOption.Let_syntax in
(* check that the expression consists of exactly two variables with coeffs 1 and -1 *)
let* (x, cx), vs_seq = Seq.uncons vs_seq in
Expand All @@ -276,7 +294,7 @@ end = struct

let get_constant_part (c, _) = c

let get_coefficient v (_, vs) = Var.Map.find_opt v vs
let get_coefficient v (_, vs) = VarMap.find_opt v vs

let of_subst_target v0 = function
| QSubst q ->
Expand All @@ -292,7 +310,7 @@ end = struct
let fold_subst_variables ((c, vs_foreign) as l0) ~init ~f =
let changed = ref false in
let acc_f, l' =
Var.Map.fold
VarMap.fold
(fun v_foreign q0 (acc_f, l) ->
let acc_f, op = f acc_f v_foreign in
( match op with
Expand All @@ -304,7 +322,7 @@ end = struct
changed := true ) ;
(acc_f, add (mult q0 (of_subst_target v_foreign op)) l) )
vs_foreign
(init, (c, Var.Map.empty))
(init, (c, VarMap.empty))
in
let l' = if !changed then l' else l0 in
(acc_f, l')
Expand All @@ -314,38 +332,40 @@ end = struct

(* OPTIM: for a single variable we can avoid iterating over the coefficient map *)
let subst_variable x subst_target ((c, vs) as l0) =
match Var.Map.find_opt x vs with
match VarMap.find_opt x vs with
| None ->
l0
| Some q ->
let vs' = Var.Map.remove x vs in
let vs' = VarMap.remove x vs in
add (mult q (of_subst_target x subst_target)) (c, vs')


let get_variables (_, vs) = Var.Map.to_seq vs |> Seq.map fst
let get_variables (_, vs) = VarMap.to_seq vs |> Seq.map fst

let get_largest l = Var.Map.max_binding_opt (snd l) |> Option.map ~f:fst
let get_simplest l = VarMap.min_binding_opt (snd l) |> Option.map ~f:fst

(** {2 Tableau-Specific Operations} *)

let is_restricted (_q, l) =
(* HACK: since restricted < unrestricted, it's enough to check the maximum *)
match Var.Map.max_binding_opt l with None -> true | Some (x, _) -> Var.is_restricted x
let is_restricted l =
(* HACK: unrestricted variables come first so we first test if there exists any unrestricted
variable in the map by checking its min element *)
not (get_simplest l |> Option.exists ~f:Var.is_unrestricted)


let solve_for_unrestricted w l =
match solve_eq (of_var w) l with
| Unsat | Sat None ->
None
| Sat (Some (x, _) as r) when Var.is_unrestricted x ->
r
| Sat (Some _) ->
None
if not (is_restricted l) then (
match solve_eq l (of_var w) with
| Unsat | Sat None ->
None
| Sat (Some (x, _) as r) ->
assert (Var.is_unrestricted x) ;
r )
else None


let classify_minimized_maximized (_, vs) =
let all_pos, all_neg =
Var.Map.fold
VarMap.fold
(fun _ coeff (all_pos, all_neg) ->
(Q.(coeff >= zero) && all_pos, Q.(coeff <= zero) && all_neg) )
vs (true, true)
Expand Down Expand Up @@ -2063,13 +2083,6 @@ module Formula = struct
(** opaque because we need to normalize variables in the co-domain of term equalities on the fly *)
type term_eqs
(* NOTE on variable orders. Both [var_eqs] and [linear_eqs] act as substitutions. We want these
substitutions to leave us with as many restricted variables as possible, so that we derive
more facts for the tableau. As representants in [var_eqs], we prefer restricted variables;
as keys in the [linear_eqs] map, we prefer unrestricted variables. We do so based on the
ASSUMPTION that the order on {!PulseAbstractValue.t} says that restricted variables are
smaller than unrestricted variables: for [var_eqs] we pick the representant to be the minimum,
for [linear_eqs] we pick the key to be the maximum. *)
type t = private
{ var_eqs: var_eqs
(** Equality relation between variables. We want to only use canonical representatives
Expand All @@ -2093,8 +2106,8 @@ module Formula = struct
[domain(linear_eqs) ∩ range(linear_eqs) = ∅], when seeing [linear_eqs] as a map
[x->l]
2. for all [x=l ∊ linear_eqs], [x > max({x'|x'∊l})] according to [is_simpler_than]
(in other words: [x] is the most complex variable in [x=l]). *)
2. for all [x=l ∊ linear_eqs], [x < min({x'|x'∊l})] according to [is_simpler_than]
(in other words: [x] is the simplest variable in [x=l]). *)
; term_eqs: term_eqs
(** Equalities of the form [t = x], used to detect when two abstract values are equal to
the same term (hence equal). Together with [var_eqs] and [linear_eqs] this gives a
Expand Down Expand Up @@ -2259,8 +2272,6 @@ module Formula = struct
-> atoms_occurrences:AtomMapOccurrences.t
-> t
(** escape hatch *)
val check_invariant : t -> unit
end = struct
type term_eqs = Term.VarMap.t_ [@@deriving compare, equal, yojson_of]
Expand Down Expand Up @@ -2296,16 +2307,6 @@ module Formula = struct
; atoms_occurrences= Var.Map.empty }
let check_invariant_linear_eq_lhs_max v l =
LinArith.fold l ~init:() ~f:(fun () (w, _) ->
if Var.compare v w <= 0 then
L.die InternalError "linear_eqs: lhs should be strictly bigger that all vars in rhs" )
let check_invariant_linear_eqs linear_eqs =
Var.Map.iter check_invariant_linear_eq_lhs_max linear_eqs
let get_repr phi x = VarUF.find phi.var_eqs x
let get_repr_as_var phi x = (get_repr phi x :> Var.t)
Expand Down Expand Up @@ -2446,12 +2447,6 @@ module Formula = struct
F.pp_close_box fmt ()
let check_invariant phi =
if Debug.debug then (
Debug.p "Checking invariant of %a@\n" (pp_with_pp_var Var.pp) phi ;
check_invariant_linear_eqs phi.linear_eqs )
(* {2 mutations} *)
let add_const_eq v t phi =
Expand Down Expand Up @@ -2642,7 +2637,6 @@ module Formula = struct
let add_linear_eq v l phi =
Debug.p "add_linear_eq %a=%a@\n" Var.pp v (LinArith.pp Var.pp) l ;
if Debug.debug then check_invariant_linear_eq_lhs_max v l ;
let phi =
match Var.Map.find_opt v phi.linear_eqs with
| Some l_old ->
Expand Down Expand Up @@ -2954,7 +2948,6 @@ module Formula = struct
[l1] and [l2] should have already been through {!normalize_linear} (w.r.t. [phi]) *)
let rec solve_normalized_lin_eq ~fuel ?(force_no_tableau = false) new_eqs l1 l2 phi =
Unsafe.check_invariant phi ;
Debug.p "solve_normalized_lin_eq: %a=%a@\n" (LinArith.pp Var.pp) l1 (LinArith.pp Var.pp) l2 ;
LinArith.solve_eq l1 l2
>>= function
Expand Down Expand Up @@ -3260,15 +3253,15 @@ module Formula = struct
Var.pp v (LinArith.pp Var.pp) lv ;
let r =
let lv' = LinArith.subst_variable x (LinSubst lx) lv in
(* check the invariant that [v] is (strictly) larger than any variable in
(* check the invariant that [v] is (strictly) simpler than any variable in
[lv']; because the invariant was true before it's enough to check that it's
larger than any variable in [lx] *)
simpler than any variable in [lx] *)
let needs_pivot =
match LinArith.get_largest lx with
match LinArith.get_simplest lx with
| None ->
false
| Some min ->
Var.compare v min <= 0
Var.is_simpler_or_equal min v
in
Debug.p "needs_pivot= %b@\n" needs_pivot ;
if needs_pivot then
Expand Down
Loading

0 comments on commit 06b81b1

Please sign in to comment.