diff --git a/nx_cugraph/classes/digraph.py b/nx_cugraph/classes/digraph.py index 178bf44f1..a73ca8d0b 100644 --- a/nx_cugraph/classes/digraph.py +++ b/nx_cugraph/classes/digraph.py @@ -34,6 +34,7 @@ __all__ = ["CudaDiGraph", "DiGraph"] networkx_api = nxcg.utils.decorators.networkx_class(nx.DiGraph) +gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.DiGraph, __name__) class DiGraph(nx.DiGraph, Graph): @@ -105,6 +106,43 @@ def to_cudagraph_class(cls) -> type[CudaDiGraph]: def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph + ########################## + # Networkx graph methods # + ########################## + + # Dispatch to nx.DiGraph or CudaDiGraph + __contains__ = gpu_cpu_api("__contains__") + __len__ = gpu_cpu_api("__len__") + __iter__ = gpu_cpu_api("__iter__") + + @networkx_api + def clear(self) -> None: + cudagraph = self._cudagraph if self._is_on_gpu else None + if self._is_on_cpu: + super().clear() + if cudagraph is not None: + cudagraph.clear() + self._set_cudagraph(cudagraph, clear_cpu=False) + + @networkx_api + def clear_edges(self) -> None: + cudagraph = self._cudagraph if self._is_on_gpu else None + if self._is_on_cpu: + super().clear_edges() + if cudagraph is not None: + cudagraph.clear_edges() + self._set_cudagraph(cudagraph, clear_cpu=False) + + get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True) + has_edge = gpu_cpu_api("has_edge") + neighbors = gpu_cpu_api("neighbors") + has_node = gpu_cpu_api("has_node") + nbunch_iter = gpu_cpu_api("nbunch_iter") + number_of_edges = Graph.number_of_edges + number_of_nodes = gpu_cpu_api("number_of_nodes") + order = gpu_cpu_api("order") + successors = gpu_cpu_api("successors") + class CudaDiGraph(CudaGraph): ################# @@ -244,6 +282,7 @@ def to_undirected(self, reciprocal=False, as_view=False): rv.graph.update(deepcopy(self.graph)) return rv + successors = CudaGraph.neighbors # Alias # Many more methods to implement... 
################### diff --git a/nx_cugraph/classes/graph.py b/nx_cugraph/classes/graph.py index 11c6215d8..018636228 100644 --- a/nx_cugraph/classes/graph.py +++ b/nx_cugraph/classes/graph.py @@ -48,6 +48,7 @@ __all__ = ["CudaGraph", "Graph"] networkx_api = nxcg.utils.decorators.networkx_class(nx.Graph) +gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.Graph, __name__) # The "everything" cache key is an internal implementation detail of NetworkX # that may change between releases. @@ -63,6 +64,8 @@ True, # Include all node values # `.graph` attributes are always included now ) +_EDGE_KEY_INDEX = 0 +_NODE_KEY_INDEX = 1 # Use to indicate when a full conversion to GPU failed so we don't try again. _CANT_CONVERT_TO_GPU = "_CANT_CONVERT_TO_GPU" @@ -210,8 +213,7 @@ def _cudagraph(self): cache[_CACHE_KEY] = Gcg return Gcg - @_cudagraph.setter - def _cudagraph(self, val, *, clear_cpu=True): + def _set_cudagraph(self, val, *, clear_cpu=True): """Set the full ``CudaGraph`` for this graph, or remove from device if None.""" if (cache := getattr(self, "__networkx_cache__", None)) is None: # Should we warn? @@ -228,6 +230,32 @@ def _cudagraph(self, val, *, clear_cpu=True): for key in self._nx_attrs: self.__dict__[key] = None + def _get_cudagraph(self, *, edge_data=False, node_data=False): + """Get a valid cached ``CudaGraph``, optionally with edge or node data. + + Returns None if no valid graph is found. + + Parameters + ---------- + edge_data : bool, default False + Whether to return a CudaGraph with edge data. + node_data : bool, default False + Whether to return a CudaGraph with node data. 
+ """ + nx_cache = getattr(self, "__networkx_cache__", None) + if nx_cache is None or _CANT_CONVERT_TO_GPU in nx_cache: + return None + cache = nx_cache.get("backends", {}).get("cugraph", {}) + if _CACHE_KEY in cache: + # Always return the canonical CudaGraph if it exists + return cache[_CACHE_KEY] + for key, val in cache.items(): + if (key[_EDGE_KEY_INDEX] is True or edge_data is False) and ( + key[_NODE_KEY_INDEX] is True or node_data is False + ): + return val + return None + @nx.Graph.name.setter def name(self, s): # Don't clear the cache when setting the name, since `.graph` is shared. @@ -509,6 +537,54 @@ def from_dcsc( **attr, ) + ########################## + # Networkx graph methods # + ########################## + + # Dispatch to nx.Graph or CudaGraph + __contains__ = gpu_cpu_api("__contains__") + __len__ = gpu_cpu_api("__len__") + __iter__ = gpu_cpu_api("__iter__") + + @networkx_api + def clear(self) -> None: + cudagraph = self._cudagraph if self._is_on_gpu else None + if self._is_on_cpu: + super().clear() + if cudagraph is not None: + cudagraph.clear() + self._set_cudagraph(cudagraph, clear_cpu=False) + + @networkx_api + def clear_edges(self) -> None: + cudagraph = self._cudagraph if self._is_on_gpu else None + if self._is_on_cpu: + super().clear_edges() + if cudagraph is not None: + cudagraph.clear_edges() + self._set_cudagraph(cudagraph, clear_cpu=False) + + get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True) + has_edge = gpu_cpu_api("has_edge") + neighbors = gpu_cpu_api("neighbors") + has_node = gpu_cpu_api("has_node") + nbunch_iter = gpu_cpu_api("nbunch_iter") + + @networkx_api + def number_of_edges( + self, u: NodeKey | None = None, v: NodeKey | None = None + ) -> int: + if u is not None or v is not None: + # NotImplemented by CudaGraph + nx_class = self.to_networkx_class() + return nx_class.number_of_edges(self, u, v) + return self._number_of_edges(u, v) + + _number_of_edges = gpu_cpu_api("number_of_edges") + number_of_nodes = 
gpu_cpu_api("number_of_nodes") + order = gpu_cpu_api("order") + # Future work: implement more graph methods, and handle e.g. `copy` + class CudaGraph: # Tell networkx to dispatch calls with this object to nx-cugraph @@ -803,7 +879,7 @@ def to_undirected(self, as_view: bool = False) -> CudaGraph: def _to_compat_graph(self) -> Graph: rv = self._to_compat_graph_class()() - rv._cudagraph = self + rv._set_cudagraph(self) return rv # Not implemented... diff --git a/nx_cugraph/classes/multidigraph.py b/nx_cugraph/classes/multidigraph.py index 5a6595567..0671d21a8 100644 --- a/nx_cugraph/classes/multidigraph.py +++ b/nx_cugraph/classes/multidigraph.py @@ -23,6 +23,7 @@ __all__ = ["CudaMultiDiGraph", "MultiDiGraph"] networkx_api = nxcg.utils.decorators.networkx_class(nx.MultiDiGraph) +gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.MultiDiGraph, __name__) class MultiDiGraph(nx.MultiDiGraph, MultiGraph, DiGraph): @@ -50,6 +51,26 @@ def to_cudagraph_class(cls) -> type[CudaMultiDiGraph]: def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + ########################## + # Networkx graph methods # + ########################## + + # Dispatch to nx.MultiDiGraph or CudaMultiDiGraph + __contains__ = gpu_cpu_api("__contains__") + __len__ = gpu_cpu_api("__len__") + __iter__ = gpu_cpu_api("__iter__") + clear = DiGraph.clear + clear_edges = DiGraph.clear_edges + get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True) + has_edge = gpu_cpu_api("has_edge") + neighbors = gpu_cpu_api("neighbors") + has_node = gpu_cpu_api("has_node") + nbunch_iter = gpu_cpu_api("nbunch_iter") + number_of_edges = MultiGraph.number_of_edges + number_of_nodes = gpu_cpu_api("number_of_nodes") + order = gpu_cpu_api("order") + successors = gpu_cpu_api("successors") + class CudaMultiDiGraph(CudaMultiGraph, CudaDiGraph): is_directed = classmethod(MultiDiGraph.is_directed.__func__) diff --git a/nx_cugraph/classes/multigraph.py b/nx_cugraph/classes/multigraph.py index 
c8c8f1dfb..3f0204f69 100644 --- a/nx_cugraph/classes/multigraph.py +++ b/nx_cugraph/classes/multigraph.py @@ -37,6 +37,7 @@ __all__ = ["MultiGraph", "CudaMultiGraph"] networkx_api = nxcg.utils.decorators.networkx_class(nx.MultiGraph) +gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.MultiGraph, __name__) class MultiGraph(nx.MultiGraph, Graph): @@ -277,6 +278,36 @@ def from_dcsc( **attr, ) + ########################## + # Networkx graph methods # + ########################## + + # Dispatch to nx.MultiGraph or CudaMultiGraph + __contains__ = gpu_cpu_api("__contains__") + __len__ = gpu_cpu_api("__len__") + __iter__ = gpu_cpu_api("__iter__") + clear = Graph.clear + clear_edges = Graph.clear_edges + get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True) + has_edge = gpu_cpu_api("has_edge") + neighbors = gpu_cpu_api("neighbors") + has_node = gpu_cpu_api("has_node") + nbunch_iter = gpu_cpu_api("nbunch_iter") + + @networkx_api + def number_of_edges( + self, u: NodeKey | None = None, v: NodeKey | None = None + ) -> int: + if u is not None or v is not None: + # NotImplemented by CudaGraph + nx_class = self.to_networkx_class() + return nx_class.number_of_edges(self, u, v) + return self._number_of_edges(u, v) + + _number_of_edges = gpu_cpu_api("number_of_edges") + number_of_nodes = gpu_cpu_api("number_of_nodes") + order = gpu_cpu_api("order") + class CudaMultiGraph(CudaGraph): # networkx properties @@ -390,14 +421,13 @@ def get_edge_data( mask = (self.src_indices == u) & (self.dst_indices == v) if not mask.any(): return default - if self.edge_keys is None: + if self.edge_keys is None and key is not None: if self.edge_indices is None: self._calculate_edge_indices() - if key is not None: - try: - mask = mask & (self.edge_indices == key) - except TypeError: - return default + try: + mask = mask & (self.edge_indices == key) + except TypeError: + return default indices = cp.nonzero(mask)[0] if indices.size == 0: return default diff --git 
a/nx_cugraph/tests/test_graph_methods.py b/nx_cugraph/tests/test_graph_methods.py index 40a361b10..abc71ed77 100644 --- a/nx_cugraph/tests/test_graph_methods.py +++ b/nx_cugraph/tests/test_graph_methods.py @@ -17,6 +17,8 @@ from .testing_utils import assert_graphs_equal +CREATE_USING = [nxcg.Graph, nxcg.DiGraph, nxcg.MultiGraph, nxcg.MultiDiGraph] + def _create_Gs(): rv = [] @@ -65,3 +67,51 @@ def test_multidigraph_to_undirected(): Gcg = nxcg.CudaMultiDiGraph(Gnx) with pytest.raises(NotImplementedError): Gcg.to_undirected() + + +@pytest.mark.parametrize("create_using", CREATE_USING) +@pytest.mark.parametrize( + "method", + [ + ("__iter__", ()), + ("__len__", ()), + ("clear", ()), + ("clear_edges", ()), + ("nbunch_iter", ()), + ("number_of_edges", ()), + ("number_of_nodes", ()), + ("order", ()), + ("__contains__", (0,)), + ("neighbors", (0,)), + ("has_node", (0,)), + ("successors", (0,)), + ("get_edge_data", (0, 1)), + ("has_edge", (0, 1)), + ("nbunch_iter", ([0, 1],)), + ], +) +@pytest.mark.parametrize("where", ["gpu", "cpu"]) +def test_method_does_not_convert_to_cpu_or_gpu(create_using, method, where): + attr, args = method + if attr == "successors" and not create_using.is_directed(): + return + G = nxcg.complete_graph(3, create_using=create_using) + is_on_gpu = where == "gpu" + is_on_cpu = where == "cpu" + if is_on_cpu: + G.add_edge(10, 20) + assert G._is_on_gpu == is_on_gpu + assert G._is_on_cpu == is_on_cpu + getattr(G, attr)(*args) + assert G._is_on_gpu == is_on_gpu + assert G._is_on_cpu == is_on_cpu + # Also usable from the class and dispatches correctly + func = getattr(create_using, attr) + func(G, *args) + assert G._is_on_gpu == is_on_gpu + assert G._is_on_cpu == is_on_cpu + # Basic "looks like networkx" checks + nx_class = create_using.to_networkx_class() + nx_func = getattr(nx_class, attr) + assert func.__name__ == nx_func.__name__ + assert func.__module__.startswith("nx_cugraph") diff --git a/nx_cugraph/utils/decorators.py 
b/nx_cugraph/utils/decorators.py index 15cb010e9..26fce7350 100644 --- a/nx_cugraph/utils/decorators.py +++ b/nx_cugraph/utils/decorators.py @@ -167,3 +167,31 @@ def _default_should_run(*args, **kwargs): def _restore_networkx_dispatched(name): return getattr(BackendInterface, name) + + +def _gpu_cpu_api(nx_class, module_name): + def _gpu_cpu_graph_method(attr, *, edge_data=False, node_data=False): + """Dispatch method to NetworkX or CudaGraph based on cache. + + For example, this will use any cached CudaGraph for ``len(G)``, which + prevents creating NetworkX data structures. + """ + nx_func = getattr(nx_class, attr) + + def inner(self, *args, **kwargs): + cuda_graph = self._get_cudagraph(edge_data=edge_data, node_data=node_data) + if cuda_graph is None: + return nx_func(self, *args, **kwargs) + return getattr(cuda_graph, attr)(*args, **kwargs) + + inner.__name__ = nx_func.__name__ + inner.__doc__ = nx_func.__doc__ + inner.__qualname__ = nx_func.__qualname__ + inner.__defaults__ = nx_func.__defaults__ + inner.__kwdefaults__ = nx_func.__kwdefaults__ + inner.__module__ = module_name + inner.__dict__.update(nx_func.__dict__) + inner.__wrapped__ = nx_func + return inner + + return _gpu_cpu_graph_method