Skip to content

Commit

Permalink
#18 Replace solved by not solved
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcere committed Dec 23, 2024
1 parent 4c29c4f commit 95c11ec
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/greedy/greedy_new_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def __init__(self, initial_stack: List[var_id_T], dependency_graph: nx.DiGraph,
self.stack_var_copies_needed: Dict[var_id_T, int] = stack_var_copies_needed.copy()

# Solved: positions in the final stack that contain elements in the correct position
self.solved: Set[fstack_pos_T] = {len(final_stack) - i - 1
for i, (ini_var, fin_var) in enumerate(zip(reversed(initial_stack), reversed(final_stack)))
solved: Set[fstack_pos_T] = {len(final_stack) - i - 1
for i, (ini_var, fin_var) in enumerate(zip(reversed(initial_stack),
reversed(final_stack)))
if ini_var == fin_var}
self.not_solved = set(range(len(self.final_stack))).difference(solved)

# Number of times each variable is computed in the stack
self.n_computed: Counter = Counter(initial_stack)
Expand All @@ -123,16 +125,21 @@ def __init__(self, initial_stack: List[var_id_T], dependency_graph: nx.DiGraph,
if self.debug_mode:
self.trace: List[Tuple[List[var_id_T], instr_id_T]] = [(self.stack.copy(), "Initial")]

@property
def not_solved(self):
return set(range(len(self.final_stack))).difference(self.solved)

def _remove_solved(self, idx: fstack_pos_T):
"""
Removes an element from the solved set, even if it is not there
Annotates that the idx is no longer solved
"""
try:
self.not_solved.add(idx)
except KeyError:
pass

def _add_solved(self, idx: fstack_pos_T):
"""
Annotates that the idx is solved
"""
try:
self.solved.remove(idx)
self.not_solved.remove(idx)
except KeyError:
pass

Expand All @@ -144,7 +151,7 @@ def _check_idx_solved_cstack(self, idx: cstack_pos_T):
self._remove_solved(fstack_pos)
var_elem = self.stack[idx]
if 0 <= fstack_pos < len(self.final_stack) and self.final_stack[fstack_pos] == var_elem:
self.solved.add(fstack_pos)
self._add_solved(fstack_pos)

def idx_wrt_fstack(self, idx: cstack_pos_T) -> fstack_pos_T:
"""
Expand Down Expand Up @@ -202,7 +209,7 @@ def dup(self, x: int) -> List[instr_id_T]:
fstack_idx = self.idx_wrt_fstack(0)

if fstack_idx >= 0 and self.final_stack[fstack_idx] == new_topmost:
self.solved.add(fstack_idx)
self._add_solved(fstack_idx)

# N computed: add one to the element we have computed
self.n_computed[new_topmost] += 1
Expand Down Expand Up @@ -272,7 +279,7 @@ def insert_element(self, instr: instr_JSON_T, output_var: var_id_T) -> None:
fstack_idx = self.idx_wrt_fstack(0)

if fstack_idx >= 0 and self.final_stack[fstack_idx] == output_var:
self.solved.add(fstack_idx)
self._add_solved(fstack_idx)

# N computed: add one to the element we have inserted
self.n_computed[output_var] += 1
Expand Down Expand Up @@ -324,7 +331,7 @@ def from_memory(self, var_elem: var_id_T) -> List[instr_id_T]:
fstack_idx = self.idx_wrt_fstack(0)

if fstack_idx >= 0 and self.final_stack[fstack_idx] == var_elem:
self.solved.add(fstack_idx)
self._add_solved(fstack_idx)

if self.debug_mode:
self.trace.append((self.stack.copy(), f"MEM({var_elem})"))
Expand Down Expand Up @@ -596,7 +603,7 @@ def _is_condensed(self, node):
"""
Whether to consider the instruction associated to the node
"""
return "STORE" in node or self._id2instr[node]["outpt_sk"] > 1 or \
return "STORE" in node or len(self._id2instr[node]["outpt_sk"]) > 1 or \
any(self._stack_var_copies_needed[out_stack] > 1 or out_stack in self._final_stack
for out_stack in self._id2instr[node]["outpt_sk"])

Expand Down

0 comments on commit 95c11ec

Please sign in to comment.