Skip to content

Commit

Permalink
Cleaned initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ignacia Echeverria authored and Ignacia Echeverria committed Jun 21, 2024
1 parent 038ef11 commit 64562c4
Showing 1 changed file with 55 additions and 75 deletions.
130 changes: 55 additions & 75 deletions pyext/src/analysis_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,31 @@
import multiprocessing
from equilibration import detectEquilibration
import hdbscan

import IMP
import IMP.rmf
import RMF

import matplotlib as mpl
import matplotlib.pylab as plt
import matplotlib.gridspec as gridspec

mpl.rcParams.update({"font.size": 10})

# Table of known restraints, one row per restraint:
#     (stat-file handle, short name, flag attribute)
# __init__ loops over this table and, for every row, creates
# `self.<flag> = False` and `self.sum_<flag> = True`.
# NOTE(review): the "XL" and "atomic_XL" rows share the same handle string;
# confirm that is intended before keying anything on the handle alone.
restraints = [
    ("ConnectivityRestraint", "CR", "Connectivity_restraint"),
    ("ExcludedVolumeSphere", "EV", "ExcludedVolume_restraint"),
    ("GaussianEMRestraint", "EM3D", "EM_restraint"),
    ("DistanceRestraint_Score", "DR", "Distance_restraint"),
    ("ResidueBindingRestraint_score", "BR", "ResidueBinding_restraint"),
    ("OccamsPositionalRestraint_Score", "OccPos", "OccamsPos_restraint"),
    ("pEMapRestraint_Score", "pEMap", "pEMAP_restraint"),
    ("DOPE_Restraint_score", "DOPE", "DOPE_restraint"),
    ("MembraneExclusionRestraint", "MEX", "MembraneExclusion_restraint"),
    # Short name is "MSL", matching set_analyze_MembraneSurfaceLocation_restraint
    # (the table previously spelled it "MLS").
    ("MembraneSurfaceLocation", "MSL", "MembraneSurfaceLocation_restraint"),
    (
        "CrossLinkingMassSpectrometryRestraint_Data_Score",
        "XL",
        "XLs_restraint",
    ),
    (
        "CrossLinkingMassSpectrometryRestraint_Data_Score",
        "atomic_XL",
        "atomic_XLs_restraint",
    ),
    ("OccamsRestraint_Score", "Occ", "Occams_restraint"),
]

def generate_n_distinct_colors(n):
# Generate a palette of N distinct colors using matplotlib's HSV colormap
colors = plt.cm.hsv(np.linspace(0, 1, n))
Expand Down Expand Up @@ -114,10 +128,6 @@ def __init__(
)
self.plot_fmt = plot_fmt





# report if equilibration detection has been requested
if not self.detect_equilibration:
print("Not running equilibration check.")
Expand All @@ -130,29 +140,14 @@ def __init__(
"discarding them" % self.burn_in_frac
)

# Define with restraints to analyze
self.Connectivity_restraint = False
self.Excluded_volume_restraint = False
self.XLs_restraint = False
self.XLs_restraint_nuisances = True
self.Multiple_XLs_restraints = False
self.atomic_XLs_restraint = False
self.atomic_XLs_restraint_nuisances = True
self.Multiple_atomic_XLs_restraints = False

self.EM_restraint = False
self.Distance_restraint = False
self.Binding_restraint = False
self.Occams_restraint = False
self.Occams_restraint_nuisances = False
self.Occams_positional_restraint = False
self.Occams_positional_nuisances = False
self.pEMAP_restraint = False
self.pEMAP_restraint_new = False
self.DOPE_restraint = False
self.MembraneExclusion_restraint = False
self.MembraneSurfaceLocation_restraint = False
# Initialize all restraints to False
# By default, add restraints of same type
for (handle, name, flag) in restraints:
setattr(self, flag, False)
setattr(self, f'sum_{flag}', True)
self.score_only_restraint = False
self.sum_score_only_restraint = {}
self.Multiple_psi_values = False

# Other handles
self.restraint_handles = [
Expand All @@ -164,26 +159,17 @@ def __init__(
self.info_handles = []
self.all_score_fields = []

# By default, add restraints of same type
self.sum_Connectivity_restraint = True
self.sum_Excluded_volume_restraint = True
self.sum_Binding_restraint = True
self.sum_Distance_restraint = True
self.sum_XLs_restraint = True
self.sum_Occams_positional_restraint = True
self.sum_atomic_XLs_restraint = True
self.sum_DOPE_restraint = True
self.sum_MembraneExclusion_restraint = True
self.sum_MembraneSurfaceLocation_restraint = True
self.sum_EM_restraint = False
self.sum_score_only_restraint = {}

self.Multiple_psi_values = False


# Separate trajectories into two halves
self.dir_halfA = np.sort(self.out_dirs)[::2]
self.dir_halfB = np.sort(self.out_dirs)[1::2]

def set_analyze_restraint(self, handle, name, flag=None):
    """Register a restraint for analysis.

    Appends ``handle`` to ``self.restraint_handles`` and records ``name``
    as its short name in ``self.restraint_names``.  When ``flag`` is given,
    the attribute of that name is switched on.

    handle: stat-file key of the restraint score.
    name: short label stored for that handle.
    flag: optional name of the boolean attribute to enable.
    """
    self.restraint_handles.append(handle)
    self.restraint_names[handle] = name
    if flag:
        # __init__ already defaults every restraint flag to False, and the
        # dedicated set_analyze_* methods set theirs to True; writing False
        # here would leave the restraint registered but never enabled.
        setattr(self, flag, True)

def set_analyze_XLs_restraint(
self,
get_nuisances=True,
Expand Down Expand Up @@ -251,10 +237,10 @@ def set_analyze_Distance_restraint(self):
self.restraint_names["DistanceRestraint_Score"] = "DR"
self.Distance_restraint = True

def set_analyze_Binding_restraint(self):
def set_analyze_ResidueBinding_restraint(self):
self.restraint_handles.append("ResidueBindingRestraint_score")
self.restraint_names["ResidueBindingRestraint_score"] = "BR"
self.Binding_restraint = True
self.ResidueBinding_restraint = True

def set_analyze_Occams_restraint(self):
self.restraint_handles.append("OccamsRestraint_Score")
Expand All @@ -279,16 +265,11 @@ def set_analyze_Occams_positional_restraint(self):
# self.Occams_positional_nuisances = True

def set_analyze_pEMAP_restraint(self):
self.restraint_handles.append("SimplifiedPEMAP_data_Score")
self.restraint_names["SimplifiedPEMAP_data_Score"] = "pEMap"
# self.pEMAP_restraint = True

def set_analyze_pEMAP_restraint_new(self):
self.restraint_handles.append("pEMapRestraint_Score")
self.restraint_names["pEMapRestraint_Score"] = "pEMap"
self.info_handles.append("pEMapRestraint_satisfaction")
self.info_handles.append("pEMapRestraint_sigma")
self.pEMAP_restraint_new = True
self.pEMAP_restraint = True

def set_analyze_DOPE_restraint(self):
self.restraint_handles.append("DOPE_Restraint_score")
Expand All @@ -298,12 +279,12 @@ def set_analyze_DOPE_restraint(self):
def set_analyze_MembraneExclusion_restraint(self):
    """Enable analysis of the membrane-exclusion restraint.

    Registers the "MembraneExclusionRestraint" handle under the short
    name "MEX" and sets ``self.MembraneExclusion_restraint`` to True.
    """
    self.restraint_handles.append("MembraneExclusionRestraint")
    self.restraint_names["MembraneExclusionRestraint"] = "MEX"
    self.MembraneExclusion_restraint = True

def set_analyze_MembraneSurfaceLocation_restraint(self):
    """Enable analysis of the membrane-surface-location restraint.

    Registers the "MembraneSurfaceLocation" handle under the short name
    "MSL" and sets ``self.MembraneSurfaceLocation_restraint`` to True.
    """
    self.restraint_handles.append("MembraneSurfaceLocation")
    self.restraint_names["MembraneSurfaceLocation"] = "MSL"
    self.MembraneSurfaceLocation_restraint = True

def set_analyze_score_only_restraint(self, handle, short_name,
do_sum=True):
Expand Down Expand Up @@ -479,7 +460,6 @@ def read_stats_detailed(self, traj, stat_files):
S_dist = []
P_info = []


for sf in stat_files:
# Read header
stat2_dict = self.get_keys(sf)
Expand Down Expand Up @@ -658,15 +638,15 @@ def read_traj_info(self, out_dirs_sel):
if "CR_" in v and "sum" not in v
]
if (
self.Excluded_volume_restraint
and not self.sum_Excluded_volume_restraint
self.ExcludedVolume_restraint
and not self.sum_ExcludedVolume_restraint
):
sel_entries += [
v
for v in S_tot_scores.columns.values
if "EV_" in v and "sum" not in v
]
if self.Binding_restraint and not self.sum_Binding_restraint:
if self.ResidueBinding_restraint and not self.sum_ResdiueBinding_restraint:
sel_entries += [
v
for v in S_tot_scores.columns.values
Expand Down Expand Up @@ -735,7 +715,7 @@ def read_traj_info(self, out_dirs_sel):
ts_eq, burn_in, file_out
)

if self.pEMAP_restraint_new:
if self.pEMAP_restraint:
file_out_pemap = "plot_pEMAP_%s.%s" % (traj_number,
self.plot_fmt)
self.plot_pEMAP_satisfaction(S_info, file_out_pemap)
Expand Down Expand Up @@ -1262,7 +1242,7 @@ def do_extract_models(self, gsms_info, filename, gsms_dir):

# Setup a list of processes that we want to run
processes = [
mp.Process(
multiprocessing.Process(
target=self.extract_models,
args=(df_array[x], filename, gsms_dir)
)
Expand Down Expand Up @@ -1914,7 +1894,7 @@ def plot_XLs_satisfaction(self, S_info, ts_max, file_out):
XLs_percent.sort()
XLs_nuis.sort()

fig, ax = pl.subplots(figsize=(10.0, 4.0), nrows=1, ncols=3)
fig, ax = plt.subplots(figsize=(10.0, 4.0), nrows=1, ncols=3)
axes = ax.flatten()
for i, c in enumerate(XLs_percent):
label = c
Expand Down Expand Up @@ -1948,9 +1928,9 @@ def plot_XLs_satisfaction(self, S_info, ts_max, file_out):
handles, labels = ax[1].get_legend_handles_labels()
ax[2].legend(handles[::-1], labels[::-1])

pl.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.5)
plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.5)
fig.savefig(os.path.join(self.analysis_dir, file_out))
pl.close()
plt.close()

def boxplot_XLs_distances(self, cluster=0, type_XLs=None, cutoff=30.0):
if type_XLs:
Expand Down Expand Up @@ -2028,7 +2008,7 @@ def boxplot_XLs_distances(self, cluster=0, type_XLs=None, cutoff=30.0):
n_frac = int(math.ceil(n_xls / float(n_plots)))

# Generate plot
fig, ax = pl.subplots(figsize=(12, 6.0 * n_plots), nrows=n_plots,
fig, ax = plt.subplots(figsize=(12, 6.0 * n_plots), nrows=n_plots,
ncols=1)
if n_plots == 1:
ax = [ax]
Expand Down Expand Up @@ -2062,9 +2042,9 @@ def boxplot_XLs_distances(self, cluster=0, type_XLs=None, cutoff=30.0):
ax[i].set_ylabel("XLs distances (A)")
ax[i].set_title("XLs distance distributions")

pl.tight_layout()
plt.tight_layout()
fig.savefig(os.path.join(self.analysis_dir, file_out))
pl.close()
plt.close()

# Plot histogram of best cluster distances
if type_XLs:
Expand All @@ -2083,19 +2063,19 @@ def boxplot_XLs_distances(self, cluster=0, type_XLs=None, cutoff=30.0):
self.plot_XLs_satisfaction_histogram(min_all, cutoff, file_out_hist)

def plot_XLs_satisfaction_histogram(self, min_all, cutoff, file_out_hist):
    """Save a histogram of XL distances with the cutoff marked.

    min_all: values to histogram (20 bins); presumably the per-cross-link
        minimum distances computed by the caller -- confirm there.
    cutoff: x position of the vertical marker line.
    file_out_hist: output file name, joined onto ``self.analysis_dir``.
    """
    # Use the `plt` alias throughout; the module imports matplotlib.pylab
    # as `plt` and does not define `pl`.
    fig, ax = plt.subplots(figsize=(5, 5), nrows=1, ncols=1)

    ax.hist(min_all, 20, color="b", alpha=0.5)
    ax.axvline(x=cutoff, color="orange", alpha=0.7, lw=3)
    ax.set_xlabel("Distance (A)", fontsize=12)
    ax.set_ylabel("Number of XLs")
    ax.set_title("XLs satisfaction")
    plt.tight_layout()
    fig.savefig(os.path.join(self.analysis_dir, file_out_hist))
    plt.close()

def plot_pEMAP_satisfaction(self, S_info, file_out):
fig, ax = pl.subplots(figsize=(4.0, 4.0), nrows=1, ncols=1)
fig, ax = plt.subplots(figsize=(4.0, 4.0), nrows=1, ncols=1)

pemap_satif = [
v for v in S_info.columns.values
Expand All @@ -2111,17 +2091,17 @@ def plot_pEMAP_satisfaction(self, S_info, file_out):
ax.set_xlabel("Step", fontsize=12)
ax.set_ylabel("Percent satisfied", fontsize=12)

pl.tight_layout(pad=1.2, w_pad=1.5, h_pad=2.5)
plt.tight_layout(pad=1.2, w_pad=1.5, h_pad=2.5)
fig.savefig(os.path.join(self.analysis_dir, file_out))
pl.close()
plt.close()

def plot_Occams_satisfaction(self, Occams_info, file_out):
"""Plot percent of restraint satisfied and the distribution of
the nuisances"""

c = ["gold", "red", "blue", "green"]

fig, ax = pl.subplots(figsize=(12.0, 4.0), nrows=1, ncols=3)
fig, ax = plt.subplots(figsize=(12.0, 4.0), nrows=1, ncols=3)
# Occams satisfaction
occams_satif = [
k for k in Occams_info.columns.values
Expand Down Expand Up @@ -2184,9 +2164,9 @@ def plot_Occams_satisfaction(self, Occams_info, file_out):
handles, labels = ax[2].get_legend_handles_labels()
ax[2].legend(handles[::-1], labels[::-1])

pl.tight_layout(pad=1.2, w_pad=1.5, h_pad=2.5)
plt.tight_layout(pad=1.2, w_pad=1.5, h_pad=2.5)
fig.savefig(os.path.join(self.analysis_dir, file_out))
pl.close()
plt.close()

def substrings(self, s):
for i in range(len(s)):
Expand Down

0 comments on commit 64562c4

Please sign in to comment.