Skip to content

Commit

Permalink
Slight cleanup in cluster.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen committed Nov 1, 2023
1 parent 77c7156 commit 987fdc5
Showing 1 changed file with 27 additions and 44 deletions.
71 changes: 27 additions & 44 deletions vamb/cluster.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
__doc__ = """Iterative medoid clustering.
Usage:
>>> clusters = list(ClusterIterator(matrix))
Implements one core function, cluster, along with the helper
functions write_clusters and read_clusters.
"""
__doc__ = "Iterative medoid clustering"

import random as _random
import numpy as _np
Expand Down Expand Up @@ -37,11 +30,6 @@ class NoThreshold:
pass


class Default:
__slots__ = []
pass


# This is the PDF of normal with µ=0, s=0.01 from -0.075 to 0.075 with intervals
# of DELTA_X, for a total of 31 values. We multiply by _DELTA_X so the density
# of one point sums to approximately one
Expand Down Expand Up @@ -312,6 +300,7 @@ def __iter__(self):
def __next__(self) -> Cluster:
if self.n_remaining_points == 0:
raise StopIteration
assert self.n_remaining_points > 0 # not negative

cluster, _, points = self.find_cluster()
self.n_emitted_clusters += 1
Expand Down Expand Up @@ -389,9 +378,8 @@ def get_next_seed(self) -> int:
self.order[i] = -1
continue

self.order_index = (
i + 1
) # Move to next index for the next time this is called
# Move to next index for the next time this is called
self.order_index = i + 1
return new_index

def update_successes(self, success: bool):
Expand Down Expand Up @@ -469,12 +457,9 @@ def wander_medoid(self, seed) -> tuple[int, _Tensor]:

return (medoid, distances)

def find_threshold(
self, distances: _Tensor
) -> Union[Loner, NoThreshold, Default, float]:
def find_threshold(self, distances: _Tensor) -> Union[Loner, NoThreshold, float]:
# If the point is a loner, immediately return a threshold in where only
# that point is contained.
# TODO: Avoid this dual pass in this critical function for performance? How...?
if _torch.count_nonzero(distances < 0.05) == 1:
return Loner()

Expand All @@ -501,9 +486,6 @@ def find_threshold(
weight=picked_lengths,
)

# When the peak_valley_ratio is too high, we need to return something to not get caught
# in an infinite loop.
must_return_points = self.peak_valley_ratio > 0.55
peak_density = 0.0
peak_over = False
minimum_x = 0.0
Expand Down Expand Up @@ -532,7 +514,7 @@ def find_threshold(
if not peak_over and density > peak_density:
# Do not accept first peak to be after x = 0.1
if x > 0.1:
return Default() if must_return_points else NoThreshold()
return NoThreshold()
peak_density = density

# Peak is over when density drops below 60% of peak density
Expand All @@ -556,11 +538,11 @@ def find_threshold(

# If we have not detected a threshold, we can't return one.
if threshold is None:
return Default() if must_return_points else NoThreshold()
return NoThreshold()
# Else, we check whether the threshold is too high. If not, we return it.
else:
if threshold > 0.2 + self.peak_valley_ratio:
return Default() if must_return_points else NoThreshold()
return NoThreshold()
else:
return threshold

Expand All @@ -582,25 +564,26 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
)
points = _torch.IntTensor([medoid])
return (cluster, medoid, points)

elif isinstance(threshold, Default):
points = _smaller_indices(
distances, self.kept_mask, _DEFAULT_RADIUS, self.cuda
)
cluster = Cluster(
int(self.indices[medoid].item()), # type: ignore
seed,
self.indices[points].numpy(),
self.peak_valley_ratio,
_DEFAULT_RADIUS,
True,
self.successes,
len(self.attempts),
)
return (cluster, medoid, points)

elif isinstance(threshold, NoThreshold):
self.update_successes(False)
# When the peak_valley_ratio is too high, we need to return something to not get caught
# in an infinite loop.
if self.peak_valley_ratio > 0.55:
points = _smaller_indices(
distances, self.kept_mask, _DEFAULT_RADIUS, self.cuda
)
cluster = Cluster(
int(self.indices[medoid].item()), # type: ignore
seed,
self.indices[points].numpy(),
self.peak_valley_ratio,
_DEFAULT_RADIUS,
True,
self.successes,
len(self.attempts),
)
return (cluster, medoid, points)
else:
self.update_successes(False)

elif isinstance(threshold, float):
points = _smaller_indices(
Expand Down

0 comments on commit 987fdc5

Please sign in to comment.