-
Notifications
You must be signed in to change notification settings - Fork 1
/
ghost_code.py
491 lines (380 loc) · 21.1 KB
/
ghost_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
from __future__ import annotations
from dataclasses import dataclass
import dataclasses
from enum import Enum, unique
from typing import Callable, Iterable, Mapping, NamedTuple, Sequence, Tuple, TypeAlias, Set
from typing_extensions import assert_never
import abc_cfg
import source
import nip
import ghost_data
import syntax
from provenance import *
@dataclass(frozen=True)
class NodePostConditionProofObligation(source.NodeCond[source.VarNameKind]):
    """Conditional node carrying the subject function's postcondition.

    Built by ``sprinkle_subject_pre_and_post_conditions`` on the edge leading
    to ``source.NodeNameRet``, with ``source.NodeNameErr`` as its second
    successor.
    """
    pass
@dataclass(frozen=True)
class NodePreconditionAssumption(source.NodeAssume[source.VarNameKind]):
    """Assume node carrying the subject function's precondition.

    Built by ``sprinkle_subject_pre_and_post_conditions`` on the edge leaving
    the entry node.
    """
    pass
@dataclass(frozen=True)
class NodeLoopInvariantAssumption(source.NodeAssume[source.VarNameKind]):
    """Assume node carrying a loop invariant.

    Built by ``sprinkle_loop_invariant`` on every edge leaving the loop
    header.
    """
    pass
@dataclass(frozen=True)
class NodeLoopInvariantProofObligation(source.NodeCond[source.VarNameKind]):
    """Conditional node carrying a loop invariant that has to be proved.

    Built by ``sprinkle_loop_invariant`` on every edge entering the loop
    header, with ``source.NodeNameErr`` as its second successor.
    """
    pass
@dataclass(frozen=True)
class NodePrecondObligationFnCall(source.NodeAssert[source.VarNameKind]):
    """Assert node carrying the precondition of a called function.

    Built by ``sprinkle_function_call_pre_and_post_condition`` on every edge
    entering the call node, after the node that stashes the call arguments.
    """
    pass
@dataclass(frozen=True)
class NodeAssumePostCondFnCall(source.NodeAssume[source.VarNameKind]):
    """Assume node carrying the postcondition of a called function.

    Built by ``sprinkle_function_call_pre_and_post_condition`` on the edge
    leaving the call node.
    """
    pass
@dataclass(frozen=True)
class NodeLoopPreIterationAssumption(source.NodeAssume[source.VarNameKind]):
    """Assume node carrying a loop's ``pre_iter`` expression.

    Built by ``sprinkle_pre_and_post_loop_iterations`` on an edge from the
    loop header to one of its successors inside the loop.
    """
    pass
# should we be using node assert here?
@dataclass(frozen=True)
class NodeLoopPostIterationProofObligation(source.NodeAssert[source.VarNameKind]):
    """Assert node carrying a loop's ``post_iter`` expression.

    Built by ``sprinkle_pre_and_post_loop_iterations`` on the edge from a loop
    latch back to the loop header.
    """
    pass
# Union of six of the ghost node classes declared above.
# NOTE(review): NodeLoopPreIterationAssumption and
# NodeLoopPostIterationProofObligation are not members of this union even
# though this module inserts them too -- confirm whether that is intentional.
NodeGhostCode = (NodePostConditionProofObligation[source.VarNameKind]
                 | NodePreconditionAssumption[source.VarNameKind]
                 | NodeLoopInvariantAssumption[source.VarNameKind]
                 | NodeLoopInvariantProofObligation[source.VarNameKind]
                 | NodePrecondObligationFnCall[source.VarNameKind]
                 | NodeAssumePostCondFnCall[source.VarNameKind])
class GenericFunction(nip.GenericFunction[source.VarNameKind, source.VarNameKind2]):
    """
    Function in which the pre conditions, post conditions, and loop
    invariants have been inserted in the CFG.
    """
# GenericFunction specialised to program and guard variable names; this is
# the type returned by sprinkle_ghost_code.
Function = GenericFunction[source.ProgVarName |
                           nip.GuardVarName, source.ProgVarName | nip.GuardVarName]
class Insertion(NamedTuple):
    """A request to place one new node on the CFG edge ``after -> before``.

    Consumed by ``apply_insertions``.
    """
    # node the edge starts from
    after: source.NodeName
    # node the edge points to
    before: source.NodeName
    # name under which the new node is stored
    node_name: source.NodeName
    # builds the new node; called with the name of the node it must jump to
    mk_node: Callable[[source.NodeName],
                      source.Node[source.ProgVarName | nip.GuardVarName]]
# A CFG edge as a pair of node names; apply_insertions builds it as
# (insertion.after, insertion.before).
Edge: TypeAlias = tuple[source.NodeName, source.NodeName]
def apply_insertions(func: nip.Function, insertions: Sequence[Insertion]) -> Mapping[source.NodeName, source.Node[source.ProgVarName | nip.GuardVarName]]:
    """Return the nodes of ``func`` with every insertion spliced onto its edge.

    Insertions are grouped by their ``(after, before)`` edge. For each such
    edge found in ``func.cfg.all_succs``, the new nodes are chained in the
    order they appear in ``insertions``: ``after`` is redirected to the first
    new node, each new node jumps to the next one, and the last new node
    jumps to ``before``.

    ``func`` is not modified; the original nodes are copied into the result
    and the redirected ones are replaced with ``dataclasses.replace``.
    Insertions whose edge is not found in ``func.cfg.all_succs`` are not
    applied.

    Raises:
        AssertionError: if the name of an inserted node is already taken,
            either by an original node or by a previously inserted one.
    """
    # edge -> list of insertions to apply on that edge, _in order_ (first is
    # inserted first, etc)
    edge_insertions: dict[Edge, list[Insertion]] = {}
    for insertion in insertions:
        edge_insertions.setdefault(
            (insertion.after, insertion.before), []).append(insertion)

    new_nodes: dict[source.NodeName,
                    source.Node[source.ProgVarName | nip.GuardVarName]] = {}
    for after_name, node in func.nodes.items():
        new_nodes[after_name] = node

        # dict.fromkeys de-duplicates while keeping the order, so that a
        # successor listed twice does not apply the same chain twice
        for before_name in dict.fromkeys(func.cfg.all_succs[after_name]):
            chain = edge_insertions.get((after_name, before_name))
            if chain is None:
                continue
            assert len(chain) > 0

            for i, insertion in enumerate(chain):
                # new_nodes only holds the original nodes visited so far, so
                # func.nodes is checked as well; otherwise an inserted node
                # would be silently overwritten by an original node visited
                # later
                assert insertion.node_name not in new_nodes and insertion.node_name not in func.nodes, f"trying to insert a new node, but the name ({insertion.node_name}) is already taken"
                if i == len(chain) - 1:
                    # last one jumps to after_name's successor
                    insertion_succ = before_name
                else:
                    insertion_succ = chain[i + 1].node_name
                new_nodes[insertion.node_name] = insertion.mk_node(
                    insertion_succ)

            first_inserted_node_name: source.NodeName = chain[0].node_name

            # connect after_name to the first insertion node
            if isinstance(node, source.NodeBasic | source.NodeCall | source.NodeEmpty | source.NodeAssume | source.NodeAssert):
                new_nodes[after_name] = dataclasses.replace(new_nodes[after_name],
                                                            succ=first_inserted_node_name)
            elif isinstance(node, source.NodeCond):
                # The branches are compared on the original node, but the
                # replacement starts from new_nodes[after_name]: it may
                # already carry a redirection made for the other branch.
                #
                # Two independent tests (not if/elif): when both branches
                # lead to before_name, both have to go through the inserted
                # nodes.
                if node.succ_then == before_name:
                    new_nodes[after_name] = dataclasses.replace(new_nodes[after_name],
                                                                succ_then=first_inserted_node_name)
                if node.succ_else == before_name:
                    new_nodes[after_name] = dataclasses.replace(new_nodes[after_name],
                                                                succ_else=first_inserted_node_name)
            else:
                assert_never(node)
    return new_nodes
class GhostVarName(source.ProgVarName):
    """Name of a variable introduced by this module (see the subclasses)."""
    pass
# a variable that ends with /subject-arg
class SubjectArgVarName(GhostVarName):
    """Name built by ``subject_arg_var_name`` and by
    ``sprinkle_subject_pre_and_post_conditions`` for the stashed value of a
    parameter of the subject function."""
    pass
# a variable that ends with /call-arg
class CallArgVarName(GhostVarName):
    """Name built by ``call_arg_var_name`` and by
    ``sprinkle_function_call_pre_and_post_condition`` for the stashed value
    of an argument at a call site."""
    pass
def subject_arg_var_name(arg: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[GhostVarName]:
    """Turn the variable ``<x>/arg`` into the variable ``<x>/subject-arg``.

    The type of ``arg`` is kept. Asserts that the name of ``arg`` ends with
    ``/arg``.
    """
    suffix = '/arg'
    assert arg.name.endswith(suffix), f"{arg.name!r}"
    stem = arg.name[:-len(suffix)]
    stashed_name = SubjectArgVarName(stem + "/subject-arg")
    return source.ExprVar(arg.typ, stashed_name)
def call_arg_var_name(arg: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[GhostVarName]:
    """Turn the variable ``<x>/arg`` into the variable ``<x>/call-arg``.

    The type of ``arg`` is kept. Asserts that the name of ``arg`` ends with
    ``/arg``.
    """
    suffix = '/arg'
    assert arg.name.endswith(suffix), f"{arg.name!r}"
    stem = arg.name[:-len(suffix)]
    stashed_name = CallArgVarName(stem + "/call-arg")
    return source.ExprVar(arg.typ, stashed_name)
@unique
class Mode(Enum):
    # NOTE(review): not referenced in the visible part of this file; the
    # members presumably mirror the /subject-arg and /call-arg stashing
    # schemes -- confirm against the callers.
    subject = "subject"
    call = "call"
# Subtracted from the length of a list of return variables when bounding the
# field_num of a source.CRetSpecialVar (see the `g` helpers below).
NUM_GHOST_VARIABLES_CPARSER_FUNCTION_CALL = 4  # mem, htd, pms, ghost assertions
# new_variables is mutated
def sprinkle_subject_pre_and_post_conditions(func: nip.Function, new_variables: Set[source.ExprVarT[GhostVarName]]) -> Iterable[Insertion]:
    """
    We stash the initial values of all the parameters in variables with the
    suffix /subject-arg, assume that the precondition holds, and then check
    that the post condition holds at the bottom of the function, replacing
    the variables with suffix /arg (referring to their old value) with the
    variables with suffix /subject-arg (ie. the stashed variables).

    Yields three insertions: 'stash_initial_args' and 'pre_condition' (in
    that order) on the edge leaving the entry node, and 'post_condition' on
    the edge entering the Ret node.

    Every stash variable is added to new_variables.

    Raises AssertionError if the entry node isn't a NodeEmpty, if the
    precondition mentions a variable that doesn't end with /arg, if a
    CRetSpecialVar refers to a field that isn't a (non ghost) return value,
    or if the Ret node doesn't have exactly one predecessor.
    """
    entry_node = func.nodes[func.cfg.entry]
    assert isinstance(entry_node, source.NodeEmpty)

    pending_updates: list[source.Update[source.ProgVarName |
                                        nip.GuardVarName]] = []
    for param in func.signature.parameters:
        var = source.ExprVar(param.typ, SubjectArgVarName(
            param.name + '/subject-arg'))
        # track all newly introduced variables
        new_variables.add(var)
        pending_updates.append(source.Update(var, param))
    stash_updates: Tuple[source.Update[source.ProgVarName |
                                       nip.GuardVarName], ...] = tuple(pending_updates)

    # a1/subject-arg = a1; a2/subject-arg = a2, ... (for all arguments)
    yield Insertion(after=func.cfg.entry,
                    before=entry_node.succ,
                    node_name=source.NodeName('stash_initial_args'),
                    mk_node=lambda succ: source.NodeBasic(Provenance.CALL_STASH_INITIAL_ARGS, stash_updates, succ))

    def f(var: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[source.ProgVarName | nip.GuardVarName]:
        # this will change with the new way of writing specs
        if not var.name.endswith('/arg'):
            assert False, f"unknown variable {var.name}"
        return subject_arg_var_name(var)

    precondition = source.convert_expr_vars(f, func.ghost.precondition)
    yield Insertion(after=func.cfg.entry,
                    before=entry_node.succ,
                    node_name=source.NodeName('pre_condition'),
                    mk_node=lambda succ: NodePreconditionAssumption(Provenance.PRE_COND, precondition, succ))

    def g(var: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[source.ProgVarName | nip.GuardVarName]:
        # this will be cleaned up when we implement the new way of writing specs
        if isinstance(var.name, source.CRetSpecialVar):
            # field_num is used as an index, so it has to be strictly below
            # the number of non ghost returns (`<=` would let through the
            # index of the first ghost variable)
            num_non_ghost_returns = len(
                func.signature.returns) - NUM_GHOST_VARIABLES_CPARSER_FUNCTION_CALL
            assert 0 <= var.name.field_num < num_non_ghost_returns
            return func.signature.returns[var.name.field_num]
        elif var.name.endswith("/arg"):
            return subject_arg_var_name(var)
        return var

    converted_post_condition = source.convert_expr_vars(
        g, func.ghost.postcondition)

    assert len(func.cfg.all_preds[source.NodeNameRet]) == 1, ("not to worry, just need to handle the case "
                                                              "where the Ret node has multiple predecessors")
    pred = func.cfg.all_preds[source.NodeNameRet][0]
    yield Insertion(after=pred,
                    before=source.NodeNameRet,
                    node_name=source.NodeName('post_condition'),
                    mk_node=lambda succ: NodePostConditionProofObligation(Provenance.POST_COND, converted_post_condition, succ, source.NodeNameErr))
def sprinkle_loop_invariant(func: nip.Function, lh: source.LoopHeaderName) -> Iterable[Insertion]:
    """Yield the insertions that check, then assume, the invariant of loop lh.

    The invariant is taken from func.ghost.loop_invariants[lh], with the
    variables ending with /arg replaced by their /subject-arg counterpart.

    One proof obligation (failing branch: source.NodeNameErr) is yielded for
    every edge entering lh, followed by one assumption for every edge
    leaving lh.
    """
    # TODO
    # ----
    #
    # to generate more readable SMT, we should put the loop invariant into an
    # SMT function. It would be safe to also provide a proof that this
    # function only needs the loop targets as parameters.
    #
    # proof sketch: suppose the loop invariant depends on a variable which
    # isn't a loop target. By definition of loop targets, it is never on the
    # lhs of an assignment within the loop, thus its value is constant, and
    # hence doesn't need to be a parameter. By exhaustion of cases, the
    # invariant's parameters only need to be the loop targets.
    #
    # If a variable isn't a loop target, the incarnation number to use is the
    # one that occurs in the loop header's DSA context (ie. the only incarnation
    # for that variable throughout the loop)
    #
    # UPDATE: this will be fixed when we switch to the new way of writing
    # specs

    def use_stashed_args(var: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[source.ProgVarName | nip.GuardVarName]:
        return subject_arg_var_name(var) if var.name.endswith('/arg') else var

    invariant = source.convert_expr_vars(
        use_stashed_args, func.ghost.loop_invariants[lh])

    def mk_obligation(succ: source.NodeName) -> source.Node[source.ProgVarName | nip.GuardVarName]:
        return NodeLoopInvariantProofObligation(
            Provenance.LOOP_INV_OBLIGATION,
            invariant,
            succ,
            source.NodeNameErr)

    def mk_assumption(succ: source.NodeName) -> source.Node[source.ProgVarName | nip.GuardVarName]:
        return NodeLoopInvariantAssumption(
            Provenance.LOOP_INV_ASSUME,
            invariant,
            succ)

    # ALL predecessors, including predecessors that follow a back edge
    for index, pred in enumerate(func.cfg.all_preds[lh], start=1):
        latch_name = source.NodeName(f'loop_{lh}_latch_{index}')
        yield Insertion(after=pred,
                        before=lh,
                        node_name=latch_name,
                        mk_node=mk_obligation)

    for index, nsucc in enumerate(func.cfg.all_succs[lh], start=1):
        assumption_name = source.NodeName(f'loop_{lh}_inv_asm_{index}')
        yield Insertion(after=lh,
                        before=nsucc,
                        node_name=assumption_name,
                        mk_node=mk_assumption)
def sprinkle_loop_invariants(func: nip.Function) -> Iterable[Insertion]:
    """Yield the loop invariant insertions of every loop in func.loops."""
    for header in func.loops:
        invariant_insertions = sprinkle_loop_invariant(func, header)
        yield from invariant_insertions
def sprinkle_function_call_pre_and_post_condition(func: nip.Function,
                                                  node_name: source.NodeName,
                                                  node: source.NodeCall[source.ProgVarName | nip.GuardVarName],
                                                  new_variables: Set[source.ExprVarT[GhostVarName]],
                                                  signatures: Mapping[str, TemporaryFunctionSignature]) -> Iterable[Insertion]:
    """Yield the insertions surrounding the call node `node` (named node_name).

    On every edge entering the call node, two nodes are inserted, in this
    order: one that stashes the arguments in variables with the suffix
    /call-arg, and one that asserts the callee's precondition. On the edge
    leaving the call node, one node assuming the callee's postcondition is
    inserted.

    In both conditions (read from signatures[node.fname]) the variables with
    suffix /arg are replaced by the /call-arg variables; in the
    postcondition, a CRetSpecialVar is replaced by the corresponding entry
    of node.rets.

    Every stash variable is added to new_variables.

    Raises AssertionError if the number of arguments and parameters differ,
    if the precondition mentions a variable that doesn't end with /arg, or
    if a CRetSpecialVar refers to a field that isn't a (non ghost) return
    value.
    """
    # the parameters are the "variable" in a method definition
    # the arguments are the values you pass at function call
    # (you define the parameters, you make the arguments)
    signature = signatures[node.fname]
    params = signature.parameters
    assert len(node.args) == len(params)

    _call_stash_updates: list[source.Update[source.ProgVarName |
                                            nip.GuardVarName]] = []
    for param, arg in zip(params, node.args):
        exprVar = source.ExprVar(
            param.typ, CallArgVarName(param.name + "/call-arg"))
        new_variables.add(exprVar)
        _call_stash_updates.append(source.Update(exprVar, arg))
    call_stash_updates = tuple(_call_stash_updates)

    def f(var: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[source.ProgVarName | nip.GuardVarName]:
        # this will change with the new way of writing specs
        if not var.name.endswith('/arg'):
            assert False, f"unknown variable {var.name}"
        return call_arg_var_name(var)

    # doesn't depend on the predecessor: converted once, outside of the loop
    precond = source.convert_expr_vars(f, signature.precondition)

    for i, pred in enumerate(func.cfg.all_preds[node_name], start=1):
        yield Insertion(after=pred,
                        before=node_name,
                        node_name=source.NodeName(
                            f'call_stash_{node_name}_pred_{i}'),
                        mk_node=lambda succ: source.NodeBasic(Provenance.CALL_STASH, call_stash_updates, succ))

        yield Insertion(after=pred,
                        before=node_name,
                        node_name=source.NodeName(
                            f'call_pre_{node_name}_pred_{i}'),
                        mk_node=lambda succ: NodePrecondObligationFnCall(Provenance.PRE_COND_FN_OBLIGATION, precond, succ))

    rets = node.rets  # pyright isn't smart enough

    def g(var: source.ExprVarT[source.ProgVarName | nip.GuardVarName]) -> source.ExprVarT[source.ProgVarName | nip.GuardVarName]:
        if isinstance(var.name, source.CRetSpecialVar):
            # field_num is used as an index, so it has to be strictly below
            # the number of non ghost returns (`<=` would let through the
            # index of the first ghost variable)
            num_non_ghost_returns = len(
                rets) - NUM_GHOST_VARIABLES_CPARSER_FUNCTION_CALL
            assert 0 <= var.name.field_num < num_non_ghost_returns
            return rets[var.name.field_num]
        elif var.name.endswith("/arg"):
            return call_arg_var_name(var)
        return var

    postcond = source.convert_expr_vars(g, signature.postcondition)
    yield Insertion(after=node_name,
                    before=node.succ,
                    node_name=source.NodeName(f'call_post_{node_name}'),
                    mk_node=lambda succ: NodeAssumePostCondFnCall(
                        Provenance.POST_COND_FN_ASSUME,
                        postcond,
                        succ))
def sprinkle_function_call_pre_and_post_conditions(func: nip.Function,
                                                   new_variables: Set[source.ExprVarT[GhostVarName]],
                                                   signatures: Mapping[str, TemporaryFunctionSignature]) -> Iterable[Insertion]:
    """Yield the pre/post condition insertions of every call node of func.

    The nodes are visited in the order given by
    func.traverse_topologically(skip_err_and_ret=True); nodes which aren't a
    source.NodeCall are ignored. new_variables is handed over to
    sprinkle_function_call_pre_and_post_condition, which adds to it.
    """
    for name in func.traverse_topologically(skip_err_and_ret=True):
        candidate = func.nodes[name]
        if not isinstance(candidate, source.NodeCall):
            continue
        yield from sprinkle_function_call_pre_and_post_condition(func, name, candidate, new_variables, signatures)
@dataclass(frozen=True, slots=True)
class TemporaryFunctionSignature:
    """ Signature of a function together with its pre and post conditions.

        These function signatures should all be loaded at the start,
        and then passed through each function.

        At load time, we should make sure that the precondition and the post
        condition only talk about the variables they are allowed to talk
        about.

        TODO: make this temporary function signature the new
              FunctionSignature
    """

    # parameters of the function (zipped with the arguments at call sites)
    parameters: Tuple[source.ExprVarT[source.ProgVarName], ...]
    # variables returned by the function
    returns: Tuple[source.ExprVarT[source.ProgVarName], ...]
    # source.expr_true when sprinkle_ghost_code finds no ghost data
    precondition: source.ExprT[source.ProgVarName | nip.GuardVarName]
    # source.expr_true when sprinkle_ghost_code finds no ghost data
    postcondition: source.ExprT[source.ProgVarName | nip.GuardVarName]
def sprinkle_pre_and_post_loop_iterations(func: nip.Function) -> Iterable[Insertion]:
    """Yield the insertions for the pre_iter / post_iter expressions of loops.

    The nodes are visited in the order given by
    func.traverse_topologically(skip_err_and_ret=True):

    - for a loop header whose pre_iter isn't source.expr_true, an assumption
      of pre_iter is inserted on the edge to its first successor that
      belongs to the loop;
    - for a loop latch whose post_iter isn't source.expr_true, an assertion
      of post_iter is inserted on the edge back to the loop header.

    Both checks are made for every node, independently of each other: a
    node that is a header with a trivial pre_iter can still be a latch with
    a non-trivial post_iter.

    Raises AssertionError if a loop header has no successor in its loop.
    """

    def mk_pre_insertion(after: source.NodeName,
                         before: source.NodeName,
                         node_name: source.NodeName,
                         expr: source.ExprT[source.ProgVarName | nip.GuardVarName]) -> Insertion:
        # builds the pre iteration assumption; appends "_pre_iter_asm" to
        # node_name
        return Insertion(after=after,
                         before=before,
                         node_name=source.NodeName(
                             node_name + "_pre_iter_asm"),
                         mk_node=lambda succ: NodeLoopPreIterationAssumption(origin=Provenance.LOOP_ITER_PRE,
                                                                             succ=succ,
                                                                             expr=expr))

    def mk_post_insertion(after: source.NodeName,
                          before: source.NodeName,
                          node_name: source.NodeName,
                          expr: source.ExprT[source.ProgVarName | nip.GuardVarName]) -> Insertion:
        # builds the post iteration proof obligation; appends
        # "_post_iter_proof" to node_name
        return Insertion(after=after,
                         before=before,
                         node_name=source.NodeName(
                             node_name + "_post_iter_proof"),
                         mk_node=lambda succ: NodeLoopPostIterationProofObligation(origin=Provenance.LOOP_ITER_POST,
                                                                                   succ=succ,
                                                                                   expr=expr))

    # NOTE(review): the call sites below append the suffix too, so the
    # resulting names end with the suffix twice (eg.
    # "<node>_pre_iter_asm_pre_iter_asm"). Kept as is so that the generated
    # node names don't change -- confirm whether anything relies on them
    # before dropping one of the two.
    for node_name in func.traverse_topologically(skip_err_and_ret=True):
        if loop_header := func.is_loop_header(node_name):
            pre_iter = func.ghost.loop_iterations[loop_header].pre_iter
            # no `continue` when there is nothing to assume: it would also
            # skip the latch check below for this very node
            if pre_iter != source.expr_true:
                succs = [succ for succ in func.cfg.all_succs[node_name]
                         if succ in func.loops[loop_header].nodes]
                assert len(
                    succs) > 0, "couldn't find a successor in the loop when inserting pre loop iteration assumption"
                succ_in_the_loop = succs[0]
                yield mk_pre_insertion(after=node_name,
                                       before=succ_in_the_loop,
                                       node_name=source.NodeName(
                                           node_name + "_pre_iter_asm"),
                                       expr=pre_iter)

        if loop_header := func.is_loop_latch(node_name):
            post_iter = func.ghost.loop_iterations[loop_header].post_iter
            if post_iter != source.expr_true:
                yield mk_post_insertion(after=node_name,
                                        before=loop_header,
                                        node_name=source.NodeName(
                                            node_name + "_post_iter_proof"),
                                        expr=post_iter)
def sprinkle_ghost_code(filename: str, func: nip.Function, unsafe_ctx: Mapping[str, syntax.Function]) -> Function:
    """Insert all the ghost code into func and return the resulting Function.

    A TemporaryFunctionSignature is first built for every function of
    unsafe_ctx, using ghost_data.get(filename, fname) for the conditions
    (source.expr_true for both when there is no ghost data).

    The insertions are then collected, in this order: subject pre/post
    conditions, function call pre/post conditions, loop invariants, loop
    pre/post iterations. They are applied with apply_insertions, and the
    CFG and the loops are recomputed from the new nodes.

    Raises AssertionError if the loop headers of the result differ from the
    loop headers of func.
    """
    signatures: dict[str, TemporaryFunctionSignature] = {}
    for fname, syn_func in unsafe_ctx.items():
        metadata = source.convert_function_metadata(syn_func)
        spec = ghost_data.get(filename, fname)

        pre: source.ExprT[source.ProgVarName | nip.GuardVarName]
        post: source.ExprT[source.ProgVarName | nip.GuardVarName]
        if spec is not None:
            pre = spec.precondition
            post = spec.postcondition
        else:
            pre = source.expr_true
            post = source.expr_true

        signatures[fname] = TemporaryFunctionSignature(parameters=metadata.parameters,
                                                       returns=metadata.returns,
                                                       precondition=pre,
                                                       postcondition=post)

    # filled in by the sprinkle_* generators while they are being consumed
    new_variables: Set[source.ExprVarT[GhostVarName]] = set()

    # the generators are unpacked (hence fully consumed) one after the other
    insertions: list[Insertion] = [
        *sprinkle_subject_pre_and_post_conditions(func, new_variables),
        *sprinkle_function_call_pre_and_post_conditions(
            func, new_variables, signatures),
        *sprinkle_loop_invariants(func),
        *sprinkle_pre_and_post_loop_iterations(func),
    ]

    new_nodes = apply_insertions(func, insertions)
    all_succs = abc_cfg.compute_all_successors_from_nodes(new_nodes)
    cfg = abc_cfg.compute_cfg_from_all_succs(all_succs, func.cfg.entry)
    loops = abc_cfg.compute_loops(new_nodes, cfg)

    same_loop_headers = loops.keys() == func.loops.keys()
    assert same_loop_headers, "more work required: loop headers changed during conversion, need to keep ghost's loop invariant in sync"

    return Function(name=func.name,
                    variables=func.variables | new_variables,
                    nodes=new_nodes,
                    cfg=cfg,
                    loops=loops,
                    ghost=func.ghost,
                    signature=func.signature)