Skip to content

Commit

Permalink
discrete semantics for new tapes
Browse files Browse the repository at this point in the history
  • Loading branch information
markusdemedeiros committed Sep 23, 2024
1 parent 8fef565 commit 5bf2f54
Showing 1 changed file with 92 additions and 46 deletions.
138 changes: 92 additions & 46 deletions theories/meas_lang/lang.v
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
From HB Require Import structures.
From Coq Require Import Logic.ClassicalEpsilon Psatz.
From stdpp Require Import base numbers binders strings gmap.
From mathcomp.analysis Require Import reals measure.
From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets.
From mathcomp.analysis Require Import reals measure itv.
From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets fintype.
From iris.algebra Require Export ofe.
From clutch.prelude Require Export stdpp_ext.
From clutch.common Require Export locations.
Expand Down Expand Up @@ -78,8 +78,8 @@ Inductive expr :=
| AllocTape (e : expr)
| Rand (e1 e2 : expr)
(* Real probabilistic choice *)
| AllocUTape (e : expr)
| URand (e1 e2 : expr)
| AllocUTape
| URand (e : expr)
(* No-op operator used for cost *)
| Tick (e : expr)
with val :=
Expand Down Expand Up @@ -157,6 +157,8 @@ Definition tapeUpdateUnsafe {A} (i : nat) (v : option A) (t : tape A) : tape A :

Global Instance tape_insert {A} : Insert nat (option A) (tape A) := tapeUpdateUnsafe.

Program Definition tapeAdvance {A} (t : tape A) : tape A
:= {| tape_position := 1 + tape_position _ t; tape_contents := tape_contents _ t |}.

(* Advance the tape by 1, returning an updated tape and the first sample on the tape. *)
Program Definition tapeNext {A} (t : tape A) (H : isSome (t !! 0)) : A * (tape A)
Expand Down Expand Up @@ -267,11 +269,9 @@ Inductive ectx_item :=
| StoreLCtx (v2 : val)
| StoreRCtx (e1 : expr)
| AllocTapeCtx
| AllocUTapeCtx
| RandLCtx (v2 : val)
| RandRCtx (e1 : expr)
| URandLCtx (v2 : val)
| URandRCtx (e1 : expr)
| URandCtx
| TickCtx.

Definition fill_item (Ki : ectx_item) (e : expr) : expr :=
Expand All @@ -295,11 +295,9 @@ Definition fill_item (Ki : ectx_item) (e : expr) : expr :=
| StoreLCtx v2 => Store e (Val v2)
| StoreRCtx e1 => Store e1 e
| AllocTapeCtx => AllocTape e
| AllocUTapeCtx => AllocUTape e
| RandLCtx v2 => Rand e (Val v2)
| RandRCtx e1 => Rand e1 e
| URandLCtx v2 => URand e (Val v2)
| URandRCtx e1 => URand e1 e
| URandCtx => URand e
| TickCtx => Tick e
end.

Expand Down Expand Up @@ -342,17 +340,12 @@ Definition decomp_item (e : expr) : option (ectx_item * expr) :=
| _ => Some (StoreRCtx e1, e2)
end
| AllocTape e => noval e AllocTapeCtx
| AllocUTape e => noval e AllocUTapeCtx
| Rand e1 e2 =>
match e2 with
| Val v => noval e1 (RandLCtx v)
| _ => Some (RandRCtx e1, e2)
end
| URand e1 e2 =>
match e2 with
| Val v => noval e1 (URandLCtx v)
| _ => Some (URandRCtx e1, e2)
end
| URand e => noval e URandCtx
| Tick e => noval e TickCtx
| _ => None
end.
Expand All @@ -378,9 +371,9 @@ Fixpoint subst (x : string) (v : val) (e : expr) : expr :=
| Load e => Load (subst x v e)
| Store e1 e2 => Store (subst x v e1) (subst x v e2)
| AllocTape e => AllocTape (subst x v e)
| AllocUTape e => AllocUTape (subst x v e)
| AllocUTape => AllocUTape
| Rand e1 e2 => Rand (subst x v e1) (subst x v e2)
| URand e1 e2 => URand (subst x v e1) (subst x v e2)
| URand e => URand (subst x v e)
| Tick e => Tick (subst x v e)
end.

Expand Down Expand Up @@ -634,17 +627,30 @@ Section pointed_instances.
(** state * loc is pointed (automatic) *)
(* Check (<<discr (state * loc)>> : measurableType _). *)

(*
(** [0, 1] is pointed *)
(* FIXME Only used to build a discrete space over [0, 1], which we will delete *)
HB.instance Definition _ (R : realType) := isPointed.Build {i01 R} (0)%:i01.
(* Check (<<discr {i01 R}>> : measurableType _). *)
*)

(* FIXME: Super bad, casuses NFI, but I can't figure out any other way to get HB to *)
(* recognize R as a pointedType for <<discr R>>. This is temporary, and will be deleted *)
(* or fixed when we move to the new sigma algebra. *)
(** R is pointed *)
#[non_forgetful_inheritance]
HB.instance Definition _ (R : realType) := isPointed.Build R (0)%R.
(* Check (<<discr R>> : measurableType _). *)

End pointed_instances.

Definition cfg : Type := expr * state.

Section meas_semantics.
Local Open Scope classical_set_scope.
Context {R : realType}.
Notation giryM := (giryM (R := R)).
Notation giryM := (giryM (R := Real.sort meas_lang.R)).
Local Open Scope expr_scope.


Definition head_stepM_def (c : cfg) : giryM <<discr cfg>> :=
let (e1, σ1) := c in
match e1 with
Expand Down Expand Up @@ -695,47 +701,85 @@ Section meas_semantics.
| Some v => giryM_ret R ((Val $ LitV LitUnit, state_upd_heap <[l:=w]> σ1) : <<discr cfg>>)
| None => giryM_zero
end

(* FIXME: Finish implementation for tapes *)

(*
(* Uniform sampling from [0, 1 , ..., N] *)
| Rand (Val (LitV (LitInt N))) (Val (LitV LitUnit)) =>
giryM_map
(m_discr (fun (n : 'I_(S (Z.to_nat N))) => ((Val $ LitV $ LitInt n, σ1) : <<discr cfg>>)))
(giryM_unif (Z.to_nat N))
| AllocTape (Val (LitV (LitInt z))) =>
let ι := fresh_loc σ1.(tapes) in
giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_tapes <[ι := (fin (Z.to_nat z; [])) ]> σ1) : <<discr cfg>>)
*)

(*
(* Labelled sampling, conditional on tape contents *)
giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_tapes <[ι := {| btape_tape := emptyTape ; btape_bound := (Z.to_nat z) |} ]> σ1) : <<discr cfg>>)
(* Rand with a tape *)
| Rand (Val (LitV (LitInt N))) (Val (LitV (LitLbl l))) =>
match σ1.(tapes) !! l with
| Some (M; ns) =>
if bool_decide (M = Z.to_nat N) then
match ns with
| n :: ns =>
(* the tape is non-empty so we consume the first number *)
giryM_ret R ((Val $ LitV $ LitInt $ fin_to_nat n, state_upd_tapes <[l:=(fin(M; ns))]> σ1) : <<discr cfg>>)
| [] =>
| Some btape =>
(* There exists a tape with label l *)
let τ := btape.(btape_tape) in
let M := btape.(btape_bound) in
if (bool_decide (M = Z.to_nat N)) then
(* Tape bounds match *)
match (τ !! 0) with
| Some v =>
(* There is a next value on the tape *)
let σ' := state_upd_tapes <[ l := {| btape_tape := (tapeAdvance τ); btape_bound := M |} ]> σ1 in
(giryM_ret R ((Val $ LitV $ LitInt $ Z.of_nat v, σ') : <<discr cfg>>))
| None =>
(* Next slot on tape is empty *)
giryM_map
(m_discr (fun (n : 'I_(S (Z.to_nat M))) => ((Val $ LitV $ LitInt n, σ1) : <<discr cfg>>)))
(giryM_unif (Z.to_nat _))
(m_discr (fun (v : 'I_(S (Z.to_nat N))) =>
(* Fill the tape head with new sample *)
let τ' := <[ (0 : nat) := Some (v : nat) ]> τ in
(* Advance the tape *)
let σ' := state_upd_tapes <[ l := {| btape_tape := (tapeAdvance τ'); btape_bound := M |} ]> σ1 in
(* Return the new sample and state *)
((Val $ LitV $ LitInt $ Z.of_nat v, σ') : <<discr cfg>>)))
(giryM_unif (Z.to_nat N))
end
else
(* bound did not match the bound of the tape *)
(* Tape bounds do not match *)
(* Do not advance the tape, but still generate a new sample *)
giryM_map
(m_discr (fun (n : 'I_(S (Z.to_nat M))) => ((Val $ LitV $ LitInt n, σ1) : <<discr cfg>>)))
(giryM_unif (Z.to_nat _))
| None => mzero
(m_discr (fun (n : 'I_(S (Z.to_nat N))) => ((Val $ LitV $ LitInt n, σ1) : <<discr cfg>>)))
(giryM_unif (Z.to_nat N))
| None => giryM_zero
end
| AllocUTape =>
let ι := fresh_loc σ1.(utapes) in
giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_utapes <[ ι := emptyTape ]> σ1) : <<discr cfg>>)
(* Urand with no tape *)
| URand (Val (LitV LitUnit)) =>
giryM_map
(m_discr (fun u => ((Val $ LitV $ LitReal u, σ1) : <<discr cfg>>)))
giryM_zero
(* Urand with a tape *)
| URand (Val (LitV (LitLbl l))) =>
match σ1.(utapes) !! l with
| Some τ =>
(* tape l is allocated *)
match (τ !! 0) with
| Some u =>
(* Head has a sample *)
let σ' := state_upd_utapes <[ l := (tapeAdvance τ) ]> σ1 in
(giryM_ret R ((Val $ LitV $ LitReal u, σ') : <<discr cfg>>))
| None =>
(* Head has no sample *)
giryM_map
(m_discr (fun (u : R) =>
(* Fill tape head with new sample *)
let τ' := <[ (0 : nat) := Some u ]> τ in
(* Advance tape *)
let σ' := state_upd_utapes <[ l := (tapeAdvance τ') ]> σ1 in
(* Return the update value an state *)
((Val $ LitV $ LitReal u, σ') : <<discr cfg>>)))
giryM_zero
end
| None => giryM_zero
end
*)
| Tick (Val (LitV (LitInt n))) => giryM_ret R ((Val $ LitV $ LitUnit, σ1) : <<discr cfg>>)
| _ => giryM_zero
end.


(* head_stepM is a measurable map because it is a function out of a discrete space.
After we add continuous varaibles this argument gets more complex. It is not
Expand Down Expand Up @@ -1004,9 +1048,9 @@ Fixpoint height (e : expr) : nat :=
| Load e => 1 + height e
| Store e1 e2 => 1 + height e1 + height e2
| AllocTape e => 1 + height e
| AllocUTape e => 1 + height e
| AllocUTape => 1
| Rand e1 e2 => 1 + height e1 + height e2
| URand e1 e2 => 1 + height e1 + height e2
| URand e => 1 + height e
| Tick e => 1 + height e
end.

Expand Down Expand Up @@ -1048,6 +1092,8 @@ Definition get_active (σ : state) : list loc := elements (dom σ.(tapes)).


Local Open Scope classical_set_scope.


(*
Program Definition meas_lang_mixin : Type :=
@MeasEctxiLanguageMixin R _ _ _ <<discr expr>> <<discr val>> <<discr state>> ectx_item of_val to_val _ _ _ _.
Expand Down

0 comments on commit 5bf2f54

Please sign in to comment.