Skip to content

Commit

Permalink
Fixing and simplifying term substitution, keeping uninitialized state…
Browse files Browse the repository at this point in the history
… as free variables
  • Loading branch information
ckirsch committed Jan 13, 2025
1 parent 9564197 commit fc79211
Showing 1 changed file with 31 additions and 42 deletions.
73 changes: 31 additions & 42 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,25 @@ def __init__(self, nid, sid_line, domain, comment, line_no):
if not isinstance(sid_line, Sort):
raise model_error("sort", line_no)

def get_domain(self):
# filter out uninitialized states
return [state for state in self.domain if state.init_line is not None]

def get_z3_lambda(self):
if self.z3_lambda is None:
if self.domain:
self.z3_lambda = z3.Lambda([state.get_z3() for state in self.domain], self.get_z3())
domain = self.get_domain()
if domain:
self.z3_lambda = z3.Lambda([state.get_z3() for state in domain], self.get_z3())
else:
self.z3_lambda = self.get_z3()
return self.z3_lambda

def get_bitwuzla_lambda(self, tm):
if self.bitwuzla_lambda is None:
if self.domain:
domain = self.get_domain()
if domain:
self.bitwuzla_lambda = tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in self.domain], self.get_bitwuzla(tm)])
[*[state.get_bitwuzla(tm) for state in domain], self.get_bitwuzla(tm)])
else:
self.bitwuzla_lambda = self.get_bitwuzla(tm)
return self.bitwuzla_lambda
Expand Down Expand Up @@ -566,6 +572,11 @@ def get_z3(self):
self.z3 = z3.Const(self.name, self.sid_line.get_z3())
return self.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
self.bitwuzla = tm.mk_const(self.sid_line.get_bitwuzla(tm), self.name)
return self.bitwuzla

class Input(Variable):
keyword = OP_INPUT

Expand All @@ -586,11 +597,6 @@ def get_values(self, step):
def get_z3_name(self, step):
return self.get_z3()

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
self.bitwuzla = tm.mk_const(self.sid_line.get_bitwuzla(tm), self.name)
return self.bitwuzla

def get_bitwuzla_name(self, step, tm):
return self.get_bitwuzla(tm)

Expand Down Expand Up @@ -618,28 +624,22 @@ def set_instance(self, instance, step):
def get_z3_select(self, step):
instance = self.get_instance(step)
assert step not in self.cache_z3_instance
if instance.domain:
domain = instance.get_domain()
if domain:
self.cache_z3_instance[step] = z3.Select(instance.get_z3_lambda(),
*[state.get_z3_name(step) for state in instance.domain])
*[state.get_z3_name(step) for state in domain])
else:
self.cache_z3_instance[step] = instance.get_z3_lambda()
return self.cache_z3_instance[step]

def get_z3_substitute(self, step):
instance = self.get_instance(step)
assert step not in self.cache_z3_instance
if step <= 0:
self.cache_z3_instance[step] = instance.get_z3()
else:
assert step - 1 in self.cache_z3_instance
self.cache_z3_instance[step] = self.cache_z3_instance[step - 1]
if instance.domain:
if step <= 0:
current_states = [state.get_z3() for state in instance.domain]
else:
# assuming that cached z3 term is a term over states of step - 1
current_states = [state.get_z3_name(step - 1) for state in instance.domain]
next_states = [state.get_z3_name(step) for state in instance.domain]
self.cache_z3_instance[step] = instance.get_z3()
domain = instance.get_domain()
if domain:
current_states = [state.get_z3() for state in domain]
next_states = [state.get_z3_name(step) for state in domain]
renaming = list(zip(current_states, next_states))

self.cache_z3_instance[step] = z3.substitute(self.cache_z3_instance[step], renaming)
Expand All @@ -656,29 +656,23 @@ def get_z3_instance(self, step):
def get_bitwuzla_select(self, step, tm):
instance = self.get_instance(step)
assert step not in self.cache_bitwuzla_instance
if instance.domain:
domain = instance.get_domain()
if domain:
self.cache_bitwuzla_instance[step] = tm.mk_term(bitwuzla.Kind.APPLY,
[instance.get_bitwuzla_lambda(tm),
*[state.get_bitwuzla_name(step, tm) for state in instance.domain]])
*[state.get_bitwuzla_name(step, tm) for state in domain]])
else:
self.cache_bitwuzla_instance[step] = instance.get_bitwuzla_lambda(tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, step, tm):
instance = self.get_instance(step)
assert step not in self.cache_bitwuzla_instance
if step <= 0:
self.cache_bitwuzla_instance[step] = instance.get_bitwuzla(tm)
else:
assert step - 1 in self.cache_bitwuzla_instance
self.cache_bitwuzla_instance[step] = self.cache_bitwuzla_instance[step - 1]
if instance.domain:
if step <= 0:
current_states = [state.get_bitwuzla(tm) for state in instance.domain]
else:
# assuming that cached bitwuzla term is a term over states of step - 1
current_states = [state.get_bitwuzla_name(step - 1, tm) for state in instance.domain]
next_states = [state.get_bitwuzla_name(step, tm) for state in instance.domain]
self.cache_bitwuzla_instance[step] = instance.get_bitwuzla(tm)
domain = instance.get_domain()
if domain:
current_states = [state.get_bitwuzla(tm) for state in domain]
next_states = [state.get_bitwuzla_name(step, tm) for state in domain]
renaming = dict(zip(current_states, next_states))

self.cache_bitwuzla_instance[step] = tm.substitute_term(self.cache_bitwuzla_instance[step], renaming)
Expand Down Expand Up @@ -763,11 +757,6 @@ def get_z3_name(self, step):
def get_z3_instance(self, step):
return self.instance.get_z3_instance(step)

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
self.bitwuzla = tm.mk_var(self.sid_line.get_bitwuzla(tm), self.name)
return self.bitwuzla

def get_bitwuzla_name(self, step, tm):
if step == -1:
step = 0
Expand Down

0 comments on commit fc79211

Please sign in to comment.