Skip to content

Commit

Permalink
#18 Score instructions lexicographically
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcere committed Dec 12, 2024
1 parent 8a3ac17 commit 3e476a9
Showing 1 changed file with 84 additions and 50 deletions.
134 changes: 84 additions & 50 deletions src/greedy/greedy_new_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,30 @@ def first_occurrence(self, var_elem: var_id_T) -> int:
except:
return -1

def last_accessible_occurrence(self, var_elem: var_id_T) -> int:
def last_occurrence(self, var_elem: var_id_T) -> int:
"""
Returns the first position in which the element appears
Returns the last position in which the element appears
"""
try:
return self.stack.index(var_elem)
except:
return -1
for idx, stack_elem in enumerate(self.stack[::-1]):
if stack_elem == var_elem:
return len(self.stack) - 1 - idx
return -1

def last_swap_occurrence(self, var_elem: var_id_T) -> int:
"""
Returns the last accessible position in which the element appears and can be swapped
"""
current_idx = min(STACK_DEPTH, len(self.stack) - 1)

while current_idx > 0:
stack_elem = self.stack[current_idx]

# Find the last occurrence in which the element is not placed in its position
if stack_elem == var_elem and self.idx_wrt_fstack(current_idx) != var_elem:
return current_idx
current_idx -= 1

return -1

def has_computations(self):
"""
Expand Down Expand Up @@ -524,18 +540,7 @@ def greedy(self) -> List[instr_id_T]:
optg.extend(self.solve_permutation(cstate))
self.debug_logger.debug_after_permutation(cstate, optg)

if cstate.debug_mode:
list_strings = [str(t[0]) for t in cstate.trace]
max_list_len = max(len(ls) for ls in list_strings)

# Calculate the maximum length of the string part
max_str_len = max(len(t[1]) for t in cstate.trace)

print(list_strings)

# Print each tuple with aligned formatting
for (lst, string), list_str in zip(cstate.trace, list_strings):
print(f"{string:<{max_str_len}} {list_str:>{max_list_len}}")
self.print_traces(cstate)
return optg

def _available_positions(self, var_elem: var_id_T, cstate: SymbolicState) -> Generator[cstack_pos_T, None, None]:
Expand Down Expand Up @@ -581,16 +586,16 @@ def choose_next_computation(self, cstate: SymbolicState) -> Tuple[Union[instr_id
"""
Returns either a stack element or an instruction that must be computed
"""
candidate = self._score_candidate(cstate)
candidate = self._select_candidate(cstate)
return candidate

def _score_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var_id_T], str]:
def _select_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var_id_T], str]:
"""
Decides which stack variable or instruction must be computed using a scoring system
"""
new_instr, cheap_stack_elems, dup_stack_elems = cstate.candidates()
current_top = cstate.top_stack()
best_candidate_info = False, dict(), False
best_candidate_score = [-1]
candidate = None

# TODO: pass candidates as arguments
Expand All @@ -605,34 +610,22 @@ def _score_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var

# First, we evaluate the remaining instructions
for id_ in new_instr:
top_instr = self._id2instr[id_]
deepest_pos = dict()

# Function invocations might generate multiple values that we should take into account
for out_var in top_instr['outpt_sk']:
# We detect which is the deepest position in which the element can be placed
deepest_occurrence = self._deepest_position(out_var)
if deepest_occurrence is not None:
deepest_pos[out_var] = deepest_occurrence

# We can reuse the topmost element and consume it
uses_top = top_can_be_reused and current_top in self._top_can_be_used[id_]
current_candidate_info = uses_top, deepest_pos, top_instr["id"] in self._instrs_with_deps
score_id = self._score_instr(self._id2instr[id_], cstate, top_can_be_reused)

# To decide whether the current candidate is the best so far, we use the information from deepest_pos
# and reuses_pos
better_candidate = self._le_ranked_options(best_candidate_info, current_candidate_info)
better_candidate = score_id > best_candidate_score
if better_candidate:
candidate = id_
best_candidate_info = current_candidate_info
best_candidate_score = score_id

self.debug_logger.debug_rank_candidates(id_, current_candidate_info, better_candidate)
self.debug_logger.debug_rank_candidates(id_, score_id, better_candidate)

# If the best candidate does not reuse the topmost element, we also try duplicating already existing elements
# or cheap computations
if not best_candidate_info[0]:
if best_candidate_score[0] <= 0:
# Search among the positions not solved that are deepest than the one in the best candidate
deepest_position = max(best_candidate_info[1].values(), default=-1)
deepest_position = best_candidate_score[3] if len(best_candidate_score) == 4 else -1
for position_not_solved in sorted(cstate.not_solved, reverse=True):

if position_not_solved < deepest_position:
Expand Down Expand Up @@ -675,17 +668,43 @@ def _handle_too_deep(self, cstate: SymbolicState) -> Optional[Tuple[Union[instr_

return None

def _le_ranked_options(self, option1: Tuple[bool, Dict[var_id_T, int], bool],
option2: Tuple[bool, Dict[var_id_T, int], bool]) -> bool:
# First we prioritize whether it can reuse the topmost element
if option1[0] != option2[0]:
return option2[0]
opt1_deepest = max(option1[1].values(), default=-1)
opt2_deepest = max(option2[1].values(), default=-1)
if opt1_deepest != opt2_deepest:
return opt1_deepest < opt2_deepest
def _score_instr(self, instr: instr_JSON_T, cstate: SymbolicState, top_can_reused: bool) -> Tuple[int, int, int, int]:
"""
We score the instructions according to the following lexicographic order:
1) Can reuse the topmost element
2) Number of stack elements that can consume by swapping
3) Deepest position that needs to access. If > STACK_DEPTH, then, we assign to -1
4) Deepest position in which one of the produced stack vars can be consumed
"""
can_reuse_topmost = int(top_can_reused and self._top_can_be_used.get(cstate.top_stack(), False))
n_swappable = 0
max_pos = -1

# From the input stack, retrieves how many stack elements can be consumed by swapping
# and the deepest position needed to access
for input_var in instr['inpt_sk']:

return not option1[2]
# Does not need to be duplicated
if cstate.stack_var_copies_needed[input_var] == 0:
swap_position = cstate.last_swap_occurrence(input_var)
if swap_position != -1:
n_swappable += 1
max_pos = max(max_pos, swap_position)
else:
max_pos = max(max_pos, cstate.first_occurrence(input_var))

deepest_to_place = -1
# Function invocations might generate multiple values that we should take into account
for out_var in instr['outpt_sk']:
# We detect which is the deepest position in which the element can be placed
deepest_position = self._deepest_position(out_var)
if deepest_position is not None:
deepest_to_place = max(deepest_position, deepest_to_place)

return [can_reuse_topmost, n_swappable, max_pos, deepest_to_place]

def _score_stack_var(self, instr: instr_JSON_T, cstate: SymbolicState) -> int:
pass

def compute_instr(self, instr: instr_JSON_T, cstate: SymbolicState) -> List[instr_id_T]:
"""
Expand Down Expand Up @@ -779,7 +798,7 @@ def compute_var(self, var_elem: var_id_T, cstate: SymbolicState) -> List[instr_i
if cstate.is_accessible_swap(var_elem) and cstate.stack_var_copies_needed[var_elem] == 0 \
and self.fixed_elements == 0:
# We swap to the deepest accesible copy
idx = cstate.last_accessible_occurrence(var_elem)
idx = cstate.last_swap_occurrence(var_elem)
seq = cstate.swap(idx)
self.debug_logger.debug_message(f"SWAP{idx} {cstate.stack}")

Expand All @@ -805,6 +824,21 @@ def solve_permutation(self, cstate: SymbolicState) -> List[instr_id_T]:
# TODO: complete code
return []

def print_traces(self, cstate: SymbolicState) -> None:
"""
Prints the traces so far from the current state. Debug mode must be activated
"""
if cstate.debug_mode:
list_strings = [str(t[0]) for t in cstate.trace]
max_list_len = max(len(ls) for ls in list_strings)

# Calculate the maximum length of the string part
max_str_len = max(len(t[1]) for t in cstate.trace)

# Print each tuple with aligned formatting
for (lst, string), list_str in zip(cstate.trace, list_strings):
print(f"{string:<{max_str_len}} {list_str:>{max_list_len}}")


class DebugLogger:
"""
Expand Down

0 comments on commit 3e476a9

Please sign in to comment.