diff --git a/theories/meas_lang/lang.v b/theories/meas_lang/lang.v index 2ef91d3f..9d0b4f03 100644 --- a/theories/meas_lang/lang.v +++ b/theories/meas_lang/lang.v @@ -1,8 +1,8 @@ From HB Require Import structures. From Coq Require Import Logic.ClassicalEpsilon Psatz. From stdpp Require Import base numbers binders strings gmap. -From mathcomp.analysis Require Import reals measure. -From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets. +From mathcomp.analysis Require Import reals measure itv. +From mathcomp Require Import ssrbool all_algebra eqtype choice boolp classical_sets fintype. From iris.algebra Require Export ofe. From clutch.prelude Require Export stdpp_ext. From clutch.common Require Export locations. @@ -78,8 +78,8 @@ Inductive expr := | AllocTape (e : expr) | Rand (e1 e2 : expr) (* Real probabilistic choice *) - | AllocUTape (e : expr) - | URand (e1 e2 : expr) + | AllocUTape + | URand (e : expr) (* No-op operator used for cost *) | Tick (e : expr) with val := @@ -157,6 +157,8 @@ Definition tapeUpdateUnsafe {A} (i : nat) (v : option A) (t : tape A) : tape A : Global Instance tape_insert {A} : Insert nat (option A) (tape A) := tapeUpdateUnsafe. +Program Definition tapeAdvance {A} (t : tape A) : tape A + := {| tape_position := 1 + tape_position _ t; tape_contents := tape_contents _ t |}. (* Advance the tape by 1, returning an updated tape and the first sample on the tape. *) Program Definition tapeNext {A} (t : tape A) (H : isSome (t !! 0)) : A * (tape A) @@ -267,11 +269,9 @@ Inductive ectx_item := | StoreLCtx (v2 : val) | StoreRCtx (e1 : expr) | AllocTapeCtx - | AllocUTapeCtx | RandLCtx (v2 : val) | RandRCtx (e1 : expr) - | URandLCtx (v2 : val) - | URandRCtx (e1 : expr) + | URandCtx | TickCtx. Definition fill_item (Ki : ectx_item) (e : expr) : expr := @@ -295,11 +295,9 @@ Definition fill_item (Ki : ectx_item) (e : expr) : expr := | StoreLCtx v2 => Store e (Val v2) | StoreRCtx e1 => Store e1 e | AllocTapeCtx => AllocTape e - | AllocUTapeCtx => AllocUTape e | RandLCtx v2 => Rand e (Val v2) | RandRCtx e1 => Rand e1 e - | URandLCtx v2 => URand e (Val v2) - | URandRCtx e1 => URand e1 e + | URandCtx => URand e | TickCtx => Tick e end. @@ -342,17 +340,12 @@ Definition decomp_item (e : expr) : option (ectx_item * expr) := | _ => Some (StoreRCtx e1, e2) end | AllocTape e => noval e AllocTapeCtx - | AllocUTape e => noval e AllocUTapeCtx | Rand e1 e2 => match e2 with | Val v => noval e1 (RandLCtx v) | _ => Some (RandRCtx e1, e2) end - | URand e1 e2 => - match e2 with - | Val v => noval e1 (URandLCtx v) - | _ => Some (URandRCtx e1, e2) - end + | URand e => noval e URandCtx | Tick e => noval e TickCtx | _ => None end. @@ -378,9 +371,9 @@ Fixpoint subst (x : string) (v : val) (e : expr) : expr := | Load e => Load (subst x v e) | Store e1 e2 => Store (subst x v e1) (subst x v e2) | AllocTape e => AllocTape (subst x v e) - | AllocUTape e => AllocUTape (subst x v e) + | AllocUTape => AllocUTape | Rand e1 e2 => Rand (subst x v e1) (subst x v e2) - | URand e1 e2 => URand (subst x v e1) (subst x v e2) + | URand e => URand (subst x v e) | Tick e => Tick (subst x v e) end. @@ -634,17 +627,30 @@ Section pointed_instances. (** state * loc is pointed (automatic) *) (* Check (<> : measurableType _). *) + (* + (** [0, 1] is pointed *) + (* FIXME Only used to build a discrete space over [0, 1], which we will delete *) + HB.instance Definition _ (R : realType) := isPointed.Build {i01 R} (0)%:i01. + (* Check (<> : measurableType _). *) + *) + + (* FIXME: Super bad, casuses NFI, but I can't figure out any other way to get HB to *) + (* recognize R as a pointedType for <>. This is temporary, and will be deleted *) + (* or fixed when we move to the new sigma algebra. *) + (** R is pointed *) + #[non_forgetful_inheritance] + HB.instance Definition _ (R : realType) := isPointed.Build R (0)%R. + (* Check (<> : measurableType _). *) + End pointed_instances. Definition cfg : Type := expr * state. Section meas_semantics. Local Open Scope classical_set_scope. - Context {R : realType}. - Notation giryM := (giryM (R := R)). + Notation giryM := (giryM (R := Real.sort meas_lang.R)). Local Open Scope expr_scope. - Definition head_stepM_def (c : cfg) : giryM <> := let (e1, σ1) := c in match e1 with @@ -695,10 +701,6 @@ Section meas_semantics. | Some v => giryM_ret R ((Val $ LitV LitUnit, state_upd_heap <[l:=w]> σ1) : <>) | None => giryM_zero end - - (* FIXME: Finish implementation for tapes *) - - (* (* Uniform sampling from [0, 1 , ..., N] *) | Rand (Val (LitV (LitInt N))) (Val (LitV LitUnit)) => giryM_map @@ -706,36 +708,78 @@ Section meas_semantics. (giryM_unif (Z.to_nat N)) | AllocTape (Val (LitV (LitInt z))) => let ι := fresh_loc σ1.(tapes) in - giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_tapes <[ι := (fin (Z.to_nat z; [])) ]> σ1) : <>) - *) - - (* - (* Labelled sampling, conditional on tape contents *) + giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_tapes <[ι := {| btape_tape := emptyTape ; btape_bound := (Z.to_nat z) |} ]> σ1) : <>) + (* Rand with a tape *) | Rand (Val (LitV (LitInt N))) (Val (LitV (LitLbl l))) => match σ1.(tapes) !! l with - | Some (M; ns) => - if bool_decide (M = Z.to_nat N) then - match ns with - | n :: ns => - (* the tape is non-empty so we consume the first number *) - giryM_ret R ((Val $ LitV $ LitInt $ fin_to_nat n, state_upd_tapes <[l:=(fin(M; ns))]> σ1) : <>) - | [] => + | Some btape => + (* There exists a tape with label l *) + let τ := btape.(btape_tape) in + let M := btape.(btape_bound) in + if (bool_decide (M = Z.to_nat N)) then + (* Tape bounds match *) + match (τ !! 0) with + | Some v => + (* There is a next value on the tape *) + let σ' := state_upd_tapes <[ l := {| btape_tape := (tapeAdvance τ); btape_bound := M |} ]> σ1 in + (giryM_ret R ((Val $ LitV $ LitInt $ Z.of_nat v, σ') : <>)) + | None => + (* Next slot on tape is empty *) giryM_map - (m_discr (fun (n : 'I_(S (Z.to_nat M))) => ((Val $ LitV $ LitInt n, σ1) : <>))) - (giryM_unif (Z.to_nat _)) + (m_discr (fun (v : 'I_(S (Z.to_nat N))) => + (* Fill the tape head with new sample *) + let τ' := <[ (0 : nat) := Some (v : nat) ]> τ in + (* Advance the tape *) + let σ' := state_upd_tapes <[ l := {| btape_tape := (tapeAdvance τ'); btape_bound := M |} ]> σ1 in + (* Return the new sample and state *) + ((Val $ LitV $ LitInt $ Z.of_nat v, σ') : <>))) + (giryM_unif (Z.to_nat N)) end else - (* bound did not match the bound of the tape *) + (* Tape bounds do not match *) + (* Do not advance the tape, but still generate a new sample *) giryM_map - (m_discr (fun (n : 'I_(S (Z.to_nat M))) => ((Val $ LitV $ LitInt n, σ1) : <>))) - (giryM_unif (Z.to_nat _)) - | None => mzero + (m_discr (fun (n : 'I_(S (Z.to_nat N))) => ((Val $ LitV $ LitInt n, σ1) : <>))) + (giryM_unif (Z.to_nat N)) + | None => giryM_zero + end + | AllocUTape => + let ι := fresh_loc σ1.(utapes) in + giryM_ret R ((Val $ LitV $ LitLbl ι, state_upd_utapes <[ ι := emptyTape ]> σ1) : <>) + (* Urand with no tape *) + | URand (Val (LitV LitUnit)) => + giryM_map + (m_discr (fun u => ((Val $ LitV $ LitReal u, σ1) : <>))) + giryM_zero + (* Urand with a tape *) + | URand (Val (LitV (LitLbl l))) => + match σ1.(utapes) !! l with + | Some τ => + (* tape l is allocated *) + match (τ !! 0) with + | Some u => + (* Head has a sample *) + let σ' := state_upd_utapes <[ l := (tapeAdvance τ) ]> σ1 in + (giryM_ret R ((Val $ LitV $ LitReal u, σ') : <>)) + | None => + (* Head has no sample *) + giryM_map + (m_discr (fun (u : R) => + (* Fill tape head with new sample *) + let τ' := <[ (0 : nat) := Some u ]> τ in + (* Advance tape *) + let σ' := state_upd_utapes <[ l := (tapeAdvance τ') ]> σ1 in + (* Return the update value an state *) + ((Val $ LitV $ LitReal u, σ') : <>))) + giryM_zero + end + | None => giryM_zero end - *) | Tick (Val (LitV (LitInt n))) => giryM_ret R ((Val $ LitV $ LitUnit, σ1) : <>) | _ => giryM_zero end. + (* head_stepM is a measurable map because it is a function out of a discrete space. After we add continuous varaibles this argument gets more complex. It is not @@ -1004,9 +1048,9 @@ Fixpoint height (e : expr) : nat := | Load e => 1 + height e | Store e1 e2 => 1 + height e1 + height e2 | AllocTape e => 1 + height e - | AllocUTape e => 1 + height e + | AllocUTape => 1 | Rand e1 e2 => 1 + height e1 + height e2 - | URand e1 e2 => 1 + height e1 + height e2 + | URand e => 1 + height e | Tick e => 1 + height e end. @@ -1048,6 +1092,8 @@ Definition get_active (σ : state) : list loc := elements (dom σ.(tapes)). Local Open Scope classical_set_scope. + + (* Program Definition meas_lang_mixin : Type := @MeasEctxiLanguageMixin R _ _ _ <> <> <> ectx_item of_val to_val _ _ _ _.