Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do we keep that ? (NME GPU) #41

Open
wants to merge 8 commits into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyannote/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ An example of .env file is provided in [pyannote/.envdefault](https://github.com
|:-|:-|:-|
| `SERVING_MODE` | (Required) Specify launch mode | `http` |
| `CONCURRENCY` | Number of worker(s) additional to the main worker | `0` \| `1` \| `2` \| ... |
| `DEVICE` | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` |
| `DEVICE` | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` \| `cuda:1` ... |
| `NUM_THREADS` | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... |
| `CUDA_VISIBLE_DEVICES` | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... |

Expand Down
3 changes: 3 additions & 0 deletions simple/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ An example of .env file is provided in [simple/.envdefault](https://github.com/l
| `DEVICE` | Device to use for the embedding model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` |
| `DEVICE_CLUSTERING` | Device to use for clustering (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` |
| `NUM_THREADS` | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... |
| `DEVICE` | Device to use for the embeddings model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` \| `cuda:1` ... |
| `DEVICE_VAD` | Device to use for the Voice Activity Detection (by default, CPU) | `cpu` \| `cuda` \| `cuda:1` ... |
| `DEVICE_CLUSTERING` | Device to use for the clustering (by default, CPU) | `cpu` \| `cuda` \| `cuda:1` ... |
| `CUDA_VISIBLE_DEVICES` | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... |


Expand Down
3 changes: 3 additions & 0 deletions simple/diarization/processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
else:
USE_GPU = False

device_vad = os.environ.get("DEVICE_VAD", "cpu")
device_clustering = os.environ.get("DEVICE_CLUSTERING", "cpu")

# Number of CPU threads
NUM_THREADS = os.environ.get("NUM_THREADS", torch.get_num_threads())
NUM_THREADS = int(NUM_THREADS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances
from .spectral_clustering import NME_SpectralClustering

from .nmesc_clustering import (
NMESC,
Expand Down Expand Up @@ -95,7 +96,8 @@ def cluster_NME_SC(embeds, n_clusters=None, max_speakers= None, threshold=None,
labels = NME_SpectralClustering(
S,
num_clusters=n_clusters,
max_num_clusters=max_speakers
max_num_clusters=max_speakers,
device=device,
)

"""
Expand Down Expand Up @@ -123,7 +125,16 @@ def cluster_NME_SC(embeds, n_clusters=None, max_speakers= None, threshold=None,

return labels


from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler(feature_range=(0, 1))
def getCosAffinityMatrix(emb):
"""
Calculate cosine similarity values among speaker embeddings.
"""
sim_d = cosine_similarity(emb)
scaler.fit(sim_d)
sim_d = scaler.transform(sim_d)
return sim_d
def diagonal_fill(A):
"""
Sets the diagonal elemnts of the matrix to the max of each row
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
)

self.window = window
self.period = period
self.period = period

def setup_VAD(self, device):
self.device_vad = device
use_gpu = device != "cpu"
Expand Down Expand Up @@ -266,8 +266,7 @@ def diarize(

if self.num_threads:
# For VAD / embedding
torch.set_num_threads(self.num_threads)

torch.set_num_threads(self.num_threads)
recname = os.path.splitext(os.path.basename(wav_file))[0]

if check_wav_16khz_mono(wav_file):
Expand All @@ -294,8 +293,7 @@ def diarize(
self.log("Extracting embeddings...")
tic = time.time()
embeds, segments = self.recording_embeds(signal, fs, speech_ts)
self.log(f"Done in {time.time() - tic:.3f} seconds")

self.log(f"Done in {time.time() - tic:.3f} seconds")
[w, k] = embeds.shape
if w >= 2:
self.log("Clustering to {} speakers...".format(num_speakers))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import numpy as np
import scipy
from sklearn.cluster import SpectralClustering
import torch

# NME low-level operations
# These functions are taken from the Kaldi scripts.

# Prepares binarized(0/1) affinity matrix with p_neighbors non-zero elements in each row
def get_kneighbors_conn(X_dist, p_neighbors):
X_dist_out = np.zeros_like(X_dist)
for i, line in enumerate(X_dist):
sorted_idx = np.argsort(line)
sorted_idx = sorted_idx[::-1]
indices = sorted_idx[:p_neighbors]
X_dist_out[indices, i] = 1
return X_dist_out


# Thresolds affinity matrix to leave p maximum non-zero elements in each row
def Threshold(A, p):
N = A.shape[0]
Ap = np.zeros((N, N))
for i in range(N):
thr = sorted(A[i, :], reverse=True)[p]
Ap[i, A[i, :] > thr] = A[i, A[i, :] > thr]
return Ap


# Computes Laplacian of a matrix
def Laplacian(A):
d = np.sum(A, axis=1) - np.diag(A)
D = np.diag(d)
return D - A


# Calculates eigengaps (differences between adjacent eigenvalues sorted in descending order)
def Eigengap(S):
S = sorted(S)
return np.diff(S)

def getLamdaGaplist(lambdas):
lambdas = np.real(lambdas)
return list(lambdas[1:] - lambdas[:-1])

# Computes parameters of normalized eigenmaps for automatic thresholding selection
def ComputeNMEParameters(A, p, max_num_clusters, device):
# p-Neighbour binarization
Ap = get_kneighbors_conn(A, p)
# Symmetrization
Ap = (Ap + np.transpose(Ap)) / 2
# Laplacian matrix computation
Lp = Laplacian(Ap)
# Get max_num_clusters+1 smallest eigenvalues
from torch.linalg import eigh

Lp = torch.from_numpy(Lp).float().to(device)
lambdas, _ = eigh(Lp)
S = lambdas.cpu().numpy()
# Eigengap computation

e = np.sort(S)
g = getLamdaGaplist(e)
k = np.argmax(g[: min(max_num_clusters, len(g))])
arg_sorted_idx = np.argsort(g[: max_num_clusters])[::-1]
max_key = arg_sorted_idx[0]
max_eig_gap = g[max_key] / (max(e) + 1e-10)
r = (p / A.shape[0]) / (max_eig_gap + 1e-10)


return (e, g, k, r)

"""
Performs spectral clustering with Normalized Maximum Eigengap (NME)
Parameters:
A: affinity matrix (matrix of pairwise cosine similarities or PLDA scores between speaker embeddings)
num_clusters: number of clusters to generate (if None, determined automatically)
max_num_clusters: maximum allowed number of clusters to generate
pmax: maximum count for matrix binarization (should be at least 2)
pbest: best count for matrix binarization (if 0, determined automatically)
Returns: cluster assignments for every speaker embedding
"""
def getPvalueList(mat,max_rp_threshold):
"""
Generates a p-value (p_neighbour) list for searching.
"""

max_N = int(mat.shape[0] * max_rp_threshold)

N = min(max_N, 30)
p_value_list = list(np.linspace(1, max_N, N, endpoint=True).astype(int))

if p_value_list ==[] :
p_value_list= range(1,mat.shape[0])
return p_value_list

def NME_SpectralClustering(
A, num_clusters=None, max_num_clusters=None, pbest=0, pmin=3, pmax=20, device=None
):
if max_num_clusters is None:
assert num_clusters is not None, "Cannot have both num_clusters and max_num_clusters be None"
max_num_clusters = num_clusters

if pbest == 0:
# Selecting best number of neighbors for affinity matrix thresolding
rbest = None
kbest = None
p_value_list = getPvalueList(A,0.25)
for p in p_value_list:
e, g, k, r = ComputeNMEParameters(A, p, max_num_clusters,device)
if rbest is None or rbest > r:
rbest = r
pbest = p
kbest = k
num_clusters = num_clusters if num_clusters is not None else (kbest + 1)
return NME_SpectralClustering_sklearn(
A, num_clusters, pbest
)

if num_clusters is None:
e, g, k, r = ComputeNMEParameters(A, pbest, max_num_clusters)
return NME_SpectralClustering_sklearn(A, k + 1, pbest)

return NME_SpectralClustering_sklearn(A, num_clusters, pbest)


"""
Performs spectral clustering with Normalized Maximum Eigengap (NME) with fixed threshold and number of clusters
Parameters:
A: affinity matrix (matrix of pairwise cosine similarities or PLDA scores between speaker embeddings)
OLVec: 0/1 vector denoting which segments are overlap segments
num_clusters: number of clusters to generate
pbest: best count for matrix binarization
Returns: cluster assignments for every speaker embedding
"""


def NME_SpectralClustering_sklearn(A, num_clusters, pbest):

# Ap = Threshold(A, pbest)
Ap = get_kneighbors_conn(A, pbest) # thresholded and binarized
Ap = (Ap + np.transpose(Ap)) / 2


model = SpectralClustering(
n_clusters=num_clusters, affinity="precomputed", random_state=0
)
labels = model.fit_predict(Ap)
return labels