Skip to content

Commit

Permalink
First attempt at domain propagation with unary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Jan 13, 2025
1 parent b1962b7 commit e2a642a
Showing 1 changed file with 55 additions and 30 deletions.
85 changes: 55 additions & 30 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,19 +353,30 @@ def get_bitwuzla(self, tm):

class Values:
def __init__(self, sid_line):
self.sid_line = sid_line
self.number_of_values = 0
self.values = {}
self.number_of_constraints = 0
self.constraints = {}

def OR(arg1, arg2):
assert arg1 != Constant.false or arg2 != Constant.false
if arg1 == Constant.true or arg2 == Constant.true:
return Constant.true
else:
return Logical(next_nid(), OP_OR, Bool.boolean, arg1, arg2, arg1.comment, arg1.line_no)

def get_value(self):
assert self.number_of_values == 1
return list(self.values)[0]

def set_value(self, constraint, value):
assert self.sid_line == value.sid_line
assert constraint != Constant.false
if value not in self.values:
self.number_of_values += 1
self.values[value] = constraint
if constraint not in self.constraints:
self.number_of_constraints += 1
self.constraints[constraint] = value
self.values[value] = constraint
else:
self.values[value] = Values.OR(constraint, self.values[value])
return self

class Expression(Line):
def __init__(self, nid, sid_line, domain, comment, line_no):
Expand All @@ -382,6 +393,10 @@ def get_domain(self):
# filter out uninitialized states
return [state for state in self.domain if state.init_line is not None]

def get_value(self):
# TODO: remove when done with domain propagation
return self

def get_z3_lambda(self):
if self.z3_lambda is None:
domain = self.get_domain()
Expand Down Expand Up @@ -432,7 +447,9 @@ def get_mapped_array_expression_for(self, index):
return self

def get_values(self, step):
return self
if 0 not in self.cache_values:
self.cache_values[0] = Values(self.sid_line).set_value(Constant.true, self)
return self.cache_values[0]

def get_z3(self):
if self.z3 is None:
Expand Down Expand Up @@ -619,7 +636,7 @@ def get_instance(self, step):
def set_instance(self, instance, step):
self.cache_instance[step] = instance
if Instance.PROPAGATE:
self.cache_instance[step] = self.cache_instance[step].get_values(step)
self.cache_instance[step] = self.cache_instance[step].get_values(step).get_value()

def get_z3_select(self, step):
instance = self.get_instance(step)
Expand Down Expand Up @@ -805,7 +822,7 @@ def copy(self, arg1_line):

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg1_value = self.arg1_line.get_values(step).get_value()
if isinstance(arg1_value, Constant):
if self.op == 'sext':
self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line, arg1_value.signed_value, self.comment, self.line_no)
Expand Down Expand Up @@ -861,7 +878,7 @@ def copy(self, arg1_line):

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg1_value = self.arg1_line.get_values(step).get_value()
if isinstance(arg1_value, Constant):
self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line,
(arg1_value.value & 2**(self.u + 1) - 1) >> self.l, self.comment, self.line_no)
Expand Down Expand Up @@ -910,30 +927,38 @@ def get_mapped_array_expression_for(self, index):
arg1_line = self.arg1_line.get_mapped_array_expression_for(None)
return self.copy(arg1_line)

def get_unaries(self, values, op):
results = Values(self.sid_line)
for value in values.values:
constraint = values.values[value]
if op == (lambda x: not x):
if value == Constant.false:
results.set_value(constraint, Constant.true)
else:
assert value == Constant.true
results.set_value(constraint, Constant.false)
else:
results.set_value(constraint,
type(value)(next_nid(), self.sid_line,
op(value.value) % 2**self.sid_line.size, self.comment, self.line_no))
return results

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
if isinstance(arg1_value, Constant):
value = arg1_value.value
if isinstance(arg1_value, Values):
if self.op == 'not':
if isinstance(self.sid_line, Bool):
if arg1_value == Constant.false:
self.cache_values[step] = Constant.true
else:
assert arg1_value == Constant.true
self.cache_values[step] = Constant.false
return self.cache_values[step]
assert arg1_value.number_of_values <= 2
self.cache_values[step] = self.get_unaries(arg1_value, lambda x: not x)
else:
value = ~value
self.cache_values[step] = self.get_unaries(arg1_value, lambda x: ~x)
elif self.op == 'inc':
value = value + 1
self.cache_values[step] = self.get_unaries(arg1_value, lambda x: x + 1)
elif self.op == 'dec':
value = value - 1
self.cache_values[step] = self.get_unaries(arg1_value, lambda x: x - 1)
elif self.op == 'neg':
value = -value

self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line,
value % 2**self.sid_line.size, self.comment, self.line_no)
self.cache_values[step] = self.get_unaries(arg1_value, lambda x: -x)
else:
self.cache_values[step] = self.copy(arg1_value)
return self.cache_values[step]
Expand Down Expand Up @@ -999,8 +1024,8 @@ def get_mapped_array_expression_for(self, index):

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg2_value = self.arg2_line.get_values(step)
arg1_value = self.arg1_line.get_values(step).get_value()
arg2_value = self.arg2_line.get_values(step).get_value()
self.cache_values[step] = self.copy(arg1_value, arg2_value)
return self.cache_values[step]

Expand Down Expand Up @@ -1335,9 +1360,9 @@ def __str__(self):

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
arg2_value = self.arg2_line.get_values(step)
arg3_value = self.arg3_line.get_values(step)
arg1_value = self.arg1_line.get_values(step).get_value()
arg2_value = self.arg2_line.get_values(step).get_value()
arg3_value = self.arg3_line.get_values(step).get_value()
self.cache_values[step] = self.copy(arg1_value, arg2_value, arg3_value)
return self.cache_values[step]

Expand Down

0 comments on commit e2a642a

Please sign in to comment.