From 2e1d4c3f20dbebcdd203a0a70bacea7d3dced825 Mon Sep 17 00:00:00 2001 From: specialized boy <41695878+HarshaSatyavardhan@users.noreply.github.com> Date: Mon, 9 Jan 2023 04:43:19 +0530 Subject: [PATCH] numpy style docstrings --- dodiscover/replearning/gin.py | 73 +++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/dodiscover/replearning/gin.py b/dodiscover/replearning/gin.py index 3609a6f9..2110d312 100644 --- a/dodiscover/replearning/gin.py +++ b/dodiscover/replearning/gin.py @@ -3,19 +3,30 @@ TODO: Need type hints """ from causallearn.search.HiddenCausal.GIN.GIN import GIN as GIN_ - from pywhy_graphs import CPDAG class GIN: + """Wrapper for GIN in the causal-learn package. + + Parameters + ---------- + indep_test_method : str + The method to use for testing independence, by default "kci" + alpha : float + The significance level for independence tests, by default 0.05 + + Attributes + ---------- + graph_ : CPDAG + The estimated causal graph. + causal_learn_graph : CausalGraph + The causal graph object from causal-learn. + causal_ordering : list of str + The causal ordering of the variables. """ - Dodiscover wrapper for GIN in causal-learn package - """ - def __init__(self, indep_test_method="kci", alpha=0.05): - """ - Using default parameters from GIN. - TODO: Add full set of GIN parameters with default - TODO: Add a base class - """ + def __init__(self, indep_test_method: str="kci", alpha: float=0.05): + """Initialize GIN object with specified parameters.""" + self.graph_ = None # Should be in a base class # GIN default parameters. @@ -27,8 +38,33 @@ def __init__(self, indep_test_method="kci", alpha=0.05): self.causal_ordering = None def _causal_learn_to_pdag(self, cl_graph): - """""" + """Convert a causal-learn graph to a CPDAG object. + + Parameters + ---------- + cl_graph : CausalGraph + The causal-learn graph to be converted. + + Returns + ------- + pdag : CPDAG + The equivalent CPDAG object. + """ def _extract_edgelists(adj_mat, names): + """Extracts directed and undirected edges from an adjacency matrix. + + Parameters: + - adj_mat: numpy array + The adjacency matrix of the graph. + - names: list of str + The names of the nodes in the graph. + + Returns: + - directed_edges: list of tuples + The directed edges of the graph. + - undirected_edges: list of sets + The undirected edges of the graph. + """ directed_edges = [] undirected_edges = [] for i, row in enumerate(adj_mat): @@ -51,8 +87,21 @@ def _extract_edgelists(adj_mat, names): return pdag - def fit(self, data, context): + def fit(self, data: 'DataFrame', context: 'DataFrame'): """Fit to data. + + Parameters + ---------- + data : DataFrame + The data to fit to. + context : DataFrame + The context variables to use as constraints. + + Returns + ------- + self : GIN + The fitted GIN object. + TODO: How to apply context constraints? Need to create issue""" causal_learn_graph, ordering = GIN_( data.to_numpy(), @@ -61,4 +110,4 @@ def fit(self, data, context): ) self.causal_learn_graph = causal_learn_graph self.causal_ordering = ordering - self.graph_ = self._causal_learn_to_pdag(causal_learn_graph) + self.graph_ = self._causal_learn_to_pdag(causal_learn_graph) \ No newline at end of file