diff --git a/tools/bitme.py b/tools/bitme.py index bb12304c..22b0bd2c 100755 --- a/tools/bitme.py +++ b/tools/bitme.py @@ -378,19 +378,25 @@ def __init__(self, nid, sid_line, domain, comment, line_no): if not isinstance(sid_line, Sort): raise model_error("sort", line_no) + def get_domain(self): + # filter out uninitialized states + return [state for state in self.domain if state.init_line is not None] + def get_z3_lambda(self): if self.z3_lambda is None: - if self.domain: - self.z3_lambda = z3.Lambda([state.get_z3() for state in self.domain], self.get_z3()) + domain = self.get_domain() + if domain: + self.z3_lambda = z3.Lambda([state.get_z3() for state in domain], self.get_z3()) else: self.z3_lambda = self.get_z3() return self.z3_lambda def get_bitwuzla_lambda(self, tm): if self.bitwuzla_lambda is None: - if self.domain: + domain = self.get_domain() + if domain: self.bitwuzla_lambda = tm.mk_term(bitwuzla.Kind.LAMBDA, - [*[state.get_bitwuzla(tm) for state in self.domain], self.get_bitwuzla(tm)]) + [*[state.get_bitwuzla(tm) for state in domain], self.get_bitwuzla(tm)]) else: self.bitwuzla_lambda = self.get_bitwuzla(tm) return self.bitwuzla_lambda @@ -566,6 +572,11 @@ def get_z3(self): self.z3 = z3.Const(self.name, self.sid_line.get_z3()) return self.z3 + def get_bitwuzla(self, tm): + if self.bitwuzla is None: + self.bitwuzla = tm.mk_const(self.sid_line.get_bitwuzla(tm), self.name) + return self.bitwuzla + class Input(Variable): keyword = OP_INPUT @@ -586,11 +597,6 @@ def get_values(self, step): def get_z3_name(self, step): return self.get_z3() - def get_bitwuzla(self, tm): - if self.bitwuzla is None: - self.bitwuzla = tm.mk_const(self.sid_line.get_bitwuzla(tm), self.name) - return self.bitwuzla - def get_bitwuzla_name(self, step, tm): return self.get_bitwuzla(tm) @@ -618,9 +624,10 @@ def set_instance(self, instance, step): def get_z3_select(self, step): instance = self.get_instance(step) assert step not in self.cache_z3_instance - if instance.domain: + domain = instance.get_domain() + if domain: self.cache_z3_instance[step] = z3.Select(instance.get_z3_lambda(), - *[state.get_z3_name(step) for state in instance.domain]) + *[state.get_z3_name(step) for state in domain]) else: self.cache_z3_instance[step] = instance.get_z3_lambda() return self.cache_z3_instance[step] @@ -628,18 +635,11 @@ def get_z3_select(self, step): def get_z3_substitute(self, step): instance = self.get_instance(step) assert step not in self.cache_z3_instance - if step <= 0: - self.cache_z3_instance[step] = instance.get_z3() - else: - assert step - 1 in self.cache_z3_instance - self.cache_z3_instance[step] = self.cache_z3_instance[step - 1] - if instance.domain: - if step <= 0: - current_states = [state.get_z3() for state in instance.domain] - else: - # assuming that cached z3 term is a term over states of step - 1 - current_states = [state.get_z3_name(step - 1) for state in instance.domain] - next_states = [state.get_z3_name(step) for state in instance.domain] + self.cache_z3_instance[step] = instance.get_z3() + domain = instance.get_domain() + if domain: + current_states = [state.get_z3() for state in domain] + next_states = [state.get_z3_name(step) for state in domain] renaming = list(zip(current_states, next_states)) self.cache_z3_instance[step] = z3.substitute(self.cache_z3_instance[step], renaming) @@ -656,10 +656,11 @@ def get_z3_instance(self, step): def get_bitwuzla_select(self, step, tm): instance = self.get_instance(step) assert step not in self.cache_bitwuzla_instance - if instance.domain: + domain = instance.get_domain() + if domain: self.cache_bitwuzla_instance[step] = tm.mk_term(bitwuzla.Kind.APPLY, [instance.get_bitwuzla_lambda(tm), - *[state.get_bitwuzla_name(step, tm) for state in instance.domain]]) + *[state.get_bitwuzla_name(step, tm) for state in domain]]) else: self.cache_bitwuzla_instance[step] = instance.get_bitwuzla_lambda(tm) return self.cache_bitwuzla_instance[step] @@ -667,18 +668,11 @@ def get_bitwuzla_select(self, step, tm): def get_bitwuzla_substitute(self, step, tm): instance = self.get_instance(step) assert step not in self.cache_bitwuzla_instance - if step <= 0: - self.cache_bitwuzla_instance[step] = instance.get_bitwuzla(tm) - else: - assert step - 1 in self.cache_bitwuzla_instance - self.cache_bitwuzla_instance[step] = self.cache_bitwuzla_instance[step - 1] - if instance.domain: - if step <= 0: - current_states = [state.get_bitwuzla(tm) for state in instance.domain] - else: - # assuming that cached bitwuzla term is a term over states of step - 1 - current_states = [state.get_bitwuzla_name(step - 1, tm) for state in instance.domain] - next_states = [state.get_bitwuzla_name(step, tm) for state in instance.domain] + self.cache_bitwuzla_instance[step] = instance.get_bitwuzla(tm) + domain = instance.get_domain() + if domain: + current_states = [state.get_bitwuzla(tm) for state in domain] + next_states = [state.get_bitwuzla_name(step, tm) for state in domain] renaming = dict(zip(current_states, next_states)) self.cache_bitwuzla_instance[step] = tm.substitute_term(self.cache_bitwuzla_instance[step], renaming) @@ -763,11 +757,6 @@ def get_z3_name(self, step): def get_z3_instance(self, step): return self.instance.get_z3_instance(step) - def get_bitwuzla(self, tm): - if self.bitwuzla is None: - self.bitwuzla = tm.mk_var(self.sid_line.get_bitwuzla(tm), self.name) - return self.bitwuzla - def get_bitwuzla_name(self, step, tm): if step == -1: step = 0