Skip to content

Commit

Permalink
a quick fix on the azure shiny app to get it working again (#262)
Browse files Browse the repository at this point in the history
* a quick fix on the shinyapp to get it working again

* forgot to turn off pre-filtering while testing

* helper function was not uploaded
  • Loading branch information
arik-shurygin authored Oct 8, 2024
1 parent 078a62a commit 96d62f6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
6 changes: 5 additions & 1 deletion shiny_visualizers/shiny_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from plotly.subplots import make_subplots

from mechanistic_azure.azure_utilities import download_directory_from_azure
from resp_ode.utils import flatten_list_parameters
from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters


class Node:
Expand Down Expand Up @@ -258,6 +258,8 @@ def load_checkpoint_inference_chains(
# any sampled parameters created via numpyro.plate will mess up the data
# flatten plated parameters into separate keys
posteriors: dict[str, list] = flatten_list_parameters(posteriors)
# drop any final_timestep variables if they exist within the posteriors
posteriors = drop_keys_with_substring(posteriors, "final_timestep")
num_sampled_parameters = len(posteriors.keys())
# we want a mostly square subplot, so lets sqrt and take floor/ceil to deal with odd numbers
num_rows = math.isqrt(num_sampled_parameters)
Expand Down Expand Up @@ -331,6 +333,8 @@ def load_checkpoint_inference_correlations(
posteriors = {
key: np.array(matrix).flatten() for key, matrix in posteriors.items()
}
# drop any final_timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "final_timestep")
# Compute the correlation matrix, reverse it so diagonal starts @ top left
correlation_matrix = pd.DataFrame(posteriors).corr()[::-1]

Expand Down
21 changes: 21 additions & 0 deletions src/resp_ode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,27 @@ def flatten_list_parameters(
return return_dict


def drop_keys_with_substring(dct: dict[str], drop_s: str):
"""A simple helper function designed to drop keys from a dictionary if they contain some substring
Parameters
----------
dct : dict[str, Any]
a dictionary with string keys
drop_s : str
keys containing `drop_s` as a substring will be dropped
Returns
-------
dict[str, any]
dct with keys containing drop_s removed, otherwise untouched.
"""
keys_to_drop = [key for key in dct.keys() if drop_s in key]
for key in keys_to_drop:
del dct[key]
return dct


# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# DEATH CALCULATION CODE
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
Expand Down

0 comments on commit 96d62f6

Please sign in to comment.