From 3e476a936261afd9077de519f36d66bfabd1b0e1 Mon Sep 17 00:00:00 2001 From: alexcere <48130030+alexcere@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:30:29 +0100 Subject: [PATCH] #18 Score instructions lexicographically --- src/greedy/greedy_new_version.py | 134 +++++++++++++++++++------------ 1 file changed, 84 insertions(+), 50 deletions(-) diff --git a/src/greedy/greedy_new_version.py b/src/greedy/greedy_new_version.py index 9c4e4790..419a7a2f 100644 --- a/src/greedy/greedy_new_version.py +++ b/src/greedy/greedy_new_version.py @@ -330,14 +330,30 @@ def first_occurrence(self, var_elem: var_id_T) -> int: except: return -1 - def last_accessible_occurrence(self, var_elem: var_id_T) -> int: + def last_occurrence(self, var_elem: var_id_T) -> int: """ - Returns the first position in which the element appears + Returns the last position in which the element appears """ - try: - return self.stack.index(var_elem) - except: - return -1 + for idx, stack_elem in enumerate(self.stack[::-1]): + if stack_elem == var_elem: + return len(self.stack) - 1 - idx + return -1 + + def last_swap_occurrence(self, var_elem: var_id_T) -> int: + """ + Returns the last accessible position in which the element appears and can be swapped + """ + current_idx = min(STACK_DEPTH, len(self.stack) - 1) + + while current_idx > 0: + stack_elem = self.stack[current_idx] + + # Find the last occurrence in which the element is not placed in its position + if stack_elem == var_elem and self.idx_wrt_fstack(current_idx) != var_elem: + return current_idx + current_idx -= 1 + + return -1 def has_computations(self): """ @@ -524,18 +540,7 @@ def greedy(self) -> List[instr_id_T]: optg.extend(self.solve_permutation(cstate)) self.debug_logger.debug_after_permutation(cstate, optg) - if cstate.debug_mode: - list_strings = [str(t[0]) for t in cstate.trace] - max_list_len = max(len(ls) for ls in list_strings) - - # Calculate the maximum length of the string part - max_str_len = max(len(t[1]) for t in cstate.trace) - - print(list_strings) - - # Print each tuple with aligned formatting - for (lst, string), list_str in zip(cstate.trace, list_strings): - print(f"{string:<{max_str_len}} {list_str:>{max_list_len}}") + self.print_traces(cstate) return optg def _available_positions(self, var_elem: var_id_T, cstate: SymbolicState) -> Generator[cstack_pos_T, None, None]: @@ -581,16 +586,16 @@ def choose_next_computation(self, cstate: SymbolicState) -> Tuple[Union[instr_id """ Returns either a stack element or an instruction that must be computed """ - candidate = self._score_candidate(cstate) + candidate = self._select_candidate(cstate) return candidate - def _score_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var_id_T], str]: + def _select_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var_id_T], str]: """ Decides which stack variable or instruction must be computed using a scoring system """ new_instr, cheap_stack_elems, dup_stack_elems = cstate.candidates() current_top = cstate.top_stack() - best_candidate_info = False, dict(), False + best_candidate_score = [-1] candidate = None # TODO: pass candidates as arguments @@ -605,34 +610,22 @@ def _score_candidate(self, cstate: SymbolicState) -> Tuple[Union[instr_id_T, var # First, we evaluate the remaining instructions for id_ in new_instr: - top_instr = self._id2instr[id_] - deepest_pos = dict() - - # Function invocations might generate multiple values that we should take into account - for out_var in top_instr['outpt_sk']: - # We detect which is the deepest position in which the element can be placed - deepest_occurrence = self._deepest_position(out_var) - if deepest_occurrence is not None: - deepest_pos[out_var] = deepest_occurrence - - # We can reuse the topmost element and consume it - uses_top = top_can_be_reused and current_top in self._top_can_be_used[id_] - current_candidate_info = uses_top, deepest_pos, top_instr["id"] in self._instrs_with_deps + score_id = self._score_instr(self._id2instr[id_], cstate, top_can_be_reused) # To decide whether the current candidate is the best so far, we use the information from deepest_pos # and reuses_pos - better_candidate = self._le_ranked_options(best_candidate_info, current_candidate_info) + better_candidate = score_id > best_candidate_score if better_candidate: candidate = id_ - best_candidate_info = current_candidate_info + best_candidate_score = score_id - self.debug_logger.debug_rank_candidates(id_, current_candidate_info, better_candidate) + self.debug_logger.debug_rank_candidates(id_, score_id, better_candidate) # If the best candidate does not reuse the topmost element, we also try duplicating already existing elements # or cheap computations - if not best_candidate_info[0]: + if best_candidate_score[0] <= 0: # Search among the positions not solved that are deepest than the one in the best candidate - deepest_position = max(best_candidate_info[1].values(), default=-1) + deepest_position = best_candidate_score[3] if len(best_candidate_score) == 4 else -1 for position_not_solved in sorted(cstate.not_solved, reverse=True): if position_not_solved < deepest_position: @@ -675,17 +668,43 @@ def _handle_too_deep(self, cstate: SymbolicState) -> Optional[Tuple[Union[instr_ return None - def _le_ranked_options(self, option1: Tuple[bool, Dict[var_id_T, int], bool], - option2: Tuple[bool, Dict[var_id_T, int], bool]) -> bool: - # First we prioritize whether it can reuse the topmost element - if option1[0] != option2[0]: - return option2[0] - opt1_deepest = max(option1[1].values(), default=-1) - opt2_deepest = max(option2[1].values(), default=-1) - if opt1_deepest != opt2_deepest: - return opt1_deepest < opt2_deepest + def _score_instr(self, instr: instr_JSON_T, cstate: SymbolicState, top_can_reused: bool) -> Tuple[int, int, int, int]: + """ + We score the instructions according to the following lexicographic order: + 1) Can reuse the topmost element + 2) Number of stack elements that can consume by swapping + 3) Deepest position that needs to access. If > STACK_DEPTH, then, we assign to -1 + 4) Deepest position in which one of the produced stack vars can be consumed + """ + can_reuse_topmost = int(top_can_reused and self._top_can_be_used.get(cstate.top_stack(), False)) + n_swappable = 0 + max_pos = -1 + + # From the input stack, retrieves how many stack elements can be consumed by swapping + # and the deepest position needed to access + for input_var in instr['inpt_sk']: - return not option1[2] + # Does not need to be duplicated + if cstate.stack_var_copies_needed[input_var] == 0: + swap_position = cstate.last_swap_occurrence(input_var) + if swap_position != -1: + n_swappable += 1 + max_pos = max(max_pos, swap_position) + else: + max_pos = max(max_pos, cstate.first_occurrence(input_var)) + + deepest_to_place = -1 + # Function invocations might generate multiple values that we should take into account + for out_var in instr['outpt_sk']: + # We detect which is the deepest position in which the element can be placed + deepest_position = self._deepest_position(out_var) + if deepest_position is not None: + deepest_to_place = max(deepest_position, deepest_to_place) + + return [can_reuse_topmost, n_swappable, max_pos, deepest_to_place] + + def _score_stack_var(self, instr: instr_JSON_T, cstate: SymbolicState) -> int: + pass def compute_instr(self, instr: instr_JSON_T, cstate: SymbolicState) -> List[instr_id_T]: """ @@ -779,7 +798,7 @@ def compute_var(self, var_elem: var_id_T, cstate: SymbolicState) -> List[instr_i if cstate.is_accessible_swap(var_elem) and cstate.stack_var_copies_needed[var_elem] == 0 \ and self.fixed_elements == 0: # We swap to the deepest accesible copy - idx = cstate.last_accessible_occurrence(var_elem) + idx = cstate.last_swap_occurrence(var_elem) seq = cstate.swap(idx) self.debug_logger.debug_message(f"SWAP{idx} {cstate.stack}") @@ -805,6 +824,21 @@ def solve_permutation(self, cstate: SymbolicState) -> List[instr_id_T]: # TODO: complete code return [] + def print_traces(self, cstate: SymbolicState) -> None: + """ + Prints the traces so far from the current state. Debug mode must be activated + """ + if cstate.debug_mode: + list_strings = [str(t[0]) for t in cstate.trace] + max_list_len = max(len(ls) for ls in list_strings) + + # Calculate the maximum length of the string part + max_str_len = max(len(t[1]) for t in cstate.trace) + + # Print each tuple with aligned formatting + for (lst, string), list_str in zip(cstate.trace, list_strings): + print(f"{string:<{max_str_len}} {list_str:>{max_list_len}}") + class DebugLogger: """