From db8be568ccfd843f0a665527e45f1939cd06aead Mon Sep 17 00:00:00 2001 From: Markus de Medeiros Date: Tue, 24 Sep 2024 11:05:55 -0400 Subject: [PATCH] copy remaining prob_lang files --- theories/meas_lang/class_instances.v | 160 +++ theories/meas_lang/ctx_subst.v | 48 + theories/meas_lang/erasure.v | 691 +++++++++++ theories/meas_lang/exec_lang.v | 66 + theories/meas_lang/metatheory.v | 1724 ++++++++++++++++++++++++++ theories/meas_lang/tactics.v | 87 ++ theories/meas_lang/wp_tactics.v | 756 +++++++++++ 7 files changed, 3532 insertions(+) create mode 100644 theories/meas_lang/class_instances.v create mode 100644 theories/meas_lang/ctx_subst.v create mode 100644 theories/meas_lang/erasure.v create mode 100644 theories/meas_lang/exec_lang.v create mode 100644 theories/meas_lang/metatheory.v create mode 100644 theories/meas_lang/tactics.v create mode 100644 theories/meas_lang/wp_tactics.v diff --git a/theories/meas_lang/class_instances.v b/theories/meas_lang/class_instances.v new file mode 100644 index 00000000..fa92b141 --- /dev/null +++ b/theories/meas_lang/class_instances.v @@ -0,0 +1,160 @@ +From Coq Require Import Reals Psatz. +From clutch.common Require Export language. +From clutch.meas_lang Require Export lang tactics notation. +From iris.prelude Require Import options. + +(* +Global Instance into_val_val v : IntoVal (Val v) v. +Proof. done. Qed. +Global Instance as_val_val v : AsVal (Val v). +Proof. by eexists. Qed. + +(** * Instances of the [Atomic] class *) +Section atomic. + Local Ltac solve_atomic := + apply strongly_atomic_atomic, ectx_language_atomic; + [intros ????; simpl; by inv_head_step + |apply ectxi_language_sub_redexes_are_values; intros [] **; naive_solver]. + + Global Instance rec_atomic s f x e : Atomic s (Rec f x e). + Proof. solve_atomic. Qed. + Global Instance injl_atomic s v : Atomic s (InjL (Val v)). + Proof. solve_atomic. Qed. + Global Instance injr_atomic s v : Atomic s (InjR (Val v)). + Proof. solve_atomic. Qed. + (** The instance below is a more general version of [Skip] *) + Global Instance beta_atomic s f x v1 v2 : Atomic s (App (RecV f x (Val v1)) (Val v2)). + Proof. destruct f,x; solve_atomic. Qed. + + Global Instance unop_atomic s op v : Atomic s (UnOp op (Val v)). + Proof. solve_atomic. Qed. + Global Instance binop_atomic s op v1 v2 : Atomic s (BinOp op (Val v1) (Val v2)). + Proof. solve_atomic. Qed. + Global Instance if_true_atomic s v1 e2 : + Atomic s (If (Val $ LitV $ LitBool true) (Val v1) e2). + Proof. solve_atomic. Qed. + Global Instance if_false_atomic s e1 v2 : + Atomic s (If (Val $ LitV $ LitBool false) e1 (Val v2)). + Proof. solve_atomic. Qed. + + Global Instance fst_atomic s v : Atomic s (Fst (Val v)). + Proof. solve_atomic. Qed. + Global Instance snd_atomic s v : Atomic s (Snd (Val v)). + Proof. solve_atomic. Qed. + + Global Instance alloc_atomic s v : Atomic s (Alloc (Val v)). + Proof. solve_atomic. Qed. + Global Instance load_atomic s v : Atomic s (Load (Val v)). + Proof. solve_atomic. Qed. + Global Instance store_atomic s v1 v2 : Atomic s (Store (Val v1) (Val v2)). + Proof. solve_atomic. Qed. + + Global Instance rand_atomic s z l : Atomic s (Rand (Val (LitV (LitInt z))) (Val (LitV (LitLbl l)))). + Proof. solve_atomic. Qed. + Global Instance rand_atomic_int s z : Atomic s (Rand (Val (LitV (LitInt z))) (Val (LitV LitUnit))). + Proof. solve_atomic. Qed. + Global Instance alloc_tape_atomic s z : Atomic s (AllocTape (Val (LitV (LitInt z)))). + Proof. solve_atomic. Qed. + + Global Instance tick_atomic s z : Atomic s (Tick (Val (LitV (LitInt z)))). + Proof. solve_atomic. Qed. +End atomic. + +(** * Instances of the [PureExec] class *) +(** The behavior of the various [wp_] tactics with regard to lambda differs in +the following way: + +- [wp_pures] does *not* reduce lambdas/recs that are hidden behind a definition. +- [wp_rec] and [wp_lam] reduce lambdas/recs that are hidden behind a definition. + +To realize this behavior, we define the class [AsRecV v f x erec], which takes a +value [v] as its input, and turns it into a [RecV f x erec] via the instance +[AsRecV_recv : AsRecV (RecV f x e) f x e]. We register this instance via +[Hint Extern] so that it is only used if [v] is syntactically a lambda/rec, and +not if [v] contains a lambda/rec that is hidden behind a definition. + +To make sure that [wp_rec] and [wp_lam] do reduce lambdas/recs that are hidden +behind a definition, we activate [AsRecV_recv] by hand in these tactics. *) +Class AsRecV (v : val) (f x : binder) (erec : expr) := + as_recv : v = RecV f x erec. +Global Hint Mode AsRecV ! - - - : typeclass_instances. +Definition AsRecV_recv f x e : AsRecV (RecV f x e) f x e := eq_refl. +Global Hint Extern 0 (AsRecV (RecV _ _ _) _ _ _) => + apply AsRecV_recv : typeclass_instances. + +Section pure_exec. + Local Ltac solve_exec_safe := intros; subst; eexists; eapply head_step_support_equiv_rel; eauto with head_step. + Local Ltac solve_exec_puredet := + intros; simpl; + (repeat case_match); simplify_eq; + rewrite dret_1_1 //. + Local Ltac solve_pure_exec := + subst; intros ?; apply nsteps_once, pure_head_step_pure_step; + constructor; [solve_exec_safe | solve_exec_puredet]. + + Global Instance pure_recc f x (erec : expr) : + PureExec True 1 (Rec f x erec) (Val $ RecV f x erec). + Proof. + solve_pure_exec. + Qed. + + Global Instance pure_pairc (v1 v2 : val) : + PureExec True 1 (Pair (Val v1) (Val v2)) (Val $ PairV v1 v2). + Proof. solve_pure_exec. Qed. + Global Instance pure_injlc (v : val) : + PureExec True 1 (InjL $ Val v) (Val $ InjLV v). + Proof. solve_pure_exec. Qed. + Global Instance pure_injrc (v : val) : + PureExec True 1 (InjR $ Val v) (Val $ InjRV v). + Proof. solve_pure_exec. Qed. + + Global Instance pure_beta f x (erec : expr) (v1 v2 : val) `{!AsRecV v1 f x erec} : + PureExec True 1 (App (Val v1) (Val v2)) (subst' x v2 (subst' f v1 erec)). + Proof. unfold AsRecV in *. subst. solve_pure_exec. Qed. + + Global Instance pure_unop op v v' : + PureExec (un_op_eval op v = Some v') 1 (UnOp op (Val v)) (Val v'). + Proof. solve_pure_exec. Qed. + + Global Instance pure_binop op v1 v2 v' : + PureExec (bin_op_eval op v1 v2 = Some v') 1 (BinOp op (Val v1) (Val v2)) (Val v') | 10. + Proof. solve_pure_exec. Qed. + + (* Lower-cost instance for [EqOp]. *) + Global Instance pure_eqop v1 v2 : + PureExec (vals_compare_safe v1 v2) 1 + (BinOp EqOp (Val v1) (Val v2)) + (Val $ LitV $ LitBool $ bool_decide (v1 = v2)) | 1. + Proof. + intros Hcompare. + cut (bin_op_eval EqOp v1 v2 = Some $ LitV $ LitBool $ bool_decide (v1 = v2)). + { intros. revert Hcompare. solve_pure_exec. } + rewrite /bin_op_eval /= decide_True //. + Qed. + + Global Instance pure_if_true e1 e2 : + PureExec True 1 (If (Val $ LitV $ LitBool true) e1 e2) e1. + Proof. solve_pure_exec. Qed. + Global Instance pure_if_false e1 e2 : + PureExec True 1 (If (Val $ LitV $ LitBool false) e1 e2) e2. + Proof. solve_pure_exec. Qed. + + Global Instance pure_fst v1 v2 : + PureExec True 1 (Fst (Val $ PairV v1 v2)) (Val v1). + Proof. solve_pure_exec. Qed. + Global Instance pure_snd v1 v2 : + PureExec True 1 (Snd (Val $ PairV v1 v2)) (Val v2). + Proof. solve_pure_exec. Qed. + + Global Instance pure_case_inl v e1 e2 : + PureExec True 1 (Case (Val $ InjLV v) e1 e2) (App e1 (Val v)). + Proof. solve_pure_exec. Qed. + Global Instance pure_case_inr v e1 e2 : + PureExec True 1 (Case (Val $ InjRV v) e1 e2) (App e2 (Val v)). + Proof. solve_pure_exec. Qed. + + Global Instance pure_tick (z : Z) : + PureExec True 1 (Tick #z) #(). + Proof. solve_pure_exec. Qed. +End pure_exec. +*) diff --git a/theories/meas_lang/ctx_subst.v b/theories/meas_lang/ctx_subst.v new file mode 100644 index 00000000..319242e0 --- /dev/null +++ b/theories/meas_lang/ctx_subst.v @@ -0,0 +1,48 @@ +From stdpp Require Import base stringmap fin_sets fin_map_dom. +From clutch.meas_lang Require Export lang metatheory ectx_language ectxi_language. + +(* +(** Substitution in the contexts *) +Definition subst_map_ctx_item (es : stringmap val) (K : ectx_item) := + match K with + | AppLCtx v2 => AppLCtx v2 + | AppRCtx e1 => AppRCtx (subst_map es e1) + | UnOpCtx op => UnOpCtx op + | BinOpLCtx op v2 => BinOpLCtx op v2 + | BinOpRCtx op e1 => BinOpRCtx op (subst_map es e1) + | IfCtx e1 e2 => IfCtx (subst_map es e1) (subst_map es e2) + | PairLCtx v2 => PairLCtx v2 + | PairRCtx e1 => PairRCtx (subst_map es e1) + | FstCtx => FstCtx + | SndCtx => SndCtx + | InjLCtx => InjLCtx + | InjRCtx => InjRCtx + | CaseCtx e1 e2 => CaseCtx (subst_map es e1) (subst_map es e2) + | AllocNLCtx v2 => AllocNLCtx v2 + | AllocNRCtx e1 => AllocNRCtx (subst_map es e1) + | LoadCtx => LoadCtx + | StoreLCtx v2 => StoreLCtx v2 + | StoreRCtx e1 => StoreRCtx (subst_map es e1) + | AllocTapeCtx => AllocTapeCtx + | RandLCtx v2 => RandLCtx v2 + | RandRCtx e1 => RandRCtx (subst_map es e1) + | TickCtx => TickCtx + end. + +Definition subst_map_ctx (es : stringmap val) (K : list ectx_item) := + map (subst_map_ctx_item es) K. + +Lemma subst_map_fill_item (vs : stringmap val) (Ki : ectx_item) (e : expr) : + subst_map vs (fill_item Ki e) = + fill_item (subst_map_ctx_item vs Ki) (subst_map vs e). +Proof. induction Ki; simpl; eauto with f_equal. Qed. + +Lemma subst_map_fill (vs : stringmap val) (K : list ectx_item) (e : expr) : + subst_map vs (fill K e) = fill (subst_map_ctx vs K) (subst_map vs e). +Proof. + generalize dependent e. generalize dependent vs. + induction K as [|Ki K]; eauto. + intros es e. simpl. + by rewrite IHK subst_map_fill_item. +Qed. +*) diff --git a/theories/meas_lang/erasure.v b/theories/meas_lang/erasure.v new file mode 100644 index 00000000..f019dc62 --- /dev/null +++ b/theories/meas_lang/erasure.v @@ -0,0 +1,691 @@ +From Coq Require Import Reals Psatz. +From Coquelicot Require Import Rcomplements Rbar Lim_seq. +From stdpp Require Import fin_maps fin_map_dom. +From clutch.prelude Require Import stdpp_ext. +(* From clutch.common Require Import exec language ectx_language erasable. *) +From clutch.meas_lang Require Import notation lang metatheory. +(* From clutch.prob Require Import couplings couplings_app markov. *) + +Set Default Proof Using "Type*". +Local Open Scope R. + +(* +Section erasure_helpers. + + Variable (m : nat). + Hypothesis IH : + ∀ (e1 : expr) (σ1 : state) α N zs, + tapes σ1 !! α = Some (N; zs) → + Rcoupl + (dmap (λ x, x.1) (pexec m (e1, σ1))) + (dmap (λ x, x.1) (dunifP N ≫= (λ z, pexec m (e1, state_upd_tapes <[α:= (N; zs ++ [z])]> σ1)))) eq. + + Local Lemma ind_case_det e σ α N zs K : + tapes σ !! α = Some (N; zs) → + is_det_head_step e σ = true → + Rcoupl + (dmap (fill_lift K) (head_step e σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP N ≫= (λ z, dmap + (fill_lift K) + (head_step e (state_upd_tapes <[α:= (N; zs ++ [z]) ]> σ)) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + (=). + Proof using m IH. + intros Hα (e2 & (σ2 & Hdet))%is_det_head_step_true%det_step_pred_ex_rel. + erewrite 1!det_head_step_singleton; [|done..]. + setoid_rewrite (det_head_step_singleton ); eauto; last first. + - eapply det_head_step_upd_tapes; eauto. + - erewrite det_step_eq_tapes in Hα; [|done]. + rewrite !dmap_dret. + setoid_rewrite (dmap_dret (fill_lift K)). + rewrite !dret_id_left. + erewrite (distr_ext (dunifP _ ≫= _) _); last first. + { intros. apply dbind_pmf_ext; [|done..]. intros. + rewrite dret_id_left. done. } + rewrite -dmap_dbind. apply IH. done. + Qed. + + Local Lemma ind_case_dzero e σ α N zs K : + tapes σ !! α = Some (N; zs) → + head_step e σ = dzero → + Rcoupl + (dmap (fill_lift K) (head_step e σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP N ≫= (λ z, + dmap (fill_lift K) (head_step e (state_upd_tapes <[α:=(N; zs ++ [z])]> σ)) ≫= + λ ρ, dmap (λ x, x.1) (pexec m ρ))) eq. + Proof using m IH. + intros Hα Hz. + rewrite Hz. + setoid_rewrite head_step_dzero_upd_tapes; [|by eapply elem_of_dom_2|done]. + rewrite dmap_dzero dbind_dzero dzero_dbind. + apply Rcoupl_dzero_dzero. + Qed. + + Local Lemma ind_case_alloc σ α N ns (z : Z) K : + tapes σ !! α = Some (N; ns) → + Rcoupl + (dmap (fill_lift K) (head_step (alloc #z) σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP N ≫= + (λ n, dmap (fill_lift K) (head_step (alloc #z) (state_upd_tapes <[α:= (N; ns ++ [n])]> σ)) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + eq. + Proof using m IH. + intros Hα. + rewrite dmap_dret dret_id_left -/exec. + setoid_rewrite (dmap_dret (fill_lift K)). + erewrite (distr_ext (dunifP N ≫= _)); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite dret_id_left. done. } + rewrite -dmap_dbind. + (* TODO: fix slightly ugly hack ... *) + revert IH; intro IHm. + apply lookup_total_correct in Hα as Hαtot. + pose proof (elem_fresh_ne _ _ _ Hα) as Hne. + erewrite dbind_ext_right; last first. + { intros n. + rewrite -(fresh_loc_upd_some _ _ (N; ns)); [|done]. + rewrite (fresh_loc_upd_swap σ α (N; ns) (_; [])) //. } + apply IHm. + by apply fresh_loc_lookup. + Qed. + + Local Lemma ind_case_rand_some σ α α' K N M (z : Z) n ns ns' : + N = Z.to_nat z → + tapes σ !! α = Some (M; ns') → + tapes σ !! α' = Some (N; n :: ns) → + Rcoupl + (dmap (fill_lift K) (head_step (rand(#lbl:α') #z) σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP M ≫= + (λ n, dmap (fill_lift K) + (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α:= (M; ns' ++ [n])]> σ)) + ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + (=). + Proof using m IH. + intros Hz Hα Hα'. + apply lookup_total_correct in Hα as Hαtot. + apply lookup_total_correct in Hα' as Hα'tot. + destruct (decide (α = α')) as [-> | Hαneql]. + - simplify_eq. rewrite /head_step Hα. + setoid_rewrite lookup_insert. + rewrite bool_decide_eq_true_2 //. + rewrite dmap_dret dret_id_left -/exec. + erewrite dbind_ext_right; last first. + { intros. + rewrite -app_comm_cons. + rewrite upd_tape_twice dmap_dret dret_id_left -/exec //. } + assert (Haux : ∀ n, + state_upd_tapes <[α':=(Z.to_nat z; ns ++ [n])]> σ = + state_upd_tapes <[α':=(Z.to_nat z; ns ++ [n])]> (state_upd_tapes <[α':=(Z.to_nat z; ns)]> σ)). + { intros. rewrite /state_upd_tapes. f_equal. rewrite insert_insert //. } + erewrite dbind_ext_right; [| intros; rewrite Haux; done]. + rewrite -dmap_dbind. + apply IH. + apply lookup_insert. + - rewrite /head_step Hα'. + rewrite bool_decide_eq_true_2 //. + setoid_rewrite lookup_insert_ne; [|done]. + rewrite Hα' bool_decide_eq_true_2 //. + rewrite !dmap_dret !dret_id_left -/exec. + erewrite dbind_ext_right; last first. + { intros. + rewrite upd_diff_tape_comm; [|done]. + rewrite dmap_dret dret_id_left -/exec //. } + rewrite -dmap_dbind. + eapply IH. + rewrite lookup_insert_ne //. + Qed. + + Local Lemma ind_case_rand_empty σ α α' K (N M : nat) z ns : + M = Z.to_nat z → + tapes σ !! α = Some (N; ns) → + tapes σ !! α' = Some (M; []) → + Rcoupl + (dmap (fill_lift K) (head_step (rand(#lbl:α') #z) σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP N ≫= + (λ n, dmap (fill_lift K) + (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α := (N; ns ++ [n])]> σ)) + ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + eq. + Proof using m IH. + intros Hz Hα Hα'. + destruct (decide (α = α')) as [-> | Hαneql]. + + simplify_eq. rewrite /head_step Hα. + rewrite bool_decide_eq_true_2 //. + rewrite {1 2}/dmap. + rewrite -!dbind_assoc -/exec. + eapply (Rcoupl_dbind _ _ _ _ (=)); [ |apply Rcoupl_eq]. + intros ? b ->. + do 2 rewrite dret_id_left. + rewrite lookup_insert. + rewrite bool_decide_eq_true_2 //. + rewrite dmap_dret dret_id_left -/exec. + rewrite upd_tape_twice. + rewrite /state_upd_tapes insert_id //. + destruct σ; simpl. + apply Rcoupl_eq. + + rewrite /head_step /=. + setoid_rewrite lookup_insert_ne; [|done]. + rewrite Hα'. + rewrite bool_decide_eq_true_2 //. + rewrite {1 2}/dmap. + erewrite (dbind_ext_right (dunifP N)); last first. + { intro. + rewrite {1 2}/dmap. + do 2 rewrite -dbind_assoc -/exec. + done. } + rewrite -!dbind_assoc -/exec. + rewrite dbind_comm. + eapply Rcoupl_dbind; [|apply Rcoupl_eq]. + intros; simplify_eq. + do 2 rewrite dret_id_left /=. + erewrite (distr_ext (dunifP N≫=_)); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite !dret_id_left. done. + } + rewrite dbind_assoc. + by apply IH. + Qed. + + Local Lemma ind_case_rand_some_neq σ α α' K N M ns ns' z : + N ≠ Z.to_nat z → + tapes σ !! α = Some (M; ns') → + tapes σ !! α' = Some (N; ns) → + Rcoupl + (dmap (fill_lift K) (head_step (rand(#lbl:α') #z) σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP M ≫= + (λ n, dmap (fill_lift K) + (head_step (rand(#lbl:α') #z) (state_upd_tapes <[α:= (M; ns' ++ [n]) : tape]> σ)) + ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + (=). + Proof using m IH. + intros Hz Hα Hα'. + rewrite /head_step Hα'. + rewrite bool_decide_eq_false_2 //. + destruct (decide (α = α')) as [-> | Heq]. + - simplify_eq. + setoid_rewrite lookup_insert. + rewrite bool_decide_eq_false_2 //. + rewrite /dmap /=. + rewrite -!dbind_assoc -/exec. + erewrite (dbind_ext_right (dunifP M)); last first. + { intros. rewrite -!dbind_assoc -/exec //. } + rewrite dbind_comm. + eapply Rcoupl_dbind; [|apply Rcoupl_eq]. + intros; simplify_eq. + rewrite 2!dret_id_left. + erewrite (distr_ext (dunifP M ≫=_ )); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite !dret_id_left; done. + } + rewrite -dmap_dbind. + by apply IH. + - setoid_rewrite lookup_insert_ne; [|done]. + rewrite Hα' bool_decide_eq_false_2 //. + rewrite /dmap. + rewrite -!dbind_assoc -/exec. + erewrite (dbind_ext_right (dunifP M)); last first. + { intros. rewrite -!dbind_assoc -/exec //. } + rewrite dbind_comm. + eapply Rcoupl_dbind; [|apply Rcoupl_eq]. + intros; simplify_eq. + rewrite 2!dret_id_left. + erewrite (distr_ext (dunifP M ≫=_ )); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite !dret_id_left; done. + } + rewrite -dmap_dbind. + by apply IH. + Qed. + + Local Lemma ind_case_rand σ α K (M N : nat) z ns : + N = Z.to_nat z → + tapes σ !! α = Some (M; ns) → + Rcoupl + (dmap (fill_lift K) (head_step (rand #z) σ) ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ)) + (dunifP M ≫= + (λ n, + dmap (fill_lift K) + (head_step (rand #z) (state_upd_tapes <[α := (M; ns ++ [n]) : tape]> σ)) + ≫= λ ρ, dmap (λ x, x.1) (pexec m ρ))) + eq. + Proof using m IH. + intros Hz Hα. + rewrite /head_step. + rewrite {1 2}/dmap. + erewrite (dbind_ext_right (dunifP M)); last first. + { intro. + rewrite {1 2}/dmap. + do 2 rewrite -dbind_assoc //. } + rewrite -/exec /=. + rewrite -!dbind_assoc -/exec. + erewrite (dbind_ext_right (dunifP M)); last first. + { intros n. rewrite -!dbind_assoc. done. } + rewrite dbind_comm. + eapply Rcoupl_dbind; [|apply Rcoupl_eq]. + intros; simplify_eq. + do 2 rewrite dret_id_left. + erewrite (distr_ext (dunifP M ≫=_ )); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite !dret_id_left; done. + } + rewrite -dmap_dbind. + apply IH; auto. + Qed. + +End erasure_helpers. + + +Lemma prim_coupl_upd_tapes_dom m e1 σ1 α N ns : + σ1.(tapes) !! α = Some (N; ns) → + Rcoupl + (dmap (λ x, x.1) (pexec m (e1, σ1))) + (dunifP N ≫= + (λ n, dmap (λ x, x.1) (pexec m (e1, state_upd_tapes <[α := (N; ns ++ [n])]> σ1)))) + (=). +Proof. + rewrite -dmap_dbind. + revert e1 σ1 α N ns; induction m; intros e1 σ1 α N ns Hα. + - rewrite /pexec /=. + rewrite dmap_dret. + rewrite dmap_dbind. + erewrite (distr_ext (dunifP N≫=_)); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite dmap_dret. done. + } + rewrite (dret_const (dunifP N)); [apply Rcoupl_eq | apply dunif_mass; lia]. + - rewrite pexec_Sn /step_or_final /=. + destruct (to_val e1) eqn:He1. + + rewrite dret_id_left. + rewrite -/(pexec m (e1, σ1)). + rewrite pexec_is_final; last by rewrite /is_final. + rewrite dmap_dret. simpl. + rewrite dmap_dbind. + erewrite (distr_ext (dunifP N ≫=_)); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. rewrite pexec_is_final; last by rewrite /is_final. + rewrite dmap_dret. simpl. done. + } + rewrite dret_const; [|solve_distr_mass]. + apply Rcoupl_eq. + + rewrite !dmap_dbind. + erewrite (distr_ext (dunifP N ≫= _)); last first. + { intros. apply dbind_pmf_ext; [|done..]. + intros. setoid_rewrite pexec_Sn. + rewrite /step_or_final/=He1/prim_step/=. + rewrite dmap_dbind. + done. + } + rewrite /prim_step/=. + destruct (decomp e1) as [K ered] eqn:Hdecomp_e1. + rewrite Hdecomp_e1. + destruct (det_or_prob_or_dzero ered σ1) as [ HD | [HP | HZ]]. + * eapply ind_case_det; [done|done|by apply is_det_head_step_true]. + * inversion HP; simplify_eq. + -- by eapply ind_case_alloc. + -- by eapply ind_case_rand_some. + -- by eapply ind_case_rand_empty. + -- by eapply ind_case_rand_some_neq. + -- by eapply ind_case_rand. + * by eapply ind_case_dzero. +Qed. + +Lemma pexec_coupl_step_pexec m e1 σ1 α bs : + σ1.(tapes) !! α = Some bs → + Rcoupl + (dmap (λ ρ, ρ.1) (pexec m (e1, σ1))) + (dmap (λ ρ, ρ.1) (state_step σ1 α ≫= (λ σ2, pexec m (e1, σ2)))) + eq. +Proof. + intros. + destruct bs. + eapply Rcoupl_eq_trans; first eapply prim_coupl_upd_tapes_dom; try done. + rewrite <-dmap_dbind. + apply Rcoupl_dmap. + erewrite state_step_unfold; last done. + rewrite /dmap. + rewrite -dbind_assoc. + eapply Rcoupl_dbind; last apply Rcoupl_eq. + intros ??->. + rewrite dret_id_left. + eapply Rcoupl_mono; first apply Rcoupl_eq. + intros. naive_solver. +Qed. + +Lemma prim_coupl_step_prim m e1 σ1 α bs : + σ1.(tapes) !! α = Some bs → + Rcoupl + (exec m (e1, σ1)) + (state_step σ1 α ≫= (λ σ2, exec m (e1, σ2))) + eq. +Proof. + intros Hα. + epose proof pexec_coupl_step_pexec _ _ _ _ _ Hα as H. + setoid_rewrite exec_pexec_relate. + simpl. + erewrite (distr_ext _ (dmap (λ ρ, ρ.1) (pexec m (e1, σ1)) ≫= + λ e, match to_val e with | Some b => dret b | None => dzero end)); last first. + { intros. rewrite /dmap. + rewrite -dbind_assoc. simpl. + apply dbind_pmf_ext; try done. + intros. rewrite dret_id_left. done. + } + erewrite (distr_ext (state_step _ _ ≫= _) _). + - eapply Rcoupl_dbind; last exact. + intros. subst. apply Rcoupl_eq. + - intros. rewrite /dmap. + rewrite -!dbind_assoc. simpl. + apply dbind_pmf_ext; try done. + intros. apply dbind_pmf_ext; try done. + intros. + rewrite dret_id_left. done. +Qed. + +Lemma state_step_erasable σ1 α bs : + σ1.(tapes) !! α = Some bs → + erasable (state_step σ1 α) σ1. +Proof. + intros. rewrite /erasable. + intros. + symmetry. + apply Rcoupl_eq_elim. + by eapply prim_coupl_step_prim. +Qed. + +Lemma iterM_state_step_erasable σ1 α bs n: + σ1.(tapes) !! α = Some bs → + erasable (iterM n (λ σ, state_step σ α) σ1) σ1. +Proof. + revert σ1 bs. + induction n; intros σ1 bs H. + - simpl. apply dret_erasable. + - simpl. apply erasable_dbind; first by eapply state_step_erasable. + intros ? H0. + destruct bs. + erewrite state_step_unfold in H0; last done. + rewrite dmap_pos in H0. destruct H0 as (?&->&K). + eapply IHn. simpl. apply lookup_insert. +Qed. + +Lemma limprim_coupl_step_limprim_aux e1 σ1 α bs v: + σ1.(tapes) !! α = Some bs → + (lim_exec (e1, σ1)) v = + (state_step σ1 α ≫= (λ σ2, lim_exec (e1, σ2))) v. +Proof. + intro Hsome. + rewrite lim_exec_unfold/=. + rewrite {2}/pmf/=/dbind_pmf. + setoid_rewrite lim_exec_unfold. + simpl in *. + assert + (SeriesC (λ a: state, state_step σ1 α a * Sup_seq (λ n : nat, exec n (e1, a) v)) = + SeriesC (λ a: state, Sup_seq (λ n : nat, state_step σ1 α a * exec n (e1, a) v))) as Haux. + { apply SeriesC_ext; intro v'. + apply eq_rbar_finite. + rewrite rmult_finite. + rewrite (rbar_finite_real_eq (Sup_seq (λ n : nat, exec n (e1, v') v))); auto. + - rewrite <- (Sup_seq_scal_l (state_step σ1 α v') (λ n : nat, exec n (e1, v') v)); auto. + - apply (Rbar_le_sandwich 0 1). + + apply (Sup_seq_minor_le _ _ 0%nat); simpl; auto. + + apply upper_bound_ge_sup; intro; simpl; auto. + } + rewrite Haux. + rewrite (MCT_seriesC _ (λ n, exec n (e1,σ1) v) (lim_exec (e1,σ1) v)); auto. + - real_solver. + - intros. apply Rmult_le_compat; auto; [done|apply exec_mono]. + - intro. exists (state_step σ1 α a)=>?. real_solver. + - intro n. + rewrite (Rcoupl_eq_elim _ _ (prim_coupl_step_prim n e1 σ1 α bs Hsome)); auto. + rewrite {3}/pmf/=/dbind_pmf. + apply SeriesC_correct; auto. + apply (ex_seriesC_le _ (state_step σ1 α)); auto. + real_solver. + - rewrite lim_exec_unfold. + rewrite rbar_finite_real_eq; [apply Sup_seq_correct |]. + rewrite mon_sup_succ. + + apply (Rbar_le_sandwich 0 1); auto. + * apply (Sup_seq_minor_le _ _ 0%nat); simpl; auto. + * apply upper_bound_ge_sup; intro; simpl; auto. + + intros. eapply exec_mono. +Qed. + +Lemma limprim_coupl_step_limprim e1 σ1 α bs : + σ1.(tapes) !! α = Some bs → + Rcoupl + (lim_exec (e1, σ1)) + (state_step σ1 α ≫= (λ σ2, lim_exec (e1, σ2))) + eq. +Proof. + intro Hsome. + erewrite (distr_ext (lim_exec (e1, σ1))); last first. + - intro a. + apply (limprim_coupl_step_limprim_aux _ _ _ _ _ Hsome). + - apply Rcoupl_eq. +Qed. + +Lemma lim_exec_eq_erasure αs e σ : + αs ⊆ get_active σ → + lim_exec (e, σ) = foldlM state_step σ αs ≫= (λ σ', lim_exec (e, σ')). +Proof. + induction αs as [|α αs IH] in σ |-*. + { rewrite /= dret_id_left //. } + intros Hα. + eapply Rcoupl_eq_elim. + assert (lim_exec (e, σ) = state_step σ α ≫= (λ σ2, lim_exec (e, σ2))) as ->. + { apply distr_ext => v. + assert (α ∈ get_active σ) as Hel; [apply Hα; left|]. + rewrite /get_active in Hel. + apply elem_of_elements, elem_of_dom in Hel as [? ?]. + by eapply limprim_coupl_step_limprim_aux. } + rewrite foldlM_cons -dbind_assoc. + eapply Rcoupl_dbind; [|eapply Rcoupl_pos_R, Rcoupl_eq]. + intros ?? (-> & Hs%state_step_support_equiv_rel & _). + inversion_clear Hs. + rewrite IH; [eapply Rcoupl_eq|]. + intros α' ?. rewrite /get_active /=. + apply elem_of_elements. + apply elem_of_dom. + destruct (decide (α = α')); subst. + + eexists. rewrite lookup_insert //. + + rewrite lookup_insert_ne //. + apply elem_of_dom. eapply elem_of_elements, Hα. by right. +Qed. + +Lemma refRcoupl_erasure e1 σ1 e1' σ1' α α' R Φ m bs bs': + σ1.(tapes) !! α = Some bs → + σ1'.(tapes) !! α' = Some bs' → + Rcoupl (state_step σ1 α) (state_step σ1' α') R → + (∀ σ2 σ2', R σ2 σ2' → + refRcoupl (exec m (e1, σ2)) + (lim_exec (e1', σ2')) Φ ) → + refRcoupl (exec m (e1, σ1)) + (lim_exec (e1', σ1')) Φ. +Proof. + intros Hα Hα' HR Hcont. + eapply refRcoupl_eq_refRcoupl_unfoldl ; + [eapply prim_coupl_step_prim; eauto |]. + eapply refRcoupl_eq_refRcoupl_unfoldr; + [| eapply Rcoupl_eq_sym, limprim_coupl_step_limprim; eauto]. + apply (refRcoupl_dbind _ _ _ _ R); auto. + by eapply Rcoupl_refRcoupl. +Qed. + +Lemma ARcoupl_erasure e1 σ1 e1' σ1' α α' R Φ ε ε' m bs bs': + 0 <= ε -> + 0 <= ε' -> + σ1.(tapes) !! α = Some bs → + σ1'.(tapes) !! α' = Some bs' → + ARcoupl (state_step σ1 α) (state_step σ1' α') R ε → + (∀ σ2 σ2', R σ2 σ2' → + ARcoupl (exec m (e1, σ2)) + (lim_exec (e1', σ2')) Φ ε' ) → + ARcoupl (exec m (e1, σ1)) + (lim_exec (e1', σ1')) Φ (ε + ε'). +Proof. + intros Hε Hε' Hα Hα' HR Hcont. + rewrite -(Rplus_0_l (ε + ε')). + eapply ARcoupl_eq_trans_l; try lra. + - eapply ARcoupl_from_eq_Rcoupl; try lra; eauto. + eapply prim_coupl_step_prim; eauto. + - rewrite -(Rplus_0_r (ε + ε')). + eapply ARcoupl_eq_trans_r; auto; try lra; last first. + + eapply ARcoupl_from_eq_Rcoupl; try lra; eauto. + eapply Rcoupl_eq_sym, limprim_coupl_step_limprim; eauto. + + apply (ARcoupl_dbind _ _ _ _ R); auto. +Qed. + +Lemma refRcoupl_erasure_r (e1 : expr) σ1 e1' σ1' α' R Φ m bs': + to_val e1 = None → + σ1'.(tapes) !! α' = Some bs' → + Rcoupl (prim_step e1 σ1) (state_step σ1' α') R → + (∀ e2 σ2 σ2', R (e2, σ2) σ2' → refRcoupl (exec m (e2, σ2)) (lim_exec (e1', σ2')) Φ ) → + refRcoupl (exec (S m) (e1, σ1)) (lim_exec (e1', σ1')) Φ. +Proof. + intros He1 Hα' HR Hcont. + rewrite exec_Sn_not_final; [|eauto]. + eapply (refRcoupl_eq_refRcoupl_unfoldr _ (state_step σ1' α' ≫= (λ σ2', lim_exec (e1', σ2')))). + - eapply refRcoupl_dbind; [|by apply Rcoupl_refRcoupl]. + intros [] ??. by apply Hcont. + - apply Rcoupl_eq_sym. by eapply limprim_coupl_step_limprim. +Qed. + + +Lemma ARcoupl_erasure_r (e1 : expr) σ1 e1' σ1' α' R Φ ε ε' m bs': + 0 <= ε -> + 0 <= ε' -> + to_val e1 = None → + σ1'.(tapes) !! α' = Some bs' → + ARcoupl (prim_step e1 σ1) (state_step σ1' α') R ε → + (∀ e2 σ2 σ2', R (e2, σ2) σ2' → ARcoupl (exec m (e2, σ2)) (lim_exec (e1', σ2')) Φ ε' ) → + ARcoupl (exec (S m) (e1, σ1)) (lim_exec (e1', σ1')) Φ (ε + ε'). +Proof. + intros Hε Hε' He1 Hα' HR Hcont. + rewrite exec_Sn_not_final; [|eauto]. + rewrite -(Rplus_0_r (ε + ε')). + eapply (ARcoupl_eq_trans_r _ (state_step σ1' α' ≫= (λ σ2', lim_exec (e1', σ2')))); try lra. + - eapply ARcoupl_dbind; try lra; auto; [| apply HR]. + intros [] ??. by apply Hcont. + - eapply ARcoupl_from_eq_Rcoupl; [lra | ]. + apply Rcoupl_eq_sym. by eapply limprim_coupl_step_limprim. +Qed. + +Lemma refRcoupl_erasure_l (e1 e1' : expr) σ1 σ1' α R Φ m bs : + σ1.(tapes) !! α = Some bs → + Rcoupl (state_step σ1 α) (prim_step e1' σ1') R → + (∀ σ2 e2' σ2', R σ2 (e2', σ2') → refRcoupl (exec m (e1, σ2)) (lim_exec (e2', σ2')) Φ ) → + refRcoupl (exec m (e1, σ1)) (lim_exec (e1', σ1')) Φ. +Proof. + intros Hα HR Hcont. + assert (to_val e1' = None). + { apply Rcoupl_pos_R, Rcoupl_inhabited_l in HR as (?&?&?&?&?); [eauto using val_stuck|]. + rewrite state_step_mass; [lra|]. apply elem_of_dom. eauto. } + eapply (refRcoupl_eq_refRcoupl_unfoldl _ (state_step σ1 α ≫= (λ σ2, exec m (e1, σ2)))). + - by eapply prim_coupl_step_prim. + - rewrite lim_exec_step. + rewrite step_or_final_no_final; [|eauto]. + eapply refRcoupl_dbind; [|by apply Rcoupl_refRcoupl]. + intros ? [] ?. by apply Hcont. +Qed. + +Lemma ARcoupl_erasure_l (e1 e1' : expr) σ1 σ1' α R Φ ε ε' m bs : + 0 <= ε -> + 0 <= ε' -> + σ1.(tapes) !! α = Some bs → + ARcoupl (state_step σ1 α) (prim_step e1' σ1') R ε → + (∀ σ2 e2' σ2', R σ2 (e2', σ2') → ARcoupl (exec m (e1, σ2)) (lim_exec (e2', σ2')) Φ ε') → + ARcoupl (exec m (e1, σ1)) (lim_exec (e1', σ1')) Φ (ε + ε'). +Proof. + intros Hε Hε' Hα HR Hcont. + destruct (to_val e1') eqn:Hval. + - assert (prim_step e1' σ1' = dzero) as Hz. + { by eapply (is_final_dzero (e1', σ1')), to_final_Some_2. } + rewrite Hz in HR. + rewrite -(Rplus_0_l (ε + ε')). + eapply (ARcoupl_eq_trans_l _ (state_step σ1 α ≫= (λ σ2, exec m (e1, σ2)))); [lra| lra | | ]. + + apply ARcoupl_from_eq_Rcoupl; [lra |]. + by eapply prim_coupl_step_prim. + + rewrite lim_exec_step. + rewrite step_or_final_is_final; [|eauto]. + eapply ARcoupl_dbind; [lra|lra| | ]; last first. + * rewrite -(Rplus_0_r ε). + eapply ARcoupl_eq_trans_r; [lra|lra| | apply ARcoupl_dzero; lra ]. + eauto. + * intros ? [] ?. by apply Hcont. + - rewrite -(Rplus_0_l (ε + ε')). + eapply (ARcoupl_eq_trans_l _ (state_step σ1 α ≫= (λ σ2, exec m (e1, σ2)))); [lra| lra | | ]. + + apply ARcoupl_from_eq_Rcoupl; [lra |]. + by eapply prim_coupl_step_prim. + + rewrite lim_exec_step. + rewrite step_or_final_no_final; [|eauto]. + eapply ARcoupl_dbind; [lra|lra| | apply HR]. + intros ? [] ?. by apply Hcont. +Qed. + + +Lemma refRcoupl_erasure_erasable (e1 e1' : expr) σ1 σ1' μ1 μ2 R Φ n : + Rcoupl (μ1) (μ2) R -> + erasable μ1 σ1-> + erasable μ2 σ1'-> + (∀ σ2 σ2' : language.state prob_lang, R σ2 σ2' → refRcoupl (exec (S n) (e1, σ2)) (lim_exec (e1', σ2')) Φ) -> + refRcoupl (exec (S n) (e1, σ1)) (lim_exec (e1', σ1')) Φ. +Proof. + rewrite {1}/erasable. + intros Hcoupl Hμ1 Hμ2 Hcont. + rewrite -Hμ1. + erewrite <-erasable_lim_exec; last exact Hμ2. + eapply refRcoupl_dbind; try done. + by apply Rcoupl_refRcoupl. +Qed. + +Lemma ARcoupl_erasure_erasable (e1 e1' : expr) ε ε1 ε2 σ1 σ1' μ1 μ2 R Φ n : + 0 <= ε1 -> + 0 <= ε2 -> + ε1 + ε2 <= ε -> + ARcoupl (μ1) (μ2) R ε1-> + erasable μ1 σ1-> + erasable μ2 σ1'-> + (∀ σ2 σ2' : language.state prob_lang, R σ2 σ2' → ARcoupl (exec n (e1, σ2)) (lim_exec (e1', σ2')) Φ ε2) -> + ARcoupl (exec n (e1, σ1)) (lim_exec (e1', σ1')) Φ ε. +Proof. + rewrite {1}/erasable. + intros H1 H2 Hineq Hcoupl Hμ1 Hμ2 Hcont. + rewrite -Hμ1. + erewrite <-erasable_lim_exec; last exact. + eapply ARcoupl_mon_grading; first exact. + eapply ARcoupl_dbind; try done. +Qed. + +Lemma ARcoupl_erasure_erasable_exp_rhs ε1 μ1 μ1' (E2 : _ → R) R Φ (e1 e1' : expr) σ1 σ1' ε r n m : + 0 <= ε1 → + ARcoupl μ1 (σ2' ← μ1'; pexec m (e1', σ2')) R ε1 → + ε1 + Expval (σ2' ← μ1'; pexec m (e1', σ2')) E2 <= ε → + (∀ ρ, (0 <= E2 ρ <= r)%R) → + erasable μ1 σ1 → + erasable μ1' σ1' → + (∀ σ2 e2' σ2', R σ2 (e2', σ2') → + ARcoupl (exec n (e1, σ2)) (lim_exec (e2', σ2')) Φ (E2 (e2', σ2'))) → + ARcoupl (exec n (e1, σ1)) (lim_exec (e1', σ1')) Φ ε. +Proof. + intros H1 Hcoupl Hineq Hbound Hμ1 Hμ2 Hcont. + rewrite -Hμ1. + rewrite -(erasable_pexec_lim_exec μ1' m) //. + eapply ARcoupl_mon_grading; [done|]. + eapply (ARcoupl_dbind_adv_rhs' E2); [done|eauto|done| |done]. + intros ? [] ?. + by eapply Hcont. +Qed. + +Lemma ARcoupl_erasure_erasable_exp_lhs ε1 μ1' (E2 : _ → R) R Φ (e1 e1' : expr) σ1 σ1' ε r n m : + 0 <= ε1 → + ARcoupl (prim_step e1 σ1) (μ1' ≫= λ σ2', pexec m (e1', σ2')) R ε1 → + ε1 + Expval (prim_step e1 σ1) E2 <= ε → + (∀ ρ, (0 <= E2 ρ <= r)%R) → + erasable μ1' σ1' → + (∀ e2 σ2 e2' σ2', R (e2, σ2) (e2', σ2') → + ARcoupl (exec n (e2, σ2)) (lim_exec (e2', σ2')) Φ (E2 (e2, σ2))) → + ARcoupl (prim_step e1 σ1 ≫= exec n) (lim_exec (e1', σ1')) Φ ε. +Proof. + intros Hε Hcoupl Hle Hb Hμ1' Hcont. + rewrite -(erasable_pexec_lim_exec μ1' m) //. + eapply ARcoupl_mon_grading; [done|]. + eapply (ARcoupl_dbind_adv_lhs' E2); [done|eauto|done| |done]. + intros [] [] ?. by eapply Hcont. +Qed. +*) diff --git a/theories/meas_lang/exec_lang.v b/theories/meas_lang/exec_lang.v new file mode 100644 index 00000000..170043f9 --- /dev/null +++ b/theories/meas_lang/exec_lang.v @@ -0,0 +1,66 @@ +(* TODO move into metatheory.v ? *) + +From Coq Require Export Reals Psatz. +From clutch.meas_lang Require Import lang. + +(* +Lemma exec_det_step_ctx K `{!LanguageCtx K} n ρ (e1 e2 : expr) σ1 σ2 : + prim_step e1 σ1 (e2, σ2) = 1%R → + pexec n ρ (K e1, σ1) = 1%R → + pexec (S n) ρ (K e2, σ2) = 1%R. +Proof. + intros. eapply pexec_det_step; [|done]. + rewrite -fill_step_prob //. + eapply (val_stuck _ σ1 (e2, σ2)). + rewrite H. lra. +Qed. + +Lemma exec_PureExec_ctx K `{!LanguageCtx K} (P : Prop) m n ρ (e e' : expr) σ : + P → + PureExec P n e e' → + pexec m ρ (K e, σ) = 1 → + pexec (m + n) ρ (K e', σ) = 1. +Proof. + move=> HP /(_ HP). + destruct ρ as [e0 σ0]. + revert e e' m. induction n=> e e' m. + { rewrite -plus_n_O. by inversion 1. } + intros (e'' & Hsteps & Hpstep)%nsteps_inv_r Hdet. + specialize (IHn _ _ m Hsteps Hdet). + rewrite -plus_n_Sm. + eapply exec_det_step_ctx; [done| |done]. + apply Hpstep. +Qed. + +Lemma stepN_det_step_ctx K `{!LanguageCtx K} n ρ (e1 e2 : expr) σ1 σ2 : + prim_step e1 σ1 (e2, σ2) = 1%R → + stepN n ρ (K e1, σ1) = 1%R → + stepN (S n) ρ (K e2, σ2) = 1%R. +Proof. + intros. + rewrite -Nat.add_1_r. + erewrite (stepN_det_trans n 1); [done|done|]. + rewrite stepN_Sn /=. + rewrite dret_id_right. + rewrite -fill_step_prob //. + eapply (val_stuck _ σ1 (e2, σ2)). + rewrite H. lra. +Qed. + +Lemma stepN_PureExec_ctx K `{!LanguageCtx K} (P : Prop) m n ρ (e e' : expr) σ : + P → + PureExec P n e e' → + stepN m ρ (K e, σ) = 1 → + stepN (m + n) ρ (K e', σ) = 1. +Proof. + move=> HP /(_ HP). + destruct ρ as [e0 σ0]. + revert e e' m. induction n=> e e' m. + { rewrite -plus_n_O. by inversion 1. } + intros (e'' & Hsteps & Hpstep)%nsteps_inv_r Hdet. + specialize (IHn _ _ m Hsteps Hdet). + rewrite -plus_n_Sm. + eapply stepN_det_step_ctx; [done| |done]. + apply Hpstep. +Qed. +*) diff --git a/theories/meas_lang/metatheory.v b/theories/meas_lang/metatheory.v new file mode 100644 index 00000000..640e004a --- /dev/null +++ b/theories/meas_lang/metatheory.v @@ -0,0 +1,1724 @@ +From Coq Require Import Reals Psatz. +From stdpp Require Import functions gmap stringmap fin_sets. +From clutch.prelude Require Import stdpp_ext NNRbar fin uniform_list. +(* From clutch.prob Require Import distribution couplings couplings_app. *) +From clutch.meas_lang Require Import ectx_language. +From clutch.prob_lang Require Import tactics notation lang. +(* From clutch.prob Require Import distribution couplings. *) +From iris.prelude Require Import options. +Set Default Proof Using "Type*". +(* This file contains some metatheory about the [meas_lang] language *) + +(* +(* Adding a binder to a set of identifiers. *) +Local Definition set_binder_insert (x : binder) (X : stringset) : stringset := + match x with + | BAnon => X + | BNamed f => {[f]} ∪ X + end. + +(* Check if expression [e] is closed w.r.t. the set [X] of variable names, + and that all the values in [e] are closed *) +Fixpoint is_closed_expr (X : stringset) (e : expr) : bool := + match e with + | Val v => is_closed_val v + | Var x => bool_decide (x ∈ X) + | Rec f x e => is_closed_expr (set_binder_insert f (set_binder_insert x X)) e + | UnOp _ e | Fst e | Snd e | InjL e | InjR e | Load e => + is_closed_expr X e + | App e1 e2 | BinOp _ e1 e2 | Pair e1 e2 | AllocN e1 e2 | Store e1 e2 | Rand e1 e2 => + is_closed_expr X e1 && is_closed_expr X e2 + | If e0 e1 e2 | Case e0 e1 e2 => + is_closed_expr X e0 && is_closed_expr X e1 && is_closed_expr X e2 + | AllocTape e => is_closed_expr X e + | Tick e => is_closed_expr X e + end +with is_closed_val (v : val) : bool := + match v with + | LitV _ => true + | RecV f x e => is_closed_expr (set_binder_insert f (set_binder_insert x ∅)) e + | PairV v1 v2 => is_closed_val v1 && is_closed_val v2 + | InjLV v | InjRV v => is_closed_val v + end. + +(** Parallel substitution *) +Fixpoint subst_map (vs : gmap string val) (e : expr) : expr := + match e with + | Val _ => e + | Var y => if vs !! y is Some v then Val v else Var y + | Rec f y e => Rec f y (subst_map (binder_delete y (binder_delete f vs)) e) + | App e1 e2 => App (subst_map vs e1) (subst_map vs e2) + | UnOp op e => UnOp op (subst_map vs e) + | BinOp op e1 e2 => BinOp op (subst_map vs e1) (subst_map vs e2) + | If e0 e1 e2 => If (subst_map vs e0) (subst_map vs e1) (subst_map vs e2) + | Pair e1 e2 => Pair (subst_map vs e1) (subst_map vs e2) + | Fst e => Fst (subst_map vs e) + | Snd e => Snd (subst_map vs e) + | InjL e => InjL (subst_map vs e) + | InjR e => InjR (subst_map vs e) + | Case e0 e1 e2 => Case (subst_map vs e0) (subst_map vs e1) (subst_map vs e2) + | AllocN e1 e2 => AllocN (subst_map vs e1) (subst_map vs e2) + | Load e => Load (subst_map vs e) + | Store e1 e2 => Store (subst_map vs e1) (subst_map vs e2) + | AllocTape e => AllocTape (subst_map vs e) + | Rand e1 e2 => Rand (subst_map vs e1) (subst_map vs e2) + | Tick e => Tick (subst_map vs e) + end. + +(* Properties *) +Local Instance set_unfold_elem_of_insert_binder x y X Q : + SetUnfoldElemOf y X Q → + SetUnfoldElemOf y (set_binder_insert x X) (Q ∨ BNamed y = x). +Proof. destruct 1; constructor; destruct x; set_solver. Qed. + +Lemma is_closed_weaken X Y e : is_closed_expr X e → X ⊆ Y → is_closed_expr Y e. +Proof. revert X Y; induction e; naive_solver (eauto; set_solver). Qed. + +Lemma is_closed_weaken_empty X e : is_closed_expr ∅ e → is_closed_expr X e. +Proof. intros. by apply is_closed_weaken with ∅, empty_subseteq. Qed. + +Lemma is_closed_subst X e y v : + is_closed_val v → + is_closed_expr ({[y]} ∪ X) e → + is_closed_expr X (subst y v e). +Proof. + intros Hv. revert X. + induction e=> X /= ?; destruct_and?; split_and?; simplify_option_eq; + try match goal with + | H : ¬(_ ∧ _) |- _ => apply not_and_l in H as [?%dec_stable|?%dec_stable] + end; eauto using is_closed_weaken with set_solver. +Qed. +Lemma is_closed_subst' X e x v : + is_closed_val v → + is_closed_expr (set_binder_insert x X) e → + is_closed_expr X (subst' x v e). +Proof. destruct x; eauto using is_closed_subst. Qed. + +Lemma subst_is_closed X e x es : is_closed_expr X e → x ∉ X → subst x es e = e. +Proof. + revert X. induction e=> X /=; + rewrite ?bool_decide_spec ?andb_True=> ??; + repeat case_decide; simplify_eq/=; f_equal; intuition eauto with set_solver. +Qed. + +Lemma subst_is_closed_empty e x v : is_closed_expr ∅ e → subst x v e = e. +Proof. intros. apply subst_is_closed with (∅:stringset); set_solver. Qed. + +Lemma subst_subst e x v v' : + subst x v (subst x v' e) = subst x v' e. +Proof. + intros. induction e; simpl; try (f_equal; by auto); + simplify_option_eq; auto using subst_is_closed_empty with f_equal. +Qed. +Lemma subst_subst' e x v v' : + subst' x v (subst' x v' e) = subst' x v' e. +Proof. destruct x; simpl; auto using subst_subst. Qed. + +Lemma subst_subst_ne e x y v v' : + x ≠ y → subst x v (subst y v' e) = subst y v' (subst x v e). +Proof. + intros. induction e; simpl; try (f_equal; by auto); + simplify_option_eq; auto using eq_sym, subst_is_closed_empty with f_equal. +Qed. +Lemma subst_subst_ne' e x y v v' : + x ≠ y → subst' x v (subst' y v' e) = subst' y v' (subst' x v e). +Proof. destruct x, y; simpl; auto using subst_subst_ne with congruence. Qed. + +Lemma subst_rec' f y e x v : + x = f ∨ x = y ∨ x = BAnon → + subst' x v (Rec f y e) = Rec f y e. +Proof. intros. destruct x; simplify_option_eq; naive_solver. Qed. +Lemma subst_rec_ne' f y e x v : + (x ≠ f ∨ f = BAnon) → (x ≠ y ∨ y = BAnon) → + subst' x v (Rec f y e) = Rec f y (subst' x v e). +Proof. intros. destruct x; simplify_option_eq; naive_solver. Qed. + +Lemma bin_op_eval_closed op v1 v2 v' : + is_closed_val v1 → is_closed_val v2 → bin_op_eval op v1 v2 = Some v' → + is_closed_val v'. +Proof. + rewrite /bin_op_eval /bin_op_eval_bool /bin_op_eval_int /bin_op_eval_loc; + repeat case_match; by naive_solver. +Qed. + +Lemma heap_closed_alloc σ l n w : + (0 < n)%Z → + is_closed_val w → + map_Forall (λ _ v, is_closed_val v) (heap σ) → + (∀ i : Z, (0 ≤ i)%Z → (i < n)%Z → heap σ !! (l +ₗ i) = None) → + map_Forall (λ _ v, is_closed_val v) + (heap_array l (replicate (Z.to_nat n) w) ∪ heap σ). +Proof. + intros Hn Hw Hσ Hl. + eapply (map_Forall_ind + (λ k v, ((heap_array l (replicate (Z.to_nat n) w) ∪ heap σ) + !! k = Some v))). + - apply map_Forall_empty. + - intros m i x Hi Hix Hkwm Hm. + apply map_Forall_insert_2; auto. + apply lookup_union_Some in Hix; last first. + { eapply heap_array_map_disjoint; + rewrite replicate_length Z2Nat.id; auto with lia. } + destruct Hix as [(?&?&?&[-> Hlt%inj_lt]%lookup_replicate_1)%heap_array_lookup| + [j Hj]%elem_of_map_to_list%elem_of_list_lookup_1]. + + simplify_eq/=. rewrite !Z2Nat.id in Hlt; eauto with lia. + + apply map_Forall_to_list in Hσ. + by eapply Forall_lookup in Hσ; eauto; simpl in *. + - apply map_Forall_to_list, Forall_forall. + intros [? ?]; apply elem_of_map_to_list. +Qed. + +Lemma subst_map_empty e : subst_map ∅ e = e. +Proof. + assert (∀ x, binder_delete x (∅:gmap _ val) = ∅) as Hdel. + { intros [|x]; by rewrite /= ?delete_empty. } + induction e; simplify_map_eq; rewrite ?Hdel; auto with f_equal. +Qed. +Lemma subst_map_insert x v vs e : + subst_map (<[x:=v]>vs) e = subst x v (subst_map (delete x vs) e). +Proof. + revert vs. induction e=> vs; simplify_map_eq; auto with f_equal. + - match goal with + | |- context [ <[?x:=_]> _ !! ?y ] => + destruct (decide (x = y)); simplify_map_eq=> // + end. by case (vs !! _); simplify_option_eq. + - destruct (decide _) as [[??]|[<-%dec_stable|[<-%dec_stable ?]]%not_and_l_alt]. + + rewrite !binder_delete_insert // !binder_delete_delete; eauto with f_equal. + + by rewrite /= delete_insert_delete delete_idemp. + + by rewrite /= binder_delete_insert // delete_insert_delete + !binder_delete_delete delete_idemp. +Qed. +Lemma subst_map_singleton x v e : + subst_map {[x:=v]} e = subst x v e. +Proof. by rewrite subst_map_insert delete_empty subst_map_empty. Qed. + +Lemma subst_map_binder_insert b v vs e : + subst_map (binder_insert b v vs) e = + subst' b v (subst_map (binder_delete b vs) e). +Proof. destruct b; rewrite ?subst_map_insert //. Qed. +Lemma subst_map_binder_insert_empty b v e : + subst_map (binder_insert b v ∅) e = subst' b v e. +Proof. by rewrite subst_map_binder_insert binder_delete_empty subst_map_empty. Qed. + +Lemma subst_map_binder_insert_2 b1 v1 b2 v2 vs e : + subst_map (binder_insert b1 v1 (binder_insert b2 v2 vs)) e = + subst' b2 v2 (subst' b1 v1 (subst_map (binder_delete b2 (binder_delete b1 vs)) e)). +Proof. + destruct b1 as [|s1], b2 as [|s2]=> /=; auto using subst_map_insert. + rewrite subst_map_insert. destruct (decide (s1 = s2)) as [->|]. + - by rewrite delete_idemp subst_subst delete_insert_delete. + - by rewrite delete_insert_ne // subst_map_insert subst_subst_ne. +Qed. +Lemma subst_map_binder_insert_2_empty b1 v1 b2 v2 e : + subst_map (binder_insert b1 v1 (binder_insert b2 v2 ∅)) e = + subst' b2 v2 (subst' b1 v1 e). +Proof. + by rewrite subst_map_binder_insert_2 !binder_delete_empty subst_map_empty. +Qed. + +Lemma subst_map_is_closed X e vs : + is_closed_expr X e → + (∀ x, x ∈ X → vs !! x = None) → + subst_map vs e = e. +Proof. + revert X vs. assert (∀ x x1 x2 X (vs : gmap string val), + (∀ x, x ∈ X → vs !! x = None) → + x ∈ set_binder_insert x2 (set_binder_insert x1 X) → + binder_delete x1 (binder_delete x2 vs) !! x = None). + { intros x x1 x2 X vs ??. rewrite !lookup_binder_delete_None. set_solver. } + induction e=> X vs /= ? HX; repeat case_match; naive_solver eauto with f_equal. +Qed. + +Lemma subst_map_is_closed_empty e vs : is_closed_expr ∅ e → subst_map vs e = e. +Proof. intros. apply subst_map_is_closed with (∅ : stringset); set_solver. Qed. + +Local Open Scope R. + +Lemma ARcoupl_state_step_dunifP σ α N ns: + tapes σ !! α = Some (N; ns) -> + ARcoupl (state_step σ α) (dunifP N) + ( + λ σ' n, σ' = state_upd_tapes <[α := (N; ns ++ [n])]> σ + ) 0. +Proof. + intros H. + erewrite state_step_unfold; last done. + rewrite -{2}(dmap_id (dunifP N)). + apply ARcoupl_map; first lra. + apply ARcoupl_refRcoupl. + eapply refRcoupl_mono; last apply refRcoupl_eq_refl. + intros ??->. done. +Qed. + +(** * rand(N) ~ rand(N) coupling *) +Lemma Rcoupl_rand_rand N f `{Bij (fin (S N)) (fin (S N)) f} z σ1 σ1' : + N = Z.to_nat z → + Rcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #z) σ1') + (λ ρ2 ρ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #(f n), σ1')). +Proof. + intros Hz. + rewrite head_prim_step_eq /=. + rewrite head_prim_step_eq /=. + rewrite /dmap -Hz. + eapply Rcoupl_dbind; [|by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. + eauto. +Qed. + +(** * rand(N, α1) ~ rand(N, α2) coupling, "wrong" N *) +Lemma Rcoupl_rand_lbl_rand_lbl_wrong N M f `{Bij (fin (S N)) (fin (S N)) f} α1 α2 z σ1 σ2 xs ys : + σ1.(tapes) !! α1 = Some (M; xs) → + σ2.(tapes) !! α2 = Some (M; ys) → + N ≠ M → + N = Z.to_nat z → + Rcoupl + (prim_step (rand(#lbl:α1) #z) σ1) + (prim_step (rand(#lbl:α2) #z) σ2) + (λ ρ2 ρ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #(f n), σ2)). +Proof. + intros Hσ1 Hσ2 Hneq Hz. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz Hσ1 Hσ2. + rewrite bool_decide_eq_false_2 //. + eapply Rcoupl_dbind; [|by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. + eauto. +Qed. + +(** * rand(N,α) ~ rand(N) coupling, "wrong" N *) +Lemma Rcoupl_rand_lbl_rand_wrong N M f `{Bij (fin (S N)) (fin (S N)) f} α1 z σ1 σ2 xs : + σ1.(tapes) !! α1 = Some (M; xs) → + N ≠ M → + N = Z.to_nat z → + Rcoupl + (prim_step (rand(#lbl:α1) #z) σ1) + (prim_step (rand #z) σ2) + (λ ρ2 ρ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #(f n), σ2)). +Proof. + intros Hσ1 Hneq Hz. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz Hσ1. + rewrite bool_decide_eq_false_2 //. + eapply Rcoupl_dbind; [|by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. + eauto. +Qed. + +(** * rand(N) ~ rand(N, α) coupling, "wrong" N *) +Lemma Rcoupl_rand_rand_lbl_wrong N M f `{Bij (fin (S N)) (fin (S N)) f} α2 z σ1 σ2 ys : + σ2.(tapes) !! α2 = Some (M; ys) → + N ≠ M → + N = Z.to_nat z → + Rcoupl + (prim_step (rand #z) σ1) + (prim_step (rand(#lbl:α2) #z) σ2) + (λ ρ2 ρ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #(f n), σ2)). +Proof. + intros Hσ2 Hneq Hz. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz Hσ2. + rewrite bool_decide_eq_false_2 //. + eapply Rcoupl_dbind; [|by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. + eauto. +Qed. + +(** * state_step(α, N) ~ state_step(α', N) coupling *) +Lemma Rcoupl_state_state N f `{Bij (fin (S N)) (fin (S N)) f} σ1 σ2 α1 α2 xs ys : + σ1.(tapes) !! α1 = Some (N; xs) → + σ2.(tapes) !! α2 = Some (N; ys) → + Rcoupl + (state_step σ1 α1) + (state_step σ2 α2) + (λ σ1' σ2', ∃ (n : fin (S N)), + σ1' = state_upd_tapes <[α1 := (N; xs ++ [n])]> σ1 ∧ + σ2' = state_upd_tapes <[α2 := (N; ys ++ [f n])]> σ2). +Proof. + intros Hα1 Hα2. + rewrite /state_step. + do 2 (rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2]). + rewrite (lookup_total_correct _ _ _ Hα1). + rewrite (lookup_total_correct _ _ _ Hα2). + eapply Rcoupl_dbind; [|by apply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. eauto. +Qed. + +(** * Generalized state_step(α) ~ state_step(α') coupling *) +Lemma Rcoupl_state_step_gen (m1 m2 : nat) (R : fin (S m1) -> fin (S m2) -> Prop) σ1 σ2 α1 α2 xs ys : + σ1.(tapes) !! α1 = Some (m1; xs) → + σ2.(tapes) !! α2 = Some (m2; ys) → + Rcoupl (dunif (S m1)) (dunif (S m2)) R → + Rcoupl + (state_step σ1 α1) + (state_step σ2 α2) + (λ σ1' σ2', ∃ (n1 : fin (S m1)) (n2 : fin (S m2)), + R n1 n2 ∧ + σ1' = state_upd_tapes <[α1 := (m1; xs ++ [n1])]> σ1 ∧ + σ2' = state_upd_tapes <[α2 := (m2; ys ++ [n2])]> σ2). +Proof. + intros Hα1 Hα2 Hcoupl. + apply Rcoupl_pos_R in Hcoupl. + rewrite /state_step. + pose proof (elem_of_dom_2 _ _ _ Hα1) as Hdom1. + pose proof (elem_of_dom_2 _ _ _ Hα2) as Hdom2. + rewrite bool_decide_eq_true_2; auto. + rewrite bool_decide_eq_true_2; auto. + rewrite (lookup_total_correct _ _ _ Hα1). + rewrite (lookup_total_correct _ _ _ Hα2). + rewrite /dmap. + eapply Rcoupl_dbind; [ | apply Hcoupl ]; simpl. + intros a b (Hab & HposA & HposB). + rewrite /pmf/dunif/= in HposA. + rewrite /pmf/dunif/= in HposB. + apply Rcoupl_dret. + exists a. exists b. split; try split; auto. +Qed. + +(** * rand(unit, N) ~ state_step(α', N) coupling *) +Lemma Rcoupl_rand_state N f `{Bij (fin (S N)) (fin (S N)) f} z σ1 σ1' α' xs: + N = Z.to_nat z → + σ1'.(tapes) !! α' = Some (N; xs) → + Rcoupl + (prim_step (rand #z) σ1) + (state_step σ1' α') + (λ ρ2 σ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ σ2' = state_upd_tapes <[α' := (N; xs ++ [f n])]> σ1'). +Proof. + intros Hz Hα'. + rewrite head_prim_step_eq /=. + rewrite /state_step. + rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2] . + rewrite -Hz. + rewrite (lookup_total_correct _ _ _ Hα'). + eapply Rcoupl_dbind; [|by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. eauto. +Qed. + +(** * state_step(α, N) ~ rand(unit, N) coupling *) +Lemma Rcoupl_state_rand N f `{Bij (fin (S N)) (fin (S N)) f} z σ1 σ1' α xs : + N = Z.to_nat z → + σ1.(tapes) !! α = Some (N; xs) → + Rcoupl + (state_step σ1 α) + (prim_step (rand #z) σ1') + (λ σ2 ρ2' , ∃ (n : fin (S N)), + σ2 = state_upd_tapes <[α := (N; xs ++ [n])]> σ1 ∧ ρ2' = (Val #(f n), σ1') ). +Proof. + intros Hz Hα. + rewrite head_prim_step_eq /=. + rewrite /state_step. + rewrite bool_decide_eq_true_2; [ |by eapply elem_of_dom_2] . + rewrite -Hz. + rewrite (lookup_total_correct _ _ _ Hα). + eapply Rcoupl_dbind; [ |by eapply Rcoupl_dunif]. + intros n ? ->. + apply Rcoupl_dret. eauto. +Qed. + +Lemma Rcoupl_rand_r `{Countable A} N z (a : A) σ1' : + N = Z.to_nat z → + Rcoupl + (dret a) + (prim_step (rand #z) σ1') + (λ a' ρ2', ∃ (n : fin (S N)), a' = a ∧ ρ2' = (Val #n, σ1')). +Proof. + intros ?. + assert (head_reducible (rand #z) σ1') as hr by solve_red. + rewrite head_prim_step_eq //. + eapply Rcoupl_mono. + - apply Rcoupl_pos_R, Rcoupl_trivial. + all : auto using dret_mass, head_step_mass. + - intros ? [] (_ & hh%dret_pos & ?). + inv_head_step; eauto. +Qed. + +(** * e1 ~ rand(α', N) coupling for α' ↪ₛ (N, []) *) +Lemma Rcoupl_rand_empty_r `{Countable A} N z (a : A) σ1' α' : + N = Z.to_nat z → + tapes σ1' !! α' = Some (N; []) → + Rcoupl + (dret a) + (prim_step (rand(#lbl:α') #z) σ1') + (λ a' ρ2', ∃ (n : fin (S N)), a' = a ∧ ρ2' = (Val #n, σ1')). +Proof. + intros ??. + assert (head_reducible (rand(#lbl:α') #z) σ1') as hr by solve_red. + rewrite head_prim_step_eq //. + eapply Rcoupl_mono. + - apply Rcoupl_pos_R, Rcoupl_trivial. + all : auto using dret_mass, head_step_mass. + - intros ? [] (_ & hh%dret_pos & ?). + inv_head_step; eauto. +Qed. + +Lemma Rcoupl_rand_wrong_r `{Countable A} N M z (a : A) ns σ1' α' : + N = Z.to_nat z → + N ≠ M → + tapes σ1' !! α' = Some (M; ns) → + Rcoupl + (dret a) + (prim_step (rand(#lbl:α') #z) σ1') + (λ a' ρ2', ∃ (n : fin (S N)), a' = a ∧ ρ2' = (Val #n, σ1')). +Proof. + intros ???. + assert (head_reducible (rand(#lbl:α') #z) σ1') as hr by solve_red. + rewrite head_prim_step_eq //. + eapply Rcoupl_mono. + - apply Rcoupl_pos_R, Rcoupl_trivial. + all : auto using dret_mass, head_step_mass. + - intros ? [] (_ & hh%dret_pos & ?). + inv_head_step; eauto. +Qed. + +Lemma S_INR_le_compat (N M : nat) : + (N <= M)%R -> + (0 < S N <= S M)%R. +Proof. + split; [| do 2 rewrite S_INR; lra ]. + rewrite S_INR. + apply Rplus_le_lt_0_compat; [ apply pos_INR | lra]. +Qed. + +(** * Approximate rand(N) ~ rand(M) coupling, N <= M *) +Lemma ARcoupl_rand_rand (N M : nat) z w σ1 σ1' (ε : nonnegreal) : + (N ≤ M)%nat → + (((S M - S N) / S M) = ε)%R → + N = Z.to_nat z → + M = Z.to_nat w → + ARcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #w) σ1') + (λ ρ2 ρ2', ∃ (n : fin (S N)) (m : fin (S M)), + (fin_to_nat n = m) ∧ + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #m, σ1')) + ε. +Proof. + intros NMpos NMε Hz Hw. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz -Hw. + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + eapply ARcoupl_dbind. + 1,2: apply cond_nonneg. + 2 : { + rewrite -NMε. + eapply ARcoupl_dunif_leq. + split; real_solver. + } + intros n m Hnm. + apply ARcoupl_dret; [done|]. + exists n . exists m. + by rewrite Hnm //. +Qed. + +(** * Approximate rand(N) ~ rand(M) coupling, N <= M, along an injection *) +Lemma ARcoupl_rand_rand_inj (N M : nat) f `{Inj (fin (S N)) (fin (S M)) (=) (=) f} z w σ1 σ1' (ε : nonnegreal) : + (N <= M)%nat → + ((S M - S N) / S M = ε)%R → + N = Z.to_nat z → + M = Z.to_nat w → + ARcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #w) σ1') + (λ ρ2 ρ2', ∃ (n : fin (S N)), + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #(f n), σ1')) + ε. +Proof. + intros NMpos NMε Hz Hw. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz -Hw. + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + eapply ARcoupl_dbind. + 1,2: apply cond_nonneg. + 2 : { + rewrite -NMε. + eapply ARcoupl_dunif_leq_inj; eauto. + apply S_INR_le_compat. real_solver. + } + intros n m Hnm. + apply ARcoupl_dret; [done|]. + exists n . + by rewrite Hnm //. +Qed. + +(** * Approximate rand(N) ~ rand(M) coupling, M <= N *) +Lemma ARcoupl_rand_rand_rev (N M : nat) z w σ1 σ1' (ε : nonnegreal) : + (M <= N)%nat → + (((S N - S M) / S N) = ε)%R → + N = Z.to_nat z → + M = Z.to_nat w → + ARcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #w) σ1') + (λ ρ2 ρ2', ∃ (n : fin (S N)) (m : fin (S M)), + (fin_to_nat n = m) ∧ + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #m, σ1')) + ε. +Proof. + intros NMpos NMε Hz Hw. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz -Hw. + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + eapply ARcoupl_dbind. + 1,2: apply cond_nonneg. + 2 : { + rewrite -NMε. + eapply ARcoupl_dunif_leq_rev, S_INR_le_compat. + real_solver. + } + intros n m Hnm. + apply ARcoupl_dret; [done|]. + exists n . exists m. + by rewrite Hnm //. +Qed. + + +(** * Approximate rand(N) ~ rand(M) coupling, M <= N, along an injection *) +Lemma ARcoupl_rand_rand_rev_inj (N M : nat) f `{Inj (fin (S M)) (fin (S N)) (=) (=) f} z w σ1 σ1' (ε : nonnegreal) : + (M <= N)%nat → + (((S N - S M) / S N) = ε)%R → + N = Z.to_nat z → + M = Z.to_nat w → + ARcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #w) σ1') + (λ ρ2 ρ2', ∃ (m : fin (S M)), + ρ2 = (Val #(f m), σ1) ∧ ρ2' = (Val #m, σ1')) + ε. +Proof. + intros NMpos NMε Hz Hw. + rewrite ?head_prim_step_eq /=. + rewrite /dmap -Hz -Hw. + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + eapply ARcoupl_dbind. + 1,2: apply cond_nonneg. + 2 : { + rewrite -NMε. + eapply ARcoupl_dunif_leq_rev_inj, S_INR_le_compat; [done|]. + real_solver. + } + intros n m Hnm. + apply ARcoupl_dret; [done|]. + exists m. + by rewrite Hnm //. +Qed. + + +(** * Approximate state_step(α, N) ~ state_step(α', N) coupling *) +Lemma ARcoupl_state_state (N M : nat) σ1 σ2 α1 α2 xs ys (ε : nonnegreal) : + (N <= M)%nat → + (((S M - S N) / S M) = ε)%R → + σ1.(tapes) !! α1 = Some (N; xs) → + σ2.(tapes) !! α2 = Some (M; ys) → + ARcoupl + (state_step σ1 α1) + (state_step σ2 α2) + (λ σ1' σ2', ∃ (n : fin (S N)) (m : fin (S M)), + (fin_to_nat n = m) ∧ + σ1' = state_upd_tapes <[α1 := (N; xs ++ [n])]> σ1 ∧ + σ2' = state_upd_tapes <[α2 := (M; ys ++ [m])]> σ2) + ε. +Proof. + intros NMpos NMε Hα1 Hα2. + rewrite /state_step. + do 2 (rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2]). + rewrite (lookup_total_correct _ _ _ Hα1). + rewrite (lookup_total_correct _ _ _ Hα2). + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + unshelve eapply ARcoupl_dbind. + { exact (λ (n : fin (S N)) (m : fin (S M)), fin_to_nat n = m). } + { destruct ε ; done. } { simpl ; lra. } + 2: { rewrite -NMε. apply ARcoupl_dunif_leq, S_INR_le_compat. real_solver. } + intros n m nm. + apply ARcoupl_dret; [done|]. + simpl in nm. eauto. +Qed. + +Lemma ARcoupl_state_state_rev (N M : nat) σ1 σ2 α1 α2 xs ys (ε : nonnegreal) : + (M <= N)%nat → + (((S N - S M) / S N) = ε)%R → + σ1.(tapes) !! α1 = Some (N; xs) → + σ2.(tapes) !! α2 = Some (M; ys) → + ARcoupl + (state_step σ1 α1) + (state_step σ2 α2) + (λ σ1' σ2', ∃ (n : fin (S N)) (m : fin (S M)), + (fin_to_nat n = m) ∧ + σ1' = state_upd_tapes <[α1 := (N; xs ++ [n])]> σ1 ∧ + σ2' = state_upd_tapes <[α2 := (M; ys ++ [m])]> σ2) + ε. +Proof. + intros NMpos NMε Hα1 Hα2. + rewrite /state_step. + do 2 (rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2]). + rewrite (lookup_total_correct _ _ _ Hα1). + rewrite (lookup_total_correct _ _ _ Hα2). + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + unshelve eapply ARcoupl_dbind. + { exact (λ (n : fin (S N)) (m : fin (S M)), fin_to_nat n = m). } + { destruct ε ; done. } { simpl ; lra. } + 2: { rewrite -NMε. apply ARcoupl_dunif_leq_rev, S_INR_le_compat. real_solver. } + intros n m nm. + apply ARcoupl_dret; [done|]. + simpl in nm. eauto. +Qed. + +Lemma ARcoupl_rand_no_coll_l `{Countable A} N (x : fin (S N)) z (σ : state) (a : A) (ε : nonnegreal) : + (1 / S N = ε)%R → + N = Z.to_nat z → + ARcoupl + (prim_step (rand #z) σ) + (dret a) + (λ ρ a', ∃ n : fin (S N), + ρ = (Val #n, σ) ∧ (n ≠ x) ∧ a' = a) + ε. +Proof. + intros Nε Nz. + rewrite head_prim_step_eq /=. + rewrite -Nz. + rewrite -(dmap_dret (λ x, x) _) /dmap. + replace ε with (ε + nnreal_zero)%NNR by (apply nnreal_ext ; simpl ; lra). + eapply ARcoupl_dbind ; [destruct ε ; done | simpl ; lra |..]; last first. + { rewrite -Nε. apply (ARcoupl_dunif_no_coll_l _ _ x). real_solver. } + move => n ? [xn ->]. apply ARcoupl_dret; [done|]. + exists n. auto. +Qed. + +Lemma ARcoupl_rand_no_coll_r `{Countable A} N (x : fin (S N)) z (σₛ : state) (a : A) (ε : nonnegreal) : + (1 / S N = ε)%R → + N = Z.to_nat z → + ARcoupl + (dret a) + (prim_step (rand #z) σₛ) + (λ a' ρₛ, ∃ n : fin (S N), + a' = a ∧ ρₛ = (Val #n, σₛ) ∧ (n ≠ x)) + ε. +Proof. + intros Nε Nz. + rewrite head_prim_step_eq /=. + rewrite -Nz. + rewrite -(dmap_dret (λ x, x) _). + rewrite /dmap. + replace ε with (nnreal_plus ε nnreal_zero) by (apply nnreal_ext ; simpl ; lra). + eapply ARcoupl_dbind ; [destruct ε ; done | simpl ; lra |..]. + 2: rewrite -Nε; apply (ARcoupl_dunif_no_coll_r _ _ x); real_solver. + move => ? n [-> xn]. apply ARcoupl_dret; [done|]. + exists n. auto. +Qed. + +(** * a coupling between rand n and rand n avoiding results from a list *) +Lemma ARcoupl_rand_rand_avoid_list (N : nat) z σ1 σ1' (ε : nonnegreal) l: + NoDup l -> + (length l / S N = ε)%R → + N = Z.to_nat z → + ARcoupl + (prim_step (rand #z) σ1) + (prim_step (rand #z) σ1') + (λ ρ2 ρ2', ∃ (n : fin (S N)), + (n∉l)/\ + ρ2 = (Val #n, σ1) ∧ ρ2' = (Val #n, σ1')) + ε. +Proof. + intros Hl Hε Hz. + rewrite !head_prim_step_eq /=. + rewrite /dmap -Hz. + replace ε with (nnreal_plus ε nnreal_zero); last first. + { apply nnreal_ext; simpl; lra. } + eapply ARcoupl_dbind. + 1,2: apply cond_nonneg. + 2 : { + rewrite -Hε. + by apply ARcoupl_dunif_avoid. + } + simpl. + intros n m [Hnm ->]. + apply ARcoupl_dret; [done|]. + naive_solver. +Qed. + +(** * state_step ~ fair_coin *) +Lemma state_step_fair_coin_coupl σ α bs : + σ.(tapes) !! α = Some ((1%nat; bs) : tape) → + Rcoupl + (state_step σ α) + fair_coin + (λ σ' b, σ' = state_upd_tapes (<[α := (1%nat; bs ++ [bool_to_fin b])]>) σ). +Proof. + intros Hα. + exists (dmap (λ b, (state_upd_tapes (<[α := (1%nat; bs ++ [bool_to_fin b]) : tape]>) σ, b)) fair_coin). + repeat split. + - rewrite /lmarg dmap_comp /state_step. + rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2]. + rewrite lookup_total_alt Hα /=. + eapply distr_ext=> σ'. + rewrite /dmap /= /pmf /= /dbind_pmf. + rewrite SeriesC_bool SeriesC_fin2 /=. + rewrite {1 3 5 7}/pmf /=. + destruct (decide (state_upd_tapes <[α:=(1%nat; bs ++ [1%fin])]> σ = σ')); subst. + + rewrite {1 2}dret_1_1 // dret_0; [lra|]. + intros [= H%(insert_inv (tapes σ))]. simplify_eq. + + destruct (decide (state_upd_tapes <[α:=(1%nat; bs ++ [0%fin])]> σ = σ')); subst. + * rewrite {1 2}dret_0 // dret_1_1 //. lra. + * rewrite !dret_0 //. lra. + - rewrite /rmarg dmap_comp. + assert ((snd ∘ (λ b : bool, _)) = Datatypes.id) as -> by f_equal. + rewrite dmap_id //. + - by intros [σ' b] (b' & [=-> ->] & ?)%dmap_pos=>/=. +Qed. + +(** * state_step ≫= state_step ~ dprod fair_coin fair_coin *) +Lemma state_steps_fair_coins_coupl (σ : state) (α1 α2 : loc) (bs1 bs2 : list (fin 2)): + α1 ≠ α2 → + σ.(tapes) !! α1 = Some ((1%nat; bs1) : tape) → + σ.(tapes) !! α2 = Some ((1%nat; bs2) : tape) → + Rcoupl + (state_step σ α1 ≫= (λ σ', state_step σ' α2)) + (dprod fair_coin fair_coin) + (λ σ' '(b1, b2), + σ' = (state_upd_tapes (<[α1 := (1%nat; bs1 ++ [bool_to_fin b1])]>) + (state_upd_tapes (<[α2 := (1%nat; bs2 ++ [bool_to_fin b2])]>) σ))). +Proof. + intros Hneq Hα1 Hα2. + rewrite /dprod. + rewrite -(dret_id_right (state_step _ _ ≫= _)) -dbind_assoc. + eapply Rcoupl_dbind; [|by eapply state_step_fair_coin_coupl]. + intros σ' b1 ->. + eapply Rcoupl_dbind; [|eapply state_step_fair_coin_coupl]; last first. + { rewrite lookup_insert_ne //. } + intros σ' b2 ->. + eapply Rcoupl_dret. + rewrite /state_upd_tapes insert_commute //. +Qed. + +Lemma Rcoupl_state_1_3 σ σₛ α1 α2 αₛ (xs ys:list(fin (2))) (zs:list(fin (4))): + α1 ≠ α2 -> + σ.(tapes) !! α1 = Some (1%nat; xs) -> + σ.(tapes) !! α2 = Some (1%nat; ys) -> + σₛ.(tapes) !! αₛ = Some (3%nat; zs) -> + Rcoupl + (state_step σ α1 ≫= (λ σ1', state_step σ1' α2)) + (state_step σₛ αₛ) + (λ σ1' σ2', ∃ (x y:fin 2) (z:fin 4), + σ1' = state_upd_tapes <[α2 := (1%nat; ys ++ [y])]> (state_upd_tapes <[α1 := (1%nat; xs ++ [x])]> σ) ∧ + σ2' = state_upd_tapes <[αₛ := (3%nat; zs ++ [z])]> σₛ /\ + (2*fin_to_nat x + fin_to_nat y = fin_to_nat z)%nat + ). +Proof. + intros Hneq H1 H2 H3. + rewrite /state_step. + do 2 (rewrite bool_decide_eq_true_2; [|by eapply elem_of_dom_2]). + rewrite (lookup_total_correct _ _ _ H1). + rewrite (lookup_total_correct _ _ _ H3). + erewrite (dbind_eq _ (λ σ, dmap + (λ n : fin 2, + state_upd_tapes <[α2:=(1%nat; ys ++ [n])]> σ) + (dunifP 1))); last first. + - done. + - intros [??] H. + rewrite dmap_pos in H. destruct H as (?&->&H). + rewrite bool_decide_eq_true_2; last first. + { eapply elem_of_dom_2. by rewrite /state_upd_tapes/=lookup_insert_ne. } + rewrite lookup_total_insert_ne; last done. + rewrite (lookup_total_correct _ _ _ H2). + done. + - pose (witness:=dmap (λ n: fin 4, ( match fin_to_nat n with + | 0%nat =>state_upd_tapes <[α2:=(1%nat; ys ++ [0%fin])]> + (state_upd_tapes <[α1:=(1%nat; xs ++ [0%fin])]> σ) + | 1%nat =>state_upd_tapes <[α2:=(1%nat; ys ++ [1%fin])]> + (state_upd_tapes <[α1:=(1%nat; xs ++ [0%fin])]> σ) + | 2%nat =>state_upd_tapes <[α2:=(1%nat; ys ++ [0%fin])]> + (state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ) + | 3%nat => state_upd_tapes <[α2:=(1%nat; ys ++ [1%fin])]> + (state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ) + | _ => σ + end + ,state_upd_tapes <[αₛ:=(3%nat; zs ++ [n])]> σₛ) + )(dunifP 3)). + exists witness. + split; last first. + + intros [??]. + rewrite /witness dmap_pos. + intros [?[??]]. + repeat (inv_fin x => x); simpl in *; simplify_eq => _; naive_solver. + + rewrite /witness. split. + -- rewrite /lmarg dmap_comp. + erewrite dmap_eq; last first. + ** done. + ** intros ??. simpl. done. + ** apply distr_ext. intros s. + (** prove left marginal of witness is correct *) + rewrite {1}/dmap{1}/dbind/dbind_pmf{1}/pmf. + etrans; last first. + { (** simplify the RHS *) + rewrite /dmap/dbind/dbind_pmf/pmf/=. + erewrite (SeriesC_ext _ (λ a, + if (bool_decide (a ∈ [state_upd_tapes <[α1:=(1%nat; xs ++ [0%fin])]> σ; state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ])) + then + SeriesC (λ a0 : fin 2, / (1 + 1) * dret_pmf (state_upd_tapes <[α1:=(1%nat; xs ++ [a0])]> σ) a) * + SeriesC (λ a0 : fin 2, / (1 + 1) * dret_pmf (state_upd_tapes <[α2:=(1%nat; ys ++ [a0])]> a) s) + else 0)); first rewrite SeriesC_list/=. + - by rewrite !SeriesC_finite_foldr/dret_pmf/=. + - repeat constructor; last (set_unfold; naive_solver). + rewrite elem_of_list_singleton. move /state_upd_tapes_same'. done. + - intros [??]. + case_bool_decide; first done. + apply Rmult_eq_0_compat_r. + set_unfold. + rewrite SeriesC_finite_foldr/dret_pmf/=. + repeat case_bool_decide; try lra; naive_solver. + } + pose proof state_upd_tapes_same' as K1. + pose proof state_upd_tapes_neq' as K2. + case_bool_decide; last done. + rewrite (bool_decide_eq_false_2 (state_upd_tapes <[α1:=(1%nat; xs ++ [0%fin])]> σ = + state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ)); last first. + { apply K2. done. } + rewrite (bool_decide_eq_false_2 (state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ = + state_upd_tapes <[α1:=(1%nat; xs ++ [0%fin])]> σ)); last first. + { apply K2. done. } + rewrite (bool_decide_eq_true_2 (state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ = + state_upd_tapes <[α1:=(1%nat; xs ++ [1%fin])]> σ)); last done. + rewrite !Rmult_0_r. + rewrite SeriesC_finite_foldr/dunifP /dunif/pmf /=/dret_pmf. + case_bool_decide. + { repeat rewrite bool_decide_eq_false_2. + - lra. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α2). rewrite !lookup_insert in K. simplify_eq. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α1). rewrite lookup_insert_ne in K; last done. + rewrite (lookup_insert_ne (<[_:=_]> _ )) in K; last done. + rewrite !lookup_insert in K. simplify_eq. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α2). rewrite !lookup_insert in K. simplify_eq. + } + case_bool_decide. + { repeat rewrite bool_decide_eq_false_2. + - lra. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α1). rewrite lookup_insert_ne in K; last done. + rewrite (lookup_insert_ne (<[_:=_]> _ )) in K; last done. + rewrite !lookup_insert in K. simplify_eq. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α2). rewrite !lookup_insert in K. simplify_eq. + } + case_bool_decide. + { repeat rewrite bool_decide_eq_false_2. + - lra. + - subst. intro K. simplify_eq. rewrite map_eq_iff in K. + specialize (K α2). rewrite !lookup_insert in K. simplify_eq. + } + lra. + -- rewrite /rmarg dmap_comp. + f_equal. +Qed. + +Lemma iterM_state_step_unfold σ (N p:nat) α xs : + σ.(tapes) !! α = Some (N%nat; xs) -> + (iterM p (λ σ1', state_step σ1' α) σ) = + dmap (λ v, state_upd_tapes <[α := (N%nat; xs ++ v)]> σ) + (dunifv N p). +Proof. + revert σ N α xs. + induction p as [|p' IH]. + { (* base case *) + intros. simpl. + apply distr_ext. + intros. rewrite /dret/dret_pmf{1}/pmf/=. + rewrite dmap_unfold_pmf. + + (** Why doesnt this work?? *) + (* rewrite (@SeriesC_subset _ _ _ (λ x, x= (nil:list (fin (S N)))) _ _). *) + (* - rewrite (SeriesC_singleton_dependent nil). *) + + erewrite (SeriesC_ext ). + - erewrite (SeriesC_singleton_dependent [] (λ a0:list (fin (S N)), dunifv N 0 a0 * (if bool_decide (a = state_upd_tapes <[α:=(N; xs ++ a0)]> σ) then 1 else 0))). + rewrite dunifv_pmf. simpl. + case_bool_decide. + + rewrite bool_decide_eq_true_2; [lra|]. rewrite app_nil_r. + rewrite state_upd_tapes_no_change; done. + + rewrite bool_decide_eq_false_2; [lra|]. rewrite app_nil_r. + rewrite state_upd_tapes_no_change; done. + - intros. simpl. + symmetry. + case_bool_decide; first done. + rewrite dunifv_pmf. + rewrite bool_decide_eq_false_2; first lra. + intros ?%nil_length_inv. done. + } + (* inductive case *) + intros σ N α xs Ht. + replace (S p') with (p'+1)%nat; last lia. + rewrite iterM_plus; simpl. + erewrite IH; last done. + erewrite dbind_ext_right; last first. + { intros. apply dret_id_right. } + apply distr_ext. intros σ'. rewrite dmap_unfold_pmf. + replace (p'+1)%nat with (S p') by lia. + assert (Decision (∃ v: list (fin (S N)), length v = S p' /\ σ' = state_upd_tapes <[α:=(N; xs ++ v)]> σ)) as Hdec. + { (* can be improved *) + apply make_decision. } + destruct (decide (∃ v: list (fin (S N)), length v = S p' /\ σ' = state_upd_tapes <[α:=(N; xs ++ v)]> σ)) as [K|K]. + - (* σ' is reachable *) + destruct K as [v [Hlen ->]]. + rewrite (SeriesC_subset (λ a, a = v)); last first. + { intros a Ha. + rewrite bool_decide_eq_false_2; first lra. + move => /state_upd_tapes_same. intros L. simplify_eq. + } + rewrite SeriesC_singleton_dependent. rewrite bool_decide_eq_true_2; last done. + rewrite dunifv_pmf. rewrite Rmult_1_r. + remember (/ (S N ^ S p')%nat) as val eqn:Hval. + rewrite /dbind/dbind_pmf{1}/pmf/=. + rewrite (SeriesC_subset (λ a, a = state_upd_tapes <[α := (N; xs ++ take p' v)]> σ)). + + rewrite SeriesC_singleton_dependent. erewrite state_step_unfold; last first. + { simpl. rewrite lookup_insert. done. } + rewrite !dmap_unfold_pmf. + rewrite (SeriesC_subset (λ a, a = take p' v)); last first. + { intros. rewrite bool_decide_eq_false_2; first lra. + move => /state_upd_tapes_same. intros L. simplify_eq. + } + rewrite SeriesC_singleton_dependent. + rewrite bool_decide_eq_true_2; last done. + rewrite dunifv_pmf. + rewrite bool_decide_eq_true_2; last first. + { rewrite firstn_length_le; lia. } + assert (is_Some (last v)) as [x Hsome]. + { rewrite last_is_Some. intros ?. subst. done. } + rewrite (SeriesC_subset (λ a, a = x)); last first. + { intros a H'. rewrite bool_decide_eq_false_2; first lra. + rewrite state_upd_tapes_twice. move => /state_upd_tapes_same. + rewrite <-app_assoc. intros K. + simplify_eq. apply H'. + rewrite -K in Hsome. + rewrite last_snoc in Hsome. by simplify_eq. + } + rewrite SeriesC_singleton_dependent. + rewrite /dunifP dunif_pmf. rewrite bool_decide_eq_true_2; last rewrite state_upd_tapes_twice. + * rewrite bool_decide_eq_true_2; last done. + rewrite Hval. + cut ( / INR (S N ^ p') * (/ INR (S N)) = / INR (S N ^ S p')); first lra. + rewrite -Rinv_mult. + f_equal. rewrite -mult_INR. f_equal. simpl. lia. + * rewrite -app_assoc. repeat f_equal. + rewrite <-(firstn_all v) at 1. + rewrite Hlen. erewrite take_S_r; first done. + rewrite -Hsome last_lookup. + f_equal. rewrite Hlen. done. + + (* prove that σ' is not an intermediate step*) + intros σ'. + intros Hσ. + assert (dmap (λ v0, state_upd_tapes <[α:=(N; xs ++ v0)]> σ) (dunifv N p') σ' * state_step σ' α (state_upd_tapes <[α:=(N; xs ++ v)]> σ) >= 0) as [H|H]; last done. + { apply Rle_ge. apply Rmult_le_pos; auto. } + exfalso. + apply Rmult_pos_cases in H as [[H1 H2]|[? H]]; last first. + { pose proof pmf_pos (state_step σ' α) (state_upd_tapes <[α:=(N; xs ++ v)]> σ). lra. } + rewrite dmap_pos in H1. + destruct H1 as [v' [-> H1]]. + apply Hσ. repeat f_equal. + erewrite state_step_unfold in H2; last first. + { simpl. apply lookup_insert. } + apply dmap_pos in H2. + destruct H2 as [a [H2?]]. + rewrite state_upd_tapes_twice in H2. + apply state_upd_tapes_same in H2. rewrite -app_assoc in H2. simplify_eq. + rewrite take_app_length'; first done. + rewrite app_length in Hlen. simpl in *; lia. + - (* σ' is not reachable, i.e. both sides are zero *) + rewrite SeriesC_0; last first. + { intros x. + assert (0<=dunifv N (S p') x) as [H|<-] by auto; last lra. + apply Rlt_gt in H. + rewrite -dunifv_pos in H. + rewrite bool_decide_eq_false_2; [lra|naive_solver]. + } + rewrite /dbind/dbind_pmf{1}/pmf/=. + apply SeriesC_0. + intros σ''. + rewrite /dmap/dbind/dbind_pmf{1}/pmf/=. + setoid_rewrite dunifv_pmf. + assert (SeriesC + (λ a : list (fin (S N)), + (if bool_decide (length a = p') then / (S N ^ p')%nat else 0) * + dret (state_upd_tapes <[α:=(N; xs ++ a)]> σ) σ'') * state_step σ'' α σ' >= 0) as [H|H]; last done. + { apply Rle_ge. apply Rmult_le_pos; auto. apply SeriesC_ge_0'. + intros. apply Rmult_le_pos; auto. case_bool_decide; try lra. + rewrite -Rdiv_1_l. apply Rcomplements.Rdiv_le_0_compat; first lra. + apply lt_0_INR. + epose proof Nat.pow_le_mono_r (S N) 0 p' _ _ as H0; simpl in H0; lia. + Unshelve. + all: lia. + } + apply Rmult_pos_cases in H as [[H1 H2]|[? H]]; last first. + { pose proof pmf_pos (state_step σ'' α) σ'. lra. } + epose proof SeriesC_gtz_ex _ _ H1. simpl in *. + destruct H as [v H]. + apply Rmult_pos_cases in H as [[? H]|[]]; last first. + { epose proof pmf_pos (dret (state_upd_tapes <[α:=(N; xs ++ v)]> σ)) σ''. lra. } + apply dret_pos in H; subst. + case_bool_decide; last lra. + erewrite state_step_unfold in H2; last first. + { simpl. rewrite lookup_insert. done. } + exfalso. + apply K. rewrite dmap_pos in H2. destruct H2 as [x[-> H2]]. subst. + setoid_rewrite state_upd_tapes_twice. + rewrite -app_assoc. + exists (v++[x]); rewrite app_length; simpl; split; first lia. done. + Unshelve. + simpl. + intros. case_bool_decide; last real_solver. + apply Rmult_le_pos; last auto. + rewrite -Rdiv_1_l. apply Rcomplements.Rdiv_le_0_compat; first lra. + apply lt_0_INR. + epose proof Nat.pow_le_mono_r (S N) 0 p' _ _ as H0; simpl in H0; lia. + Unshelve. + all: lia. +Qed. + +Lemma Rcoupl_state_state_exp N p M σ σₛ α αₛ xs zs + (f:(list (fin (S N))) -> fin (S M)) + (Hinj: forall l1 l2, length l1 = p -> length l2 = p -> f l1 = f l2 -> l1 = l2): + (S N ^ p = S M)%nat-> + σ.(tapes) !! α = Some (N%nat; xs) -> + σₛ.(tapes) !! αₛ = Some (M%nat; zs) -> + Rcoupl + (iterM p (λ σ1', state_step σ1' α) σ) + (state_step σₛ αₛ) + (λ σ1' σ2', ∃ (xs':list(fin (S N))) (z:fin (S M)), + length xs' = p /\ + σ1' = state_upd_tapes <[α := (N%nat; xs ++ xs')]> σ ∧ + σ2' = state_upd_tapes <[αₛ := (M%nat; zs ++ [z])]> σₛ /\ + f xs' = z + ). +Proof. + intros H Hσ Hσₛ. + erewrite state_step_unfold; last done. + erewrite iterM_state_step_unfold; last done. + apply Rcoupl_dmap. + exists (dmap (λ v, (v, f v)) (dunifv N p)). + split. + - split; apply distr_ext. + + intros v. rewrite lmarg_pmf. + rewrite (SeriesC_ext _ + (λ b : fin (S M), if bool_decide (b=f v) then dmap (λ v0, (v0, f v0)) (dunifv N p) (v, b) else 0)). + * rewrite SeriesC_singleton_dependent. rewrite dmap_unfold_pmf. + rewrite (SeriesC_ext _ + (λ a, if bool_decide (a = v) then dunifv N p a * (if bool_decide ((v, f v) = (a, f a)) then 1 else 0) else 0)). + { rewrite SeriesC_singleton_dependent. rewrite bool_decide_eq_true_2; first lra. + done. } + intros. + case_bool_decide; simplify_eq. + -- rewrite bool_decide_eq_true_2; done. + -- rewrite bool_decide_eq_false_2; first lra. + intros ->. done. + * intros. case_bool_decide; first done. + rewrite dmap_unfold_pmf. + setoid_rewrite bool_decide_eq_false_2. + -- rewrite SeriesC_scal_r; lra. + -- intros ?. simplify_eq. + + intros a. + rewrite rmarg_pmf. + assert (∃ x, length x = p /\ f x = a) as [x [H1 H2]]. + { + assert (Surj eq (λ x:vec(fin(S N)) p, f (vec_to_list x)) ) as K. + - apply finite_inj_surj; last first. + + rewrite vec_card !fin_card. + done. + + intros v1 v2 Hf. + apply vec_to_list_inj2. + apply Hinj; last done. + * by rewrite vec_to_list_length. + * by rewrite vec_to_list_length. + - pose proof K a as [v K']. + subst. + exists (vec_to_list v). split; last done. + apply vec_to_list_length. + } + rewrite (SeriesC_subset (λ x', x' = x)). + * rewrite SeriesC_singleton_dependent. rewrite dmap_unfold_pmf. + rewrite (SeriesC_subset (λ x', x' = x)). + -- rewrite SeriesC_singleton_dependent. rewrite bool_decide_eq_true_2; last by subst. + rewrite dunifv_pmf /dunifP dunif_pmf. + rewrite bool_decide_eq_true_2; last done. rewrite H. lra. + -- intros. subst. rewrite bool_decide_eq_false_2; first lra. + naive_solver. + * intros ? H0. subst. rewrite dmap_unfold_pmf. + apply SeriesC_0. intros x0. + assert (0<=dunifv N (length x) x0) as [H1|<-] by auto; last lra. + apply Rlt_gt in H1. rewrite <-dunifv_pos in H1. + rewrite bool_decide_eq_false_2; first lra. + intros ?. simplify_eq. + apply H0. by apply Hinj. + - intros []. rewrite dmap_pos. + intros [?[? Hpos]]. simplify_eq. + rewrite -dunifv_pos in Hpos. + naive_solver. +Qed. + +Lemma Rcoupl_fragmented_rand_rand_inj (N M: nat) (f: fin (S M) -> fin (S N)) (Hinj: Inj (=) (=) f) σ σₛ ms ns α αₛ: + (M<=N)%nat → + σ.(tapes) !! α = Some (N%nat; ns) → + σₛ.(tapes) !! αₛ = Some (M%nat; ms) → + Rcoupl + (state_step σ α) + (dunifP N≫= λ x, if bool_decide (∃ m, f m = x) then state_step σₛ αₛ else dret σₛ) + (λ σ1' σ2', ∃ (n : fin (S N)), + if bool_decide (∃ m, f m = n) + then ∃ (m : fin (S M)), + σ1' = state_upd_tapes <[α := (N; ns ++ [n])]> σ ∧ + σ2' = state_upd_tapes <[αₛ := (M; ms ++ [m])]> σₛ /\ + f m = n + else + σ1' = state_upd_tapes <[α := (N; ns ++ [n])]> σ ∧ + σ2' = σₛ + ). +Proof. + intros Hineq Hσ Hσₛ. (* rewrite <-(dret_id_right (state_step _ _)). *) + replace (0)%NNR with (0+0)%NNR; last first. + { apply nnreal_ext. simpl. lra. } + erewrite (distr_ext (dunifP _ ≫= _) + (MkDistr (dunifP N ≫= (λ x : fin (S N), + match ClassicalEpsilon.excluded_middle_informative + (∃ m, f m = x) + with + | left Hproof => + dret (state_upd_tapes <[αₛ:=(M; ms ++ [epsilon Hproof])]> σₛ) + | _ => + dret σₛ + end)) _ _ _) ); last first. + { intros σ'. simpl. rewrite /pmf/=. + rewrite /dbind_pmf. rewrite /dunifP. setoid_rewrite dunif_pmf. + rewrite !SeriesC_scal_l. apply Rmult_eq_compat_l. + erewrite (SeriesC_ext _ + (λ x : fin (S N), (if bool_decide (∃ m : fin (S M), f m = x) then state_step σₛ αₛ σ' else 0) + + (if bool_decide (∃ m : fin (S M), f m = x) then 0 else dret σₛ σ') + )); last first. + { intros. case_bool_decide; lra. } + trans (SeriesC + (λ x : fin (S N), + match ClassicalEpsilon.excluded_middle_informative + (∃ m, f m = x) with + | left Hproof => dret (state_upd_tapes <[αₛ:=(M; ms ++ [epsilon Hproof])]> σₛ) σ' + | right _ => 0 + end + + match ClassicalEpsilon.excluded_middle_informative + (∃ m, f m = x) with + | left Hproof => 0 + | right _ => dret σₛ σ' + end + ) + ); last first. + { apply SeriesC_ext. intros. case_match; lra. } + rewrite !SeriesC_plus; last first. + all: try apply ex_seriesC_finite. + etrans; first eapply Rplus_eq_compat_l; last apply Rplus_eq_compat_r. + { apply SeriesC_ext. intros. case_bool_decide as H; case_match; done. } + destruct (ExcludedMiddle (∃ x, σ' = (state_upd_tapes <[αₛ:=(M; ms ++ [x])]> σₛ))) as [H|H]. + + destruct H as [n ->]. + trans 1. + * rewrite /state_step. + rewrite bool_decide_eq_true_2; last first. + { rewrite elem_of_dom. rewrite Hσₛ. done. } + setoid_rewrite (lookup_total_correct (tapes σₛ) αₛ (M; ms)); last done. + rewrite /dmap/dbind/dbind_pmf{1}/pmf/=. + rewrite /dunifP. setoid_rewrite dunif_pmf. + setoid_rewrite SeriesC_scal_l. + rewrite (SeriesC_ext _ (λ x : fin (S N), + if bool_decide (∃ m : fin (S M), f m = x) + then + / S M + else 0)). + -- erewrite (SeriesC_ext _ (λ x : fin (S N), / S M * if bool_decide (x∈f<$> enum (fin (S M))) then 1 else 0)). + { rewrite SeriesC_scal_l. rewrite SeriesC_list_1. + - rewrite fmap_length. rewrite length_enum_fin. rewrite Rinv_l; first lra. + replace 0 with (INR 0) by done. + move => /INR_eq. lia. + - apply NoDup_fmap_2; try done. + apply NoDup_enum. + } + intros n'. + case_bool_decide as H. + ++ rewrite bool_decide_eq_true_2; first lra. + destruct H as [?<-]. + apply elem_of_list_fmap_1. + apply elem_of_enum. + ++ rewrite bool_decide_eq_false_2; first lra. + intros H0. apply H. + apply elem_of_list_fmap_2 in H0 as [?[->?]]. + naive_solver. + -- intros. + erewrite (SeriesC_ext _ (λ x, if (bool_decide (x=n)) then 1 else 0)). + ++ rewrite SeriesC_singleton. case_bool_decide as H1; lra. + ++ intros m. case_bool_decide; subst. + ** by apply dret_1. + ** apply dret_0. intro H1. apply H. apply state_upd_tapes_same in H1. + simplify_eq. + * symmetry. + rewrite (SeriesC_ext _ (λ x, if bool_decide (x = f n) then 1 else 0)). + { apply SeriesC_singleton. } + intros n'. + case_match eqn:Heqn. + { destruct e as [m <-] eqn:He. + case_bool_decide as Heqn'. + - apply Hinj in Heqn' as ->. + apply dret_1. + repeat f_equal. + pose proof epsilon_correct (λ m : fin (S M), f m = f n) as H. simpl in H. + apply Hinj. rewrite H. done. + - apply dret_0. + move => /state_upd_tapes_same. intros eq. simplify_eq. + apply Heqn'. pose proof epsilon_correct (λ m0 : fin (S M), f m0 = f m) as H. + by rewrite H. + } + rewrite bool_decide_eq_false_2; first done. + intros ->. naive_solver. + + trans 0. + * apply SeriesC_0. + intros. case_bool_decide; last done. + rewrite /state_step. + rewrite bool_decide_eq_true_2; last first. + { rewrite elem_of_dom. rewrite Hσₛ. done. } + setoid_rewrite (lookup_total_correct (tapes σₛ) αₛ (M; ms)); last done. + rewrite /dmap/dbind/dbind_pmf{1}/pmf/=. + rewrite /dunifP. setoid_rewrite dunif_pmf. + apply SeriesC_0. + intros m. apply Rmult_eq_0_compat_l. + apply dret_0. + intros ->. apply H. + exists m. done. + * symmetry. + apply SeriesC_0. + intros. case_match; last done. + apply dret_0. + intros ->. apply H. + naive_solver. + } + erewrite state_step_unfold; last done. + rewrite /dmap. + eapply Rcoupl_dbind; last apply Rcoupl_eq. + intros ??->. + case_match eqn:Heqn. + - destruct e as [m He]. + replace (epsilon _) with m; last first. + { pose proof epsilon_correct (λ m0 : fin (S M), f m0 = b) as H. + simpl in H. apply Hinj. rewrite H. done. + } + apply Rcoupl_dret. + exists b. + rewrite bool_decide_eq_true_2; last naive_solver. + naive_solver. + - apply Rcoupl_dret. + exists b. rewrite bool_decide_eq_false_2; naive_solver. +Qed. + + +(** Some useful lemmas to reason about language properties *) + +Inductive det_head_step_rel : expr → state → expr → state → Prop := +| RecDS f x e σ : + det_head_step_rel (Rec f x e) σ (Val $ RecV f x e) σ +| PairDS v1 v2 σ : + det_head_step_rel (Pair (Val v1) (Val v2)) σ (Val $ PairV v1 v2) σ +| InjLDS v σ : + det_head_step_rel (InjL $ Val v) σ (Val $ InjLV v) σ +| InjRDS v σ : + det_head_step_rel (InjR $ Val v) σ (Val $ InjRV v) σ +| BetaDS f x e1 v2 e' σ : + e' = subst' x v2 (subst' f (RecV f x e1) e1) → + det_head_step_rel (App (Val $ RecV f x e1) (Val v2)) σ e' σ +| UnOpDS op v v' σ : + un_op_eval op v = Some v' → + det_head_step_rel (UnOp op (Val v)) σ (Val v') σ +| BinOpDS op v1 v2 v' σ : + bin_op_eval op v1 v2 = Some v' → + det_head_step_rel (BinOp op (Val v1) (Val v2)) σ (Val v') σ +| IfTrueDS e1 e2 σ : + det_head_step_rel (If (Val $ LitV $ LitBool true) e1 e2) σ e1 σ +| IfFalseDS e1 e2 σ : + det_head_step_rel (If (Val $ LitV $ LitBool false) e1 e2) σ e2 σ +| FstDS v1 v2 σ : + det_head_step_rel (Fst (Val $ PairV v1 v2)) σ (Val v1) σ +| SndDS v1 v2 σ : + det_head_step_rel (Snd (Val $ PairV v1 v2)) σ (Val v2) σ +| CaseLDS v e1 e2 σ : + det_head_step_rel (Case (Val $ InjLV v) e1 e2) σ (App e1 (Val v)) σ +| CaseRDS v e1 e2 σ : + det_head_step_rel (Case (Val $ InjRV v) e1 e2) σ (App e2 (Val v)) σ +| AllocNDS z N v σ l : + l = fresh_loc σ.(heap) → + N = Z.to_nat z → + (0 < N)%nat -> + det_head_step_rel (AllocN (Val (LitV (LitInt z))) (Val v)) σ + (Val $ LitV $ LitLoc l) (state_upd_heap_N l N v σ) +| LoadDS l v σ : + σ.(heap) !! l = Some v → + det_head_step_rel (Load (Val $ LitV $ LitLoc l)) σ (of_val v) σ +| StoreDS l v w σ : + σ.(heap) !! l = Some v → + det_head_step_rel (Store (Val $ LitV $ LitLoc l) (Val w)) σ + (Val $ LitV LitUnit) (state_upd_heap <[l:=w]> σ) +| TickDS z σ : + det_head_step_rel (Tick (Val $ LitV $ LitInt z)) σ (Val $ LitV $ LitUnit) σ. + +Inductive det_head_step_pred : expr → state → Prop := +| RecDSP f x e σ : + det_head_step_pred (Rec f x e) σ +| PairDSP v1 v2 σ : + det_head_step_pred (Pair (Val v1) (Val v2)) σ +| InjLDSP v σ : + det_head_step_pred (InjL $ Val v) σ +| InjRDSP v σ : + det_head_step_pred (InjR $ Val v) σ +| BetaDSP f x e1 v2 σ : + det_head_step_pred (App (Val $ RecV f x e1) (Val v2)) σ +| UnOpDSP op v σ v' : + un_op_eval op v = Some v' → + det_head_step_pred (UnOp op (Val v)) σ +| BinOpDSP op v1 v2 σ v' : + bin_op_eval op v1 v2 = Some v' → + det_head_step_pred (BinOp op (Val v1) (Val v2)) σ +| IfTrueDSP e1 e2 σ : + det_head_step_pred (If (Val $ LitV $ LitBool true) e1 e2) σ +| IfFalseDSP e1 e2 σ : + det_head_step_pred (If (Val $ LitV $ LitBool false) e1 e2) σ +| FstDSP v1 v2 σ : + det_head_step_pred (Fst (Val $ PairV v1 v2)) σ +| SndDSP v1 v2 σ : + det_head_step_pred (Snd (Val $ PairV v1 v2)) σ +| CaseLDSP v e1 e2 σ : + det_head_step_pred (Case (Val $ InjLV v) e1 e2) σ +| CaseRDSP v e1 e2 σ : + det_head_step_pred (Case (Val $ InjRV v) e1 e2) σ +| AllocNDSP z N v σ l : + l = fresh_loc σ.(heap) → + N = Z.to_nat z → + (0 < N)%nat -> + det_head_step_pred (AllocN (Val (LitV (LitInt z))) (Val v)) σ +| LoadDSP l v σ : + σ.(heap) !! l = Some v → + det_head_step_pred (Load (Val $ LitV $ LitLoc l)) σ +| StoreDSP l v w σ : + σ.(heap) !! l = Some v → + det_head_step_pred (Store (Val $ LitV $ LitLoc l) (Val w)) σ +| TickDSP z σ : + det_head_step_pred (Tick (Val $ LitV $ LitInt z)) σ. + +Definition is_det_head_step (e1 : expr) (σ1 : state) : bool := + match e1 with + | Rec f x e => true + | Pair (Val v1) (Val v2) => true + | InjL (Val v) => true + | InjR (Val v) => true + | App (Val (RecV f x e1)) (Val v2) => true + | UnOp op (Val v) => bool_decide(is_Some(un_op_eval op v)) + | BinOp op (Val v1) (Val v2) => bool_decide (is_Some(bin_op_eval op v1 v2)) + | If (Val (LitV (LitBool true))) e1 e2 => true + | If (Val (LitV (LitBool false))) e1 e2 => true + | Fst (Val (PairV v1 v2)) => true + | Snd (Val (PairV v1 v2)) => true + | Case (Val (InjLV v)) e1 e2 => true + | Case (Val (InjRV v)) e1 e2 => true + | AllocN (Val (LitV (LitInt z))) (Val v) => bool_decide (0 < Z.to_nat z)%nat + | Load (Val (LitV (LitLoc l))) => + bool_decide (is_Some (σ1.(heap) !! l)) + | Store (Val (LitV (LitLoc l))) (Val w) => + bool_decide (is_Some (σ1.(heap) !! l)) + | Tick (Val (LitV (LitInt z))) => true + | _ => false + end. + +Lemma det_step_eq_tapes e1 σ1 e2 σ2 : + det_head_step_rel e1 σ1 e2 σ2 → σ1.(tapes) = σ2.(tapes). +Proof. inversion 1; auto. Qed. + +Inductive prob_head_step_pred : expr -> state -> Prop := +| AllocTapePSP σ N z : + N = Z.to_nat z → + prob_head_step_pred (alloc #z) σ +| RandTapePSP α σ N n ns z : + N = Z.to_nat z → + σ.(tapes) !! α = Some ((N; n :: ns) : tape) → + prob_head_step_pred (rand(#lbl:α) #z) σ +| RandEmptyPSP N α σ z : + N = Z.to_nat z → + σ.(tapes) !! α = Some ((N; []) : tape) → + prob_head_step_pred (rand(#lbl:α) #z) σ +| RandTapeOtherPSP N M α σ ns z : + N ≠ M → + M = Z.to_nat z → + σ.(tapes) !! α = Some ((N; ns) : tape) → + prob_head_step_pred (rand(#lbl:α) #z) σ +| RandNoTapePSP (N : nat) σ z : + N = Z.to_nat z → + prob_head_step_pred (rand #z) σ. + +Definition head_step_pred e1 σ1 := + det_head_step_pred e1 σ1 ∨ prob_head_step_pred e1 σ1. + +Lemma det_step_is_unique e1 σ1 e2 σ2 e3 σ3 : + det_head_step_rel e1 σ1 e2 σ2 → + det_head_step_rel e1 σ1 e3 σ3 → + e2 = e3 ∧ σ2 = σ3. +Proof. + intros H1 H2. + inversion H1; inversion H2; simplify_eq; auto. +Qed. + +Lemma det_step_pred_ex_rel e1 σ1 : + det_head_step_pred e1 σ1 ↔ ∃ e2 σ2, det_head_step_rel e1 σ1 e2 σ2. +Proof. + split. + - intro H; inversion H; simplify_eq; eexists; eexists; econstructor; eauto. + - intros (e2 & (σ2 & H)); inversion H ; econstructor; eauto. +Qed. + +Local Ltac solve_step_det := + rewrite /pmf /=; + repeat (rewrite bool_decide_eq_true_2 // || case_match); + try (lra || lia || done). + +Local Ltac inv_det_head_step := + repeat + match goal with + | H : to_val _ = Some _ |- _ => apply of_to_val in H + | H : is_det_head_step _ _ = true |- _ => + rewrite /is_det_head_step in H; + repeat (case_match in H; simplify_eq) + | H : is_Some _ |- _ => destruct H + | H : bool_decide _ = true |- _ => rewrite bool_decide_eq_true in H; destruct_and? + | _ => progress simplify_map_eq/= + end. + +Lemma is_det_head_step_true e1 σ1 : + is_det_head_step e1 σ1 = true ↔ det_head_step_pred e1 σ1. +Proof. + split; intro H. + - destruct e1; inv_det_head_step; by econstructor. + - inversion H; solve_step_det. +Qed. + +Lemma det_head_step_singleton e1 σ1 e2 σ2 : + det_head_step_rel e1 σ1 e2 σ2 → head_step e1 σ1 = dret (e2, σ2). +Proof. + intros Hdet. + apply pmf_1_eq_dret. + inversion Hdet; simplify_eq/=; try case_match; + simplify_option_eq; rewrite ?dret_1_1 //. +Qed. + +Lemma val_not_head_step e1 σ1 : + is_Some (to_val e1) → ¬ head_step_pred e1 σ1. +Proof. + intros [] [Hs | Hs]; inversion Hs; simplify_eq. +Qed. + +Lemma head_step_pred_ex_rel e1 σ1 : + head_step_pred e1 σ1 ↔ ∃ e2 σ2, head_step_rel e1 σ1 e2 σ2. +Proof. + split. + - intros [Hdet | Hdet]; + inversion Hdet; simplify_eq; do 2 eexists; try (by econstructor). + Unshelve. all : apply 0%fin. + - intros (?&?& H). inversion H; simplify_eq; + (try by (left; econstructor)); + (try by (right; econstructor)). + right. by eapply RandTapeOtherPSP; [|done|done]. +Qed. + +Lemma not_head_step_pred_dzero e1 σ1: + ¬ head_step_pred e1 σ1 ↔ head_step e1 σ1 = dzero. +Proof. + split. + - intro Hnstep. + apply dzero_ext. + intros (e2 & σ2). + destruct (Rlt_le_dec 0 (head_step e1 σ1 (e2, σ2))) as [H1%Rgt_lt | H2]; last first. + { pose proof (pmf_pos (head_step e1 σ1) (e2, σ2)). destruct H2; lra. } + apply head_step_support_equiv_rel in H1. + assert (∃ e2 σ2, head_step_rel e1 σ1 e2 σ2) as Hex; eauto. + by apply head_step_pred_ex_rel in Hex. + - intros Hhead (e2 & σ2 & Hstep)%head_step_pred_ex_rel. + apply head_step_support_equiv_rel in Hstep. + assert (head_step e1 σ1 (e2, σ2) = 0); [|lra]. + rewrite Hhead //. +Qed. + +Lemma det_or_prob_or_dzero e1 σ1 : + det_head_step_pred e1 σ1 + ∨ prob_head_step_pred e1 σ1 + ∨ head_step e1 σ1 = dzero. +Proof. + destruct (Rlt_le_dec 0 (SeriesC (head_step e1 σ1))) as [H1%Rlt_gt | [HZ | HZ]]. + - pose proof (SeriesC_gtz_ex (head_step e1 σ1) (pmf_pos (head_step e1 σ1)) H1) as [[e2 σ2] Hρ]. + pose proof (head_step_support_equiv_rel e1 e2 σ1 σ2) as [H3 H4]. + specialize (H3 Hρ). + assert (head_step_pred e1 σ1) as []; [|auto|auto]. + apply head_step_pred_ex_rel; eauto. + - by pose proof (pmf_SeriesC_ge_0 (head_step e1 σ1)) + as ?%Rle_not_lt. + - apply SeriesC_zero_dzero in HZ. eauto. +Qed. + +Lemma head_step_dzero_upd_tapes α e σ N zs z : + α ∈ dom σ.(tapes) → + head_step e σ = dzero → + head_step e (state_upd_tapes <[α:=(N; zs ++ [z]) : tape]> σ) = dzero. +Proof. + intros Hdom Hz. + destruct e; simpl in *; + repeat case_match; done || inv_dzero; simplify_map_eq. + (* TODO: [simplify_map_eq] should solve this? *) + - destruct (decide (α = l1)). + + simplify_eq. + by apply not_elem_of_dom_2 in H5. + + rewrite lookup_insert_ne // in H6. + rewrite H5 in H6. done. + - destruct (decide (α = l1)). + + simplify_eq. + by apply not_elem_of_dom_2 in H5. + + rewrite lookup_insert_ne // in H6. + rewrite H5 in H6. done. + - destruct (decide (α = l1)). + + simplify_eq. + by apply not_elem_of_dom_2 in H5. + + rewrite lookup_insert_ne // in H6. + rewrite H5 in H6. done. +Qed. + +Lemma det_head_step_upd_tapes N e1 σ1 e2 σ2 α z zs : + det_head_step_rel e1 σ1 e2 σ2 → + tapes σ1 !! α = Some (N; zs) → + det_head_step_rel + e1 (state_upd_tapes <[α := (N; zs ++ [z])]> σ1) + e2 (state_upd_tapes <[α := (N; zs ++ [z])]> σ2). +Proof. + inversion 1; try econstructor; eauto. + (* Unsolved case *) + intros. rewrite state_upd_tapes_heap. econstructor; eauto. +Qed. + +Lemma upd_tape_some σ α N n ns : + tapes σ !! α = Some (N; ns) → + tapes (state_upd_tapes <[α:= (N; ns ++ [n])]> σ) !! α = Some (N; ns ++ [n]). +Proof. + intros H. rewrite /state_upd_tapes /=. rewrite lookup_insert //. +Qed. + +Lemma upd_tape_some_trivial σ α bs: + tapes σ !! α = Some bs → + state_upd_tapes <[α:=tapes σ !!! α]> σ = σ. +Proof. + destruct σ. simpl. + intros H. + rewrite (lookup_total_correct _ _ _ H). + f_equal. + by apply insert_id. +Qed. + +Lemma upd_diff_tape_comm σ α β bs bs': + α ≠ β → + state_upd_tapes <[β:= bs]> (state_upd_tapes <[α := bs']> σ) = + state_upd_tapes <[α:= bs']> (state_upd_tapes <[β := bs]> σ). +Proof. + intros. rewrite /state_upd_tapes /=. rewrite insert_commute //. +Qed. + +Lemma upd_diff_tape_tot σ α β bs: + α ≠ β → + tapes σ !!! α = tapes (state_upd_tapes <[β:=bs]> σ) !!! α. +Proof. symmetry ; by rewrite lookup_total_insert_ne. Qed. + +Lemma upd_tape_twice σ β bs bs' : + state_upd_tapes <[β:= bs]> (state_upd_tapes <[β:= bs']> σ) = state_upd_tapes <[β:= bs]> σ. +Proof. rewrite /state_upd_tapes insert_insert //. Qed. + +Lemma fresh_loc_upd_some σ α bs bs' : + (tapes σ) !! α = Some bs → + fresh_loc (tapes σ) = (fresh_loc (<[α:= bs']> (tapes σ))). +Proof. + intros Hα. + apply fresh_loc_eq_dom. + by rewrite dom_insert_lookup_L. +Qed. + +Lemma elem_fresh_ne {V} (ls : gmap loc V) k v : + ls !! k = Some v → fresh_loc ls ≠ k. +Proof. + intros; assert (is_Some (ls !! k)) as Hk by auto. + pose proof (fresh_loc_is_fresh ls). + rewrite -elem_of_dom in Hk. + set_solver. +Qed. + +Lemma fresh_loc_upd_swap σ α bs bs' bs'' : + (tapes σ) !! α = Some bs → + state_upd_tapes <[fresh_loc (tapes σ):=bs']> (state_upd_tapes <[α:=bs'']> σ) + = state_upd_tapes <[α:=bs'']> (state_upd_tapes <[fresh_loc (tapes σ):=bs']> σ). +Proof. + intros H. + apply elem_fresh_ne in H. + unfold state_upd_tapes. + by rewrite insert_commute. +Qed. + +Lemma fresh_loc_lookup σ α bs bs' : + (tapes σ) !! α = Some bs → + (tapes (state_upd_tapes <[fresh_loc (tapes σ):=bs']> σ)) !! α = Some bs. +Proof. + intros H. + pose proof (elem_fresh_ne _ _ _ H). + by rewrite lookup_insert_ne. +Qed. + +Lemma prim_step_empty_tape σ α (z:Z) K N : + (tapes σ) !! α = Some (N; []) -> prim_step (fill K (rand(#lbl:α) #z)) σ = prim_step (fill K (rand #z)) σ. +Proof. + intros H. + rewrite !fill_dmap; [|done|done]. + rewrite /dmap. + f_equal. + simpl. apply distr_ext; intros [e s]. + erewrite !head_prim_step_eq; simpl; last first. + (** type classes dont work? *) + { destruct (decide (Z.to_nat z=N)) as [<-|?] eqn:Heqn. + all: eexists (_, σ); eapply head_step_support_equiv_rel; + eapply head_step_support_eq; simpl; last first. + - rewrite H. rewrite bool_decide_eq_true_2; last lia. + eapply dmap_unif_nonzero; last done. + intros ???. simplify_eq. done. + - apply Rinv_pos. pose proof pos_INR_S (Z.to_nat z). lra. + - rewrite H. case_bool_decide as H0; first lia. + eapply dmap_unif_nonzero; last done. + intros ???. by simplify_eq. + - apply Rinv_pos. pose proof pos_INR_S (Z.to_nat z). lra. + } + { eexists (_, σ); eapply head_step_support_equiv_rel; + eapply head_step_support_eq; simpl; last first. + - eapply dmap_unif_nonzero; last done. + intros ???. simplify_eq. done. + - apply Rinv_pos. pose proof pos_INR_S (Z.to_nat z). lra. + } rewrite H. + case_bool_decide; last done. + subst. done. + Unshelve. + all: exact (0%fin). +Qed. + +*) diff --git a/theories/meas_lang/tactics.v b/theories/meas_lang/tactics.v new file mode 100644 index 00000000..dea96e3d --- /dev/null +++ b/theories/meas_lang/tactics.v @@ -0,0 +1,87 @@ +From Coq Require Import Reals Psatz. +From stdpp Require Import fin_maps. +From iris.proofmode Require Import environments proofmode. +From clutch.meas_lang Require Import lang ectx_language. +From iris.prelude Require Import options. +Import meas_lang. + +(* +(** The tactic [reshape_expr e tac] decomposes the expression [e] into an +evaluation context [K] and a subexpression [e']. It calls the tactic [tac K e'] +for each possible decomposition until [tac] succeeds. *) +Ltac reshape_expr e tac := + let rec go K e := + match e with + | _ => tac K e + | App ?e (Val ?v) => go (AppLCtx v :: K) e + | App ?e1 ?e2 => go (AppRCtx e1 :: K) e2 + | UnOp ?op ?e => go (UnOpCtx op :: K) e + | BinOp ?op ?e (Val ?v) => go (BinOpLCtx op v :: K) e + | BinOp ?op ?e1 ?e2 => go (BinOpRCtx op e1 :: K) e2 + | If ?e0 ?e1 ?e2 => go (IfCtx e1 e2 :: K) e0 + | Pair ?e (Val ?v) => go (PairLCtx v :: K) e + | Pair ?e1 ?e2 => go (PairRCtx e1 :: K) e2 + | Fst ?e => go (FstCtx :: K) e + | Snd ?e => go (SndCtx :: K) e + | InjL ?e => go (InjLCtx :: K) e + | InjR ?e => go (InjRCtx :: K) e + | Case ?e0 ?e1 ?e2 => go (CaseCtx e1 e2 :: K) e0 + | AllocN ?e (Val ?v) => go (AllocNLCtx v :: K) e + | AllocN ?e1 ?e2 => go (AllocNRCtx e1 :: K) e2 + | Load ?e => go (LoadCtx :: K) e + | Store ?e (Val ?v) => go (StoreLCtx v :: K) e + | Store ?e1 ?e2 => go (StoreRCtx e1 :: K) e2 + | AllocTape ?e => go (AllocTapeCtx :: K) e + | Rand ?e (Val ?v) => go (RandLCtx v :: K) e + | Rand ?e1 ?e2 => go (RandRCtx e1 :: K) e2 + | Tick ?e => go (TickCtx :: K) e + end in go (@nil ectx_item) e. + +Local Open Scope R. + +Lemma head_step_support_eq e1 e2 σ1 σ2 r : + r > 0 → head_step e1 σ1 (e2, σ2) = r → head_step_rel e1 σ1 e2 σ2. +Proof. intros ? <-. by eapply head_step_support_equiv_rel. Qed. + +Lemma head_step_support_eq_1 e1 e2 σ1 σ2 : + head_step e1 σ1 (e2, σ2) = 1 → head_step_rel e1 σ1 e2 σ2. +Proof. eapply head_step_support_eq; lra. Qed. + +(** The tactic [inv_head_step] performs inversion on hypotheses of the shape + [head_step]. The tactic will discharge head-reductions starting from values, + and simplifies hypothesis related to conversions from and to values, and + finite map operations. This tactic is slightly ad-hoc and tuned for proving + our lifting lemmas. *) + +Global Hint Extern 0 (head_reducible _ _) => + eexists (_, _); eapply head_step_support_equiv_rel : head_step. +Global Hint Extern 1 (head_step _ _ _ > 0) => + eapply head_step_support_equiv_rel; econstructor : head_step. + +Global Hint Extern 2 (head_reducible _ _) => + by eauto with head_step : typeclass_instances. + +Ltac solve_step := + simpl; + match goal with + | |- (prim_step _ _).(pmf) _ = 1%R => + rewrite head_prim_step_eq /= ; + simplify_map_eq ; solve_distr + | |- (head_step _ _).(pmf) _ = 1%R => simplify_map_eq; solve_distr + | |- (head_step _ _).(pmf) _ > 0%R => eauto with head_step + end. + +Ltac solve_red := + match goal with + | |- (environments.envs_entails _ ( ⌜ _ ⌝ ∗ _)) => + iSplitR ; [ by (iPureIntro ; solve_red) | ] + | |- (environments.envs_entails _ ( _ ∗ ⌜ _ ⌝)) => + iSplitL ; [ by (iPureIntro ; solve_red) | ] + | |- reducible ((fill _ _), _) => + apply reducible_fill ; solve_red + | |- reducible _ => + apply head_prim_reducible ; solve_red + | |- (head_reducible _ _) => + by eauto with head_step + end. +*) diff --git a/theories/meas_lang/wp_tactics.v b/theories/meas_lang/wp_tactics.v new file mode 100644 index 00000000..271ab4e8 --- /dev/null +++ b/theories/meas_lang/wp_tactics.v @@ -0,0 +1,756 @@ +From iris.bi Require Export bi updates. +From iris.base_logic.lib Require Import fancy_updates. +From iris.proofmode Require Import coq_tactics reduction spec_patterns. +From iris.proofmode Require Export tactics. + +(* +(*From clutch.bi Require Import weakestpre.*) +From clutch.prob_lang Require Import lang tactics notation class_instances. +Set Default Proof Using "Type*". + +(** A basic set of requirements for a weakest precondition *) +Class GwpTacticsBase (Σ : gFunctors) (A : Type) `{!invGS_gen hlc Σ} (gwp : A → coPset → expr → (val → iProp Σ) → iProp Σ) := { + wptac_wp_value E Φ v a : Φ v ⊢ gwp a E (of_val v) Φ; + wptac_wp_fupd E Φ e a : gwp a E e (λ v, |={E}=> Φ v) ⊢ gwp a E e Φ; + }. + +Class GwpTacticsBind (Σ : gFunctors) (A : Type) `{!invGS_gen hlc Σ} (gwp : A → coPset → expr → (val → iProp Σ) → iProp Σ) := { + wptac_wp_bind K `{!LanguageCtx K} E e Φ a : + gwp a E e (λ v, gwp a E (K (of_val v)) Φ ) ⊢ gwp a E (K e) Φ +}. + +Class GwpTacticsPure Σ A (laters : bool) (gwp : A → coPset → expr → (val → iProp Σ) → iProp Σ) := { + wptac_wp_pure_step E e1 e2 φ n Φ a : + PureExec φ n e1 e2 → + φ → + ▷^(if laters then n else 0) (gwp a E e2 Φ) ⊢ (gwp a E e1 Φ); +}. + +(** Heap *) +Class GwpTacticsHeap Σ A (laters : bool) (gwp : A → coPset → expr → (val → iProp Σ) → iProp Σ):= { + wptac_mapsto : loc → dfrac → val → iProp Σ; + wptac_mapsto_array : loc → dfrac → (list val) → iProp Σ; + + wptac_wp_alloc E v a Φ : + True -∗ + (▷?laters (∀ l, (wptac_mapsto l (DfracOwn 1) v) -∗ Φ (LitV (LitLoc l))%V)) -∗ + gwp a E (Alloc (Val v)) Φ; + + wptac_wp_allocN E v n a Φ : + (0 < n)%Z → + True -∗ + (▷?laters (∀ l, (wptac_mapsto_array l (DfracOwn 1) (replicate (Z.to_nat n) v)) -∗ + Φ (LitV (LitLoc l))%V)) -∗ + gwp a E (AllocN (Val $ LitV $ LitInt $ n) (Val v)) Φ; + + wptac_wp_load E v l dq a Φ : + (▷ wptac_mapsto l dq v) -∗ + (▷?laters ((wptac_mapsto l dq v) -∗ Φ v%V)) -∗ + gwp a E (Load (Val $ LitV $ LitLoc l)) Φ; + + wptac_wp_store E v v' l a Φ : + (▷ wptac_mapsto l (DfracOwn 1) v') -∗ + (▷?laters ((wptac_mapsto l (DfracOwn 1) v) -∗ Φ (LitV (LitUnit))%V)) -∗ + gwp a E (Store (Val $ LitV $ LitLoc l) (Val v)) Φ; + }. + + +(** Tapes *) +Class GwpTacticsTapes Σ A (laters : bool) (gwp : A → coPset → expr → (val → iProp Σ) → iProp Σ):= { + wptac_mapsto_tape : loc → dfrac → nat -> (list nat) → iProp Σ; + + wptac_wp_alloctape E (N : nat) (z : Z) a Φ : + TCEq N (Z.to_nat z) -> + True -∗ + (▷?laters (∀ l, (wptac_mapsto_tape l (DfracOwn 1) N nil) -∗ Φ (LitV (LitLbl l))%V)) -∗ + gwp a E (AllocTape (Val $ LitV $ LitInt $ z)) Φ; + + wptac_wp_rand_tape E N (n : nat) (z : Z) ns l dq a Φ : + TCEq N (Z.to_nat z) -> + (▷ wptac_mapsto_tape l dq N (n::ns)) -∗ + (▷?laters ((wptac_mapsto_tape l dq N ns) -∗ ⌜ n ≤ N ⌝ -∗ Φ (LitV $ LitInt $ n)%V)) -∗ + gwp a E (Rand (LitV (LitInt z)) (LitV (LitLbl l))) Φ; +}. + +Section wp_tactics. + Context `{GwpTacticsBase Σ A hlc gwp}. + + Local Notation "'WP' e @ s ; E {{ Φ } }" := (gwp s E e%E Φ) + (at level 20, e, Φ at level 200, only parsing) : bi_scope. + Local Notation "'WP' e @ s ; E {{ v , Q } }" := (gwp s E e%E (λ v, Q)) + (at level 20, e, Q at level 200, + format "'[hv' 'WP' e '/' @ '[' s ; '/' E ']' '/' {{ '[' v , '/' Q ']' } } ']'") : bi_scope. + + Lemma tac_wp_expr_eval Δ a E Φ e e' : + (∀ (e'':=e'), e = e'') → + envs_entails Δ (WP e' @ a; E {{ Φ }}) → envs_entails Δ (WP e @ a; E {{ Φ }}). + Proof. by intros ->. Qed. + + Lemma tac_wp_pure_later laters `{!GwpTacticsPure Σ A laters gwp} Δ Δ' E K e1 e2 φ n Φ a : + PureExec φ n e1 e2 → + φ → + MaybeIntoLaterNEnvs (if laters then n else 0) Δ Δ' → + envs_entails Δ' (WP (fill K e2) @ a; E {{ Φ }}) → + envs_entails Δ (WP (fill K e1) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ??? HΔ'. rewrite into_laterN_env_sound /=. + (* We want [pure_exec_fill] to be available to TC search locally. *) + pose proof @pure_exec_fill. + rewrite HΔ' -wptac_wp_pure_step //. + Qed. + + Lemma tac_wp_value_nofupd Δ E Φ v a : + envs_entails Δ (Φ v) → envs_entails Δ (WP (of_val v) @ a; E {{ Φ }}). + Proof. rewrite envs_entails_unseal=> ->. apply wptac_wp_value. Qed. + + Lemma tac_wp_value' Δ E Φ v a : + envs_entails Δ (|={E}=> Φ v) → envs_entails Δ (WP (of_val v) @ a; E {{ Φ }}). + Proof. rewrite envs_entails_unseal=> ->. by rewrite -wptac_wp_fupd -wptac_wp_value. Qed. + +End wp_tactics. + +Section wp_bind_tactics. + Context `{GwpTacticsBind Σ A hlc gwp}. + + Local Notation "'WP' e @ s ; E {{ Φ } }" := (gwp s E e%E Φ) + (at level 20, e, Φ at level 200, only parsing) : bi_scope. + Local Notation "'WP' e @ s ; E {{ v , Q } }" := (gwp s E e%E (λ v, Q)) + (at level 20, e, Q at level 200, + format "'[hv' 'WP' e '/' @ '[' s ; '/' E ']' '/' {{ '[' v , '/' Q ']' } } ']'") : bi_scope. + + Lemma tac_wp_bind K Δ E Φ e f a : + f = (λ e, fill K e) → (* as an eta expanded hypothesis so that we can `simpl` it *) + envs_entails Δ (WP e @ a; E {{ v, WP f (Val v) @ a; E {{ Φ }} }})%I → + envs_entails Δ (WP fill K e @ a; E {{ Φ }}). + Proof. rewrite envs_entails_unseal=> -> ->. by apply: wptac_wp_bind. Qed. +End wp_bind_tactics. + +(* TODO: find a better way so that we do not need to have a case for both [wp] and [twp]... *) +Tactic Notation "wp_expr_eval" tactic3(t) := + iStartProof; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + notypeclasses refine (tac_wp_expr_eval _ _ _ _ e _ _ _ ); + [apply _|let x := fresh in intros x; simpl; unfold x; notypeclasses refine eq_refl|] + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + notypeclasses refine (tac_wp_expr_eval _ _ _ _ e _ _ _ ); + [apply _|let x := fresh in intros x; simpl; unfold x; notypeclasses refine eq_refl|] + | _ => fail "wp_expr_eval: not a 'wp'" + end. +Ltac wp_expr_simpl := wp_expr_eval simpl. + +(** Simplify the goal if it is [wp] of a value. + If the postcondition already allows a fupd, do not add a second one. + But otherwise, *do* add a fupd. This ensures that all the lemmas applied + here are bidirectional, so we never will make a goal unprovable. *) +Ltac wp_value_head := + lazymatch goal with + | |- envs_entails _ (wp ?s ?E (of_val _) (λ _, fupd ?E _ _)) => + eapply tac_wp_value_nofupd + | |- envs_entails _ (wp ?s ?E (of_val _) (λ _, wp _ ?E _ _)) => + eapply tac_wp_value_nofupd + | |- envs_entails _ (wp ?s ?E (of_val _) _) => + eapply tac_wp_value' + | |- envs_entails _ (twp ?s ?E (of_val _) (λ _, fupd ?E _ _)) => + eapply tac_wp_value_nofupd + | |- envs_entails _ (twp ?s ?E (of_val _) (λ _, twp _ ?E _ _)) => + eapply tac_wp_value_nofupd + | |- envs_entails _ (twp ?s ?E (of_val _) _) => + eapply tac_wp_value' + end. + +Ltac wp_finish := + (* simplify occurences of [wptac_mapsto] projections *) + rewrite ?[wptac_mapsto _ _ _]/=; + (* simplify occurences of [wptac_mapsto_tape] projections *) + rewrite ?[wptac_mapsto_tape _ _ _]/=; + (* simplify occurences of subst/fill *) + wp_expr_simpl; + (* in case we have reached a value, get rid of the wp *) + try wp_value_head; + (* prettify ▷s caused by [MaybeIntoLaterNEnvs] and λs caused by wp_value *) + pm_prettify. + +Ltac solve_vals_compare_safe := + (* The first branch is for when we have [vals_compare_safe] in the context. + The other two branches are for when either one of the branches reduces to + [True] or we have it in the context. *) + fast_done || (left; fast_done) || (right; fast_done). + +(** The argument [efoc] can be used to specify the construct that should be +reduced. For example, you can write [wp_pure (EIf _ _ _)], which will search +for an [EIf _ _ _] in the expression, and reduce it. + +The use of [open_constr] in this tactic is essential. It will convert all holes +(i.e. [_]s) into evars, that later get unified when an occurences is found +(see [unify e' efoc] in the code below). *) +Tactic Notation "wp_pure" open_constr(efoc) := + iStartProof; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + let e := eval simpl in e in + reshape_expr e ltac:(fun K e' => + unify e' efoc; + eapply (tac_wp_pure_later _ _ _ _ K e'); + [tc_solve (* PureExec *) + |try solve_vals_compare_safe (* The pure condition for PureExec -- handles trivial goals, including [vals_compare_safe] *) + |tc_solve (* IntoLaters *) + |wp_finish (* new goal *) + ]) + || fail "wp_pure: cannot find" efoc "in" e "or" efoc "is not a redex" + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + let e := eval simpl in e in + reshape_expr e ltac:(fun K e' => + unify e' efoc; + eapply (tac_wp_pure_later _ _ _ _ K e'); + [tc_solve (* PureExec *) + |try solve_vals_compare_safe (* The pure condition for PureExec -- handles trivial goals, including [vals_compare_safe] *) + |tc_solve (* IntoLaters *) + |wp_finish (* new goal *) + ]) + || fail "Hello! wp_pure: cannot find" efoc "in" e "or" efoc "is not a redex" + | _ => fail "wp_pure: not a 'wp'" + end. + +Tactic Notation "wp_pure" := + wp_pure _. + +Ltac wp_pures := + iStartProof; + first [ (* The `;[]` makes sure that no side-condition magically spawns. *) + progress repeat (wp_pure _; []) + | wp_finish (* In case wp_pure never ran, make sure we do the usual cleanup. *) + ]. + + +(** Unlike [wp_pures], the tactics [wp_rec] and [wp_lam] should also reduce + lambdas/recs that are hidden behind a definition, i.e. they should use + [AsRecV_recv] as a proper instance instead of a [Hint Extern]. + + We achieve this by putting [AsRecV_recv] in the current environment so that it + can be used as an instance by the typeclass resolution system. We then perform + the reduction, and finally we clear this new hypothesis. *) +Tactic Notation "wp_rec" := + let H := fresh in + assert (H := AsRecV_recv); + wp_pure (App _ _); + clear H. + +Tactic Notation "wp_if" := wp_pure (If _ _ _). +Tactic Notation "wp_if_true" := wp_pure (If (LitV (LitBool true)) _ _). +Tactic Notation "wp_if_false" := wp_pure (If (LitV (LitBool false)) _ _). +Tactic Notation "wp_unop" := wp_pure (UnOp _ _). +Tactic Notation "wp_binop" := wp_pure (BinOp _ _ _). +Tactic Notation "wp_op" := wp_unop || wp_binop. +Tactic Notation "wp_lam" := wp_rec. +Tactic Notation "wp_let" := wp_pure (Rec BAnon (BNamed _) _); wp_lam. +Tactic Notation "wp_seq" := wp_pure (Rec BAnon BAnon _); wp_lam. +Tactic Notation "wp_proj" := wp_pure (Fst _) || wp_pure (Snd _). +Tactic Notation "wp_case" := wp_pure (Case _ _ _). +Tactic Notation "wp_match" := wp_case; wp_pure (Rec _ _ _); wp_lam. +Tactic Notation "wp_inj" := wp_pure (InjL _) || wp_pure (InjR _). +Tactic Notation "wp_pair" := wp_pure (Pair _ _). +Tactic Notation "wp_closure" := wp_pure (Rec _ _ _). + +Ltac wp_bind_core K := + lazymatch eval hnf in K with + | [] => idtac + | _ => eapply (tac_wp_bind K); [simpl; reflexivity|reduction.pm_prettify] + end. + +Tactic Notation "wp_bind" open_constr(efoc) := + iStartProof; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + first [ reshape_expr e ltac:(fun K e' => unify e' efoc; wp_bind_core K) + | fail 1 "wp_bind: cannot find" efoc "in" e ] + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + first [ reshape_expr e ltac:(fun K e' => unify e' efoc; wp_bind_core K) + | fail 1 "wp_bind: cannot find" efoc "in" e ] + | _ => fail "wp_bind: not a 'wp'" + end. + +(** The tactic [wp_apply_core lem tac_suc tac_fail] evaluates [lem] to a + hypothesis [H] that can be applied, and then runs [wp_bind_core K; tac_suc H] + for every possible evaluation context [K]. + + - The tactic [tac_suc] should do [iApplyHyp H] to actually apply the hypothesis, + but can perform other operations in addition (see [wp_apply] and [awp_apply] + below). + - The tactic [tac_fail cont] is called when [tac_suc H] fails for all evaluation + contexts [K], and can perform further operations before invoking [cont] to + try again. + + TC resolution of [lem] premises happens *after* [tac_suc H] got executed. *) + +Ltac wp_apply_core lem tac_suc tac_fail := first + [iPoseProofCore lem as false (fun H => + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + reshape_expr e ltac:(fun K e' => wp_bind_core K; tac_suc H) + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + reshape_expr e ltac:(fun K e' => wp_bind_core K; tac_suc H) + | _ => fail 1 "wp_apply: not a 'wp'" + end) + |tac_fail ltac:(fun _ => wp_apply_core lem tac_suc tac_fail) + |let P := type of lem in + fail "wp_apply: cannot apply" lem ":" P ]. + +Tactic Notation "wp_apply" open_constr(lem) := + wp_apply_core lem ltac:(fun H => iApplyHyp H; try iNext; try wp_expr_simpl) + ltac:(fun cont => fail). +Tactic Notation "wp_smart_apply" open_constr(lem) := + wp_apply_core lem ltac:(fun H => iApplyHyp H; try iNext; try wp_expr_simpl) + ltac:(fun cont => wp_pure _; []; cont ()). + +(** Better tactics :) *) +Tactic Notation "wp_apply" open_constr(lem) "as" constr(pat) := + wp_apply lem; last iIntros pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) ")" + constr(pat) := + wp_apply lem; last iIntros ( x1 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) ")" + constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) ")" + constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) simple_intropattern(x9) ")" constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 x9 ) pat. +Tactic Notation "wp_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) simple_intropattern(x9) simple_intropattern(x10) ")" + constr(pat) := + wp_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 ) pat. + + +Tactic Notation "wp_smart_apply" open_constr(lem) "as" constr(pat) := + wp_smart_apply lem; last iIntros pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) ")" + constr(pat) := + wp_smart_apply lem; last iIntros ( x1 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) ")" + constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) ")" + constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) simple_intropattern(x9) ")" constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 x9 ) pat. +Tactic Notation "wp_smart_apply" open_constr(lem) "as" "(" simple_intropattern(x1) + simple_intropattern(x2) simple_intropattern(x3) simple_intropattern(x4) + simple_intropattern(x5) simple_intropattern(x6) simple_intropattern(x7) + simple_intropattern(x8) simple_intropattern(x9) simple_intropattern(x10) ")" + constr(pat) := + wp_smart_apply lem; last iIntros ( x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 ) pat. + +Section heap_tactics. + Context `{GwpTacticsBase Σ A hlc gwp, GwpTacticsBind Σ A hlc gwp, !GwpTacticsHeap Σ A laters gwp}. + + Local Notation "'WP' e @ s ; E {{ Φ } }" := (gwp s E e%E Φ) + (at level 20, e, Φ at level 200, only parsing) : bi_scope. + + (** Notations with binder. *) + Local Notation "'WP' e @ s ; E {{ v , Q } }" := (gwp s E e%E (λ v, Q)) + (at level 20, e, Q at level 200, + format "'[hv' 'WP' e '/' @ '[' s ; '/' E ']' '/' {{ '[' v , '/' Q ']' } } ']'") : bi_scope. + + Lemma tac_wp_alloc Δ Δ' E j K v Φ a : + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ Δ' → + (∀ (l : loc), + match envs_app false (Esnoc Enil j (wptac_mapsto l (DfracOwn 1) v)) Δ' with + | Some Δ'' => + envs_entails Δ'' (WP fill K (Val $ LitV $ LitLoc l) @ a ; E {{ Φ }}) + | None => False + end) → + envs_entails Δ (WP fill K (Alloc (Val v)) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ? HΔ. + rewrite -wptac_wp_bind. + eapply bi.wand_apply. + { apply bi.wand_entails, wptac_wp_alloc. } + rewrite left_id into_laterN_env_sound. + apply bi.laterN_mono, bi.forall_intro=> l. + specialize (HΔ l). + destruct (envs_app _ _ _) as [Δ''|] eqn:HΔ'; [| contradiction]. + rewrite envs_app_sound //; simpl. + apply bi.wand_intro_l. + rewrite right_id. + rewrite bi.wand_elim_r //. + Qed. + + Lemma tac_wp_allocN Δ Δ' E j K v n Φ a : + (0 < n)%Z → + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ Δ' → + (∀ l, + match envs_app false (Esnoc Enil j (wptac_mapsto_array l (DfracOwn 1) (replicate (Z.to_nat n) v))) Δ' with + | Some Δ'' => + envs_entails Δ'' (WP fill K (Val $ LitV $ LitLoc l) @ a; E {{ Φ }}) + | None => False + end) → + envs_entails Δ (WP fill K (AllocN (Val $ LitV $ LitInt n) (Val v)) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ? ? HΔ. + rewrite -wptac_wp_bind. + eapply bi.wand_apply. + { by apply bi.wand_entails, wptac_wp_allocN. } + rewrite left_id into_laterN_env_sound. + apply bi.laterN_mono, bi.forall_intro=> l. + specialize (HΔ l). + destruct (envs_app _ _ _) as [Δ''|] eqn:HΔ'; [ | contradiction ]. + rewrite envs_app_sound //; simpl. + apply bi.wand_intro_l. + rewrite right_id. + rewrite bi.wand_elim_r //. + Qed. + + Lemma tac_wp_load Δ Δ' E i K b l dq v Φ a : + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ Δ' → + envs_lookup i Δ' = Some (b, wptac_mapsto l dq v) → + envs_entails Δ' (WP fill K (Val v) @ a; E {{ Φ }}) → + envs_entails Δ (WP fill K (Load (Val $ LitV $ LitLoc l)) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ?? Hi. + rewrite -wptac_wp_bind. + eapply bi.wand_apply. + { apply bi.wand_entails, wptac_wp_load. } + rewrite into_laterN_env_sound. + destruct laters. + - rewrite -bi.later_sep. + rewrite envs_lookup_split //; simpl. + apply bi.later_mono. + destruct b; simpl. + * iIntros "[#$ He]". iIntros "_". iApply Hi. iApply "He". iFrame "#". + * by apply bi.sep_mono_r, bi.wand_mono. + - rewrite envs_lookup_split //; simpl. + destruct b; simpl. + * iIntros "[#$ He]". iIntros "_". iApply Hi. iApply "He". iFrame "#". + * iIntros "[$ He] H". iApply Hi. by iApply "He". + Qed. + + Lemma tac_wp_store Δ Δ' E i K l v v' Φ a : + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ Δ' → + envs_lookup i Δ' = Some (false, wptac_mapsto l (DfracOwn 1) v)%I → + match envs_simple_replace i false (Esnoc Enil i (wptac_mapsto l (DfracOwn 1) v')) Δ' with + | Some Δ'' => envs_entails Δ'' (WP fill K (Val $ LitV LitUnit) @ a; E {{ Φ }}) + | None => False + end → + envs_entails Δ (WP fill K (Store (Val $ LitV $ LitLoc l) (Val v')) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ?? Hcnt. + destruct (envs_simple_replace _ _ _) as [Δ''|] eqn:HΔ''; [ | contradiction ]. + rewrite -wptac_wp_bind. eapply bi.wand_apply. + { eapply bi.wand_entails, wptac_wp_store. } + rewrite into_laterN_env_sound. + destruct laters. + - rewrite -bi.later_sep envs_simple_replace_sound //; simpl. + rewrite right_id. by apply bi.later_mono, bi.sep_mono_r, bi.wand_mono. + - rewrite envs_simple_replace_sound //; simpl. + rewrite right_id. + iIntros "[$ He] H". iApply Hcnt. by iApply "He". + Qed. + +End heap_tactics. + +Tactic Notation "wp_alloc" ident(l) "as" constr(H) := + let Htmp := iFresh in + let finish _ := + first [intros l | fail 1 "wp_alloc:" l "not fresh"]; + pm_reduce; + lazymatch goal with + | |- False => fail 1 "wp_alloc:" H "not fresh" + | _ => iDestructHyp Htmp as H; wp_finish + end in + wp_pures; + (** The code first tries to use allocation lemma for a single reference, + ie, [tac_wp_alloc] (respectively, [tac_twp_alloc]). + If that fails, it tries to use the lemma [tac_wp_allocN] + (respectively, [tac_twp_allocN]) for allocating an array. + Notice that we could have used the array allocation lemma also for single + references. However, that would produce the resource l ↦∗ [v] instead of + l ↦ v for single references. These are logically equivalent assertions + but are not equal. *) +lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + let process_single _ := + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_alloc _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'Alloc' in" e]; + [tc_solve + |finish ()] + in + let process_array _ := + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_allocN _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'Alloc' in" e]; + [idtac| tc_solve + |finish ()] + in (process_single ()) || (process_array ()) +| |- envs_entails _ (twp ?s ?E ?e ?Q) => + let process_single _ := + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_alloc _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'Alloc' in" e]; + [tc_solve + |finish ()] + in + let process_array _ := + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_allocN _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'Alloc' in" e]; + [idtac| tc_solve + |finish ()] + in (process_single ()) || (process_array ()) + | _ => fail "wp_alloc: not a 'wp'" + end. + + +Tactic Notation "wp_alloc" ident(l) := + wp_alloc l as "?". + +Tactic Notation "wp_load" := + let solve_wptac_mapsto _ := + let l := match goal with |- _ = Some (_, (wptac_mapsto ?l _ _)%I) => l end in + iAssumptionCore || fail "wp_load: cannot find" l "↦ ?" in + wp_pures; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_load _ _ _ _ K)) + |fail 1 "wp_load: cannot find 'Load' in" e]; + [tc_solve + |solve_wptac_mapsto () + |wp_finish] + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_load _ _ _ _ K)) + |fail 1 "wp_load: cannot find 'Load' in" e]; + [tc_solve + |solve_wptac_mapsto () + |wp_finish] + | _ => fail "wp_load: not a 'wp'" + end. + +Tactic Notation "wp_store" := + let solve_wptac_mapsto _ := + let l := match goal with |- _ = Some (_, (wptac_mapsto ?l _ _)%I) => l end in + iAssumptionCore || fail "wp_store: cannot find" l "↦ ?" in + wp_pures; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_store _ _ _ _ K)) + |fail 1 "wp_store: cannot find 'Store' in" e]; + [tc_solve + |solve_wptac_mapsto () + |pm_reduce; first [wp_seq|wp_finish]] + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_store _ _ _ _ K)) + |fail 1 "wp_store: cannot find 'Store' in" e]; + [tc_solve + |solve_wptac_mapsto () + |pm_reduce; first [wp_seq|wp_finish]] + | _ => fail "wp_store: not a 'wp'" + end. + + +Section tape_tactics. + Context `{GwpTacticsBase Σ A hlc gwp, GwpTacticsBind Σ A hlc gwp, !GwpTacticsTapes Σ A laters gwp}. + + Local Notation "'WP' e @ s ; E {{ Φ } }" := (gwp s E e%E Φ) + (at level 20, e, Φ at level 200, only parsing) : bi_scope. + + (** Notations with binder. *) + Local Notation "'WP' e @ s ; E {{ v , Q } }" := (gwp s E e%E (λ v, Q)) + (at level 20, e, Q at level 200, + format "'[hv' 'WP' e '/' @ '[' s ; '/' E ']' '/' {{ '[' v , '/' Q ']' } } ']'") : bi_scope. + + Lemma tac_wp_alloctape Δ Δ' E j K N z Φ a : + TCEq N (Z.to_nat z) -> + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ Δ' → + (∀ l, + match envs_app false (Esnoc Enil j (wptac_mapsto_tape l (DfracOwn 1) N nil)) Δ' with + | Some Δ'' => + envs_entails Δ'' (WP fill K (Val $ LitV $ LitLbl l) @ a ; E {{ Φ }}) + | None => False + end) → + envs_entails Δ (WP fill K (AllocTape (Val $ LitV $ LitInt z)) @ a; E {{ Φ }}). + Proof. + rewrite envs_entails_unseal=> ? ? HΔ. + rewrite -wptac_wp_bind. + eapply bi.wand_apply. + { by apply bi.wand_entails, wptac_wp_alloctape. } + rewrite left_id into_laterN_env_sound. + apply bi.laterN_mono, bi.forall_intro=> l. + specialize (HΔ l). + destruct (envs_app _ _ _) as [Δ''|] eqn:HΔ'; [| contradiction]. + rewrite envs_app_sound //; simpl. + apply bi.wand_intro_l. + rewrite right_id. + rewrite bi.wand_elim_r //. + Qed. + + + Lemma tac_wp_rand_tape Δ1 Δ2 E i j K l N z n ns Φ a : + TCEq N (Z.to_nat z) -> + MaybeIntoLaterNEnvs (if laters then 1 else 0) Δ1 Δ2 → + envs_lookup i Δ2 = Some (false, wptac_mapsto_tape l (DfracOwn 1) N (n::ns)) -> + (match envs_simple_replace i false (Esnoc Enil i (wptac_mapsto_tape l (DfracOwn 1) N ns)) Δ2 with + | Some Δ3 => + (match envs_app false (Esnoc Enil j (⌜n ≤ N⌝%I)) Δ3 with + | Some Δ4 => envs_entails Δ4 (WP fill K (Val $ LitV $ LitInt n) @ a; E {{ Φ }}) + | None => False + end) + | None => False + end) → + envs_entails Δ1 (gwp a E (fill K (Rand (LitV (LitInt z)) (LitV (LitLbl l)))) Φ ). + Proof. + rewrite envs_entails_unseal=> ?? Hi HΔ. + destruct (envs_simple_replace _ _ _ _) as [Δ3|] eqn:HΔ3; last done. + rewrite -wptac_wp_bind. + eapply bi.wand_apply. + { by apply bi.wand_entails, wptac_wp_rand_tape. } + rewrite into_laterN_env_sound. + destruct laters. + - rewrite -bi.later_sep. + apply bi.later_mono. + rewrite (envs_simple_replace_sound Δ2 Δ3 i) /= //; simpl. + iIntros "[$ He]". + iIntros "Htp ?". + destruct (envs_app _ _ _) as [Δ4 |] eqn: HΔ4; [|contradiction]. + rewrite envs_app_sound //; simpl. + iApply HΔ. + iApply ("He" with "[$Htp]"). + iFrame. + - simpl. + rewrite (envs_simple_replace_sound Δ2 Δ3 i) /= //; simpl. + iIntros "[$ He]". + iIntros "Htp ?". + destruct (envs_app _ _ _) as [Δ4 |] eqn: HΔ4; [|contradiction]. + rewrite envs_app_sound //; simpl. + iApply HΔ. + iApply ("He" with "[$Htp]"). + iFrame. + Qed. + +End tape_tactics. + + +Tactic Notation "wp_alloctape" ident(l) "as" constr(H) := + let Htmp := iFresh in + let finish _ := + first [intros l | fail 1 "wp_alloctape:" l "not fresh"]; + pm_reduce; + lazymatch goal with + | |- False => fail 1 "wp_alloc:" H "not fresh" + | _ => iDestructHyp Htmp as H; wp_finish + end in + wp_pures; +lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_alloctape _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'AllocTape' in" e]; + [tc_solve | tc_solve + |finish ()] +| |- envs_entails _ (twp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_alloctape _ _ _ Htmp K)) + |fail 1 "wp_alloc: cannot find 'AllocTape' in" e]; + [tc_solve | tc_solve + |finish ()] + | _ => fail "wp_alloc: not a 'wp'" + end. + + +Tactic Notation "wp_alloctape" ident(l) := + wp_alloctape l as "?". + +Tactic Notation "wp_randtape" "as" constr(H) := + let Htmp := iFresh in + let solve_wptac_mapsto_tape _ := + let l := match goal with |- _ = Some (_, (wptac_mapsto_tape ?l _ _ (_ :: _))%I) => l end in + iAssumptionCore || fail "wp_load: cannot find" l "↪N ?" in + let finish _ := + pm_reduce; + lazymatch goal with + | |- False => fail 1 "wp_alloc:" H "not fresh" + | _ => iDestructHyp Htmp as H; wp_finish + end in + wp_pures; + lazymatch goal with + | |- envs_entails _ (wp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_rand_tape _ _ _ _ Htmp K)) + |fail 1 "wp_load: cannot find 'Rand' in" e]; + [ (* Delay resolution of TCEq *) + | tc_solve + | solve_wptac_mapsto_tape () + |]; + [try tc_solve | finish ()] + | |- envs_entails _ (twp ?s ?E ?e ?Q) => + first + [reshape_expr e ltac:(fun K e' => eapply (tac_wp_rand_tape _ _ _ _ Htmp K)) + |fail 1 "wp_load: cannot find 'Rand' in" e]; + [try tc_solve + |tc_solve + |try (solve_wptac_mapsto_tape ()) + |finish ()] + | _ => fail "wp_load: not a 'wp'" + end. + +Tactic Notation "wp_randtape" := + wp_randtape as "%". +*)