Skip to content

Commit

Permalink
change evaluation contexts to be measurable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
markusdemedeiros committed Sep 19, 2024
1 parent 1fe1a3f commit 4c82f68
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 18 deletions.
26 changes: 21 additions & 5 deletions theories/meas_lang/erasable.v
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
From Coq Require Import Reals Psatz.
From clutch.common Require Import language.
From clutch.prob Require Export couplings distribution markov.
From HB Require Import structures.
From Coq Require Import Logic.ClassicalEpsilon Psatz.
From stdpp Require Import base numbers binders strings gmap.
From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets.
From iris.prelude Require Import options.
From iris.algebra Require Import ofe.
From clutch.bi Require Import weakestpre.
From mathcomp.analysis Require Import reals measure ereal.
From clutch.prob.monad Require Import laws.
From clutch.meas_lang Require Import language.

(*
Section erasable.
Context {Λ : language}.
Context {R : realType}.
Notation giryM := (giryM (R := R)).
Context {Λ : meas_language}.

(*
Definition meas_erasable (f : measurable_map (state Λ) (giryM (state Λ))) : Prop :=
forall e m,
*)


(*
Definition erasable (μ : distr (state Λ)) σ:=
∀ e m, μ ≫= (λ σ', exec m (e, σ')) = exec m (e, σ).
Expand Down
48 changes: 35 additions & 13 deletions theories/meas_lang/language.v
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
From HB Require Import structures.
From Coq Require Import Logic.ClassicalEpsilon Psatz.
From Coq Require Import Logic.ClassicalEpsilon Psatz Logic.FunctionalExtensionality.
From stdpp Require Import base numbers binders strings gmap.
From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets.
From iris.prelude Require Import options.
Expand All @@ -25,7 +25,7 @@ Section language_mixin.
Context (to_val : expr → option val).

Definition dead_cfg (s : giryM (expr * state)%type) : Prop
:= s = giryM_zero.
:= measure_eq s giryM_zero.

Definition live_cfg (s : giryM (expr * state)%type) : Prop
:= (@giryM_eval _ _ _ _ (@measurableT _ _) s = 1)%E.
Expand Down Expand Up @@ -64,6 +64,16 @@ Structure meas_language := Language {
language_mixin : MeasLanguageMixin R of_val to_val prim_step
}.



(** Register MCA products into measurableMap hierarchy *)

HB.instance Definition _ {d1 d2 } {T1 : measurableType d1} {T2 : measurableType d2} :=
isMeasurableMap.Build _ _ (T1 * T2)%type T1 fst measurable_fst.

HB.instance Definition _ {d1 d2 } {T1 : measurableType d1} {T2 : measurableType d2} :=
isMeasurableMap.Build _ _ (T1 * T2)%type T2 snd measurable_snd.

Bind Scope expr_scope with expr.
Bind Scope val_scope with val.

Expand All @@ -78,28 +88,38 @@ Canonical Structure exprO Λ := leibnizO (expr Λ).

Definition cfg (Λ : meas_language) := (expr Λ * state Λ)%type.

Definition fill_lift {Λ} (K : measurable_map (expr Λ) (expr Λ)) : (expr Λ * state Λ) → (expr Λ * state Λ) :=
λ c, (K (fst c), (snd c)).

Local Lemma fill_lift_measurable {Λ} (K : measurable_map (expr Λ) (expr Λ)) :
@measurable_fun _ _ (expr Λ * state Λ)%type (expr Λ * state Λ)%type setT (fill_lift K).
Proof.
apply measurable_fun_prod.
{ simpl.
have -> : (λ x : expr Λ * state Λ, K x.1) = m_cmp K fst.
{ apply functional_extensionality.
intro x.
by rewrite m_cmp_eval/=. }
eapply measurable_mapP. }
{ eapply measurable_mapP. }
Qed.

HB.instance Definition _ {Λ} (K : measurable_map (expr Λ) (expr Λ)) :=
isMeasurableMap.Build _ _ (expr Λ * state Λ)%type (expr Λ * state Λ)%type (fill_lift K) (fill_lift_measurable K).

Definition fill_lift {Λ} (K : expr Λ → expr Λ) : (expr Λ * state Λ) → (expr Λ * state Λ) :=
λ '(e, σ), (K e, σ).

Global Instance inj_fill_lift {Λ : meas_language} (K : expr Λexpr Λ) :
Global Instance inj_fill_lift {Λ : meas_language} (K : measurable_map (expr Λ) (expr Λ)) :
Inj (=) (=) K →
Inj (=) (=) (fill_lift K).
Proof. by intros ? [] [] [=->%(inj _) ->]. Qed.

Class MeasLanguageCtx {Λ : meas_language} (K : expr Λ → expr Λ) := {

(** To specify that fill_lift is measurable, give a different measurable_function,
and prove that it is measurable. *)
meas_fill_lift_K : measurable_map (expr Λ * state Λ)%type (expr Λ * state Λ)%type;
meas_fill_lift_spec : forall (ρ : (expr Λ * state Λ)%type), meas_fill_lift_K ρ = fill_lift K ρ;

Class MeasLanguageCtx {Λ : meas_language} (K : measurable_map (expr Λ) (expr Λ)) := {
fill_not_val e :
to_val e = None → to_val (K e) = None;
fill_inj : Inj (=) (=) K;
fill_dmap e1 σ1 :
to_val e1 = None →
prim_step ((K e1), σ1) = giryM_map meas_fill_lift_K (prim_step (e1, σ1))
prim_step ((K e1), σ1) = giryM_map (fill_lift K) (prim_step (e1, σ1))
}.

#[global] Existing Instance fill_inj.
Expand Down Expand Up @@ -172,6 +192,7 @@ Section language.
apply: Hs.
rewrite /dead_cfg; rewrite /dead_cfg in Hs'.

(*
apply giryM_ext.
intro S.

Expand All @@ -182,6 +203,7 @@ Section language.

rewrite giryM_zero_eval in X'.
rewrite giryM_zero_eval.
*)
(* rewrite /pushforward in X'. *)
Admitted.
(*
Expand Down

0 comments on commit 4c82f68

Please sign in to comment.