From 43fdb1519782f48006bea50822f069b11d44846e Mon Sep 17 00:00:00 2001 From: jakakokosar Date: Wed, 9 Oct 2019 21:57:47 +0200 Subject: [PATCH] OWVolcanoPlot: general improvements, use of GeneScoring component --- .../tests/widgets/ow_components/__init__.py | 0 .../ow_components/test_gene_scoring.py | 73 ++++++ .../bioinformatics/utils/statistics.py | 16 +- .../bioinformatics/widgets/OWVolcanoPlot.py | 151 ++++++------ .../widgets/ow_components/__init__.py | 3 + .../widgets/ow_components/gene_scoring.py | 230 ++++++++++++++++++ 6 files changed, 398 insertions(+), 75 deletions(-) create mode 100644 orangecontrib/bioinformatics/tests/widgets/ow_components/__init__.py create mode 100644 orangecontrib/bioinformatics/tests/widgets/ow_components/test_gene_scoring.py create mode 100644 orangecontrib/bioinformatics/widgets/ow_components/__init__.py create mode 100644 orangecontrib/bioinformatics/widgets/ow_components/gene_scoring.py diff --git a/orangecontrib/bioinformatics/tests/widgets/ow_components/__init__.py b/orangecontrib/bioinformatics/tests/widgets/ow_components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orangecontrib/bioinformatics/tests/widgets/ow_components/test_gene_scoring.py b/orangecontrib/bioinformatics/tests/widgets/ow_components/test_gene_scoring.py new file mode 100644 index 00000000..4486f1f4 --- /dev/null +++ b/orangecontrib/bioinformatics/tests/widgets/ow_components/test_gene_scoring.py @@ -0,0 +1,73 @@ +import unittest + +from AnyQt.QtTest import QSignalSpy + +from Orange.data import Table +from Orange.widgets.widget import OWWidget +from Orange.widgets.settings import SettingProvider +from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.tests.utils import simulate + +from orangecontrib.bioinformatics.utils.statistics import score_hypergeometric_test +from orangecontrib.bioinformatics.widgets.ow_components import GeneScoringComponent + + +class MockWidget(OWWidget): + name = "Mock" + scoring_component = SettingProvider(GeneScoringComponent) + + def __init__(self): + self.scoring_component = GeneScoringComponent(self, self.mainArea) + + +class TestGeneScoringComponent(WidgetTest): + def setUp(self): + self.widget = MockWidget() + self.component = self.widget.scoring_component + + def test_scoring_methods_combobox(self): + combo_box_values = [ + self.component.score_method_combo.itemText(i) for i in range(self.component.score_method_combo.count()) + ] + self.assertTrue(len(combo_box_values) > 0) + self.assertEqual([name for name, _ in self.component.score_methods], combo_box_values) + + signals_cb_emits = QSignalSpy(self.component.score_method_changed) + simulate.combobox_run_through_all(self.component.score_method_combo) + + self.assertEqual(self.component.score_method_combo.currentIndex(), self.component.current_method_index) + self.assertEqual(self.component.current_method_index, len(combo_box_values) - 1) + + # number of signals combobox emits should be equal to the length of available scoring methods + self.assertEqual(len(combo_box_values), len(signals_cb_emits)) + + def test_expression_threshold_spinbox(self): + # find index of item in combobox for hypergeometric test + method_index, *_ = [ + index + for index, (name, method) in enumerate(self.component.score_methods) + if method == score_hypergeometric_test + ] + + # check if spinbox appears after hypergeometric test is selected + self.assertTrue(self.component.expression_threshold_box.isHidden()) + simulate.combobox_activate_index(self.component.score_method_combo, method_index) + self.assertFalse(self.component.expression_threshold_box.isHidden()) + + def test_group_values(self): + self.assertIsNone(self.component.data) + self.component.initialize(Table('iris')) + self.assertIsNotNone(self.component.data) + + # we expect only one value 'iris class attribute' + combo_box_value, *_ = [ + self.component.group_combo.itemText(i) for i in range(self.component.group_combo.count()) + ] + self.assertEqual(combo_box_value, 'iris') + + group_values = [self.component.list_widget.item(i).text() for i in range(self.component.list_widget.count())] + self.assertEqual(group_values, ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) + + +if __name__ == "__main__": + unittest.main() diff --git a/orangecontrib/bioinformatics/utils/statistics.py b/orangecontrib/bioinformatics/utils/statistics.py index 51574ac6..b5704e06 100644 --- a/orangecontrib/bioinformatics/utils/statistics.py +++ b/orangecontrib/bioinformatics/utils/statistics.py @@ -12,8 +12,9 @@ ALTERNATIVES = [ALT_GREATER, ALT_TWO, ALT_LESS] -def score_t_test(a, b, axis=0, alternative=ALT_TWO): - # type: (np.array, np.array, int, str) -> Tuple[Union[float, np.array], Union[float, np.array]] +def score_t_test( + a: np.array, b: np.array, axis: int = 0, **kwargs +) -> Tuple[Union[float, np.array], Union[float, np.array]]: """ Run t-test. Enable setting different alternative hypothesis. Probabilities are exact due to symmetry of the test. @@ -24,7 +25,7 @@ def score_t_test(a, b, axis=0, alternative=ALT_TWO): scipy.stats.ttest_ind """ - # alt = kwargs.get("alternative", ALT_TWO) + alternative = kwargs.get("alternative", ALT_TWO) assert alternative in ALTERNATIVES scores, pvalues = scipy.stats.ttest_ind(a, b, axis=axis) @@ -41,7 +42,7 @@ def score_t_test(a, b, axis=0, alternative=ALT_TWO): return scores, 1.0 - pvalues -def score_mann_whitney(a, b, **kwargs): +def score_mann_whitney(a: np.array, b: np.array, **kwargs) -> Tuple[np.array, np.array]: axis = kwargs.get('axis', 0) a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float) @@ -71,12 +72,15 @@ def score_mann_whitney(a, b, **kwargs): return np.array(statistics), np.array(p_values) -def score_hypergeometric_test(a, b, threshold=1, **kwargs): +def score_hypergeometric_test(a: np.array, b: np.array, threshold: float = 1.0, **kwargs) -> Tuple[np.array, np.array]: """ Run a hypergeometric test. The probability in a two-sided test is approximated with the symmetric distribution with more extreme of the tails. """ - # type: (np.ndarray, np.ndarray, float) -> np.ndarray + axis = kwargs.get('axis', 0) + + if axis == 1: + a, b = a.T, b.T # Binary expression matrices _a = (a >= threshold).astype(int) diff --git a/orangecontrib/bioinformatics/widgets/OWVolcanoPlot.py b/orangecontrib/bioinformatics/widgets/OWVolcanoPlot.py index a82832af..3f299305 100644 --- a/orangecontrib/bioinformatics/widgets/OWVolcanoPlot.py +++ b/orangecontrib/bioinformatics/widgets/OWVolcanoPlot.py @@ -1,13 +1,16 @@ +from typing import Optional + import numpy as np +from Orange.data import Table from Orange.widgets import gui, settings from Orange.widgets.widget import Msg -from Orange.widgets.settings import SettingProvider +from Orange.widgets.settings import SettingProvider, DomainContextHandler from Orange.widgets.visualize.owscatterplot import OWScatterPlotBase, OWDataProjectionWidget -from orangecontrib.bioinformatics.utils.statistics import score_t_test, score_fold_change -from orangecontrib.bioinformatics.widgets.utils.gui import label_selection -from orangecontrib.bioinformatics.widgets.utils.data import GENE_ID_COLUMN, GENE_AS_ATTRIBUTE_NAME +from orangecontrib.bioinformatics.utils.statistics import score_fold_change +from orangecontrib.bioinformatics.widgets.utils.data import TableAnnotation +from orangecontrib.bioinformatics.widgets.ow_components import GeneScoringComponent class VolcanoGraph(OWScatterPlotBase): @@ -28,117 +31,127 @@ class Warning(OWDataProjectionWidget.Warning): 'Insufficient data to compute statistics.' 'More than one measurement per class should be provided ' ) - gene_enrichment = Msg('{}, {}.') - no_selected_gene_sets = Msg('No gene set selected, select them from Gene Sets box.') - class Error(OWDataProjectionWidget.Error): exclude_error = Msg('Target labels most exclude/include at least one value.') negative_values = Msg('Negative values in the input. The inputs cannot be in ratio scale.') - data_not_annotated = Msg('The input date is not annotated as expexted. Please refer to documentation.') + data_not_annotated = Msg('The input date is not annotated as expected. Please refer to documentation.') gene_column_id_missing = Msg('Can not identify genes column. Please refer to documentation.') - GRAPH_CLASS = VolcanoGraph + settingsHandler = DomainContextHandler() graph = SettingProvider(VolcanoGraph) - embedding_variables_names = ('log2 (ratio)', '-log10 (P_value)') + scoring_component = SettingProvider(GeneScoringComponent) - stored_selections = settings.ContextSetting([]) - current_group_index = settings.ContextSetting(0) + GRAPH_CLASS = VolcanoGraph + embedding_variables_names = ('log2 (ratio)', '-log10 (P_value)') def __init__(self): super().__init__() + self._data: Optional[Table] = None + self.genes_in_columns: Optional[str] = None + self.gene_id_column: Optional[str] = None + self.gene_id_attribute: Optional[str] = None + + self.fold: Optional[np.array] = None + self.log_p_values: Optional[np.array] = None + self.valid_data: Optional[np.array] = None def _add_controls(self): - box = gui.vBox(self.controlArea, "Target Labels") - self.group_selection_widget = label_selection.LabelSelectionWidget() - self.group_selection_widget.groupChanged.connect(self.on_target_values_changed) - self.group_selection_widget.groupSelectionChanged.connect(self.on_target_values_changed) - box.layout().addWidget(self.group_selection_widget) + box = gui.vBox(self.controlArea, True, margin=0) + self.scoring_component = GeneScoringComponent(self, box) + self.scoring_component.group_changed.connect(self.setup_plot) + self.scoring_component.selection_changed.connect(self.setup_plot) + self.scoring_component.score_method_changed.connect(self.setup_plot) + self.scoring_component.expression_threshold_changed.connect(self.setup_plot) super()._add_controls() self.gui.add_widgets([self.gui.ShowGridLines], self._plot_box) - def get_embedding(self): + def _compute(self): self.Error.exclude_error.clear() - group, target_indices = self.group_selection_widget.selected_split() - - if self.data and group is not None and target_indices: - X = self.data.X - I1 = label_selection.group_selection_mask(self.data, group, target_indices) - I2 = ~I1 - - # print(group) - if isinstance(group, label_selection.RowGroup): - X = X.T - - N1, N2 = np.count_nonzero(I1), np.count_nonzero(I2) + if self.data: + x = self.data.X + score_method = self.scoring_component.get_score_method() + i1 = self.scoring_component.get_selection_mask() + i2 = ~i1 - if not N1 or not N2: + n1, n2 = np.count_nonzero(i1), np.count_nonzero(i2) + if not n1 or not n2: self.Error.exclude_error() return - if N1 < 2 and N2 < 2: + if n1 < 2 and n2 < 2: self.Warning.insufficient_data() - X1, X2 = X[:, I1], X[:, I2] - - if np.any(X1 < 0.0) or np.any(X2 < 0): + x1, x2 = x[:, i1], x[:, i2] + if np.any(x1 < 0.0) or np.any(x2 < 0): self.Error.negative_values() - X1 = np.full_like(X1, np.nan) - X2 = np.full_like(X2, np.nan) + x1 = np.full_like(x1, np.nan) + x2 = np.full_like(x2, np.nan) + + with np.errstate(divide='ignore', invalid='ignore'): + self.fold = score_fold_change(x1, x2, axis=1, log=True) + _, p_values = score_method(x1, x2, axis=1, threshold=self.scoring_component.get_expression_threshold()) + self.log_p_values = np.log10(p_values) + + def get_embedding(self): + if self.data is None: + return None + + if self.fold is None or self.log_p_values is None: + return + + self.valid_data = np.isfinite(self.fold) & np.isfinite(self.log_p_values) + return np.array([self.fold, -self.log_p_values]).T + + def send_data(self): + group_sel, data, graph = None, self._get_projection_data(), self.graph + if graph.selection is not None: + group_sel = np.zeros(len(data), dtype=int) + group_sel[self.valid_data] = graph.selection - with np.errstate(divide="ignore", invalid="ignore"): - fold = score_fold_change(X1, X2, axis=1, log=True) - _, p_values = score_t_test(X1, X2, axis=1) - log_p_values = np.log10(p_values) + selected_data = self._get_selected_data(data, graph.get_selection(), group_sel) - self.valid_data = np.isfinite(fold) & np.isfinite(p_values) - return np.array([fold, -log_p_values]).T + if self.genes_in_columns and selected_data: + selected_data = Table.transpose(selected_data, feature_names_column='Feature name') + + self.Outputs.selected_data.send(selected_data) def setup_plot(self): + self._compute() super().setup_plot() for axis, var in (("bottom", 'log2 (ratio)'), ("left", '-log10 (P_value)')): self.graph.set_axis_title(axis, var) - def on_target_values_changed(self, index): - # Save the current selection to persistent settings - self.current_group_index = index - selected_indices = [ind.row() for ind in self.group_selection_widget.currentGroupSelection().indexes()] - - if self.current_group_index != -1 and selected_indices: - self.stored_selections[self.current_group_index] = selected_indices - - self.setup_plot() - def set_data(self, data): self.Warning.clear() self.Error.clear() - super().set_data(data) - self.group_selection_widget.set_data(self, self.data) - if self.data: - if not self.stored_selections: - self.stored_selections = [[0] for _ in self.group_selection_widget.targets] - self.group_selection_widget.set_selection() + if data: + self.genes_in_columns = data.attributes.get(TableAnnotation.gene_as_attr_name, None) + self.gene_id_column = data.attributes.get(TableAnnotation.gene_id_column, None) + self.gene_id_attribute = data.attributes.get(TableAnnotation.gene_id_attribute, None) + + if self.genes_in_columns: + self._data = data + # override default meta_attr_name value to avoid unexpected changes. + data = Table.transpose(data, meta_attr_name='Feature name') + + super().set_data(data) + self.scoring_component.initialize(self.data) def check_data(self): self.clear_messages() - use_attr_names = self.data.attributes.get(GENE_AS_ATTRIBUTE_NAME, None) - gene_id_column = self.data.attributes.get(GENE_ID_COLUMN, None) - if self.data is not None and (len(self.data) == 0 or len(self.data.domain) == 0): self.data = None - if use_attr_names is None: + if self.genes_in_columns is None: # Note: input data is not annotated properly. self.Error.data_not_annotated() self.data = None - if gene_id_column is None: - # Note: Can not identify genes column. - self.Error.gene_column_id_missing() - self.data = None - if __name__ == "__main__": - pass + from Orange.widgets.utils.widgetpreview import WidgetPreview + + WidgetPreview(OWVolcanoPlot).run() diff --git a/orangecontrib/bioinformatics/widgets/ow_components/__init__.py b/orangecontrib/bioinformatics/widgets/ow_components/__init__.py new file mode 100644 index 00000000..5459a5b4 --- /dev/null +++ b/orangecontrib/bioinformatics/widgets/ow_components/__init__.py @@ -0,0 +1,3 @@ +from .gene_scoring import GeneScoringComponent + +__all__ = ('GeneScoringComponent',) diff --git a/orangecontrib/bioinformatics/widgets/ow_components/gene_scoring.py b/orangecontrib/bioinformatics/widgets/ow_components/gene_scoring.py new file mode 100644 index 00000000..f5d345cc --- /dev/null +++ b/orangecontrib/bioinformatics/widgets/ow_components/gene_scoring.py @@ -0,0 +1,230 @@ +from typing import Callable +from itertools import chain +from collections import namedtuple, defaultdict + +import numpy as np + +from AnyQt.QtCore import QObject, QItemSelection, QItemSelectionModel +from AnyQt.QtCore import pyqtSignal as Signal +from AnyQt.QtWidgets import QListView, QListWidget + +from Orange.data import Table, DiscreteVariable +from Orange.widgets.gui import comboBox, widgetBox, doubleSpin +from Orange.widgets.widget import OWComponent +from Orange.widgets.settings import ContextSetting + +from orangecontrib.bioinformatics.utils.statistics import score_t_test, score_mann_whitney, score_hypergeometric_test + +column_group = namedtuple('ColumnGroup', ['name', 'key', 'values']) +row_group = namedtuple('RowGroup', ['name', 'var', 'values']) + + +class GeneScoringComponent(OWComponent, QObject): + + # Current group/root index has changed. + group_changed = Signal(int) + # Selection for the current group/root has changed. + selection_changed = Signal(int) + # Scoring method has changed. + score_method_changed = Signal(int) + # Expression threshold changed + expression_threshold_changed = Signal(float) + + # component settings + current_method_index: int + current_method_index = ContextSetting(0) + current_group_index: int + current_group_index = ContextSetting(0) + stored_selections: dict + stored_selections = ContextSetting({}) + # default threshold defining expressed genes for Hypergeometric Test + expression_threshold_value: int + expression_threshold_value = 1.0 + + score_methods = [ + ('T-test', score_t_test), + ('Mann-Whitney', score_mann_whitney), + ('Hypergeometric Test', score_hypergeometric_test), + ] + + def __init__(self, parent_widget, parent_component): + QObject.__init__(self) + OWComponent.__init__(self, parent_widget) + + self.score_method_combo = comboBox( + parent_component, self, 'current_method_index', label='Method', callback=self.__on_score_method_changed + ) + self.score_method_combo.addItems([name for name, _ in self.score_methods]) + + self.expression_threshold_box = widgetBox(parent_component, 'Expression threshold', margin=0) + self.expression_threshold_box.setFlat(True) + self.expression_threshold_spinbox = doubleSpin( + self.expression_threshold_box, + self, + 'expression_threshold_value', + minv=0, + maxv=1e2, + step=1e-2, + callback=self.__on_expression_threshold_changed, + callbackOnReturn=True, + ) + self.__show_expression_threshold_spinbox() + + self.group_combo = comboBox( + parent_component, self, 'current_group_index', label='Group', callback=self.__on_group_index_changed + ) + + self.list_widget = QListWidget() + self.list_widget.setSelectionMode(QListView.ExtendedSelection) + self.list_widget.selectionModel().selectionChanged.connect(self.__on_selection_changed) + + box = widgetBox(parent_component, 'Values', margin=0) + box.setFlat(True) + box.layout().addWidget(self.list_widget) + + self.groups = {} + self.data = None + + def initialize(self, data): + """ Initialize widget state after receiving new data. + """ + + if data is not None: + self.data = data + + column_groups, row_groups = self.group_candidates(data) + self.groups = {index: value for index, value in enumerate(column_groups + row_groups)} + + self.group_combo.clear() + self.group_combo.addItems([str(x.name) for x in self.groups.values()]) + self.group_combo.setCurrentIndex(self.current_group_index) + + self.__populate_list_widget() + + def get_selection_mask(self): + """ Return the selection masks for the group. + """ + + group, indices = self.selected_split() + + if isinstance(group, column_group): + selected = [group.values[i] for i in indices] + target = {(group.key, value) for value in selected} + _i = [bool(set(var.attributes.items()).intersection(target)) for var in self.data.domain.attributes] + return np.array(_i, dtype=bool) + + elif isinstance(group, row_group): + target = set(indices) + x, _ = self.data.get_column_view(group.var) + _i = np.zeros_like(x, dtype=bool) + for i in target: + _i |= x == i + return _i + else: + raise TypeError("column_group or row_group expected, got {}".format(type(group).__name__)) + + def selected_split(self): + group_index = self.group_combo.currentIndex() + if not (0 <= group_index < len(self.groups)): + return None, [] + + group = self.groups[group_index] + selection = [model_idx.row() for model_idx in self.list_widget.selectedIndexes()] + + return group, selection + + def get_expression_threshold(self) -> float: + return self.expression_threshold_value + + def get_score_method(self) -> Callable: + _, method = self.score_methods[self.current_method_index] + return method + + def __populate_list_widget(self) -> None: + if self.list_widget is not None: + self.list_widget.selectionModel().selectionChanged.disconnect(self.__on_selection_changed) + self.list_widget.clear() + + target = self.groups.get(self.current_group_index, None) + if target is not None: + self.list_widget.addItems(target.values) + + self.list_widget.setSizeAdjustPolicy(QListWidget.AdjustToContents) + + self.__set_selection() + self.list_widget.selectionModel().selectionChanged.connect(self.__on_selection_changed) + + def __show_expression_threshold_spinbox(self): + self.expression_threshold_box.setHidden(self.get_score_method().__name__ != score_hypergeometric_test.__name__) + + def __on_expression_threshold_changed(self): + self.expression_threshold_changed.emit(self.expression_threshold_value) + + def __on_score_method_changed(self) -> None: + self.__show_expression_threshold_spinbox() + self.score_method_changed.emit(self.current_method_index) + + def __on_group_index_changed(self) -> None: + self.__populate_list_widget() + self.group_changed.emit(self.current_group_index) + + def __on_selection_changed(self) -> None: + self.__store_selection() + self.selection_changed.emit(self.current_group_index) + + def __store_selection(self) -> None: + self.stored_selections[self.current_group_index] = tuple( + model_idx.row() for model_idx in self.list_widget.selectedIndexes() + ) + + def __set_selection(self) -> None: + # Restore previous selection for root (if available) + group_index = self.current_group_index + indices = [ + self.list_widget.indexFromItem(self.list_widget.item(index)) + for index in self.stored_selections.get(group_index, []) + ] + selection = QItemSelection() + for ind in indices: + selection.select(ind, ind) + self.list_widget.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + + @staticmethod + def group_candidates(data: Table): + items = [attr.attributes.items() for attr in data.domain.attributes] + items = list(chain(*items)) + + targets = defaultdict(set) + for label, value in items: + targets[label].add(value) + + # Need at least 2 distinct values or key + targets = [(key, sorted(vals)) for key, vals in targets.items() if len(vals) >= 2] + + column_groups = [column_group(key, key, values) for key, values in sorted(targets)] + + disc_vars = [ + var + for var in data.domain.class_vars + data.domain.metas + if isinstance(var, DiscreteVariable) and len(var.values) >= 2 + ] + + row_groups = [row_group(var.name, var, var.values) for var in disc_vars] + return column_groups, row_groups + + +if __name__ == "__main__": + from Orange.widgets.utils.widgetpreview import WidgetPreview + from Orange.widgets.widget import OWWidget + from Orange.data import Table + from Orange.widgets.settings import SettingProvider + + class MockWidget(OWWidget): + name = "Mock" + scoring_component = SettingProvider(GeneScoringComponent) + + def __init__(self): + self.scoring_component = GeneScoringComponent(self, self.mainArea) + self.scoring_component.initialize(Table('iris')) + + WidgetPreview(MockWidget).run()