From e5cb94163931ef25a2307962a59f71b59dff4f56 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 5 Dec 2024 13:41:37 -0500 Subject: [PATCH] - Change: Optimized the correlation heatmap plot. --- Docs/ChangeLog.md | 5 ++ metax/gui/main_gui.py | 23 +++-- metax/gui/metax_gui/main_window.ui | 90 ++++++++++++++++--- metax/gui/metax_gui/ui_main_window.py | 76 ++++++++++++---- .../analyzer_utils/basic_stats.py | 8 +- metax/taxafunc_ploter/basic_plot.py | 75 ++++++++++------ metax/taxafunc_ploter/get_distinct_colors.py | 21 +++-- metax/utils/version.py | 2 +- pyproject.toml | 2 +- 9 files changed, 229 insertions(+), 73 deletions(-) diff --git a/Docs/ChangeLog.md b/Docs/ChangeLog.md index 452106c..f43b4e7 100644 --- a/Docs/ChangeLog.md +++ b/Docs/ChangeLog.md @@ -1,3 +1,8 @@ +# Version: 1.119.6 +## Date: 2024-12-5 +### Changes: +- Change: Optimized the correlation heatmap plot. + # Version: 1.119.5 ## Date: 2024-12-3 ### Changes: diff --git a/metax/gui/main_gui.py b/metax/gui/main_gui.py index 0c473fc..46ca921 100644 --- a/metax/gui/main_gui.py +++ b/metax/gui/main_gui.py @@ -365,6 +365,7 @@ def __init__(self, MainWindow): self.comboBox_method_of_protein_inference.currentIndexChanged.connect(self.update_method_of_protein_inference) self.comboBox_3dbar_sub_meta.currentIndexChanged.connect(self.change_event_comboBox_3dbar_sub_meta) self.comboBox_tflink_sub_meta.currentIndexChanged.connect(self.change_event_comboBox_tflink_sub_meta) + self.comboBox_sub_meta_pca.currentIndexChanged.connect(self.change_event_comboBox_sub_meta_pca) ## Basic Stat self.pushButton_plot_pca_sns.clicked.connect(lambda: self.plot_basic_info_sns('pca')) @@ -889,6 +890,13 @@ def change_event_comboBox_3dbar_sub_meta(self): # self.comboBox_3dbar_sub_meta.setEnabled(False) # else: # self.comboBox_3dbar_sub_meta.setEnabled(True) + + def change_event_comboBox_sub_meta_pca(self): + if self.comboBox_sub_meta_pca.currentText() != 'None': + self.checkBox_corr_plot_samples.setEnabled(False) + else: + self.checkBox_corr_plot_samples.setEnabled(True) + def change_event_comboBox_tflink_sub_meta(self): # when the sub_meta comboBox is not None, the mean plot is not available if self.comboBox_tflink_sub_meta.currentText() != 'None': @@ -4509,6 +4517,8 @@ def get_title_by_table_name(self, table_name): cluster = self.checkBox_corr_cluster.isChecked() show_all_labels = (self.checkBox_corr_show_all_labels_x.isChecked(), self.checkBox_corr_show_all_labels_y.isChecked()) cmap = self.comboBox_basic_corr_cmap.currentText() + corr_method = self.comboBox_basic_corr_method.currentText() + plot_mean = False if self.checkBox_corr_plot_samples.isChecked() else True # checek if the dataframe has at least 2 rows and 2 columns if df.shape[0] < 2 or df.shape[1] < 2: QMessageBox.warning(self.MainWindow, 'Warning', 'The number of rows or columns is less than 2, correlation cannot be plotted!') @@ -4520,7 +4530,8 @@ def get_title_by_table_name(self, table_name): BasicPlot(self.tfa, **self.heatmap_params_dict).plot_corr_sns(df=df, title_name=title_name, cluster= cluster, width=width, height=height, font_size=font_size, show_all_labels=show_all_labels, theme=theme, cmap=cmap, - rename_sample = rename_sample) + rename_sample = rename_sample, corr_method=corr_method, + plot_mean = plot_mean, sub_meta = sub_meta) elif method == 'alpha_div': self.show_message('Alpha diversity is running, please wait...') @@ -5355,9 +5366,11 @@ def plot_co_expr(self, plot_type = 'network'): self.show_message('Co-expression heatmap is plotting...\n\n It may take a long time! Please wait...') try: print(f'Calculate correlation with {corr_method} method...') - df = self.tfa.BasicStats.get_correlation(df_type = df_type, sample_list = sample_list, focus_list = focus_list, plot_list_only = plot_list_only, rename_taxa = rename_taxa, method=corr_method) + df = self.tfa.BasicStats.get_correlation(df_type = df_type, sample_list = sample_list, + focus_list = focus_list, plot_list_only = plot_list_only, + rename_taxa = rename_taxa, method=corr_method) # save df to table_dict - self.update_table_dict(f'expression correlation heatmap({df_type})', df) + self.update_table_dict(f'{corr_method} correlation heatmap({df_type})', df) show_all_labels = ( self.checkBox_corr_hetatmap_show_all_labels_x.isChecked(), @@ -5365,12 +5378,12 @@ def plot_co_expr(self, plot_type = 'network'): ) cmap = self.comboBox_corr_hetatmap_cmap.currentText() BasicPlot(self.tfa, **self.heatmap_params_dict).plot_items_corr_heatmap(df=df, - title_name=f'Expression Correlation Heatmap({df_type})', + title_name=f'{corr_method.capitalize()} Correlation of {df_type}', cluster=True, cmap=cmap, width=width, height=height, font_size=font_size, - show_all_labels=show_all_labels + show_all_labels=show_all_labels, ) except Exception: diff --git a/metax/gui/metax_gui/main_window.ui b/metax/gui/metax_gui/main_window.ui index 5ddd1fe..100fcdc 100644 --- a/metax/gui/metax_gui/main_window.ui +++ b/metax/gui/metax_gui/main_window.ui @@ -46,7 +46,7 @@ Qt::LeftToRight - 5 + 3 false @@ -1447,7 +1447,7 @@ QTabWidget::Triangular - 1 + 0 @@ -1830,7 +1830,7 @@ 0 0 - 799 + 885 239 @@ -2406,6 +2406,12 @@ + + + 0 + 0 + + Show All Labels @@ -2413,6 +2419,12 @@ + + + 0 + 0 + + X @@ -2420,11 +2432,33 @@ + + + 0 + 0 + + Y + + + + + 0 + 0 + + + + Theme + + + + + + @@ -2471,6 +2505,13 @@ + + + + Plot Samples + + + @@ -2488,20 +2529,42 @@ - + - + 0 0 - Theme + Method - + + + + 0 + 0 + + + + + pearson + + + + + spearman + + + + + kendall + + + @@ -6532,7 +6595,7 @@ QTabWidget::Triangular - 1 + 0 @@ -6840,6 +6903,11 @@ spearman + + + kendall + + @@ -7131,7 +7199,7 @@ false - Plot Co-Expression Heatmap + Plot Correlation Heatmap @@ -8350,8 +8418,8 @@ 0 0 - 885 - 225 + 775 + 102 diff --git a/metax/gui/metax_gui/ui_main_window.py b/metax/gui/metax_gui/ui_main_window.py index 9104649..c770287 100644 --- a/metax/gui/metax_gui/ui_main_window.py +++ b/metax/gui/metax_gui/ui_main_window.py @@ -904,7 +904,7 @@ def setupUi(self, metaX_main): self.scrollArea.setWidgetResizable(True) self.scrollArea.setObjectName("scrollArea") self.scrollAreaWidgetContents = QtWidgets.QWidget() - self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 799, 239)) + self.scrollAreaWidgetContents.setGeometry(QtCore.QRect(0, 0, 885, 239)) self.scrollAreaWidgetContents.setObjectName("scrollAreaWidgetContents") self.gridLayout_34 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents) self.gridLayout_34.setObjectName("gridLayout_34") @@ -1190,14 +1190,40 @@ def setupUi(self, metaX_main): self.horizontalLayout_5 = QtWidgets.QHBoxLayout() self.horizontalLayout_5.setObjectName("horizontalLayout_5") self.label_129 = QtWidgets.QLabel(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Preferred) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.label_129.sizePolicy().hasHeightForWidth()) + self.label_129.setSizePolicy(sizePolicy) self.label_129.setObjectName("label_129") self.horizontalLayout_5.addWidget(self.label_129) self.checkBox_corr_show_all_labels_x = QtWidgets.QCheckBox(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.checkBox_corr_show_all_labels_x.sizePolicy().hasHeightForWidth()) + self.checkBox_corr_show_all_labels_x.setSizePolicy(sizePolicy) self.checkBox_corr_show_all_labels_x.setObjectName("checkBox_corr_show_all_labels_x") self.horizontalLayout_5.addWidget(self.checkBox_corr_show_all_labels_x) self.checkBox_corr_show_all_labels_y = QtWidgets.QCheckBox(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.checkBox_corr_show_all_labels_y.sizePolicy().hasHeightForWidth()) + self.checkBox_corr_show_all_labels_y.setSizePolicy(sizePolicy) self.checkBox_corr_show_all_labels_y.setObjectName("checkBox_corr_show_all_labels_y") self.horizontalLayout_5.addWidget(self.checkBox_corr_show_all_labels_y) + self.label_192 = QtWidgets.QLabel(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Preferred) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.label_192.sizePolicy().hasHeightForWidth()) + self.label_192.setSizePolicy(sizePolicy) + self.label_192.setObjectName("label_192") + self.horizontalLayout_5.addWidget(self.label_192) + self.comboBox_basic_corr_cmap = QtWidgets.QComboBox(self.scrollAreaWidgetContents) + self.comboBox_basic_corr_cmap.setObjectName("comboBox_basic_corr_cmap") + self.horizontalLayout_5.addWidget(self.comboBox_basic_corr_cmap) self.gridLayout_34.addLayout(self.horizontalLayout_5, 4, 2, 1, 1) self.label_168 = QtWidgets.QLabel(self.scrollAreaWidgetContents) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Preferred) @@ -1222,6 +1248,9 @@ def setupUi(self, metaX_main): self.gridLayout_34.addLayout(self.horizontalLayout_20, 8, 1, 1, 1) self.horizontalLayout_3 = QtWidgets.QHBoxLayout() self.horizontalLayout_3.setObjectName("horizontalLayout_3") + self.checkBox_corr_plot_samples = QtWidgets.QCheckBox(self.scrollAreaWidgetContents) + self.checkBox_corr_plot_samples.setObjectName("checkBox_corr_plot_samples") + self.horizontalLayout_3.addWidget(self.checkBox_corr_plot_samples) self.checkBox_corr_cluster = QtWidgets.QCheckBox(self.scrollAreaWidgetContents) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) sizePolicy.setHorizontalStretch(0) @@ -1231,17 +1260,25 @@ def setupUi(self, metaX_main): self.checkBox_corr_cluster.setChecked(True) self.checkBox_corr_cluster.setObjectName("checkBox_corr_cluster") self.horizontalLayout_3.addWidget(self.checkBox_corr_cluster) - self.label_192 = QtWidgets.QLabel(self.scrollAreaWidgetContents) - sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Preferred) + self.label_98 = QtWidgets.QLabel(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Preferred) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) - sizePolicy.setHeightForWidth(self.label_192.sizePolicy().hasHeightForWidth()) - self.label_192.setSizePolicy(sizePolicy) - self.label_192.setObjectName("label_192") - self.horizontalLayout_3.addWidget(self.label_192) - self.comboBox_basic_corr_cmap = QtWidgets.QComboBox(self.scrollAreaWidgetContents) - self.comboBox_basic_corr_cmap.setObjectName("comboBox_basic_corr_cmap") - self.horizontalLayout_3.addWidget(self.comboBox_basic_corr_cmap) + sizePolicy.setHeightForWidth(self.label_98.sizePolicy().hasHeightForWidth()) + self.label_98.setSizePolicy(sizePolicy) + self.label_98.setObjectName("label_98") + self.horizontalLayout_3.addWidget(self.label_98) + self.comboBox_basic_corr_method = QtWidgets.QComboBox(self.scrollAreaWidgetContents) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(self.comboBox_basic_corr_method.sizePolicy().hasHeightForWidth()) + self.comboBox_basic_corr_method.setSizePolicy(sizePolicy) + self.comboBox_basic_corr_method.setObjectName("comboBox_basic_corr_method") + self.comboBox_basic_corr_method.addItem("") + self.comboBox_basic_corr_method.addItem("") + self.comboBox_basic_corr_method.addItem("") + self.horizontalLayout_3.addWidget(self.comboBox_basic_corr_method) self.gridLayout_34.addLayout(self.horizontalLayout_3, 4, 1, 1, 1) self.horizontalLayout_105 = QtWidgets.QHBoxLayout() self.horizontalLayout_105.setObjectName("horizontalLayout_105") @@ -3538,6 +3575,7 @@ def setupUi(self, metaX_main): self.comboBox_co_expr_corr_method.setObjectName("comboBox_co_expr_corr_method") self.comboBox_co_expr_corr_method.addItem("") self.comboBox_co_expr_corr_method.addItem("") + self.comboBox_co_expr_corr_method.addItem("") self.horizontalLayout_54.addWidget(self.comboBox_co_expr_corr_method) self.gridLayout_58.addLayout(self.horizontalLayout_54, 0, 1, 1, 1) self.horizontalLayout_30 = QtWidgets.QHBoxLayout() @@ -4349,7 +4387,7 @@ def setupUi(self, metaX_main): self.scrollArea_6.setWidgetResizable(True) self.scrollArea_6.setObjectName("scrollArea_6") self.scrollAreaWidgetContents_7 = QtWidgets.QWidget() - self.scrollAreaWidgetContents_7.setGeometry(QtCore.QRect(0, 0, 885, 225)) + self.scrollAreaWidgetContents_7.setGeometry(QtCore.QRect(0, 0, 775, 102)) self.scrollAreaWidgetContents_7.setObjectName("scrollAreaWidgetContents_7") self.gridLayout_69 = QtWidgets.QGridLayout(self.scrollAreaWidgetContents_7) self.gridLayout_69.setObjectName("gridLayout_69") @@ -5711,11 +5749,11 @@ def setupUi(self, metaX_main): self.retranslateUi(metaX_main) self.stackedWidget.setCurrentIndex(0) - self.tabWidget_TaxaFuncAnalyzer.setCurrentIndex(5) + self.tabWidget_TaxaFuncAnalyzer.setCurrentIndex(3) self.toolBox_2.setCurrentIndex(0) - self.tabWidget_4.setCurrentIndex(1) + self.tabWidget_4.setCurrentIndex(0) self.tabWidget_3.setCurrentIndex(3) - self.tabWidget.setCurrentIndex(1) + self.tabWidget.setCurrentIndex(0) self.tabWidget_2.setCurrentIndex(1) self.tabWidget_6.setCurrentIndex(0) self.toolBox_metalab_res_anno.setCurrentIndex(0) @@ -5960,10 +5998,15 @@ def retranslateUi(self, metaX_main): self.label_129.setText(_translate("metaX_main", "Show All Labels")) self.checkBox_corr_show_all_labels_x.setText(_translate("metaX_main", "X")) self.checkBox_corr_show_all_labels_y.setText(_translate("metaX_main", "Y")) + self.label_192.setText(_translate("metaX_main", "Theme")) self.label_168.setText(_translate("metaX_main", "Correlation Heatmap")) self.checkBox_sunburst_show_all_lables.setText(_translate("metaX_main", "Show All Lables for Sunburst")) + self.checkBox_corr_plot_samples.setText(_translate("metaX_main", "Plot Samples")) self.checkBox_corr_cluster.setText(_translate("metaX_main", "Cluster")) - self.label_192.setText(_translate("metaX_main", "Theme")) + self.label_98.setText(_translate("metaX_main", "Method")) + self.comboBox_basic_corr_method.setItemText(0, _translate("metaX_main", "pearson")) + self.comboBox_basic_corr_method.setItemText(1, _translate("metaX_main", "spearman")) + self.comboBox_basic_corr_method.setItemText(2, _translate("metaX_main", "kendall")) self.label_207.setText(_translate("metaX_main", "UpSet")) self.checkBox_basic_plot_upset_show_percentage.setText(_translate("metaX_main", "Show Percentages")) self.label_206.setText(_translate("metaX_main", "Min Subset Size")) @@ -6210,6 +6253,7 @@ def retranslateUi(self, metaX_main): self.label_65.setText(_translate("metaX_main", "Method of Correlation")) self.comboBox_co_expr_corr_method.setItemText(0, _translate("metaX_main", "pearson")) self.comboBox_co_expr_corr_method.setItemText(1, _translate("metaX_main", "spearman")) + self.comboBox_co_expr_corr_method.setItemText(2, _translate("metaX_main", "kendall")) self.label_162.setText(_translate("metaX_main", "Font Size")) self.checkBox_co_expr_show_label.setText(_translate("metaX_main", "Show Labels")) self.label_191.setText(_translate("metaX_main", "Theme")) @@ -6223,7 +6267,7 @@ def retranslateUi(self, metaX_main): self.pushButton_co_expr_add_top.setToolTip(_translate("metaX_main", "Add conditionally filtered items to the drawing box")) self.pushButton_co_expr_add_top.setText(_translate("metaX_main", "Add Top to List")) self.pushButton_co_expr_plot.setText(_translate("metaX_main", "Plot Co-Expression Network")) - self.pushButton_co_expr_heatmap_plot.setText(_translate("metaX_main", "Plot Co-Expression Heatmap")) + self.pushButton_co_expr_heatmap_plot.setText(_translate("metaX_main", "Plot Correlation Heatmap")) self.label_73.setText(_translate("metaX_main", "Select Top")) self.label_74.setText(_translate("metaX_main", "Sort by")) self.comboBox_co_expr_top_by.setItemText(0, _translate("metaX_main", "Total Intensity")) diff --git a/metax/taxafunc_analyzer/analyzer_utils/basic_stats.py b/metax/taxafunc_analyzer/analyzer_utils/basic_stats.py index 9dda064..7bbb551 100644 --- a/metax/taxafunc_analyzer/analyzer_utils/basic_stats.py +++ b/metax/taxafunc_analyzer/analyzer_utils/basic_stats.py @@ -166,7 +166,7 @@ def get_correlation(self, df_type: str, `df_type`: str: 'taxa', 'func', 'taxa_func', 'func_taxa', 'custom' `sample_list`: a list of samples to calculate correlation `plot_list_only`: bool: if True, only return the list of samples that can be plotted - `method`: str: 'pearson', 'spearman' + `method`: str: 'pearson', 'spearman', 'kendall' ''' df = self.tfa.get_df(df_type) df = self.tfa.replace_if_two_index(df) @@ -204,7 +204,7 @@ def get_combined_sub_meta_df( tuple[pd.DataFrame, Dict[str, str]]: A tuple containing the combined DataFrame and a dictionary with the sample names as keys and the group names as values. """ if sub_meta != 'None': - + plot_mean = False # if sub_meta is not None, the mean will be calculated based on the meta and sub_meta automatically sample_groups = {sample: self.tfa.get_group_of_a_sample(sample, self.tfa.meta_name) for sample in df.columns} sub_groups = {sample: self.tfa.get_group_of_a_sample(sample, sub_meta) for sample in df.columns} @@ -218,8 +218,8 @@ def get_combined_sub_meta_df( else: grouped_data = df.T.groupby([sample_groups, sub_groups]).mean().T - # group_list is the sub-meta group - group_list = [i[1] for i in grouped_data.columns] if not plot_mean else grouped_data.columns.tolist() + # group_list is the meta group:i[0], set i[1] if want to show sub_meta group + group_list = [i[0] for i in grouped_data.columns] # Convert multi-index to single index grouped_data.columns = ['_'.join(col).strip() for col in grouped_data.columns.values] diff --git a/metax/taxafunc_ploter/basic_plot.py b/metax/taxafunc_ploter/basic_plot.py index 8bc4648..d5b44a7 100644 --- a/metax/taxafunc_ploter/basic_plot.py +++ b/metax/taxafunc_ploter/basic_plot.py @@ -27,7 +27,7 @@ def __init__(self, tfobj, sns.set_theme() - def plot_taxa_stats_pie(self, theme:str = 'Auto', res_type = 'pic', font_size = 12): + def plot_taxa_stats_pie(self, theme:str = 'Auto', res_type = 'pic', font_size = 10, width = None, height = None): df = self.tfa.BasicStats.get_stats_peptide_num_in_taxa() # if 'not_found' is 0, then remove it @@ -51,17 +51,20 @@ def plot_taxa_stats_pie(self, theme:str = 'Auto', res_type = 'pic', font_size = # set color palette colors = sns.color_palette("deep") - # set figure size base on font size - if font_size <= 10: - fig_size = (8, 6) - elif font_size <= 12: - fig_size = (10, 8) - elif font_size <= 14: - fig_size = (12, 10) - elif font_size <= 16: - fig_size = (14, 12) + # set figure size base on font size if not specified + if width is None and height is None: + if font_size <= 10: + fig_size = (8, 6) + elif font_size <= 12: + fig_size = (10, 8) + elif font_size <= 14: + fig_size = (12, 10) + elif font_size <= 16: + fig_size = (14, 12) + else: + fig_size = (16, 14) else: - fig_size = (16, 14) + fig_size = (width, height) fig = plt.figure(figsize=fig_size) if res_type == 'show' else plt.figure() @@ -91,7 +94,7 @@ def plot_taxa_stats_pie(self, theme:str = 'Auto', res_type = 'pic', font_size = return fig # input: self.get_stats_taxa_level() - def plot_taxa_number(self, peptide_num = 1, theme:str = 'Auto', res_type = 'pic', font_size = 10): + def plot_taxa_number(self, peptide_num = 1, theme:str = 'Auto', res_type = 'pic', font_size = 10, width = None, height = None): df = self.tfa.BasicStats.get_stats_taxa_level(peptide_num) # if genome in taxa_level and count of species == count of genome, then remove genome, and rename species to species (genome) @@ -106,7 +109,12 @@ def plot_taxa_number(self, peptide_num = 1, theme:str = 'Auto', res_type = 'pic' else: custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", rc=custom_params) - plt.figure(figsize=(10, 8)) if res_type == 'show' else plt.figure() + # plt.figure(figsize=(10, 8)) if res_type == 'show' else plt.figure() + if width is None and height is None: + plt.figure(figsize=(10, 8)) if res_type == 'show' else plt.figure() + else: + plt.figure(figsize=(width, height)) if res_type == 'show' else plt.figure() + ax = sns.barplot(data=df, x='taxa_level', y='count',dodge=False, hue='taxa_level') for i in ax.containers: # set the label of the bar, and fontsize @@ -129,7 +137,7 @@ def plot_taxa_number(self, peptide_num = 1, theme:str = 'Auto', res_type = 'pic' return ax # input: self.get_stats_func_prop() - def plot_prop_stats(self, func_name = 'eggNOG_OGs', theme:str = 'Auto', res_type = 'pic', font_size = 10): + def plot_prop_stats(self, func_name = 'eggNOG_OGs', theme:str = 'Auto', res_type = 'pic', font_size = 10, width = None, height = None): df = self.tfa.BasicStats.get_stats_func_prop(func_name) # #dodge=False to make the bar wider # plt.figure(figsize=(8, 6)) @@ -138,8 +146,10 @@ def plot_prop_stats(self, func_name = 'eggNOG_OGs', theme:str = 'Auto', res_type else: custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", rc=custom_params) - - plt.figure(figsize=(8, 6)) if res_type == 'show' else plt.figure() + if width is None and height is None: + plt.figure(figsize=(8, 6)) if res_type == 'show' else plt.figure() + else: + plt.figure(figsize=(width, height)) if res_type == 'show' else plt.figure() ax = sns.barplot(data=df, x='prop', y='n', hue='label', dodge=False, palette='tab10_r') for i in ax.containers: @@ -382,18 +392,24 @@ def plot_corr_sns( cmap: str = "Auto", rename_sample: bool = False, corr_method: str = "pearson", + sub_meta: str | None = 'None', + plot_mean: bool = False, ): - dft= df.copy() - if rename_sample: - dft, group_list = self.tfa.add_group_name_for_sample(dft) - else: - group_list = [self.tfa.get_group_of_a_sample(i) for i in dft.columns] + dft, group_dict = self.tfa.BasicStats.prepare_dataframe_for_heatmap(df = df, sub_meta = sub_meta, + rename_sample = rename_sample, + plot_mean = plot_mean) + + # if rename_sample: + # _, group_list = self.tfa.add_group_name_for_sample(dft) + # else: + # group_list = [self.tfa.get_group_of_a_sample(i) for i in dft.columns] if cmap == 'Auto': cmap = 'RdYlBu_r' else: cmap = cmap - + + group_list = [group_dict[i] for i in dft.columns] color_list = self.assign_colors(group_list) # check if the correlation method is valid @@ -413,9 +429,12 @@ def plot_corr_sns( 'row_cluster':True if cluster else False, 'method':self.linkage_method, 'metric':self.distance_metric, - "linecolor":(0/255, 0/255, 0/255, 0.01), "dendrogram_ratio":(.1, .2),"col_colors":color_list, + "linecolor":(0/255, 0/255, 0/255, 0.01), "dendrogram_ratio":(.1, .2), + "col_colors": color_list, "figsize":(width, height), "xticklabels":True if show_all_labels[0] else "auto", - "yticklabels":True if show_all_labels[1] else 'auto'} + "yticklabels":True if show_all_labels[1] else 'auto', + "center":0, "vmin":-1, "vmax":1 + } fig = sns.clustermap(corr, **sns_params) ax = fig.ax_heatmap @@ -433,7 +452,7 @@ def plot_corr_sns( va = self.get_y_labels_va() ) - fig.ax_col_dendrogram.set_title(f'Correlation of {title_name}', fontsize=font_size+2, fontweight='bold') + fig.ax_col_dendrogram.set_title(f'{corr_method.capitalize()} Correlation of {title_name}', fontsize=font_size+2, fontweight='bold') cbar = fig.ax_heatmap.collections[0].colorbar cbar.set_label('correlation', @@ -711,12 +730,14 @@ def plot_items_corr_heatmap( 'metric':self.distance_metric, "linecolor":(0/255, 0/255, 0/255, 0.01), "dendrogram_ratio":(.1, .2), "figsize":(width, height), "xticklabels":True if show_all_labels[0] else "auto", - "yticklabels":True if show_all_labels[1] else 'auto'} + "yticklabels":True if show_all_labels[1] else 'auto', + "center":0, "vmin":-1, "vmax":1 + } fig = sns.clustermap(corr, **sns_params) ax = fig.ax_heatmap - fig.ax_col_dendrogram.set_title(f'Correlation of {title_name}', fontsize=font_size+2, fontweight='bold') + fig.ax_col_dendrogram.set_title(f'{title_name}', fontsize=font_size+2, fontweight='bold') fig.ax_heatmap.set_xticklabels( fig.ax_heatmap.get_xmajorticklabels(), fontsize=font_size, diff --git a/metax/taxafunc_ploter/get_distinct_colors.py b/metax/taxafunc_ploter/get_distinct_colors.py index 65fdd7e..fe01c6b 100644 --- a/metax/taxafunc_ploter/get_distinct_colors.py +++ b/metax/taxafunc_ploter/get_distinct_colors.py @@ -11,15 +11,20 @@ def __init__(self): pass def assign_colors(self, groups_list: list) -> list: - ''' - Assign colors of the number of unique groups in the list + """ + Assign colors to the number of unique groups in the list - return a list of colors in hex format - ''' - colors = self.get_distinct_colors(len(set(groups_list))) - result = [] - for group in groups_list: - index = sorted(set(groups_list)).index(group) - result.append(colors[index]) + """ + # Get unique sorted groups and their indices + unique_groups = sorted(set(groups_list)) + group_to_index = {group: idx for idx, group in enumerate(unique_groups)} + + # Generate distinct colors based on the number of unique groups + colors = self.get_distinct_colors(len(unique_groups)) + print(f'Assigned colors for {len(unique_groups)} groups') + + # Map groups to colors using the precomputed index + result = [colors[group_to_index[group]] for group in groups_list] return result def adjust_color(self, color, sat_factor=0.7, light_factor=1.2): diff --git a/metax/utils/version.py b/metax/utils/version.py index 7bdb37c..b6393cc 100644 --- a/metax/utils/version.py +++ b/metax/utils/version.py @@ -1,2 +1,2 @@ -__version__ = '1.119.5' +__version__ = '1.119.6' API_version = '4' \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2f31c1e..9c53d4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "MetaXTools" -version = "1.119.5" +version = "1.119.6" description = "MetaXTools is a novel tool for linking peptide sequences with taxonomic and functional information in Metaproteomics." readme = "README_PyPi.md" license = { text = "NorthOmics" }