-
Notifications
You must be signed in to change notification settings - Fork 0
/
monad_convert.ML
234 lines (202 loc) · 8.2 KB
/
monad_convert.ML
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
(*
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
*
* SPDX-License-Identifier: BSD-2-Clause
*)
(*
* Code to manage converting between L2_monad and other monad types.
*
* TypeStrengthen provides a higher level interface for converting entire programs.
*)
structure Monad_Convert = struct
(* Utilities. *)
(* Place `sep` between each pair of neighbouring elements of a list.
 * Empty and single-element lists come back unchanged. *)
fun intersperse sep items =
  case items of
    [] => []
  | first :: rest =>
      first :: List.foldr (fn (item, acc) => sep :: item :: acc) [] rest
(* Unwrap an option value; raise the supplied exception `exc` when it is NONE. *)
fun theE opt exc =
  case opt of
    SOME value => value
  | NONE => raise exc
(* Return the head of a list; raise the supplied exception `exc` when the
 * list is empty. Any further elements are ignored. *)
fun oneE items exc =
  case items of
    first :: _ => first
  | [] => raise exc
(* From Find_Theorems.
 *
 * Strip the leading abstractions of `tm`, beta-apply `tm` to one dummy
 * pattern (of the bound variable's type) per stripped binder, and then run
 * Term.replace_dummy_patterns over the result, starting at index 1. *)
fun apply_dummies tm =
  let
    val (binders, _) = Term.strip_abs tm;
    val dummy_args = map (fn (_, T) => Term.dummy_pattern T) binders;
    val applied = Term.betapplys (tm, dummy_args);
    val (result, _) = Term.replace_dummy_patterns applied 1;
  in result end;
(* Turn the user-supplied string `nm` into a term pattern.
 *
 * If `nm` names an abbreviation, the abbreviation's right-hand side is
 * expanded and passed through apply_dummies; otherwise `nm` is read as an
 * ordinary term pattern. *)
fun parse_pattern ctxt nm =
  let
    val consts = Proof_Context.consts_of ctxt;
    (* Name to look up: the constant itself if `nm` parses to one,
     * otherwise `nm` interned as a constant name. *)
    val const_name =
      (case Syntax.parse_term ctxt nm of
        Const (c, _) => c
      | _ => Consts.intern consts nm);
    val abbrev = try (Consts.the_abbreviation consts) const_name;
  in
    (case abbrev of
      NONE => Proof_Context.read_term_pattern ctxt nm
    | SOME (_, rhs) => apply_dummies (Proof_Context.expand_abbrevs ctxt rhs))
  end;
(*
 * Breadth-first term search.
 *
 * Visits `term` and its subterms in breadth-first order. For each visited
 * subterm:
 *   - if `pred` holds, `cont` is called with (vars, subterm) and a thunk
 *     that resumes the search when invoked;
 *   - otherwise, if `prune` holds, the subterm is not descended into;
 *   - otherwise an abstraction is opened by applying it to a Free whose
 *     name does not clash with `vars` (primes are appended until it is
 *     fresh), and an application queues its function, then its argument.
 *
 * `vars` lists the names introduced for the binders opened on the way to
 * the subterm, innermost first.
 *)
fun term_search_bf cont pred prune =
  let
    fun fresh_var vars v =
      if member (op =) vars v then fresh_var vars (v ^ "'") else v

    fun walk queue =
      if Queue.is_empty queue then () else search (Queue.dequeue queue)

    and search ((vars, term), queue) =
      if pred term then cont (vars, term) (fn () => walk queue)
      else if prune term then walk queue
      else
        (case term of
          Abs (v, typ, _) =>
            let
              val v' = fresh_var vars v
              val opened = betapply (term, Free (v', typ))
            in
              walk (Queue.enqueue (v' :: vars, opened) queue)
            end
        | f $ x =>
            walk (queue |> Queue.enqueue (vars, f) |> Queue.enqueue (vars, x))
        | _ => walk queue)
  in
    fn term => search (([], term), Queue.empty)
  end
(* Run term_search_bf and return only the first hit as SOME (vars, subterm),
 * or NONE when nothing satisfies `pred`. *)
fun term_search_bf_first pred prune term =
  let
    val found = Unsynchronized.ref NONE
    (* Record the hit and never call the resume thunk, so the search stops. *)
    fun stop_at_first hit _ = found := SOME hit
  in
    term_search_bf stop_at_first pred prune term;
    ! found
  end
(* From Pure/Tools/find_theorems.ML, because Florian made it private.
 *
 * True when `pat` matches `obj` itself or any subterm of it. Abstraction
 * bodies are entered through Variable.dest_abs, threading the context. *)
fun matches_subterm ctxt (pat, obj) =
  let
    val thy = Proof_Context.theory_of ctxt;
    fun go ctxt' tm =
      if Pattern.matches thy (pat, tm) then true
      else
        (case tm of
          Abs _ =>
            let val ((_, body), ctxt'') = Variable.dest_abs tm ctxt'
            in go ctxt'' body end
        | t $ u => go ctxt' t orelse go ctxt' u
        | _ => false);
  in go ctxt obj end;
(* Find the first subterm (in breadth-first order) that matches `pattern`.
 *
 * Returns a function from a term to SOME (vars, subterm) or NONE. Subterms
 * that contain no match anywhere inside them are pruned from the search. *)
fun grep_term ctxt pattern =
  let
    val thy = Proof_Context.theory_of ctxt
    fun is_match term = Pattern.matches thy (pattern, term)
    fun has_no_match term = not (matches_subterm ctxt (pattern, term))
  in
    term_search_bf_first is_match has_no_match
  end
(* Check whether the term is in L2_monad notation, by handing the list of
 * L2 combinators below to Monad_Types.check_lifting_head. *)
val term_is_L2 =
  let
    val l2_heads =
      [@{term "L2_unknown"}, @{term "L2_seq"}, @{term "L2_modify"},
       @{term "L2_gets"}, @{term "L2_condition"}, @{term "L2_catch"}, @{term "L2_while"},
       @{term "L2_throw"}, @{term "L2_spec"}, @{term "L2_guard"}, @{term "L2_fail"},
       @{term "L2_recguard"}, @{term "L2_call"}]
  in
    Monad_Types.check_lifting_head l2_heads
  end
(*
 * Rewrite `term` between L2_monad and the monad type `mt`, using any extra
 * simplification facts in `more_facts`.
 *
 * With `forward` set, the lift rules of `mt` are used and the result must
 * satisfy `#valid_term mt`; otherwise the unlift rules are used and the
 * result must be in L2_monad notation. The rewrite theorem is returned
 * only when the check on its right-hand side succeeds; otherwise NONE.
 *
 * For this conversion to be useful on recursive programs, it needs
 * to be given a fact representing the inductive assumption.
 *)
fun monad_rewrite (lthy : Proof.context) (mt : Monad_Types.monad_type)
    (more_facts : thm list) (forward : bool)
    (term : term) : thm option =
  let
    val hidden_ctxt = Utils.set_hidden_ctxt lthy
    val rules = Monad_Types.monad_type_rules mt
    val direction_rules =
      if forward then #lift_rules rules else #unlift_rules rules
    val goal = Thm.cterm_of hidden_ctxt term

    (* Just apply the simplifier and hope that it works. *)
    val simp_ctxt =
      put_simpset HOL_ss hidden_ctxt addsimps direction_rules addsimps more_facts
    val thm = Simplifier.rewrite simp_ctxt goal

    val result_term = Utils.rhs_of (term_of_thm thm)
    val accept = if forward then #valid_term mt else term_is_L2
  in
    if accept hidden_ctxt result_term then SOME thm else NONE
  end
(*
 * Apply polish to a theorem of the form:
 *
 *   <LHS> == <lift> $ <some term to polish>
 *
 * The polish rules of `mt` are always applied; the rules collected under
 * the `polish` named theorems are added only when `do_opt` is set.
 *
 * Return the new theorem.
 *)
local
  val case_prod_eta_contract_thm =
    @{lemma "(%x. (case_prod s) x) == (case_prod s)" by simp}
in
fun polish ctxt (mt : Monad_Types.monad_type) do_opt thm =
  let
    val hidden_ctxt = Utils.set_hidden_ctxt ctxt

    (* Optional user-supplied polishing rules. *)
    val extra_simps =
      if do_opt then Utils.get_rules hidden_ctxt @{named_theorems polish} else []

    (* Simplify using polish rules. *)
    val mt_rules = Monad_Types.monad_type_rules mt
    val simp_ctxt =
      put_simpset HOL_ss hidden_ctxt addsimps (#polish_rules mt_rules @ extra_simps)
    val simp_conv = Simplifier.rewrite simp_ctxt

    (* Eta-contract "case_prod" clauses, so that they render as
     * "%(a, b). P a b" instead of "case x of (a, b) => P a b". *)
    val case_prod_conv =
      Conv.bottom_conv
        (fn _ => Conv.try_conv (Conv.rewr_conv case_prod_eta_contract_thm))
        hidden_ctxt

    val polish_conv = simp_conv then_conv case_prod_conv
  in
    thm |> Conv.fconv_rule (Conv.arg_conv (Utils.rhs_conv polish_conv))
  end
end
(*
 * Guard a tactic that cannot cope with out-of-range subgoal numbers:
 * when the goal state has fewer than `n` premises the wrapped tactic is
 * not run and `no_tac` is used instead.
 *)
fun handle_invalid_subgoals (tac : int -> tactic) n =
  fn thm =>
    let
      val subgoal_count = Logic.count_prems (term_of_thm thm)
    in
      if n > subgoal_count then no_tac thm else tac n thm
    end
(*
 * monad_convert tactic.
 *
 * Locates the first subterm of subgoal `n` that matches `pattern_str`,
 * rewrites it back to L2_monad form, lifts it into the monad type named
 * `monad_name`, polishes the result (with the optional polish rules when
 * `do_opt` is set) and substitutes the resulting equation into the subgoal.
 *
 * Raises ERROR when `monad_name` is not a registered monad type, and TERM
 * when the pattern is not found, the source monad type cannot be
 * determined, either rewrite fails, or the final substitution yields no
 * result.
 *)
fun monad_convert_tac (ctxt : Proof.context) (monad_name : string) (do_opt : bool)
    (pattern_str : string) (n : int) : tactic =
  fn state =>
    let
      val monad_types = Monad_Types.TSRules.get (Context.Proof ctxt)

      (* Figure out which monad to lift into. *)
      val target_mt =
        theE (Symtab.lookup monad_types monad_name)
          (ERROR ("monad_convert: could not find monad type " ^ quote monad_name))

      (* Search the subgoal for the supplied pattern. *)
      val pattern = parse_pattern ctxt pattern_str
      val subgoal = Logic.get_goal (term_of_thm state) n
      val (m_vars, m_term) =
        theE (grep_term ctxt pattern subgoal)
          (TERM ("monad_convert: failed to match pattern", [pattern]))

      (* Find a monad type that accepts m_term as valid output.
       * This saves us from having to try every unlift rule. *)
      val candidates =
        Symtab.dest monad_types
        |> map snd
        |> filter (fn (mt : Monad_Types.monad_type) => #valid_term mt ctxt m_term)
      val source_mt =
        oneE candidates
          (TERM ("monad_convert: could not determine monad type", [m_term]))

      (* Unlift back to L2_monad. *)
      val unlift_thm =
        theE (monad_rewrite ctxt source_mt [] false m_term)
          (TERM ("monad_convert: could not unlift term (rule: " ^
                 #name source_mt ^ ")", [m_term]))
      val l2_term = Utils.rhs_of (term_of_thm unlift_thm)

      (* Lift to target monad. *)
      val relift_thm =
        theE (monad_rewrite ctxt target_mt [] true l2_term)
          (TERM ("monad_convert: could not lift to " ^ #name target_mt,
                 [m_term, l2_term]))

      (* Polish result and chain the two rewrites together. *)
      val polished_thm = polish ctxt target_mt do_opt relift_thm
      val translate_thm = Thm.transitive unlift_thm polished_thm

      (* Make variables schematic *)
      val fixes = sort_distinct string_ord m_vars
      val schematic_thm =
        Goal.prove ctxt fixes [] (term_of_thm translate_thm)
          (fn _ => resolve_tac ctxt [translate_thm] 1)

      val outcome = EqSubst.eqsubst_tac ctxt [0] [schematic_thm] n state
    in
      (case Seq.pull outcome of
        SOME (first, rest) => Seq.cons first rest
      | NONE =>
          raise TERM ("monad_convert: failed to apply conversion",
            [term_of_thm schematic_thm, subgoal]))
    end
(* Register the "monad_convert" proof method in the theory. Its arguments
 * are an optional goal specifier, the target monad type name and the term
 * pattern to convert; the optional polish rules are always enabled. *)
val _ =
  let
    (* Based on subgoal_tac parser *)
    val args_parser =
      Args.goal_spec -- Scan.lift (Parse.name -- Parse.embedded_inner_syntax)
    fun mk_method (quant, (monad_name, term_str)) ctxt =
      SIMPLE_METHOD'' quant
        (handle_invalid_subgoals (monad_convert_tac ctxt monad_name true term_str))
  in
    Context.>> (Context.map_theory
      (Method.setup (Binding.name "monad_convert")
        (args_parser >> mk_method)
        "autocorres monad conversion"))
  end
end