From a7bac88550a8fde5d3d5b476c0dbeeb6c4e8a2d1 Mon Sep 17 00:00:00 2001 From: Mike Lasby Date: Fri, 9 Aug 2024 10:08:46 -0600 Subject: [PATCH] better str method to account for ablated neurons --- sparsimony/dst/srigl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sparsimony/dst/srigl.py b/sparsimony/dst/srigl.py index 5244022..98dc415 100644 --- a/sparsimony/dst/srigl.py +++ b/sparsimony/dst/srigl.py @@ -191,7 +191,10 @@ def __str__(self) -> str: for config in self.groups: mask = get_mask(**config) mask_flat = mask.view(mask.shape[0], prod(mask.shape[1:])) - ffi.append(mask_flat.sum(dim=1, dtype=torch.int).unique().item()) + this_ffi = mask_flat.sum(dim=1, dtype=torch.int).unique() + if len(this_ffi) > 1: + this_ffi = this_ffi[this_ffi != 0] + ffi.append(this_ffi.item()) s = super().__str__() s += f"FFI: {ffi}\n" return s