diff --git a/navis/morpho/subset.py b/navis/morpho/subset.py index a003dba0..a8af6aee 100644 --- a/navis/morpho/subset.py +++ b/navis/morpho/subset.py @@ -16,25 +16,25 @@ import numpy as np import networkx as nx -from typing import Union, Sequence +from typing import Union, Sequence, Callable from .. import utils, config, core, graph # Set up logging logger = config.get_logger(__name__) -__all__ = sorted(['subset_neuron']) +__all__ = sorted(["subset_neuron"]) +@utils.map_neuronlist(desc="Subsetting", allow_parallel=True) @utils.lock_neuron -def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], - subset: Union[Sequence[Union[int, str]], - nx.DiGraph, - pd.DataFrame], - inplace: bool = False, - keep_disc_cn: bool = False, - prevent_fragments: bool = False - ) -> 'core.NeuronObject': +def subset_neuron( + x: Union["core.TreeNeuron", "core.MeshNeuron"], + subset: Union[Sequence[Union[int, str]], nx.DiGraph, pd.DataFrame, Callable], + inplace: bool = False, + keep_disc_cn: bool = False, + prevent_fragments: bool = False, +) -> "core.NeuronObject": """Subset a neuron to a given set of nodes/vertices. Note that for ``MeshNeurons`` it is not guaranteed that all vertices in @@ -43,20 +43,26 @@ def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], Parameters ---------- - x : TreeNeuron | MeshNeuron | Dotprops - Neuron to subset. - subset : list-like | set | NetworkX.Graph | pandas.DataFrame - For TreeNeurons: - - node IDs to subset the neuron to - - a boolean mask - - DataFrame with ``node_id`` column - For MeshNeurons: - - vertex indices - - a boolean mask - For Dotprops: - - point indices - - a boolean mask - + x : TreeNeuron | MeshNeuron | Dotprops | NeuronList + Neuron to subset. When passing a NeuronList, it's advised + to use a function for `subset` (see below). + subset : list-like | set | NetworkX.Graph | pandas.DataFrame | Callable + Subset of the neuron to keep. 
Depending on the neuron: + For TreeNeurons: + - node IDs + - a boolean mask matching the number of nodes + - DataFrame with ``node_id`` column + For MeshNeurons: + - vertex indices + - a boolean mask matching either the number of + vertices or faces + For Dotprops: + - point indices + - a boolean mask matching the number of points + Alternatively, you can pass a function that accepts + a neuron and returns a suitable `subset` as described + above. This is useful e.g. when wanting to subset a + list of neurons. keep_disc_cn : bool, optional If False, will remove disconnected connectors that have "lost" their parent node/vertex. @@ -69,7 +75,7 @@ def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], Returns ------- - TreeNeuron | MeshNeuron + TreeNeuron | MeshNeuron | Dotprops | NeuronList Examples -------- @@ -85,7 +91,17 @@ def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], >>> # Flatten segments into list of nodes >>> nodes_to_keep = [n for s in short_segs for n in s] >>> # Subset neuron - >>> n_short = navis.subset_neuron(n, nodes_to_keep) + >>> n_short = navis.subset_neuron(n, subset=nodes_to_keep) + + Subset multiple neurons using a callable + + >>> import navis + >>> nl = navis.example_neurons(2) + >>> # Subset neurons to all leaf nodes + >>> nl_end = navis.subset_neuron( + ... nl, + ... subset=lambda x: x.leafs.node_id + ... 
) See Also -------- @@ -98,35 +114,43 @@ def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], if isinstance(x, core.NeuronList) and len(x) == 1: x = x[0] - utils.eval_param(x, name='x', - allowed_types=(core.TreeNeuron, core.MeshNeuron, core.Dotprops)) + utils.eval_param( + x, name="x", allowed_types=(core.TreeNeuron, core.MeshNeuron, core.Dotprops) + ) + + if callable(subset): + subset = subset(x) # Make a copy of the neuron if not inplace: x = x.copy() # We have to run this in a separate function so that the lock is applied # to the copy - subset_neuron(x, - subset=subset, - inplace=True, - keep_disc_cn=keep_disc_cn, - prevent_fragments=prevent_fragments) + subset_neuron( + x, + subset=subset, + inplace=True, + keep_disc_cn=keep_disc_cn, + prevent_fragments=prevent_fragments, + ) return x if isinstance(x, core.TreeNeuron): - x = _subset_treeneuron(x, - subset=subset, - keep_disc_cn=keep_disc_cn, - prevent_fragments=prevent_fragments) + x = _subset_treeneuron( + x, + subset=subset, + keep_disc_cn=keep_disc_cn, + prevent_fragments=prevent_fragments, + ) elif isinstance(x, core.MeshNeuron): - x = _subset_meshneuron(x, - subset=subset, - keep_disc_cn=keep_disc_cn, - prevent_fragments=prevent_fragments) + x = _subset_meshneuron( + x, + subset=subset, + keep_disc_cn=keep_disc_cn, + prevent_fragments=prevent_fragments, + ) elif isinstance(x, core.Dotprops): - x = _subset_dotprops(x, - subset=subset, - keep_disc_cn=keep_disc_cn) + x = _subset_dotprops(x, subset=subset, keep_disc_cn=keep_disc_cn) return x @@ -134,23 +158,25 @@ def subset_neuron(x: Union['core.TreeNeuron', 'core.MeshNeuron'], def _subset_dotprops(x, subset, keep_disc_cn): """Subset Dotprops.""" if not utils.is_iterable(subset): - raise TypeError('Can only subset Dotprops to list, set or ' - f'numpy.ndarray, not "{type(subset)}"') + raise TypeError( + "Can only subset Dotprops to list, set or " + f'numpy.ndarray, not "{type(subset)}"' + ) subset = utils.make_iterable(subset) # Convert indices 
to mask if subset.dtype == bool: - if subset.shape != (x.points.shape[0], ): - raise ValueError('Boolean mask must be of same length as points.') + if subset.shape != (x.points.shape[0],): + raise ValueError("Boolean mask must be of same length as points.") mask = subset else: mask = np.isin(np.arange(0, len(x.points)), subset) # Filter connectors if not keep_disc_cn and x.has_connectors: - if 'point' not in x.connectors.columns: - x.connectors['point'] = x.snap(x.connectors[['x', 'y', 'z']].values)[0] + if "point" not in x.connectors.columns: + x.connectors["point"] = x.snap(x.connectors[["x", "y", "z"]].values)[0] if subset.dtype == bool: subset = np.arange(0, len(x.points))[subset] @@ -161,7 +187,7 @@ def _subset_dotprops(x, subset, keep_disc_cn): # Make old -> new indices map new_ix = dict(zip(subset, np.arange(0, len(subset)))) - x.connectors['point'] = x.connectors.point.map(new_ix) + x.connectors["point"] = x.connectors.point.map(new_ix) # Mask vectors # This will also trigger re-calculation which is necessary for two reasons: @@ -187,16 +213,24 @@ def _subset_dotprops(x, subset, keep_disc_cn): def _subset_meshneuron(x, subset, keep_disc_cn, prevent_fragments): """Subset MeshNeuron.""" if not utils.is_iterable(subset): - raise TypeError('Can only subset MeshNeuron to list, set or ' - f'numpy.ndarray, not "{type(subset)}"') + raise TypeError( + "Can only subset MeshNeuron to list, set or " + f'numpy.ndarray, not "{type(subset)}"' + ) subset = utils.make_iterable(subset) - # Convert mask to indices + # Convert mask to vertex indices if subset.dtype == bool: - if subset.shape != (x.vertices.shape[0], ): - raise ValueError('Boolean mask must be of same length as vertices.') - subset = np.arange(0, len(x.vertices))[subset] + if subset.shape[0] == x.vertices.shape[0]: + subset = np.arange(len(x.vertices))[subset] + elif subset.shape[0] == x.faces.shape[0]: + # Translate face mask to vertex indices + subset = np.unique(x.faces[subset]) + else: + raise ValueError( + 
"Boolean mask must be of same length as vertices or faces." + ) if prevent_fragments: # Generate skeleton @@ -209,9 +243,11 @@ def _subset_meshneuron(x, subset, keep_disc_cn, prevent_fragments): subset = np.arange(0, len(x.vertices))[np.isin(sk.vertex_map, subset)] # Filter connectors + # (connectors are associated with vertices, not faces which is why + # our `subset` is always a list of vertex indices) if not keep_disc_cn and x.has_connectors: - if 'vertex_id' not in x.connectors.columns: - x.connectors['vertex'] = x.snap(x.connectors[['x', 'y', 'z']].values)[0] + if "vertex_id" not in x.connectors.columns: + x.connectors["vertex"] = x.snap(x.connectors[["x", "y", "z"]].values)[0] x._connectors = x.connectors[x.connectors.vertex.isin(subset)].copy() x._connectors.reset_index(inplace=True, drop=True) @@ -219,18 +255,10 @@ def _subset_meshneuron(x, subset, keep_disc_cn, prevent_fragments): # Make old -> new indices map new_ix = dict(zip(subset, np.arange(0, len(subset)))) - x.connectors['vertex'] = x.connectors.vertex.map(new_ix) + x.connectors["vertex"] = x.connectors.vertex.map(new_ix) - # Subset the mesh (by faces) - # Build the mask bit by bit to be more efficient: - subset_faces = np.full(len(x.faces), True) - for i in range(3): - subset_faces[subset_faces] = np.isin(x.faces[subset_faces, i], subset) - subset_faces = np.where(subset_faces)[0] - - if len(subset_faces): - submesh = x.trimesh.submesh([subset_faces], append=True) - x.vertices, x.faces = submesh.vertices, submesh.faces + if len(subset): + x.vertices, x.faces = submesh(x, vertex_index=subset) else: x.vertices, x.faces = np.empty((0, 3)), np.empty((0, 3)) @@ -247,8 +275,10 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): # This forces subset into numpy array (important for e.g. 
sets) subset = utils.make_iterable(subset) else: - raise TypeError('Can only subset to list, set, numpy.ndarray or' - f'networkx.Graph, not "{type(subset)}"') + raise TypeError( + "Can only subset to list, set, numpy.ndarray or" + f'networkx.Graph, not "{type(subset)}"' + ) if prevent_fragments: subset, new_root = graph.connected_subgraph(x, subset) @@ -268,41 +298,54 @@ def _subset_treeneuron(x, subset, keep_disc_cn, prevent_fragments): # Make sure that there are root nodes # This is the fastest "pandorable" way: instead of overwriting the column, # concatenate a new column to this DataFrame - x._nodes = pd.concat([x.nodes.drop('parent_id', inplace=False, axis=1), # type: ignore # no stubs for concat - x.nodes[['parent_id']].where(x.nodes.parent_id.isin(x.nodes.node_id.values), - other=-1, inplace=False)], - axis=1) + x._nodes = pd.concat( + [ + x.nodes.drop("parent_id", inplace=False, axis=1), # type: ignore # no stubs for concat + x.nodes[["parent_id"]].where( + x.nodes.parent_id.isin(x.nodes.node_id.values), other=-1, inplace=False + ), + ], + axis=1, + ) # Make sure any new roots or leafs are properly typed # We won't produce new slabs but roots and leaves might change - x.nodes.loc[x.nodes.parent_id < 0, 'type'] = 'root' - x.nodes.loc[(~x.nodes.node_id.isin(x.nodes.parent_id.values) - & (x.nodes.parent_id >= 0)), 'type'] = 'end' + x.nodes.loc[x.nodes.parent_id < 0, "type"] = "root" + x.nodes.loc[ + (~x.nodes.node_id.isin(x.nodes.parent_id.values) & (x.nodes.parent_id >= 0)), + "type", + ] = "end" # Filter connectors if not keep_disc_cn and x.has_connectors: x._connectors = x.connectors[x.connectors.node_id.isin(x.nodes.node_id)] x._connectors.reset_index(inplace=True, drop=True) - if getattr(x, 'tags', None) is not None: + if getattr(x, "tags", None) is not None: # Filter tags - x.tags = {t: [tn for tn in x.tags[t] if tn in x.nodes.node_id.values] for t in x.tags} # type: ignore # TreeNeuron has no tags + x.tags = { + t: [tn for tn in x.tags[t] if tn in 
def submesh(mesh, *, faces_index=None, vertex_index=None):
    """Re-implementation of trimesh.submesh that is faster for our use case.

    Notably we:
     - ignore normals (possibly needed) and visuals (definitely not needed)
     - allow only one set of faces to be passed
     - return vertices and faces instead of a new mesh
     - make as few copies as possible
     - allow passing vertex indices instead of faces

    This function is 5-10x faster than trimesh.submesh for our use case.
    Note that the speed of this function was never the bottleneck though,
    it's about the memory footprint.
    See https://github.com/navis-org/navis/issues/154.

    Parameters
    ----------
    mesh :          trimesh.Trimesh
                    Mesh to submesh. Only its ``.vertices`` and ``.faces``
                    arrays are accessed.
    faces_index :   array-like
                    Indices of faces to keep. A boolean mask with one entry
                    per face is also accepted.
    vertex_index :  array-like
                    Indices of vertices to keep. A boolean mask with one
                    entry per vertex is also accepted. Only faces made up
                    entirely of these vertices are kept.

    Returns
    -------
    vertices :  np.ndarray
                Vertices of submesh. Has zero rows if nothing is selected.
    faces :     np.ndarray
                Faces of submesh, re-indexed into ``vertices``. Has zero
                rows if nothing is selected.

    Raises
    ------
    ValueError
                If neither or both of ``faces_index`` and ``vertex_index``
                are provided.

    """
    if faces_index is None and vertex_index is None:
        raise ValueError("Either `faces_index` or `vertex_index` must be provided.")
    elif faces_index is not None and vertex_index is not None:
        raise ValueError("Only one of `faces_index` or `vertex_index` can be provided.")

    by_faces = faces_index is not None
    index = faces_index if by_faces else vertex_index

    # Sets can not be converted to an index array directly
    if isinstance(index, (set, frozenset)):
        index = list(index)
    index = np.asarray(index)

    # Translate boolean masks to indices up front: counting the unique values
    # of a mask (at most two) would fool the shortcut below, and `np.isin`
    # would read a vertex mask as the indices 0 and 1
    if index.dtype == bool:
        index = np.flatnonzero(index)

    # Use a view of the original data
    original_faces = mesh.faces.view(np.ndarray)
    original_vertices = mesh.vertices.view(np.ndarray)

    # Nothing selected: return empty arrays that keep the column layout so
    # that the result has the same shape as in the general case below
    if len(index) == 0:
        return original_vertices[:0].copy(), original_faces[:0].astype(np.int32)

    # Everything selected: return the original mesh right away.
    # Note that this shortcut does not drop unreferenced vertices.
    n_total = len(original_faces) if by_faces else len(original_vertices)
    if len(index) == n_total and len(np.unique(index)) == n_total:
        return mesh.vertices.copy(), mesh.faces.copy()

    # If we're starting with vertices, find the faces for which *all* corners
    # are among our vertices. Going via faces also makes sure that we drop
    # unreferenced vertices.
    if not by_faces:
        index = np.flatnonzero(np.isin(original_faces, index).all(axis=1))

    # Get unique vertices in the to-be-kept faces
    faces = original_faces[index]
    unique = np.unique(faces.reshape(-1))

    # Generate an old -> new lookup table for the vertex indices
    # (using int32 here since we're unlikely to have more than 2B vertices)
    mask = np.arange(len(original_vertices), dtype=np.int32)
    mask[unique] = np.arange(len(unique))

    # Grab the vertices in the order they are referenced and remap the faces
    # to the new vertex indices. Indexing with an integer array already
    # returns new arrays, so no additional `.copy()` is required for `mask`
    # and the original mesh data to be released.
    vertices = original_vertices[unique]
    faces = mask[faces]

    return vertices, faces