From 10c3b0c2d2234b4f54b87b1bd1e704bed6f447c7 Mon Sep 17 00:00:00 2001 From: htagourti Date: Tue, 15 Oct 2024 09:22:30 +0000 Subject: [PATCH 01/21] added qdrant init to compute_embeddings --- identification/speaker_identify.py | 39 ++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index 06f34bb..0c37add 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -17,6 +17,8 @@ import glob import json from tqdm import tqdm +from qdrant_client import QdrantClient +from qdrant_client.http.models import VectorParams, Distance, PointStruct device = os.environ.get("DEVICE_IDENTIFICATION", os.environ.get("DEVICE", None)) if device is None: @@ -114,6 +116,7 @@ def initialize_embeddings( log = None, max_duration = 60 * 3, sample_rate = 16_000, + collection_name="speaker_embeddings", ): """ Pre-compute and store reference speaker embeddings @@ -136,8 +139,25 @@ def initialize_embeddings( ) if log: log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {device}") + # Initialize Qdrant client + client = QdrantClient(url="http://localhost:6333") + + # Create collection if not exists + if not client.collection_exists(collection_name=collection_name): + if log: + log.info(f"Creating collection: {collection_name}") + client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=192, # Adjust according to your embedding size + distance=Distance.COSINE + ), + ) + + os.makedirs(_FOLDER_EMBEDDINGS, exist_ok=True) speakers = list(_get_speaker_names()) + points = [] # List to store points for Qdrant upsert for speaker_name in tqdm(speakers, desc="Compute ref. speaker embeddings"): embedding_file = _get_speaker_embedding_file(speaker_name) if os.path.isfile(embedding_file): @@ -182,8 +202,23 @@ def initialize_embeddings( spk_embed = compute_embedding(audio) # Note: it is important to save the embeddings on the CPU (to be able to load them on the CPU later on) spk_embed = spk_embed.cpu() - with open(embedding_file, "wb") as f: - pkl.dump(spk_embed, f) + # Prepare point for Qdrant + point = PointStruct( + id=speaker_name, # Use a unique identifier for each speaker + vector=spk_embed.numpy().tolist(), # Convert to list for Qdrant + payload={"person": speaker_name} + ) + + points.append(point) # Append point to the list + + # Upsert all points to Qdrant in one go + if points: + operation_info = client.upsert( + collection_name=collection_name, + wait=True, + points=points + ) + if log: log.info(f"Speaker identification initialized with {len(speakers)} speakers") def compute_embedding(audio, min_len = 640): From 96f0831fb00169b38c866da3640b156922c802a5 Mon Sep 17 00:00:00 2001 From: htagourti Date: Tue, 15 Oct 2024 10:10:21 +0000 Subject: [PATCH 02/21] added qdrant search to speaker_identify --- identification/speaker_identify.py | 50 +++++++++++++----------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index 0c37add..2f7f324 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -116,7 +116,8 @@ def initialize_embeddings( log = None, max_duration = 60 * 3, sample_rate = 16_000, - collection_name="speaker_embeddings", + qdrant_client = QdrantClient(url="http://localhost:6333"), + qdrant_collection="speaker_embeddings", ): """ Pre-compute and store reference speaker embeddings 
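The Qdrant calls being introduced here follow a create-collection / upsert / search workflow. A minimal, self-contained sketch of that round trip, assuming a local Qdrant at http://localhost:6333 (the collection name, speaker name and zero vectors are illustrative placeholders; 192 is the ECAPA-TDNN embedding size used throughout the series):

from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct

client = QdrantClient(url="http://localhost:6333")
collection = "speaker_embeddings"

# Create the collection once, with cosine distance over 192-dimensional vectors
if not client.collection_exists(collection_name=collection):
    client.create_collection(
        collection_name=collection,
        vectors_config=VectorParams(size=192, distance=Distance.COSINE),
    )

# Store one reference embedding per speaker; the speaker name travels in the payload
client.upsert(
    collection_name=collection,
    wait=True,
    points=[PointStruct(id=1, vector=[0.0] * 192, payload={"person": "speaker_01"})],
)

# Query with an utterance embedding; each hit carries the cosine similarity as .score
for hit in client.search(collection_name=collection, query_vector=[0.0] * 192, limit=5):
    print(hit.payload["person"], hit.score)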
@@ -140,14 +141,13 @@ def initialize_embeddings( if log: log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {device}") # Initialize Qdrant client - client = QdrantClient(url="http://localhost:6333") # Create collection if not exists - if not client.collection_exists(collection_name=collection_name): + if not qdrant_client.collection_exists(collection_name=qdrant_collection): if log: - log.info(f"Creating collection: {collection_name}") - client.create_collection( - collection_name=collection_name, + log.info(f"Creating collection: {qdrant_collection}") + qdrant_client.create_collection( + collection_name=qdrant_collection, vectors_config=VectorParams( size=192, # Adjust according to your embedding size distance=Distance.COSINE @@ -159,17 +159,9 @@ def initialize_embeddings( speakers = list(_get_speaker_names()) points = [] # List to store points for Qdrant upsert for speaker_name in tqdm(speakers, desc="Compute ref. speaker embeddings"): - embedding_file = _get_speaker_embedding_file(speaker_name) - if os.path.isfile(embedding_file): - try: - with open(embedding_file, "rb") as f: - pkl.load(f) - if log: log.info(f"Speaker {speaker_name} embedding already computed") - continue - except Exception as e: - os.remove(embedding_file) audio_files = _get_speaker_sample_files(speaker_name) assert len(audio_files) > 0, f"No audio files found for speaker {speaker_name}" + audio = None max_samples = max_duration * sample_rate for audio_file in audio_files: @@ -213,8 +205,8 @@ def initialize_embeddings( # Upsert all points to Qdrant in one go if points: - operation_info = client.upsert( - collection_name=collection_name, + operation_info = qdrant_client.upsert( + collection_name=qdrant_collection, wait=True, points=points ) @@ -342,6 +334,8 @@ def speaker_identify( limit_duration=3 * 60, log = None, spk_tag = None, + qdrant_client = QdrantClient(url="http://localhost:6333"), + qdrant_collection="speaker_embeddings", ): """ Run speaker identification on given segments of an audio @@ -395,22 +389,22 @@ def speaker_identify( embedding_audio = compute_embedding(audio_selection) - # Loop on the target speakers - for speaker_name in speaker_names: + # Search for similar embeddings in Qdrant + results = qdrant_client.search(qdrant_collection, embedding_audio[0]) + + for result in results: + speaker_name = result.payload["person"] + + # Check if the speaker is in the exclude list if speaker_name in exclude_speakers: continue - - # Get speaker embedding - with open(_get_speaker_embedding_file(speaker_name), "rb") as f: - embedding_speaker = pkl.load(f) - embedding_speaker = embedding_speaker.to(_embedding_model.device) - - # Compute score similarity - score = similarity(embedding_speaker, embedding_audio) - score = score.item() + + # Use the similarity score returned by Qdrant + score = result.score # Directly get the similarity score from the result if score >= min_similarity: votes[speaker_name] += score + score = None if not votes: argmax_speaker = _UNKNOWN From 8c5187b773fcbb2b30a7025546cc0b44aa37802f Mon Sep 17 00:00:00 2001 From: htagourti Date: Tue, 15 Oct 2024 15:36:11 +0000 Subject: [PATCH 03/21] implemented qdrant for embeddings storage and search --- docker-compose.yml | 45 +++++++++---- http_server/ingress.py | 2 + identification/speaker_identify.py | 63 ++++++++++++------- pyannote/diarization/processing/__init__.py | 10 ++- .../processing/speakerdiarization.py | 8 ++- pyannote/requirements.txt | 1 + 6 files changed, 90 insertions(+), 39 deletions(-) diff --git 
a/docker-compose.yml b/docker-compose.yml index 861d26d..6f606df 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,16 +1,37 @@ -version: '3.7' - services: - my-diarization-service: - image: linto-diarization-simple:latest + qdrant: + image: qdrant/qdrant + container_name: qdrant + ports: + - "6333:6333" # Qdrant default port + volumes: + - ./qdrant_storage:/qdrant/storage:z + + diarization_app: + build: + context : . + dockerfile: pyannote/Dockerfile + container_name: diarization_app + shm_size: '1gb' + ports : + - 8080:80 + environment: + - QDRANT_HOST + - QDRANT_PORT + - QDRANT_COLLECTION_NAME + - SERVICE_MODE volumes: - - /path/to/shared/folder:/opt/audio - env_file: .env + - ./data/speakers_samples:/opt/speaker_samples + depends_on: + - qdrant # Ensure Qdrant starts before the app deploy: - replicas: 1 - networks: - - your-net + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + -networks: - your-net: - external: true +volumes: + qdrant_storage: diff --git a/http_server/ingress.py b/http_server/ingress.py index 731665d..922a08d 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -3,6 +3,7 @@ import json import logging from time import time +import os from confparser import createParser from flask import Flask, Response, abort, json, request @@ -11,6 +12,7 @@ from diarization.processing import diarizationworker, USE_GPU + app = Flask("__diarization-serving__") logging.basicConfig( diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index 2f7f324..c168033 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -40,9 +40,13 @@ _UNKNOWN = "<>" -def initialize_speaker_identification(log): +def initialize_speaker_identification( + qdrant_client = None, + qdrant_collection=None, + log=None): + initialize_db(log) - initialize_embeddings(log) + initialize_embeddings(qdrant_client, qdrant_collection, log) def is_speaker_identification_enabled(): @@ -113,11 +117,11 @@ def convert_wavfile(wavfile, outfile): return outfile def initialize_embeddings( + qdrant_client = None, + qdrant_collection=None, log = None, max_duration = 60 * 3, sample_rate = 16_000, - qdrant_client = QdrantClient(url="http://localhost:6333"), - qdrant_collection="speaker_embeddings", ): """ Pre-compute and store reference speaker embeddings @@ -140,25 +144,27 @@ def initialize_embeddings( ) if log: log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {device}") - # Initialize Qdrant client - - # Create collection if not exists - if not qdrant_client.collection_exists(collection_name=qdrant_collection): + # Check if the collection exists + if qdrant_client.collection_exists(collection_name=qdrant_collection): if log: - log.info(f"Creating collection: {qdrant_collection}") - qdrant_client.create_collection( - collection_name=qdrant_collection, - vectors_config=VectorParams( - size=192, # Adjust according to your embedding size - distance=Distance.COSINE - ), - ) + log.info(f"Deleting existing collection: {qdrant_collection}") + qdrant_client.delete_collection(collection_name=qdrant_collection) + # Create collection + if log: + log.info(f"Creating collection: {qdrant_collection}") + qdrant_client.create_collection( + collection_name=qdrant_collection, + vectors_config=VectorParams( + size=192, # Adjust according to your embedding size + distance=Distance.COSINE + ), + ) os.makedirs(_FOLDER_EMBEDDINGS, exist_ok=True) speakers = list(_get_speaker_names()) points = 
[] # List to store points for Qdrant upsert - for speaker_name in tqdm(speakers, desc="Compute ref. speaker embeddings"): + for _,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. speaker embeddings")): audio_files = _get_speaker_sample_files(speaker_name) assert len(audio_files) > 0, f"No audio files found for speaker {speaker_name}" @@ -196,9 +202,9 @@ def initialize_embeddings( spk_embed = spk_embed.cpu() # Prepare point for Qdrant point = PointStruct( - id=speaker_name, # Use a unique identifier for each speaker - vector=spk_embed.numpy().tolist(), # Convert to list for Qdrant - payload={"person": speaker_name} + id=_+1, + vector=spk_embed.flatten(),#.numpy().tolist(), # Convert to list for Qdrant + payload={"person": speaker_name.strip()} ) points.append(point) # Append point to the list @@ -332,10 +338,10 @@ def speaker_identify( min_similarity=0.5, sample_rate=16_000, limit_duration=3 * 60, + qdrant_client = None, + qdrant_collection=None, log = None, spk_tag = None, - qdrant_client = QdrantClient(url="http://localhost:6333"), - qdrant_collection="speaker_embeddings", ): """ Run speaker identification on given segments of an audio @@ -390,7 +396,7 @@ def speaker_identify( embedding_audio = compute_embedding(audio_selection) # Search for similar embeddings in Qdrant - results = qdrant_client.search(qdrant_collection, embedding_audio[0]) + results = qdrant_client.search(qdrant_collection, embedding_audio.flatten()) for result in results: speaker_name = result.payload["person"] @@ -490,7 +496,14 @@ def check_speaker_specification(speakers_spec, cursor=None): return speaker_names -def speaker_identify_given_diarization(audioFile, diarization, speakers_spec="*", log=None, options={}): +def speaker_identify_given_diarization( + audioFile, + diarization, + speakers_spec="*", + qdrant_client = None, + qdrant_collection=None, + log=None, + options={}): """ Run speaker identification on given diarized audio file @@ -560,6 +573,8 @@ def speech_duration(spk): exclude_speakers=([] if _can_identify_twice_the_same_speaker else already_identified), log=log, spk_tag=spk_tag, + qdrant_client=qdrant_client, + qdrant_collection=qdrant_collection, **options ) if spk_name == _UNKNOWN: diff --git a/pyannote/diarization/processing/__init__.py b/pyannote/diarization/processing/__init__.py index aaec5aa..78879c0 100644 --- a/pyannote/diarization/processing/__init__.py +++ b/pyannote/diarization/processing/__init__.py @@ -1,5 +1,13 @@ import os import torch +from qdrant_client import QdrantClient + +# Initialize Qdrant Client +qdrant_host = os.getenv("QDRANT_HOST", "qdrant") +qdrant_port = os.getenv("QDRANT_PORT", "6333") +qdrant_collection = os.getenv("QDRANT_COLLECTION_NAME", "speaker_embeddings") +qdrant_client = QdrantClient(url=f"http://{qdrant_host}:{qdrant_port}") # Replace with your Qdrant URL + device = os.environ.get("DEVICE") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -21,6 +29,6 @@ from .speakerdiarization import SpeakerDiarization -diarizationworker = SpeakerDiarization(device=device, num_threads=NUM_THREADS) +diarizationworker = SpeakerDiarization(device=device, num_threads=NUM_THREADS, qdrant_client=qdrant_client, qdrant_collection=qdrant_collection) __all__ = ["diarizationworker"] diff --git a/pyannote/diarization/processing/speakerdiarization.py b/pyannote/diarization/processing/speakerdiarization.py index 3c3c658..0e7479c 100644 --- a/pyannote/diarization/processing/speakerdiarization.py +++ b/pyannote/diarization/processing/speakerdiarization.py 
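Replacing id=speaker_name with the enumeration index in the PointStruct above is consistent with a Qdrant constraint: point ids must be unsigned integers or UUID strings, not arbitrary names. A small sketch of the two valid id forms, with placeholder values:

import uuid
from qdrant_client.http.models import PointStruct

name = "speaker_01"        # illustrative speaker name
embedding = [0.0] * 192    # placeholder 192-dimensional embedding

# Option 1: a positional integer id, as the patch does with id=_+1
point_int = PointStruct(id=1, vector=embedding, payload={"person": name})

# Option 2: a deterministic UUID derived from the name, stable across re-indexing runs
point_uuid = PointStruct(
    id=str(uuid.uuid5(uuid.NAMESPACE_DNS, name)),
    vector=embedding,
    payload={"person": name},
)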
@@ -30,6 +30,8 @@ def __init__( device=None, num_threads=4, tolerated_silence=0, + qdrant_client = None, + qdrant_collection=None, ): """ Speaker Diarization class @@ -53,6 +55,8 @@ def __init__( + (f" ({num_threads} threads)" if device == "cpu" else "") ) self.tolerated_silence = tolerated_silence + self.qdrant_client = qdrant_client + self.qdrant_collection = qdrant_collection home = os.path.expanduser('~') model_configuration = "pyannote/speaker-diarization-3.1" @@ -72,7 +76,7 @@ def __init__( self.num_threads = num_threads self.tempfile = None - initialize_speaker_identification(self.log) + initialize_speaker_identification(self.qdrant_client, self.qdrant_collection,self.log) def run_pyannote(self, audioFile, speaker_count, max_speaker): @@ -219,7 +223,7 @@ def run( result = self.run_pyannote( file_path, speaker_count=speaker_count, max_speaker=max_speaker ) - result = speaker_identify_given_diarization(file_path, result, speaker_names, log=self.log) + result = speaker_identify_given_diarization(file_path, result, speaker_names, log=self.log, qdrant_client=self.qdrant_client, qdrant_collection=self.qdrant_collection) return result except Exception as e: self.log.error(e) diff --git a/pyannote/requirements.txt b/pyannote/requirements.txt index ac1032c..e54cc9b 100644 --- a/pyannote/requirements.txt +++ b/pyannote/requirements.txt @@ -12,3 +12,4 @@ torchaudio==2.2.1 memory-tempfile==2.2.3 # Version 2 of numpy breaks pyannote 3.1.1 (use of np.NaN instead of np.nan) numpy<2 +qdrant-client From de2be332864a4ce14621f208a1edff29edab0eaa Mon Sep 17 00:00:00 2001 From: htagourti Date: Wed, 16 Oct 2024 13:57:23 +0000 Subject: [PATCH 04/21] Replace sqlite db queries with Qdrant --- identification/speaker_identify.py | 271 +++++++----------- pyannote/diarization/processing/__init__.py | 8 +- .../processing/speakerdiarization.py | 6 +- 3 files changed, 109 insertions(+), 176 deletions(-) diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index c168033..bcd4e9f 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -11,13 +11,10 @@ import subprocess import memory_tempfile import werkzeug -import pickle as pkl -import hashlib -import sqlite3 import glob import json from tqdm import tqdm -from qdrant_client import QdrantClient +from qdrant_client import models from qdrant_client.http.models import VectorParams, Distance, PointStruct device = os.environ.get("DEVICE_IDENTIFICATION", os.environ.get("DEVICE", None)) @@ -34,89 +31,14 @@ # Constants (that could be env variables) _FOLDER_WAV = os.environ.get("SPEAKER_SAMPLES_FOLDER", "/opt/speaker_samples") _FOLDER_INTERNAL = os.environ.get("SPEAKER_PRECOMPUTED_FOLDER", "/opt/speaker_precomputed") -_FOLDER_EMBEDDINGS = f"{_FOLDER_INTERNAL}/embeddings" + _FILE_DATABASE = f"{_FOLDER_INTERNAL}/speakers_database" _UNKNOWN = "<>" -def initialize_speaker_identification( - qdrant_client = None, - qdrant_collection=None, - log=None): - - initialize_db(log) - initialize_embeddings(qdrant_client, qdrant_collection, log) - - -def is_speaker_identification_enabled(): - return os.path.isdir(_FOLDER_WAV) - -# Create / update / check database -def initialize_db(log): - if not is_speaker_identification_enabled(): - if log: log.info(f"Speaker identification is disabled") - return - if log: log.info(f"Speaker identification is enabled") - os.makedirs(os.path.dirname(_FILE_DATABASE), exist_ok=True) - # Create connection - conn = sqlite3.connect(_FILE_DATABASE) - cursor = conn.cursor() - # Creating and 
inserting into table - cursor.execute("""CREATE TABLE IF NOT EXISTS speaker_names (id integer UNIQUE, name TEXT UNIQUE)""") - all_ids = list(_get_db_speaker_ids(cursor)) - all_names = _get_db_speaker_names(cursor) - assert all_ids == list(range(1, len(all_ids)+1)), f"Speaker ids are not continuous" - assert len(all_names) == len(all_ids), f"Speaker names are not unique" - new_id = len(all_ids) + 1 - for speaker_name in _get_speaker_names(): - if speaker_name not in all_names: - cursor.execute("INSERT OR IGNORE INTO speaker_names (id, name) VALUES (?, ?)", ( - new_id, - speaker_name, - )) - new_id += 1 - conn.commit() - conn.close() - -def check_wav_16khz_mono(wavfile, log=None): - """ - Returns True if a wav file is 16khz and single channel - """ - try: - signal, fs = torchaudio.load(wavfile) - except: - if log: log.info(f"Could not load {wavfile}") - return None - assert len(signal.shape) == 2 - mono = (signal.shape[0] == 1) - freq = (fs == 16000) - if mono and freq: - return signal - reason = "" - if not mono: - reason += " is not mono" - if not freq: - if reason: - reason += " and" - reason += f" is in {freq/1000} kHz" - if log: log.info(f"File {wavfile} {reason}") - - -def convert_wavfile(wavfile, outfile): - """ - Converts file to 16khz single channel mono wav - """ - cmd = "ffmpeg -y -i {} -acodec pcm_s16le -ar 16000 -ac 1 {}".format( - wavfile, outfile - ) - subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE).wait() - if not os.path.isfile(outfile): - raise RuntimeError(f"Failed to run conversion: {cmd}") - return outfile - -def initialize_embeddings( +def initialize_speaker_identification( qdrant_client = None, qdrant_collection=None, log = None, @@ -131,7 +53,8 @@ def initialize_embeddings( max_duration (int): maximum duration (in seconds) of speech to use for speaker embeddings sample_rate (int): sample rate (of the embedding model) """ - if not is_speaker_identification_enabled(): + if not (is_speaker_identification_enabled() and qdrant_client and qdrant_collection): + if log: log.info(f"Speaker identification is disabled") return global _embedding_model @@ -139,11 +62,12 @@ def initialize_embeddings( tic = time.time() _embedding_model = EncoderClassifier.from_hparams( source="speechbrain/spkrec-ecapa-voxceleb", - # savedir="pretrained_models/spkrec-ecapa-voxceleb", run_opts={"device":device} ) if log: log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {device}") - + + if log: log.info(f"Speaker identification is enabled") + # Check if the collection exists if qdrant_client.collection_exists(collection_name=qdrant_collection): if log: @@ -161,7 +85,6 @@ def initialize_embeddings( ), ) - os.makedirs(_FOLDER_EMBEDDINGS, exist_ok=True) speakers = list(_get_speaker_names()) points = [] # List to store points for Qdrant upsert for _,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. 
speaker embeddings")): @@ -219,6 +142,49 @@ def initialize_embeddings( if log: log.info(f"Speaker identification initialized with {len(speakers)} speakers") + +def is_speaker_identification_enabled(): + return os.path.isdir(_FOLDER_WAV) + + +def check_wav_16khz_mono(wavfile, log=None): + """ + Returns True if a wav file is 16khz and single channel + """ + try: + signal, fs = torchaudio.load(wavfile) + except: + if log: log.info(f"Could not load {wavfile}") + return None + assert len(signal.shape) == 2 + mono = (signal.shape[0] == 1) + freq = (fs == 16000) + if mono and freq: + return signal + + reason = "" + if not mono: + reason += " is not mono" + if not freq: + if reason: + reason += " and" + reason += f" is in {freq/1000} kHz" + if log: log.info(f"File {wavfile} {reason}") + + +def convert_wavfile(wavfile, outfile): + """ + Converts file to 16khz single channel mono wav + """ + cmd = "ffmpeg -y -i {} -acodec pcm_s16le -ar 16000 -ac 1 {}".format( + wavfile, outfile + ) + subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE).wait() + if not os.path.isfile(outfile): + raise RuntimeError(f"Failed to run conversion: {cmd}") + return outfile + + def compute_embedding(audio, min_len = 640): """ Compute speaker embedding from audio @@ -232,52 +198,46 @@ def compute_embedding(audio, min_len = 640): audio = torch.cat([audio, torch.zeros(audio.shape[0], min_len - audio.shape[-1])], dim=-1) return _embedding_model.encode_batch(audio) -def _get_db_speaker_ids(cursor=None): - return _get_db_possible_values("id", cursor) - -def _get_db_speaker_names(cursor=None): - return _get_db_possible_values("name", cursor) - -def _get_db_possible_values(name, cursor, check_unique=True): - create_connection = (cursor is None) - if create_connection: - conn = sqlite3.connect(_FILE_DATABASE) - cursor = conn.cursor() - cursor.execute(f"SELECT {name} FROM speaker_names") - values = cursor.fetchall() - values = [value[0] for value in values] - if check_unique: - assert len(values) == len(set(values)), f"Values are not unique" - else: - values = list(set(values)) - if create_connection: - conn.close() - return values - -def _get_db_speaker_name(speaker_id, cursor=None): - return _get_db_speaker_attribute(speaker_id, "id", "name", cursor) - -def _get_db_speaker_id(speaker_name, cursor=None): - return _get_db_speaker_attribute(speaker_name, "name", "id", cursor) - -def _get_db_speaker_attribute(value, orig, dest, cursor): - create_connection = (cursor is None) - if create_connection: - conn = sqlite3.connect(_FILE_DATABASE) - cursor = conn.cursor() - item = cursor.execute(f"SELECT {dest} FROM speaker_names WHERE {orig} = '{value}'") - item = item.fetchone() - assert item, f"Speaker {orig} {value} not found" - assert len(item) == 1, f"Speaker {orig} {value} not unique" - value = item[0] - if create_connection: - conn.close() - return value - - -def _get_speaker_embedding_file(speaker_name): - hash = _get_speaker_hash(speaker_name) - return os.path.join(_FOLDER_EMBEDDINGS, hash + '.pkl') + +def _get_db_speaker_names(qdrant_client = None,qdrant_collection=None): + + response = qdrant_client.scroll(collection_name=qdrant_collection,with_payload=True) + return [point.payload.get("person") for point in response[0]] + + +def _get_db_speaker_name(speaker_id, qdrant_client = None,qdrant_collection=None): + + # Retrieve the point from Qdrant + response = qdrant_client.retrieve( + collection_name=qdrant_collection, + ids=[speaker_id], + ) + # Extract the 'person' payload from the response + if response : + return 
response[0].payload.get('person') + +def _get_db_speaker_id(speaker_name, qdrant_client = None,qdrant_collection=None): + # Filter Qdrant for speaker_name + response = qdrant_client.scroll( + collection_name=qdrant_collection, + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="person", + match=models.MatchValue(value=speaker_name), + ) + ]) + ) + # Extract the id + points = response[0] if response else [] + + if len(points) == 0: + raise ValueError(f"Person with name '{speaker_name}' not found in the Qdrant collection.") + if len(points) > 1: + raise ValueError(f"Multiple persons with the name '{speaker_name}' found. Ensure uniqueness.") + return points[0].id + + def _get_speaker_sample_files(speaker_name): if os.path.isdir(os.path.join(_FOLDER_WAV, speaker_name)): @@ -289,37 +249,6 @@ def _get_speaker_sample_files(speaker_name): assert len(audio_files) == 1 return audio_files -_cached_speaker_hashes = {} -def _get_speaker_hash(speaker_name): - """ - Return a hash depending on the speaker audio filenames - """ - if speaker_name in _cached_speaker_hashes: - return _cached_speaker_hashes[speaker_name] - files = _get_speaker_sample_files(speaker_name) - hashes = md5sum_files(files) - hash = md5sum_object(hashes) - prefix = speaker_name.replace(" ", "-").replace("/", "--").lower().strip("_-") + "_" - hash = prefix + hash - _cached_speaker_hashes[speaker_name] = hash - return hash - -def md5sum_files(filenames): - """Compute the md5 hash of a file or a list of files""" - single = False - if not isinstance(filenames, list): - single = True - filenames = [filenames] - p = subprocess.Popen(["md5sum", *filenames], stdout = subprocess.PIPE) - (stdout, stderr) = p.communicate() - assert p.returncode == 0, f"Error running md5sum: {stderr}" - md5_string = stdout.decode("utf-8").strip() - md5_list = [f.split()[0] for f in md5_string.split("\n")] - return md5_list[0] if single else md5_list - - -def md5sum_object(obj): - return hashlib.md5(pkl.dumps(obj)).hexdigest() def _get_speaker_names(): assert os.path.isdir(_FOLDER_WAV) @@ -425,7 +354,11 @@ def speaker_identify( return argmax_speaker, score -def check_speaker_specification(speakers_spec, cursor=None): +def check_speaker_specification( + speakers_spec, + qdrant_client = None, + qdrant_collection=None, + ): """ Check and convert speaker specification to list of speaker names @@ -434,17 +367,17 @@ def check_speaker_specification(speakers_spec, cursor=None): cursor (sqlite3.Cursor): optional database cursor """ - if speakers_spec and not is_speaker_identification_enabled(): + if speakers_spec and not (is_speaker_identification_enabled() and qdrant_client and qdrant_collection): raise RuntimeError("Speaker identification is disabled (no reference speakers)") # Read list / dictionary if isinstance(speakers_spec, str): speakers_spec = speakers_spec.strip() - print("NOCOMMIT", speakers_spec, speakers_spec and (speakers_spec == "*"), _get_db_speaker_names()) + print("NOCOMMIT", speakers_spec, speakers_spec and (speakers_spec == "*"), _get_db_speaker_names(qdrant_client,qdrant_collection)) if speakers_spec: if speakers_spec == "*": # Wildcard: all speakers - speakers_spec = list(_get_db_speaker_names()) + speakers_spec = _get_db_speaker_names(qdrant_client,qdrant_collection) elif speakers_spec[0] in "[{": try: speakers_spec = json.loads(speakers_spec) @@ -470,7 +403,7 @@ def check_speaker_specification(speakers_spec, cursor=None): speaker_names = [] for item in speakers_spec: if isinstance(item, int): - items = 
[_get_db_speaker_name(item, cursor)] + items = [_get_db_speaker_name(item, qdrant_client, qdrant_collection)] elif isinstance(item, dict): # Should we really keep this format ? @@ -481,7 +414,7 @@ def check_speaker_specification(speakers_spec, cursor=None): elif isinstance(item, str): if all_speaker_names is None: - all_speaker_names = _get_db_speaker_names(cursor) + all_speaker_names = _get_db_speaker_names(qdrant_client, qdrant_collection) if item not in all_speaker_names: raise ValueError(f"Unknown speaker name '{item}'") items = [item] @@ -515,7 +448,7 @@ def speaker_identify_given_diarization( options (dict): optional options (e.g. {"min_similarity": 0.25, "limit_duration": 60}) """ - speaker_names = check_speaker_specification(speakers_spec) + speaker_names = check_speaker_specification(speakers_spec, qdrant_client, qdrant_collection) if not speaker_names: return diarization diff --git a/pyannote/diarization/processing/__init__.py b/pyannote/diarization/processing/__init__.py index 78879c0..15700f5 100644 --- a/pyannote/diarization/processing/__init__.py +++ b/pyannote/diarization/processing/__init__.py @@ -3,10 +3,10 @@ from qdrant_client import QdrantClient # Initialize Qdrant Client -qdrant_host = os.getenv("QDRANT_HOST", "qdrant") -qdrant_port = os.getenv("QDRANT_PORT", "6333") -qdrant_collection = os.getenv("QDRANT_COLLECTION_NAME", "speaker_embeddings") -qdrant_client = QdrantClient(url=f"http://{qdrant_host}:{qdrant_port}") # Replace with your Qdrant URL +qdrant_host = os.getenv("QDRANT_HOST") +qdrant_port = os.getenv("QDRANT_PORT") +qdrant_collection = os.getenv("QDRANT_COLLECTION_NAME") +qdrant_client = QdrantClient(url=f"http://{qdrant_host}:{qdrant_port}") if (qdrant_host and qdrant_port) else None device = os.environ.get("DEVICE") if device is None: diff --git a/pyannote/diarization/processing/speakerdiarization.py b/pyannote/diarization/processing/speakerdiarization.py index 0e7479c..75cb10c 100644 --- a/pyannote/diarization/processing/speakerdiarization.py +++ b/pyannote/diarization/processing/speakerdiarization.py @@ -30,8 +30,8 @@ def __init__( device=None, num_threads=4, tolerated_silence=0, - qdrant_client = None, - qdrant_collection=None, + qdrant_client= None, + qdrant_collection= None, ): """ Speaker Diarization class @@ -202,7 +202,7 @@ def run( speaker_names = None, ): # Early check on speaker names - speaker_names = check_speaker_specification(speaker_names) + speaker_names = check_speaker_specification(speaker_names, self.qdrant_client, self.qdrant_collection) # If we run both speaker diarization and speaker identification, we need to save the file if speaker_names and isinstance(file_path, werkzeug.datastructures.file_storage.FileStorage): From 76f0a49bfb7bd694ab20ba41e352aa3f73d63d2c Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 17 Oct 2024 07:28:06 +0000 Subject: [PATCH 05/21] Added SpeakerIdentifier class --- identification/__init__.py | 0 identification/speaker_identification.py | 531 ++++++++++++++++++ identification/speaker_identify.py | 4 - pyannote/diarization/processing/__init__.py | 8 +- .../processing/speakerdiarization.py | 20 +- 5 files changed, 539 insertions(+), 24 deletions(-) create mode 100644 identification/__init__.py create mode 100644 identification/speaker_identification.py diff --git a/identification/__init__.py b/identification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/identification/speaker_identification.py b/identification/speaker_identification.py new file mode 100644 index 0000000..5f265f1 --- 
/dev/null +++ b/identification/speaker_identification.py @@ -0,0 +1,531 @@ +import os +import torch +import torchaudio +import time +import glob +import subprocess +import json +import werkzeug +from collections import defaultdict +from tqdm import tqdm +from qdrant_client.http.models import VectorParams, Distance, PointStruct +from qdrant_client import models, QdrantClient +import speechbrain +if speechbrain.__version__ >= "1.0.0": + from speechbrain.inference.speaker import EncoderClassifier +else: + from speechbrain.pretrained import EncoderClassifier + + +class SpeakerIdentifier: + # Define class-level constants + _FOLDER_WAV = os.environ.get("SPEAKER_SAMPLES_FOLDER", "/opt/speaker_samples") + _can_identify_twice_the_same_speaker = os.environ.get("CAN_IDENTIFY_TWICE_THE_SAME_SPEAKER", "1").lower() in ["true", "1", "yes"] + _UNKNOWN = "<>" + + def __init__(self, qdrant_client=None, qdrant_collection=None, device=None, log=None): + self.device = device or self._get_device() + self.qdrant_host = os.getenv("QDRANT_HOST") + self.qdrant_port = os.getenv("QDRANT_PORT") + self.qdrant_client = QdrantClient(url=f"http://{self.qdrant_host}:{self.qdrant_port}") if (self.qdrant_host and self.qdrant_port) else None + self.qdrant_collection = os.getenv("QDRANT_COLLECTION_NAME") + self._embedding_model = None + self.log = log + + def _get_device(): + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + def initialize_speaker_identification( + self, + max_duration=60 * 3, + sample_rate=16_000, + ): + + if not self.is_speaker_identification_enabled() : + if self.log: self.log.info(f"Speaker identification is disabled") + return + + if self._embedding_model is None: + tic = time.time() + self._embedding_model = EncoderClassifier.from_hparams( + source="speechbrain/spkrec-ecapa-voxceleb", + run_opts={"device":self.device} + ) + if self.log: self.log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {self.device}") + + if self.log: self.log.info(f"Speaker identification is enabled") + + # Check if the collection exists + if self.qdrant_client.collection_exists(collection_name=self.qdrant_collection): + if self.log: + self.log.info(f"Deleting existing collection: {self.qdrant_collection}") + self.qdrant_client.delete_collection(collection_name=self.qdrant_collection) + + # Create collection + if self.log: + self.log.info(f"Creating collection: {self.qdrant_collection}") + self.qdrant_client.create_collection( + collection_name=self.qdrant_collection, + vectors_config=VectorParams( + size=192, + distance=Distance.COSINE + ), + ) + + speakers = list(self._get_speaker_names()) + points = [] # List to store points for Qdrant upsert + for _,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. 
speaker embeddings")): + audio_files = self._get_speaker_sample_files(speaker_name) + assert len(audio_files) > 0, f"No audio files found for speaker {speaker_name}" + + audio = None + max_samples = max_duration * sample_rate + for audio_file in audio_files: + clip_audio = self.check_wav_16khz_mono(audio_file) + if clip_audio is not None: + clip_sample_rate = 16000 + else: + if self.log: self.log.info(f"Converting audio file {audio_file} to single channel 16kHz WAV using ffmpeg...") + converted_wavfile = os.path.join( + os.path.dirname(audio_file), "___{}.wav".format(os.path.splitext(os.path.basename(audio_file))[0]) + ) + self.convert_wavfile(audio_file, converted_wavfile) + try: + clip_audio, clip_sample_rate = torchaudio.load(converted_wavfile) + finally: + os.remove(converted_wavfile) + + assert clip_sample_rate == sample_rate, f"Unsupported sample rate {clip_sample_rate} (only {sample_rate} is supported)" + if clip_audio.shape[1] > max_samples: + clip_audio = clip_audio[:, :max_samples] + if audio is None: + audio = clip_audio + else: + audio = torch.cat((audio, clip_audio), 1) + # Update maximum number of remaining samples + max_samples -= clip_audio.shape[1] + if max_samples <= 0: + break + + spk_embed = self.compute_embedding(audio) + # Note: it is important to save the embeddings on the CPU (to be able to load them on the CPU later on) + spk_embed = spk_embed.cpu() + # Prepare point for Qdrant + point = PointStruct( + id=_+1, + vector=spk_embed.flatten(),#.numpy().tolist(), # Convert to list for Qdrant + payload={"person": speaker_name.strip()} + ) + + points.append(point) # Append point to the list + + # Upsert all points to Qdrant in one go + if points: + self.qdrant_client.upsert( + collection_name=self.qdrant_collection, + wait=True, + points=points + ) + + if self.log: self.log.info(f"Speaker identification initialized with {len(speakers)} speakers") + + # Create a method to check if speaker identification is enabled + def is_speaker_identification_enabled(self): + return self.qdrant_client and self.qdrant_collection and os.path.isdir(self._FOLDER_WAV) + + @staticmethod + def convert_wavfile(wavfile, outfile): + """ + Converts file to 16khz single channel mono wav + """ + cmd = "ffmpeg -y -i {} -acodec pcm_s16le -ar 16000 -ac 1 {}".format( + wavfile, outfile + ) + subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE).wait() + if not os.path.isfile(outfile): + raise RuntimeError(f"Failed to run conversion: {cmd}") + return outfile + + def check_wav_16khz_mono(self,wavfile): + """ + Returns True if a wav file is 16khz and single channel + """ + try: + signal, fs = torchaudio.load(wavfile) + except: + if self.log: self.log.info(f"Could not load {wavfile}") + return None + assert len(signal.shape) == 2 + mono = (signal.shape[0] == 1) + freq = (fs == 16000) + if mono and freq: + return signal + + reason = "" + if not mono: + reason += " is not mono" + if not freq: + if reason: + reason += " and" + reason += f" is in {freq/1000} kHz" + if self.log: self.log.info(f"File {wavfile} {reason}") + + + def compute_embedding(self,audio, min_len = 640): + """ + Compute speaker embedding from audio + + Args: + audio (torch.Tensor): audio waveform + """ + assert self._embedding_model is not None, "Speaker identification model not initialized" + # The following is to avoid a failure on too short audio (less than 640 samples = 40ms at 16kHz) + if audio.shape[-1] < min_len: + audio = torch.cat([audio, torch.zeros(audio.shape[0], min_len - audio.shape[-1])], dim=-1) + return 
self._embedding_model.encode_batch(audio) + + + def _get_db_speaker_names(self): + + response = self.qdrant_client.scroll(collection_name=self.qdrant_collection,with_payload=True) + return [point.payload.get("person") for point in response[0]] + + + def _get_db_speaker_name(self,speaker_id): + # Retrieve the point from Qdrant + response = self.qdrant_client.retrieve( + collection_name=self.qdrant_collection, + ids=[speaker_id], + ) + # Extract the 'person' payload from the response + if response : + return response[0].payload.get('person') + + + def _get_db_speaker_id(self,speaker_name): + # Get qdrant id corresponding to speaker_name + response = self.qdrant_client.scroll( + collection_name=self.qdrant_collection, + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="person", + match=models.MatchValue(value=speaker_name), + ) + ]) + ) + # Extract the id + points = response[0] if response else [] + + if len(points) == 0: + raise ValueError(f"Person with name '{speaker_name}' not found in the Qdrant collection.") + if len(points) > 1: + raise ValueError(f"Multiple persons with the name '{speaker_name}' found. Ensure uniqueness.") + return points[0].id + + def _get_speaker_sample_files(self,speaker_name): + if os.path.isdir(os.path.join(self._FOLDER_WAV, speaker_name)): + audio_files = sorted(glob.glob(os.path.join(self._FOLDER_WAV, speaker_name, '*'))) + else: + prefix = os.path.join(self._FOLDER_WAV, speaker_name) + audio_files = glob.glob(prefix + '.*') + audio_files = [file for file in audio_files if os.path.splitext(file)[0] == prefix] + assert len(audio_files) == 1 + return audio_files + + def _get_speaker_names(self): + assert os.path.isdir(self._FOLDER_WAV) + for root, dirs, files in os.walk(self._FOLDER_WAV): + if root == self._FOLDER_WAV: + for file in files: + yield os.path.splitext(file)[0] + else: + yield os.path.basename(root.rstrip("/")) + + def speaker_identify( + self, + audio, + speaker_names, + segments, + exclude_speakers, + min_similarity=0.5, + sample_rate=16_000, + limit_duration=3 * 60, + spk_tag = None, + ): + """ + Run speaker identification on given segments of an audio + + Args: + audio (torch.Tensor): audio waveform + speaker_names (list): list of reference speaker names + segments (list): list of segments to analyze (tuples of start and end times in seconds) + exclude_speakers (list): list of speaker names to exclude + min_similarity (float): minimum similarity to consider a speaker match + The default value 0.25 was taken from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/inference/speaker.py#L61 + sample_rate (int): audio sample rate + limit_duration (int): maximum duration (in seconds) of speech to identify a speaker (the first seconds of speech will be used, the other will be ignored) + spk_tag: information for the logger + + Returns: + str: identified speaker name + float: similarity score + """ + tic = time.time() + + assert len(speaker_names) > 0 + + votes = defaultdict(int) + + # Sort segments by duration (longest first) + segments = sorted(segments, key=lambda x: x[1] - x[0], reverse=True) + assert len(segments) + + total_duration = sum([end - start for (start, end) in segments]) + + # Glue all the speaker segments up to a certain length + audio_selection = None + limit_samples = limit_duration * sample_rate + for start, end in segments: + start = int(start * sample_rate) + end = int(end * sample_rate) + if end - start > limit_samples: + end = start + limit_samples + + clip = audio[:, start:end] + if 
audio_selection is None: + audio_selection = clip + else: + audio_selection = torch.cat((audio_selection, clip), 1) + limit_samples -= (end - start) + if limit_samples <= 0: + break + + embedding_audio = self.compute_embedding(audio_selection) + + # Search for similar embeddings in Qdrant + results = self.qdrant_client.search(self.qdrant_collection, embedding_audio.flatten()) + + for result in results: + speaker_name = result.payload["person"] + + # Check if the speaker is in the exclude list + if speaker_name in exclude_speakers: + continue + + # Use the similarity score returned by Qdrant + score = result.score # Directly get the similarity score from the result + if score >= min_similarity: + votes[speaker_name] += score + + + score = None + if not votes: + argmax_speaker = self._UNKNOWN + else: + argmax_speaker = max(votes, key=votes.get) + score = votes[argmax_speaker] + + if self.log: + self.log.info( + f"Speaker recognition {spk_tag} -> {argmax_speaker} (done in {time.time() - tic:.3f} seconds, on {audio_selection.shape[1] / sample_rate:.3f} seconds of audio out of {total_duration:.3f})" + ) + + return argmax_speaker, score + + def check_speaker_specification( + self, + speakers_spec, + ): + """ + Check and convert speaker specification to list of speaker names + + Args: + speakers_spec (str, list): speaker specification + """ + + if speakers_spec and not self.is_speaker_identification_enabled(): + raise RuntimeError("Speaker identification is disabled (no reference speakers)") + + # Read list / dictionary + if isinstance(speakers_spec, str): + speakers_spec = speakers_spec.strip() + print("NOCOMMIT", speakers_spec, speakers_spec and (speakers_spec == "*"), self._get_db_speaker_names()) + if speakers_spec: + if speakers_spec == "*": + # Wildcard: all speakers + speakers_spec = self._get_db_speaker_names() + elif speakers_spec[0] in "[{": + try: + speakers_spec = json.loads(speakers_spec) + except Exception as err: + if "|" in speakers_spec: + speakers_spec = speakers_spec.split("|") + else: + raise ValueError(f"Unsupported reference speaker specification: {speakers_spec} (except empty string, \"*\", or \"speaker1|speaker2|...|speakerN\", or \"[\"speaker1\", \"speaker2\", ..., \"speakerN\"]\")") from err + if isinstance(speakers_spec, dict): + speakers_spec = [speakers_spec] + else: + speakers_spec = speakers_spec.split("|") + + # Convert to list of speaker names + if not speakers_spec: + return [] + + if not isinstance(speakers_spec, list): + raise ValueError(f"Unsupported reference speaker specification of type {type(speakers_spec)}: {speakers_spec}") + + speakers_spec = [s for s in speakers_spec if s] + all_speaker_names = None + speaker_names = [] + for item in speakers_spec: + if isinstance(item, int): + speaker_names.append(self._get_db_speaker_name(item)) + + elif isinstance(item, dict): + # Should we really keep this format ? 
+ start = item['start'] + end = item['end'] + items=[] + for id in range(start,end+1): + speaker_names.append(self._get_db_speaker_name(id)) + + + elif isinstance(item, str): + if all_speaker_names is None: + all_speaker_names = self._get_db_speaker_names() + if item not in all_speaker_names: + raise ValueError(f"Unknown speaker name '{item}'") + speaker_names.append(item) + + else: + raise ValueError(f"Unsupported reference speaker specification of type {type(item)} (in list): {speakers_spec}") + + return speaker_names + + + def speaker_identify_given_diarization( + self, + audioFile, + diarization, + speakers_spec="*", + options={}): + """ + Run speaker identification on given diarized audio file + + Args: + audioFile (str): path to audio file + diarization (dict): diarization result + speakers_spec (list): list of reference speaker ids or ranges (e.g. [1, 2, {"start": 3, "end": 5}]) + options (dict): optional options (e.g. {"min_similarity": 0.25, "limit_duration": 60}) + """ + + speaker_names = self.check_speaker_specification(speakers_spec) + + if not speaker_names: + return diarization + + if self.log: + full_tic = time.time() + self.log.info(f"Running speaker identification with {len(speaker_names)} reference speakers") + + if isinstance(audioFile, werkzeug.datastructures.file_storage.FileStorage): + tempfile = memory_tempfile.MemoryTempfile(filesystem_types=['tmpfs', 'shm'], fallback=True) + if self.log: + self.log.info(f"Using temporary folder {tempfile.gettempdir()}") + + with tempfile.NamedTemporaryFile(suffix = ".wav") as ntf: + audioFile.save(ntf.name) + return self.speaker_identify_given_diarization(ntf.name, diarization, speaker_names) + + speaker_tags = [] + speaker_segments = {} + common = [] + speaker_map = {} + speaker_surnames = {} + + for segment in diarization["segments"]: + + start = segment["seg_begin"] + end = segment["seg_end"] + speaker = segment["spk_id"] + common.append([start, end, speaker]) + + # find different speakers + if speaker not in speaker_tags: + speaker_tags.append(speaker) + speaker_map[speaker] = speaker + speaker_segments[speaker] = [] + + speaker_segments[speaker].append([start, end]) + + audio, sample_rate = torchaudio.load(audioFile) + # This should be OK, since this is enforced by the diarization API + assert sample_rate == 16_000, f"Unsupported sample rate {sample_rate} (only 16kHz is supported)" + + # Process the speakers with the longest speech turns first + def speech_duration(spk): + return sum([end - start for (start, end) in speaker_segments[spk]]) + already_identified = [] + speaker_id_scores = {} + for spk_tag in sorted(speaker_segments.keys(), key=speech_duration, reverse=True): + spk_segments = speaker_segments[spk_tag] + + spk_name, spk_id_score = self.speaker_identify( + audio, speaker_names, spk_segments, + # TODO : do we really want to avoid that 2 speakers are the same ? 
+ # and if we do, not that it's not invariant to the order in which segments are taken (so we should choose a somewhat optimal order) + exclude_speakers=([] if self._can_identify_twice_the_same_speaker else already_identified), + spk_tag=spk_tag, + **options + ) + if spk_name == self._UNKNOWN: + speaker_map[spk_tag] = spk_tag + else: + already_identified.append(spk_name) + speaker_map[spk_tag] = spk_name + speaker_id_scores[spk_name] = spk_id_score + + result = {} + _segments = [] + _speakers = {} + speaker_surnames = {} + for iseg, segment in enumerate(diarization["segments"]): + start = segment["seg_begin"] + end = segment["seg_end"] + speaker = segment["spk_id"] + + # Convert speaker names to spk1, spk2, etc. + if speaker not in speaker_surnames: + speaker_surnames[speaker] = ( + speaker # "spk"+str(len(speaker_surnames)+1) + ) + speaker = speaker_surnames[speaker] + speaker_name = speaker_map[speaker] + if speaker_name == self._UNKNOWN: + speaker_name = speaker + + segment["spk_id"] = speaker_name + + _segments.append(segment) + + if speaker_name not in _speakers: + _speakers[speaker_name] = {"spk_id": speaker_name} + if speaker_name in speaker_id_scores: + _speakers[speaker_name]["spk_id_score"] = round(speaker_id_scores[speaker_name], 3) + _speakers[speaker_name]["duration"] = round(end - start, 3) + _speakers[speaker_name]["nbr_seg"] = 1 + + else: + _speakers[speaker_name]["duration"] += round(end - start, 3) + _speakers[speaker_name]["nbr_seg"] += 1 + + result["speakers"] = list(_speakers.values()) + result["segments"] = _segments + + if self.log: + self.log.info(f"Speaker identification done in {time.time() - full_tic:.3f} seconds") + + return result \ No newline at end of file diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index bcd4e9f..cfeaaf6 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -30,10 +30,6 @@ # Constants (that could be env variables) _FOLDER_WAV = os.environ.get("SPEAKER_SAMPLES_FOLDER", "/opt/speaker_samples") -_FOLDER_INTERNAL = os.environ.get("SPEAKER_PRECOMPUTED_FOLDER", "/opt/speaker_precomputed") - -_FILE_DATABASE = f"{_FOLDER_INTERNAL}/speakers_database" - _UNKNOWN = "<>" diff --git a/pyannote/diarization/processing/__init__.py b/pyannote/diarization/processing/__init__.py index 15700f5..64855db 100644 --- a/pyannote/diarization/processing/__init__.py +++ b/pyannote/diarization/processing/__init__.py @@ -1,12 +1,6 @@ import os import torch -from qdrant_client import QdrantClient -# Initialize Qdrant Client -qdrant_host = os.getenv("QDRANT_HOST") -qdrant_port = os.getenv("QDRANT_PORT") -qdrant_collection = os.getenv("QDRANT_COLLECTION_NAME") -qdrant_client = QdrantClient(url=f"http://{qdrant_host}:{qdrant_port}") if (qdrant_host and qdrant_port) else None device = os.environ.get("DEVICE") if device is None: @@ -29,6 +23,6 @@ from .speakerdiarization import SpeakerDiarization -diarizationworker = SpeakerDiarization(device=device, num_threads=NUM_THREADS, qdrant_client=qdrant_client, qdrant_collection=qdrant_collection) +diarizationworker = SpeakerDiarization(device=device, num_threads=NUM_THREADS) __all__ = ["diarizationworker"] diff --git a/pyannote/diarization/processing/speakerdiarization.py b/pyannote/diarization/processing/speakerdiarization.py index 75cb10c..c8e3528 100644 --- a/pyannote/diarization/processing/speakerdiarization.py +++ b/pyannote/diarization/processing/speakerdiarization.py @@ -17,12 +17,8 @@ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 
"identification" ) ) -import identification -from identification.speaker_identify import ( - initialize_speaker_identification, - check_speaker_specification, - speaker_identify_given_diarization, -) + +from identification.speaker_identification import SpeakerIdentifier class SpeakerDiarization: def __init__( @@ -30,8 +26,6 @@ def __init__( device=None, num_threads=4, tolerated_silence=0, - qdrant_client= None, - qdrant_collection= None, ): """ Speaker Diarization class @@ -55,8 +49,7 @@ def __init__( + (f" ({num_threads} threads)" if device == "cpu" else "") ) self.tolerated_silence = tolerated_silence - self.qdrant_client = qdrant_client - self.qdrant_collection = qdrant_collection + home = os.path.expanduser('~') model_configuration = "pyannote/speaker-diarization-3.1" @@ -75,8 +68,9 @@ def __init__( self.pipeline = self.pipeline.to(torch.device(device)) self.num_threads = num_threads self.tempfile = None + self.speaker_identifier = SpeakerIdentifier(device=device, log=self.log) - initialize_speaker_identification(self.qdrant_client, self.qdrant_collection,self.log) + self.speaker_identifier.initialize_speaker_identification() def run_pyannote(self, audioFile, speaker_count, max_speaker): @@ -202,7 +196,7 @@ def run( speaker_names = None, ): # Early check on speaker names - speaker_names = check_speaker_specification(speaker_names, self.qdrant_client, self.qdrant_collection) + speaker_names = self.speaker_identifier.check_speaker_specification(speaker_names) # If we run both speaker diarization and speaker identification, we need to save the file if speaker_names and isinstance(file_path, werkzeug.datastructures.file_storage.FileStorage): @@ -223,7 +217,7 @@ def run( result = self.run_pyannote( file_path, speaker_count=speaker_count, max_speaker=max_speaker ) - result = speaker_identify_given_diarization(file_path, result, speaker_names, log=self.log, qdrant_client=self.qdrant_client, qdrant_collection=self.qdrant_collection) + result = self.speaker_identifier.speaker_identify_given_diarization(file_path, result, speaker_names) return result except Exception as e: self.log.error(e) From ebb3dd15485d6968b593ebd56ecfc3ecf7afa9d0 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 17 Oct 2024 09:00:05 +0000 Subject: [PATCH 06/21] Fixed get_db_speaker_names --- identification/speaker_identification.py | 31 +- identification/speaker_identify.py | 556 ----------------------- 2 files changed, 26 insertions(+), 561 deletions(-) delete mode 100644 identification/speaker_identify.py diff --git a/identification/speaker_identification.py b/identification/speaker_identification.py index 5f265f1..cec7a61 100644 --- a/identification/speaker_identification.py +++ b/identification/speaker_identification.py @@ -6,6 +6,7 @@ import subprocess import json import werkzeug +import memory_tempfile from collections import defaultdict from tqdm import tqdm from qdrant_client.http.models import VectorParams, Distance, PointStruct @@ -23,7 +24,7 @@ class SpeakerIdentifier: _can_identify_twice_the_same_speaker = os.environ.get("CAN_IDENTIFY_TWICE_THE_SAME_SPEAKER", "1").lower() in ["true", "1", "yes"] _UNKNOWN = "<>" - def __init__(self, qdrant_client=None, qdrant_collection=None, device=None, log=None): + def __init__(self, device=None, log=None): self.device = device or self._get_device() self.qdrant_host = os.getenv("QDRANT_HOST") self.qdrant_port = os.getenv("QDRANT_PORT") @@ -32,6 +33,7 @@ def __init__(self, qdrant_client=None, qdrant_collection=None, device=None, log= self._embedding_model = None self.log = 
log + @staticmethod def _get_device(): if torch.cuda.is_available(): return "cuda" @@ -187,10 +189,29 @@ def compute_embedding(self,audio, min_len = 640): return self._embedding_model.encode_batch(audio) - def _get_db_speaker_names(self): + def _get_db_speaker_names(self, batch_size=100): + all_points = [] + offset = None # Start without any offset - response = self.qdrant_client.scroll(collection_name=self.qdrant_collection,with_payload=True) - return [point.payload.get("person") for point in response[0]] + while True: + # Scroll request with batch_size + response, next_offset = self.qdrant_client.scroll( + collection_name=self.qdrant_collection, + offset=offset, # Use offset to get the next batch + limit=batch_size, + with_payload=True, + ) + + all_points.extend(response) # Collect the points + + # Break the loop if no more points are available + if next_offset is None: + break + + # Update the offset for the next iteration + offset = next_offset + + return [point.payload.get("person") for point in all_points] def _get_db_speaker_name(self,speaker_id): @@ -317,7 +338,7 @@ def speaker_identify( # Use the similarity score returned by Qdrant score = result.score # Directly get the similarity score from the result - if score >= min_similarity: + if (score >= min_similarity) and (speaker_name in speaker_names): votes[speaker_name] += score diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py deleted file mode 100644 index cfeaaf6..0000000 --- a/identification/speaker_identify.py +++ /dev/null @@ -1,556 +0,0 @@ -import speechbrain -if speechbrain.__version__ >= "1.0.0": - from speechbrain.inference.speaker import EncoderClassifier -else: - from speechbrain.pretrained import EncoderClassifier -import os -from collections import defaultdict -import torch -import torchaudio -import time -import subprocess -import memory_tempfile -import werkzeug -import glob -import json -from tqdm import tqdm -from qdrant_client import models -from qdrant_client.http.models import VectorParams, Distance, PointStruct - -device = os.environ.get("DEVICE_IDENTIFICATION", os.environ.get("DEVICE", None)) -if device is None: - if torch.cuda.is_available(): - device="cuda" - else: - device="cpu" - -_can_identify_twice_the_same_speaker = os.environ.get("CAN_IDENTIFY_TWICE_THE_SAME_SPEAKER", "1").lower() in ["true", "1", "yes"] - -_embedding_model = None - -# Constants (that could be env variables) -_FOLDER_WAV = os.environ.get("SPEAKER_SAMPLES_FOLDER", "/opt/speaker_samples") -_UNKNOWN = "<>" - - - -def initialize_speaker_identification( - qdrant_client = None, - qdrant_collection=None, - log = None, - max_duration = 60 * 3, - sample_rate = 16_000, - ): - """ - Pre-compute and store reference speaker embeddings - - Args: - log (logging.Logger): optional logger - max_duration (int): maximum duration (in seconds) of speech to use for speaker embeddings - sample_rate (int): sample rate (of the embedding model) - """ - if not (is_speaker_identification_enabled() and qdrant_client and qdrant_collection): - if log: log.info(f"Speaker identification is disabled") - return - - global _embedding_model - if _embedding_model is None: - tic = time.time() - _embedding_model = EncoderClassifier.from_hparams( - source="speechbrain/spkrec-ecapa-voxceleb", - run_opts={"device":device} - ) - if log: log.info(f"Speaker identification model loaded in {time.time() - tic:.3f} seconds on {device}") - - if log: log.info(f"Speaker identification is enabled") - - # Check if the collection exists - if 
qdrant_client.collection_exists(collection_name=qdrant_collection): - if log: - log.info(f"Deleting existing collection: {qdrant_collection}") - qdrant_client.delete_collection(collection_name=qdrant_collection) - - # Create collection - if log: - log.info(f"Creating collection: {qdrant_collection}") - qdrant_client.create_collection( - collection_name=qdrant_collection, - vectors_config=VectorParams( - size=192, # Adjust according to your embedding size - distance=Distance.COSINE - ), - ) - - speakers = list(_get_speaker_names()) - points = [] # List to store points for Qdrant upsert - for _,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. speaker embeddings")): - audio_files = _get_speaker_sample_files(speaker_name) - assert len(audio_files) > 0, f"No audio files found for speaker {speaker_name}" - - audio = None - max_samples = max_duration * sample_rate - for audio_file in audio_files: - clip_audio = check_wav_16khz_mono(audio_file, log=log) - if clip_audio is not None: - clip_sample_rate = 16000 - else: - if log: log.info(f"Converting audio file {audio_file} to single channel 16kHz WAV using ffmpeg...") - converted_wavfile = os.path.join( - os.path.dirname(audio_file), "___{}.wav".format(os.path.splitext(os.path.basename(audio_file))[0]) - ) - convert_wavfile(audio_file, converted_wavfile) - try: - clip_audio, clip_sample_rate = torchaudio.load(converted_wavfile) - finally: - os.remove(converted_wavfile) - - assert clip_sample_rate == sample_rate, f"Unsupported sample rate {clip_sample_rate} (only {sample_rate} is supported)" - if clip_audio.shape[1] > max_samples: - clip_audio = clip_audio[:, :max_samples] - if audio is None: - audio = clip_audio - else: - audio = torch.cat((audio, clip_audio), 1) - # Update maximum number of remaining samples - max_samples -= clip_audio.shape[1] - if max_samples <= 0: - break - - spk_embed = compute_embedding(audio) - # Note: it is important to save the embeddings on the CPU (to be able to load them on the CPU later on) - spk_embed = spk_embed.cpu() - # Prepare point for Qdrant - point = PointStruct( - id=_+1, - vector=spk_embed.flatten(),#.numpy().tolist(), # Convert to list for Qdrant - payload={"person": speaker_name.strip()} - ) - - points.append(point) # Append point to the list - - # Upsert all points to Qdrant in one go - if points: - operation_info = qdrant_client.upsert( - collection_name=qdrant_collection, - wait=True, - points=points - ) - - if log: log.info(f"Speaker identification initialized with {len(speakers)} speakers") - - -def is_speaker_identification_enabled(): - return os.path.isdir(_FOLDER_WAV) - - -def check_wav_16khz_mono(wavfile, log=None): - """ - Returns True if a wav file is 16khz and single channel - """ - try: - signal, fs = torchaudio.load(wavfile) - except: - if log: log.info(f"Could not load {wavfile}") - return None - assert len(signal.shape) == 2 - mono = (signal.shape[0] == 1) - freq = (fs == 16000) - if mono and freq: - return signal - - reason = "" - if not mono: - reason += " is not mono" - if not freq: - if reason: - reason += " and" - reason += f" is in {freq/1000} kHz" - if log: log.info(f"File {wavfile} {reason}") - - -def convert_wavfile(wavfile, outfile): - """ - Converts file to 16khz single channel mono wav - """ - cmd = "ffmpeg -y -i {} -acodec pcm_s16le -ar 16000 -ac 1 {}".format( - wavfile, outfile - ) - subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE).wait() - if not os.path.isfile(outfile): - raise RuntimeError(f"Failed to run conversion: {cmd}") - return outfile - - -def 
compute_embedding(audio, min_len = 640): - """ - Compute speaker embedding from audio - - Args: - audio (torch.Tensor): audio waveform - """ - assert _embedding_model is not None, "Speaker identification model not initialized" - # The following is to avoid a failure on too short audio (less than 640 samples = 40ms at 16kHz) - if audio.shape[-1] < min_len: - audio = torch.cat([audio, torch.zeros(audio.shape[0], min_len - audio.shape[-1])], dim=-1) - return _embedding_model.encode_batch(audio) - - -def _get_db_speaker_names(qdrant_client = None,qdrant_collection=None): - - response = qdrant_client.scroll(collection_name=qdrant_collection,with_payload=True) - return [point.payload.get("person") for point in response[0]] - - -def _get_db_speaker_name(speaker_id, qdrant_client = None,qdrant_collection=None): - - # Retrieve the point from Qdrant - response = qdrant_client.retrieve( - collection_name=qdrant_collection, - ids=[speaker_id], - ) - # Extract the 'person' payload from the response - if response : - return response[0].payload.get('person') - -def _get_db_speaker_id(speaker_name, qdrant_client = None,qdrant_collection=None): - # Filter Qdrant for speaker_name - response = qdrant_client.scroll( - collection_name=qdrant_collection, - scroll_filter = models.Filter( - must=[ - models.FieldCondition( - key="person", - match=models.MatchValue(value=speaker_name), - ) - ]) - ) - # Extract the id - points = response[0] if response else [] - - if len(points) == 0: - raise ValueError(f"Person with name '{speaker_name}' not found in the Qdrant collection.") - if len(points) > 1: - raise ValueError(f"Multiple persons with the name '{speaker_name}' found. Ensure uniqueness.") - return points[0].id - - - -def _get_speaker_sample_files(speaker_name): - if os.path.isdir(os.path.join(_FOLDER_WAV, speaker_name)): - audio_files = sorted(glob.glob(os.path.join(_FOLDER_WAV, speaker_name, '*'))) - else: - prefix = os.path.join(_FOLDER_WAV, speaker_name) - audio_files = glob.glob(prefix + '.*') - audio_files = [file for file in audio_files if os.path.splitext(file)[0] == prefix] - assert len(audio_files) == 1 - return audio_files - - -def _get_speaker_names(): - assert os.path.isdir(_FOLDER_WAV) - for root, dirs, files in os.walk(_FOLDER_WAV): - if root == _FOLDER_WAV: - for file in files: - yield os.path.splitext(file)[0] - else: - yield os.path.basename(root.rstrip("/")) - -def speaker_identify( - audio, - speaker_names, - segments, - exclude_speakers, - min_similarity=0.5, - sample_rate=16_000, - limit_duration=3 * 60, - qdrant_client = None, - qdrant_collection=None, - log = None, - spk_tag = None, - ): - """ - Run speaker identification on given segments of an audio - - Args: - audio (torch.Tensor): audio waveform - speaker_names (list): list of reference speaker names - segments (list): list of segments to analyze (tuples of start and end times in seconds) - exclude_speakers (list): list of speaker names to exclude - min_similarity (float): minimum similarity to consider a speaker match - The default value 0.25 was taken from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/inference/speaker.py#L61 - sample_rate (int): audio sample rate - limit_duration (int): maximum duration (in seconds) of speech to identify a speaker (the first seconds of speech will be used, the other will be ignored) - log: logger - spk_tag: information for the logger - - Returns: - str: identified speaker name - float: similarity score - """ - tic = time.time() - - similarity = torch.nn.CosineSimilarity(dim=-1, 
eps=1e-6) - assert len(speaker_names) > 0 - - votes = defaultdict(int) - - # Sort segments by duration (longest first) - segments = sorted(segments, key=lambda x: x[1] - x[0], reverse=True) - assert len(segments) - - total_duration = sum([end - start for (start, end) in segments]) - - # Glue all the speaker segments up to a certain length - audio_selection = None - limit_samples = limit_duration * sample_rate - for start, end in segments: - start = int(start * sample_rate) - end = int(end * sample_rate) - if end - start > limit_samples: - end = start + limit_samples - - clip = audio[:, start:end] - if audio_selection is None: - audio_selection = clip - else: - audio_selection = torch.cat((audio_selection, clip), 1) - limit_samples -= (end - start) - if limit_samples <= 0: - break - - embedding_audio = compute_embedding(audio_selection) - - # Search for similar embeddings in Qdrant - results = qdrant_client.search(qdrant_collection, embedding_audio.flatten()) - - for result in results: - speaker_name = result.payload["person"] - - # Check if the speaker is in the exclude list - if speaker_name in exclude_speakers: - continue - - # Use the similarity score returned by Qdrant - score = result.score # Directly get the similarity score from the result - if score >= min_similarity: - votes[speaker_name] += score - - - score = None - if not votes: - argmax_speaker = _UNKNOWN - else: - argmax_speaker = max(votes, key=votes.get) - score = votes[argmax_speaker] - - if log: - log.info( - f"Speaker recognition {spk_tag} -> {argmax_speaker} (done in {time.time() - tic:.3f} seconds, on {audio_selection.shape[1] / sample_rate:.3f} seconds of audio out of {total_duration:.3f})" - ) - - return argmax_speaker, score - -def check_speaker_specification( - speakers_spec, - qdrant_client = None, - qdrant_collection=None, - ): - """ - Check and convert speaker specification to list of speaker names - - Args: - speakers_spec (str, list): speaker specification - cursor (sqlite3.Cursor): optional database cursor - """ - - if speakers_spec and not (is_speaker_identification_enabled() and qdrant_client and qdrant_collection): - raise RuntimeError("Speaker identification is disabled (no reference speakers)") - - # Read list / dictionary - if isinstance(speakers_spec, str): - speakers_spec = speakers_spec.strip() - print("NOCOMMIT", speakers_spec, speakers_spec and (speakers_spec == "*"), _get_db_speaker_names(qdrant_client,qdrant_collection)) - if speakers_spec: - if speakers_spec == "*": - # Wildcard: all speakers - speakers_spec = _get_db_speaker_names(qdrant_client,qdrant_collection) - elif speakers_spec[0] in "[{": - try: - speakers_spec = json.loads(speakers_spec) - except Exception as err: - if "|" in speakers_spec: - speakers_spec = speakers_spec.split("|") - else: - raise ValueError(f"Unsupported reference speaker specification: {speakers_spec} (except empty string, \"*\", or \"speaker1|speaker2|...|speakerN\", or \"[\"speaker1\", \"speaker2\", ..., \"speakerN\"]\")") from err - if isinstance(speakers_spec, dict): - speakers_spec = [speakers_spec] - else: - speakers_spec = speakers_spec.split("|") - - # Convert to list of speaker names - if not speakers_spec: - return [] - - if not isinstance(speakers_spec, list): - raise ValueError(f"Unsupported reference speaker specification of type {type(speakers_spec)}: {speakers_spec}") - - speakers_spec = [s for s in speakers_spec if s] - all_speaker_names = None - speaker_names = [] - for item in speakers_spec: - if isinstance(item, int): - items = 
[_get_db_speaker_name(item, qdrant_client, qdrant_collection)] - - elif isinstance(item, dict): - # Should we really keep this format ? - start = item['start'] - end = item['end'] - for id in range(start,end+1): - items.append(_get_db_speaker_id(id)) - - elif isinstance(item, str): - if all_speaker_names is None: - all_speaker_names = _get_db_speaker_names(qdrant_client, qdrant_collection) - if item not in all_speaker_names: - raise ValueError(f"Unknown speaker name '{item}'") - items = [item] - - else: - raise ValueError(f"Unsupported reference speaker specification of type {type(item)} (in list): {speakers_spec}") - - for item in items: - if item not in speaker_names: - speaker_names.append(item) - - return speaker_names - - -def speaker_identify_given_diarization( - audioFile, - diarization, - speakers_spec="*", - qdrant_client = None, - qdrant_collection=None, - log=None, - options={}): - """ - Run speaker identification on given diarized audio file - - Args: - audioFile (str): path to audio file - diarization (dict): diarization result - speakers_spec (list): list of reference speaker ids or ranges (e.g. [1, 2, {"start": 3, "end": 5}]) - log (logging.Logger): optional logger - options (dict): optional options (e.g. {"min_similarity": 0.25, "limit_duration": 60}) - """ - - speaker_names = check_speaker_specification(speakers_spec, qdrant_client, qdrant_collection) - - if not speaker_names: - return diarization - - if log: - full_tic = time.time() - log.info(f"Running speaker identification with {len(speaker_names)} reference speakers") - - if isinstance(audioFile, werkzeug.datastructures.file_storage.FileStorage): - tempfile = memory_tempfile.MemoryTempfile(filesystem_types=['tmpfs', 'shm'], fallback=True) - if log: - log.info(f"Using temporary folder {tempfile.gettempdir()}") - - with tempfile.NamedTemporaryFile(suffix = ".wav") as ntf: - audioFile.save(ntf.name) - return speaker_identify_given_diarization(ntf.name, diarization, speaker_names) - - speaker_tags = [] - speaker_segments = {} - common = [] - speaker_map = {} - speaker_surnames = {} - - for segment in diarization["segments"]: - - start = segment["seg_begin"] - end = segment["seg_end"] - speaker = segment["spk_id"] - common.append([start, end, speaker]) - - # find different speakers - if speaker not in speaker_tags: - speaker_tags.append(speaker) - speaker_map[speaker] = speaker - speaker_segments[speaker] = [] - - speaker_segments[speaker].append([start, end]) - - audio, sample_rate = torchaudio.load(audioFile) - # This should be OK, since this is enforced by the diarization API - assert sample_rate == 16_000, f"Unsupported sample rate {sample_rate} (only 16kHz is supported)" - - # Process the speakers with the longest speech turns first - def speech_duration(spk): - return sum([end - start for (start, end) in speaker_segments[spk]]) - already_identified = [] - speaker_id_scores = {} - for spk_tag in sorted(speaker_segments.keys(), key=speech_duration, reverse=True): - spk_segments = speaker_segments[spk_tag] - - spk_name, spk_id_score = speaker_identify( - audio, speaker_names, spk_segments, - # TODO : do we really want to avoid that 2 speakers are the same ? 
- # and if we do, not that it's not invariant to the order in which segments are taken (so we should choose a somewhat optimal order) - exclude_speakers=([] if _can_identify_twice_the_same_speaker else already_identified), - log=log, - spk_tag=spk_tag, - qdrant_client=qdrant_client, - qdrant_collection=qdrant_collection, - **options - ) - if spk_name == _UNKNOWN: - speaker_map[spk_tag] = spk_tag - else: - already_identified.append(spk_name) - speaker_map[spk_tag] = spk_name - speaker_id_scores[spk_name] = spk_id_score - - result = {} - _segments = [] - _speakers = {} - speaker_surnames = {} - for iseg, segment in enumerate(diarization["segments"]): - start = segment["seg_begin"] - end = segment["seg_end"] - speaker = segment["spk_id"] - - # Convert speaker names to spk1, spk2, etc. - if speaker not in speaker_surnames: - speaker_surnames[speaker] = ( - speaker # "spk"+str(len(speaker_surnames)+1) - ) - speaker = speaker_surnames[speaker] - speaker_name = speaker_map[speaker] - if speaker_name == _UNKNOWN: - speaker_name = speaker - - segment["spk_id"] = speaker_name - - _segments.append(segment) - - if speaker_name not in _speakers: - _speakers[speaker_name] = {"spk_id": speaker_name} - if speaker_name in speaker_id_scores: - _speakers[speaker_name]["spk_id_score"] = round(speaker_id_scores[speaker_name], 3) - _speakers[speaker_name]["duration"] = round(end - start, 3) - _speakers[speaker_name]["nbr_seg"] = 1 - - else: - _speakers[speaker_name]["duration"] += round(end - start, 3) - _speakers[speaker_name]["nbr_seg"] += 1 - - result["speakers"] = list(_speakers.values()) - result["segments"] = _segments - - if log: - log.info(f"Speaker identification done in {time.time() - full_tic:.3f} seconds") - - return result From 7366db0f5c60384db38e78d3d764c426f0bca415 Mon Sep 17 00:00:00 2001 From: htagourti Date: Wed, 30 Oct 2024 11:21:18 +0000 Subject: [PATCH 07/21] added wait for qdrant to entrypoint --- docker-compose.yml | 6 +++++- docker-entrypoint.sh | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 6f606df..8596e9e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,8 +20,12 @@ services: - QDRANT_PORT - QDRANT_COLLECTION_NAME - SERVICE_MODE + - SERVICE_NAME + - SERVICES_BROKER + - CONCURRENCY volumes: - - ./data/speakers_samples:/opt/speaker_samples + - ./data/speakers_samples:/opt/speaker_samples # Reference Speaker samples + - ./data/test_samples:/opt/audio # Test audio file (Celery task mode) depends_on: - qdrant # Ensure Qdrant starts before the app deploy: diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index fe73c5c..d8a3ffa 100755 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -22,6 +22,16 @@ check_gpu_availability() { } +# Wait for Qdrant to be available +wait_for_qdrant() { + echo "Waiting for Qdrant to be reachable..." + /usr/src/app/wait-for-it.sh "${QDRANT_HOST}:${QDRANT_PORT}" --timeout=20 --strict -- echo "Qdrant is up" + if [ $? 
-ne 0 ]; then + echo "ERROR: Qdrant service not reachable at ${QDRANT_HOST}:${QDRANT_PORT}" + exit 1 + fi +} + run_http_server() { echo "HTTP server Mode" python http_server/ingress.py --debug @@ -58,6 +68,7 @@ run_celery_worker() { # Main logic check_gpu_availability +wait_for_qdrant if [ -z "$SERVICE_MODE" ]; then echo "ERROR: Must specify a serving mode: [ http | task ]" From 399c5796bab0fb6038b33f68b26b953ebd469791 Mon Sep 17 00:00:00 2001 From: htagourti Date: Wed, 30 Oct 2024 16:09:09 +0000 Subject: [PATCH 08/21] rename speaker identify --- .../{speaker_identification.py => speaker_identify.py} | 0 pyannote/diarization/processing/speakerdiarization.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename identification/{speaker_identification.py => speaker_identify.py} (100%) diff --git a/identification/speaker_identification.py b/identification/speaker_identify.py similarity index 100% rename from identification/speaker_identification.py rename to identification/speaker_identify.py diff --git a/pyannote/diarization/processing/speakerdiarization.py b/pyannote/diarization/processing/speakerdiarization.py index c8e3528..ae2fbae 100644 --- a/pyannote/diarization/processing/speakerdiarization.py +++ b/pyannote/diarization/processing/speakerdiarization.py @@ -18,7 +18,7 @@ ) ) -from identification.speaker_identification import SpeakerIdentifier +from identification.speaker_identify import SpeakerIdentifier class SpeakerDiarization: def __init__( From 8f96bbc857edfb528ef5944292e74f4f02ac0ce9 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 09:03:08 +0000 Subject: [PATCH 09/21] Renamed speaker index variable for clarity - Raise error if speaker identification is enabled without Qdrant client - Applied refactoring to simple diarization process --- identification/speaker_identify.py | 20 ++++++++++++------- .../processing/speakerdiarization.py | 15 +++++--------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index cec7a61..ad6629a 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -44,11 +44,18 @@ def initialize_speaker_identification( max_duration=60 * 3, sample_rate=16_000, ): - + # Check if speaker identification is enabled if not self.is_speaker_identification_enabled() : if self.log: self.log.info(f"Speaker identification is disabled") return + # Raise error if Qdrant client is not set + elif self.qdrant_client is None: + raise EnvironmentError( + "Qdrant client is not set. Please ensure that the environment variables 'QDRANT_HOST' " + "and 'QDRANT_PORT' are set to enable speaker identification." + ) + if self._embedding_model is None: tic = time.time() self._embedding_model = EncoderClassifier.from_hparams( @@ -78,7 +85,7 @@ def initialize_speaker_identification( speakers = list(self._get_speaker_names()) points = [] # List to store points for Qdrant upsert - for _,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. speaker embeddings")): + for speaker_idx,speaker_name in enumerate(tqdm(speakers, desc="Compute ref. 
speaker embeddings")): audio_files = self._get_speaker_sample_files(speaker_name) assert len(audio_files) > 0, f"No audio files found for speaker {speaker_name}" @@ -116,12 +123,12 @@ def initialize_speaker_identification( spk_embed = spk_embed.cpu() # Prepare point for Qdrant point = PointStruct( - id=_+1, - vector=spk_embed.flatten(),#.numpy().tolist(), # Convert to list for Qdrant + id=speaker_idx+1, + vector=spk_embed.flatten(), # Convert to 1D list for Qdrant [[[1, 2, 3, ...]]] -> [1, 2, 3, ...] payload={"person": speaker_name.strip()} ) - points.append(point) # Append point to the list + points.append(point) # Upsert all points to Qdrant in one go if points: @@ -135,7 +142,7 @@ def initialize_speaker_identification( # Create a method to check if speaker identification is enabled def is_speaker_identification_enabled(self): - return self.qdrant_client and self.qdrant_collection and os.path.isdir(self._FOLDER_WAV) + return os.path.isdir(self._FOLDER_WAV) @staticmethod def convert_wavfile(wavfile, outfile): @@ -373,7 +380,6 @@ def check_speaker_specification( # Read list / dictionary if isinstance(speakers_spec, str): speakers_spec = speakers_spec.strip() - print("NOCOMMIT", speakers_spec, speakers_spec and (speakers_spec == "*"), self._get_db_speaker_names()) if speakers_spec: if speakers_spec == "*": # Wildcard: all speakers diff --git a/simple/diarization/processing/speakerdiarization.py b/simple/diarization/processing/speakerdiarization.py index 59e296c..b0caf4f 100644 --- a/simple/diarization/processing/speakerdiarization.py +++ b/simple/diarization/processing/speakerdiarization.py @@ -18,13 +18,7 @@ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "identification" ) ) -import identification -from identification.speaker_identify import ( - initialize_speaker_identification, - check_speaker_specification, - speaker_identify_given_diarization, -) - +from identification.speaker_identify import SpeakerIdentifier class SpeakerDiarization: def __init__(self, device=None, device_vad=None, device_clustering=None, num_threads=None): @@ -58,8 +52,9 @@ def __init__(self, device=None, device_vad=None, device_clustering=None, num_thr ) self.tempfile = None + self.speaker_identifier = SpeakerIdentifier(device=device, log=self.log) - initialize_speaker_identification(self.log) + self.speaker_identifier.initialize_speaker_identification() def run_simple_diarizer(self, file_path, speaker_count, max_speaker): @@ -199,7 +194,7 @@ def run( speaker_names = None, ): # Early check on speaker names - speaker_names = check_speaker_specification(speaker_names) + speaker_names = self.speaker_identifier.check_speaker_specification(speaker_names) if isinstance(file_path, werkzeug.datastructures.file_storage.FileStorage): if self.tempfile is None: @@ -221,7 +216,7 @@ def run( result = self.run_simple_diarizer( file_path, speaker_count=speaker_count, max_speaker=max_speaker ) - result = speaker_identify_given_diarization(file_path, result, speaker_names, log=self.log) + result = self.speaker_identifier.speaker_identify_given_diarization(file_path, result, speaker_names) return result except Exception as e: self.log.error(e) From ce39ed91fe03ff6c7cf2e0b7cf3e491356690c65 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 31 Oct 2024 13:20:02 +0100 Subject: [PATCH 10/21] Add release notes for next versions --- pyannote/RELEASE.md | 3 +++ simple/RELEASE.md | 3 +++ 2 files changed, 6 insertions(+) diff --git a/pyannote/RELEASE.md b/pyannote/RELEASE.md index ac1510c..30b85d3 100644 --- 
a/pyannote/RELEASE.md +++ b/pyannote/RELEASE.md @@ -1,3 +1,6 @@ +# 2.0.1 +- Use Qdrant for efficient speaker identification + # 2.0.0 - Add speaker identification - Add progress bar diff --git a/simple/RELEASE.md b/simple/RELEASE.md index 83ef2c2..29fb167 100644 --- a/simple/RELEASE.md +++ b/simple/RELEASE.md @@ -1,3 +1,6 @@ +# 2.0.1 +- Use Qdrant for efficient speaker identification + # 2.0.0 - Add speaker identification From 1fbd964eea0ebb2ca8c34735c59ccc227bb28d01 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 13:03:22 +0000 Subject: [PATCH 11/21] skip wait for qdrant if venv not set --- docker-entrypoint.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index d8a3ffa..af44bca 100755 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -24,6 +24,11 @@ check_gpu_availability() { # Wait for Qdrant to be available wait_for_qdrant() { + # Check if QDRANT_HOST and QDRANT_PORT are set + if [[ -z "${QDRANT_HOST}" || -z "${QDRANT_PORT}" ]]; then + echo "Qdrant environment variables are not set. Skipping wait for Qdrant." + return 0 + fi echo "Waiting for Qdrant to be reachable..." /usr/src/app/wait-for-it.sh "${QDRANT_HOST}:${QDRANT_PORT}" --timeout=20 --strict -- echo "Qdrant is up" if [ $? -ne 0 ]; then From c51e848f491becf6e742e441447eb899cd199097 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 13:12:09 +0000 Subject: [PATCH 12/21] added qdrant to simple requirements --- simple/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/simple/requirements.txt b/simple/requirements.txt index a168010..824acee 100644 --- a/simple/requirements.txt +++ b/simple/requirements.txt @@ -12,4 +12,5 @@ speechbrain==1.0.0 torchaudio==2.2.1 onnxruntime-gpu==1.17.1 scipy==1.8.1 # newer version can provoke segmentation faults -numpy==1.23.5 \ No newline at end of file +numpy==1.23.5 +qdrant-client \ No newline at end of file From fb2471725cf757063e0eec32bc88391f8bc17bc4 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 14:28:55 +0000 Subject: [PATCH 13/21] Updated README and .envdefault files --- .envdefault | 9 +++++++++ README.md | 40 ++++++++++++++++++++++++++++++++++------ pyannote/.envdefault | 8 +++++++- pyannote/README.md | 25 +++++++++++++++++-------- simple/.envdefault | 8 +++++++- simple/README.md | 26 ++++++++++++++++++-------- 6 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 .envdefault diff --git a/.envdefault b/.envdefault new file mode 100644 index 0000000..937a0ea --- /dev/null +++ b/.envdefault @@ -0,0 +1,9 @@ +SERVICE_MODE=http +SERVICE_NAME=diarization +SERVICES_BROKER=redis://172.17.0.1:6379 +BROKER_PASS= +CONCURRENCY=2 +QDRANT_HOST=qdrant +QDRANT_PORT=6333 +QDRANT_COLLECTION_NAME=speaker_embeddings +QDRANT_RECREATE_COLLECTION=true \ No newline at end of file diff --git a/README.md b/README.md index e197d75..c55cef7 100644 --- a/README.md +++ b/README.md @@ -20,23 +20,38 @@ In what follow, you can replace "pyannote" by "simple" or "pybk" to try other me ### HTTP Server -1. If needed, build docker image +1. If you want to use speaker identification, make sure Qdrant is running. You can start Qdrant using the following Docker command: + +```bash +docker run + -p 6333:6333 \ # Qdrant default port + -v ./qdrant_storage:/qdrant/storage:z \ + qdrant/qdrant +``` + +2. If needed, build docker image ```bash docker build . -t linto-diarization-pyannote:latest -f pyannote/Dockerfile ``` -2. Launch docker container (and keep it running) +3. 
Launch docker container (and keep it running) ```bash docker run -it --rm \ -p 8080:80 \ --shm-size=1gb --tmpfs /run/user/0 \ --env SERVICE_MODE=http \ + --env QDRANT_HOST=localhost \ + --env QDRANT_PORT=6333 \ + --env QDRANT_COLLECTION_NAME=speaker_embeddings \ + --env QDRANT_RECREATE_COLLECTION=true \ + --env SERVICE_MODE=http \ linto-diarization-pyannote:latest ``` +Alternatively, you can use docker-compose. -3. Open the swagger in a browser: [http://localhost:8080/docs](http://localhost:8080/docs) +4. Open the swagger in a browser: [http://localhost:8080/docs](http://localhost:8080/docs) Unfold `/diarization` route and click "Try it out". Then - Choose a file - Specify either `speaker_count` (Fixed number of speaker) or `max_speaker` (Max number of speakers) @@ -52,7 +67,16 @@ In the following we assume we want to test on an audio that is in `$HOME/test.wa docker build . -t linto-diarization-pyannote:latest -f pyannote/Dockerfile ``` -2. Run Redis server +2. If you want to use speaker identification, make sure Qdrant is running. You can start Qdrant using the following Docker command: + +```bash +docker run + -p 6333:6333 \ # Qdrant default port + -v ./qdrant_storage:/qdrant/storage:z \ + qdrant/qdrant +``` + +3. Run Redis server ```bash docker run -it --rm \ @@ -61,7 +85,7 @@ docker run -it --rm \ redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug ``` -3. Launch docker container, attaching the volume where is the audio file on which you will test +4. Launch docker container, attaching the volume where is the audio file on which you will test ```bash docker run -it --rm \ @@ -71,10 +95,14 @@ docker run -it --rm \ --env SERVICES_BROKER=redis://172.17.0.1:6379 \ --env BROKER_PASS= \ --env CONCURRENCY=2 \ + --env QDRANT_HOST=localhost \ + --env QDRANT_PORT=6333 \ + --env QDRANT_COLLECTION_NAME=speaker_embeddings \ + --env QDRANT_RECREATE_COLLECTION=true \ linto-diarization-pyannote:latest ``` -3. Testing with a given audio file can be done using python3 (with packages `celery` and `redis` installed). +5. Testing with a given audio file can be done using python3 (with packages `celery` and `redis` installed). For example with the following command for the file `$HOME/test.wav` with 2 speakers ```bash diff --git a/pyannote/.envdefault b/pyannote/.envdefault index d5483ab..4a3fac7 100644 --- a/pyannote/.envdefault +++ b/pyannote/.envdefault @@ -13,4 +13,10 @@ CONCURRENCY=2 # DEVICE=cpu # Maximum number of threads on CPU -NUM_THREADS=4 \ No newline at end of file +NUM_THREADS=4 + +# Qdrant +QDRANT_HOST=qdrant +QDRANT_PORT=6333 +QDRANT_COLLECTION_NAME=speaker_embeddings +QDRANT_RECREATE_COLLECTION=true \ No newline at end of file diff --git a/pyannote/README.md b/pyannote/README.md index 15b8b4e..bf602a6 100644 --- a/pyannote/README.md +++ b/pyannote/README.md @@ -45,6 +45,13 @@ or ```bash docker pull lintoai/linto-diarization-pyannote ``` +For speaker identification, run qdrant : +```bash +docker run + -p 6333:6333 \ # Qdrant default port + -v ./qdrant_storage:/qdrant/storage:z \ + qdrant/qdrant +``` ### HTTP @@ -64,6 +71,10 @@ An example of .env file is provided in [pyannote/.envdefault](https://github.com | `CUDA_VISIBLE_DEVICES` | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... 
| | `SPEAKER_SAMPLES_FOLDER` | (default: `/opt/speaker_samples`) Folder where to find audio files for target speakers samples | `/path/to/folder` | | `SPEAKER_PRECOMPUTED_FOLDER` | (default: `/opt/speaker_precomputed`) Folder where to store precomputed embeddings of target speakers | `/path/to/folder` | +| `QDRANT_HOST` | Host address of the Qdrant instance | `localhost` | +| `QDRANT_PORT` | Port number for the Qdrant instance | `6333` | +| `QDRANT_COLLECTION` | Name of the collection in Qdrant for storing embeddings | `speaker_embeddings` | +| `QDRANT_RECREATE_COLLECTION` | Recreate collection or use existing one from mounted volume | `true` | **2- Run the container** @@ -85,14 +96,8 @@ Then the parent folder of the samples must be mounted as a volume in the contain ```bash docker run ... -v <>:/opt/speaker_samples ``` - -When speaker identification, you can also mount a volume (empty at the beginning) on **`/opt/speaker_precomputed`** -(or a custom folder set with the `SPEAKER_PRECOMPUTED_FOLDER` environment variable), -where will be stored the precomputed embeddings of the speakers. -This can avoid an initialisation time at each new docker run, if the set of target speakers remains the same or just grows. -```bash -docker run ... -v <>:/opt/speaker_precomputed -``` +When speaker identification, if you want to use an existing collection in the volume mounted to the qdrant docker container, you can specify the environment variable `QDRANT_RECREATE_COLLECTION=false` +This can avoid an initialisation time at each new docker run. You may also want to add ```--gpus all``` to enable GPU capabilitiesn and maybe set `CUDA_VISIBLE_DEVICES` if there are several available GPU cards. @@ -117,6 +122,10 @@ Parameters are the [same as for the HTTP API](#http), with the addition of the f | `SERVICE_NAME` | Service's name | `diarization-ml` | | `LANGUAGE` | Language code as a BCP-47 code | `en-US` or * or languages separated by "\|" | | `MODEL_INFO` | Human readable description of the model | `Multilingual diarization model` | +| `QDRANT_HOST` | Host address of the Qdrant instance | `localhost` | +| `QDRANT_PORT` | Port number for the Qdrant instance | `6333` | +| `QDRANT_COLLECTION` | Name of the collection in Qdrant for storing embeddings | `speaker_embeddings` | +| `QDRANT_RECREATE_COLLECTION` | Recreate collection or use existing one from mounted volume | `true` | **2- Fill the docker-compose.yml** diff --git a/simple/.envdefault b/simple/.envdefault index 762b394..f9ed424 100644 --- a/simple/.envdefault +++ b/simple/.envdefault @@ -14,4 +14,10 @@ CONCURRENCY=2 # DEVICE_CLUSTERING=cpu # Maximum number of threads on CPU -NUM_THREADS=4 \ No newline at end of file +NUM_THREADS=4 + +# Qdrant +QDRANT_HOST=qdrant +QDRANT_PORT=6333 +QDRANT_COLLECTION_NAME=speaker_embeddings +QDRANT_RECREATE_COLLECTION=true \ No newline at end of file diff --git a/simple/README.md b/simple/README.md index 75ffcd8..e87f419 100644 --- a/simple/README.md +++ b/simple/README.md @@ -54,6 +54,13 @@ or ```bash docker pull lintoai/linto-diarization-simple ``` +For speaker identification, run qdrant : +```bash +docker run + -p 6333:6333 \ # Qdrant default port + -v ./qdrant_storage:/qdrant/storage:z \ + qdrant/qdrant +``` ### HTTP @@ -73,7 +80,10 @@ An example of .env file is provided in [simple/.envdefault](https://github.com/l | `CUDA_VISIBLE_DEVICES` | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... 
| | `SPEAKER_SAMPLES_FOLDER` | (default: `/opt/speaker_samples`) Folder where to find audio files for target speakers samples | `/path/to/folder` | | `SPEAKER_PRECOMPUTED_FOLDER` | (default: `/opt/speaker_precomputed`) Folder where to store precomputed embeddings of target speakers | `/path/to/folder` | - +| `QDRANT_HOST` | Host address of the Qdrant instance | `localhost` | +| `QDRANT_PORT` | Port number for the Qdrant instance | `6333` | +| `QDRANT_COLLECTION` | Name of the collection in Qdrant for storing embeddings | `speaker_embeddings` | +| `QDRANT_RECREATE_COLLECTION` | Recreate collection or use existing one from mounted volume | `true` | **2- Run the container** @@ -96,13 +106,9 @@ Then the parent folder of the samples must be mounted as a volume in the contain docker run ... -v <>:/opt/speaker_samples ``` -When speaker identification, you can also mount a volume (empty at the beginning) on **`/opt/speaker_precomputed`** -(or a custom folder set with the `SPEAKER_PRECOMPUTED_FOLDER` environment variable), -where will be stored the precomputed embeddings of the speakers. -This can avoid an initialisation time at each new docker run, if the set of target speakers remains the same or just grows. -```bash -docker run ... -v <>:/opt/speaker_precomputed -``` +When speaker identification, if you want to use an existing collection in the volume mounted to the qdrant docker container, you can specify the environment variable `QDRANT_RECREATE_COLLECTION=false` +This can avoid an initialisation time at each new docker run. + You may also want to add ```--gpus all``` to enable GPU capabilitiesn and maybe set `CUDA_VISIBLE_DEVICES` if there are several available GPU cards. @@ -128,6 +134,10 @@ Parameters are the [same as for the HTTP API](#http), with the addition of the f | `SERVICE_NAME` | Service's name | `diarization-ml` | | `LANGUAGE` | Language code as a BCP-47 code | `en-US` or * or languages separated by "\|" | | `MODEL_INFO` | Human readable description of the model | `Multilingual diarization model` | +| `QDRANT_HOST` | Host address of the Qdrant instance | `localhost` | +| `QDRANT_PORT` | Port number for the Qdrant instance | `6333` | +| `QDRANT_COLLECTION` | Name of the collection in Qdrant for storing embeddings | `speaker_embeddings` | +| `QDRANT_RECREATE_COLLECTION` | Recreate collection or use existing one from mounted volume | `true` | **2- Fill the docker-compose.yml** From 33b1b0bb0f6eb6d2a260fa920f9318085a2e758e Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 16:04:01 +0000 Subject: [PATCH 14/21] Added the possibility to use exisitng qdrant collection --- docker-compose.yml | 7 +++---- identification/speaker_identify.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 8596e9e..091c537 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,12 +13,15 @@ services: dockerfile: pyannote/Dockerfile container_name: diarization_app shm_size: '1gb' + stdin_open: true + tty: true ports : - 8080:80 environment: - QDRANT_HOST - QDRANT_PORT - QDRANT_COLLECTION_NAME + - QDRANT_RECREATE_COLLECTION - SERVICE_MODE - SERVICE_NAME - SERVICES_BROKER @@ -35,7 +38,3 @@ services: - driver: nvidia count: 1 capabilities: [gpu] - - -volumes: - qdrant_storage: diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index ad6629a..1ecb0b4 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -23,6 +23,8 @@ class SpeakerIdentifier: 
_FOLDER_WAV = os.environ.get("SPEAKER_SAMPLES_FOLDER", "/opt/speaker_samples") _can_identify_twice_the_same_speaker = os.environ.get("CAN_IDENTIFY_TWICE_THE_SAME_SPEAKER", "1").lower() in ["true", "1", "yes"] _UNKNOWN = "<>" + _RECREATE_COLLECTION = os.getenv("QDRANT_RECREATE_COLLECTION", "False").lower() in ["true", "1", "yes"] + def __init__(self, device=None, log=None): self.device = device or self._get_device() @@ -68,9 +70,17 @@ def initialize_speaker_identification( # Check if the collection exists if self.qdrant_client.collection_exists(collection_name=self.qdrant_collection): - if self.log: - self.log.info(f"Deleting existing collection: {self.qdrant_collection}") - self.qdrant_client.delete_collection(collection_name=self.qdrant_collection) + if self._RECREATE_COLLECTION: + if self.log: + self.log.info(f"Deleting existing collection: {self.qdrant_collection}") + self.qdrant_client.delete_collection(collection_name=self.qdrant_collection) + else: + if self.log: + self.log.info(f"Using existing collection: {self.qdrant_collection}") + speakers = self._get_db_speaker_names() + if self.log: + self.log.info(f"Speaker identification initialized with {len(speakers)} speakers") + return # Create collection if self.log: From 00f14759b76bf74c9b5318c4663b63bbe9c921a8 Mon Sep 17 00:00:00 2001 From: htagourti Date: Thu, 31 Oct 2024 20:44:38 +0000 Subject: [PATCH 15/21] fixed bug on embeddings vector shape --- identification/speaker_identify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/identification/speaker_identify.py b/identification/speaker_identify.py index 1ecb0b4..46a2ffe 100644 --- a/identification/speaker_identify.py +++ b/identification/speaker_identify.py @@ -134,7 +134,7 @@ def initialize_speaker_identification( # Prepare point for Qdrant point = PointStruct( id=speaker_idx+1, - vector=spk_embed.flatten(), # Convert to 1D list for Qdrant [[[1, 2, 3, ...]]] -> [1, 2, 3, ...] + vector=spk_embed[0].flatten(), # Convert to 1D list for Qdrant [[[1, 2, 3, ...]]] -> [1, 2, 3, ...] 
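+                # encode_batch returns a (batch, 1, embedding_dim) tensor, so taking [0] and
+                # flattening yields the 1-D 192-dim vector expected by the collection's
+                # VectorParams(size=192)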
payload={"person": speaker_name.strip()} ) @@ -344,7 +344,7 @@ def speaker_identify( embedding_audio = self.compute_embedding(audio_selection) # Search for similar embeddings in Qdrant - results = self.qdrant_client.search(self.qdrant_collection, embedding_audio.flatten()) + results = self.qdrant_client.search(self.qdrant_collection, embedding_audio[0].flatten()) for result in results: speaker_name = result.payload["person"] From 7a56b575ae03dadd1947abc8aae88737f5e513e8 Mon Sep 17 00:00:00 2001 From: htagourti Date: Fri, 15 Nov 2024 15:36:06 +0000 Subject: [PATCH 16/21] upgraded pyannote-audio to fix slow diarization problem --- pyannote/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyannote/requirements.txt b/pyannote/requirements.txt index e54cc9b..6e8ef5f 100644 --- a/pyannote/requirements.txt +++ b/pyannote/requirements.txt @@ -6,9 +6,9 @@ gunicorn>=20.1.0 gevent pyyaml>=5.4.1 supervisor>=4.2.2 -pyannote.audio==3.1.1 -speechbrain==0.5.16 -torchaudio==2.2.1 +pyannote.audio==3.3.2 +speechbrain==1.0.0 +torchaudio==2.4.1 memory-tempfile==2.2.3 # Version 2 of numpy breaks pyannote 3.1.1 (use of np.NaN instead of np.nan) numpy<2 From 7535427f6e5abe275c6438f478f6f3e1368bc132 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Fri, 15 Nov 2024 16:46:10 +0100 Subject: [PATCH 17/21] Trace modifications in release notes --- pyannote/RELEASE.md | 1 + 1 file changed, 1 insertion(+) diff --git a/pyannote/RELEASE.md b/pyannote/RELEASE.md index 30b85d3..ff7d83d 100644 --- a/pyannote/RELEASE.md +++ b/pyannote/RELEASE.md @@ -1,5 +1,6 @@ # 2.0.1 - Use Qdrant for efficient speaker identification +- Update pyannote to 3.3.2 (and speechbrain 1.0.0) # 2.0.0 - Add speaker identification From d650179ce2cc319da90e8e77dbb1393dfcb3b972 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 18 Nov 2024 06:57:42 +0100 Subject: [PATCH 18/21] remove useless modif --- http_server/ingress.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/http_server/ingress.py b/http_server/ingress.py index 922a08d..731665d 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -3,7 +3,6 @@ import json import logging from time import time -import os from confparser import createParser from flask import Flask, Response, abort, json, request @@ -12,7 +11,6 @@ from diarization.processing import diarizationworker, USE_GPU - app = Flask("__diarization-serving__") logging.basicConfig( From 428132793c1d0252b6469edd34e690f326631a0c Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 18 Nov 2024 07:09:08 +0100 Subject: [PATCH 19/21] Allow to have max_speaker not specified (as for pyannote) --- simple/RELEASE.md | 1 + simple/diarization/processing/speakerdiarization.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/simple/RELEASE.md b/simple/RELEASE.md index 29fb167..51465eb 100644 --- a/simple/RELEASE.md +++ b/simple/RELEASE.md @@ -1,5 +1,6 @@ # 2.0.1 - Use Qdrant for efficient speaker identification +- Specifying max number of speakers is now optional # 2.0.0 - Add speaker identification diff --git a/simple/diarization/processing/speakerdiarization.py b/simple/diarization/processing/speakerdiarization.py index b0caf4f..158edac 100644 --- a/simple/diarization/processing/speakerdiarization.py +++ b/simple/diarization/processing/speakerdiarization.py @@ -8,6 +8,7 @@ import memory_tempfile import torch import werkzeug +import warnings sys.path.append(os.path.join(os.path.dirname(__file__), "simple_diarizer")) import simple_diarizer @@ -210,7 +211,8 @@ def run( 
self.log.info(f"Starting diarization on file {file_path}") if speaker_count is None and max_speaker is None: - raise Exception("Either speaker_count or max_speaker must be set") + max_speaker = 50 # default value + warnings.warn(f"No speaker count nor maximum specified, using default value {max_speaker=}") try: result = self.run_simple_diarizer( From d2000b9b791242b9c514d288500df4fec95f8eaf Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 18 Nov 2024 07:09:33 +0100 Subject: [PATCH 20/21] cosm --- pyannote/README.md | 2 +- simple/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyannote/README.md b/pyannote/README.md index bf602a6..e6b08fd 100644 --- a/pyannote/README.md +++ b/pyannote/README.md @@ -232,7 +232,7 @@ Diarization worker accepts requests with the following arguments: * `file`: (str) Is the relative path of the file in the shared_folder. * `speaker_count`: (int, default None) Fixed number of speakers. * `max_speaker`: (int, default None) Max number of speaker if speaker_count=None. -* `speaker_names`: (string, optional) List of target speaker names, speaker identification (if speaker samples are provided only). Possible values are +* `speaker_names`: (string, default None) List of target speaker names, speaker identification (if speaker samples are provided only). Possible values are * empty string "": no speaker identification * wild card "`*`": speaker identification for all speakers * list of speaker names in json format (ex: "`["speaker1", ..., "speakerN"]`") or separated by `|` (ex: "`speaker1|...|speakerN`"): speaker identification for the listed speakers only diff --git a/simple/README.md b/simple/README.md index e87f419..0f48ace 100644 --- a/simple/README.md +++ b/simple/README.md @@ -244,7 +244,7 @@ Diarization worker accepts requests with the following arguments: * `file`: (str) Is the relative path of the file in the shared_folder. * `speaker_count`: (int, default None) Fixed number of speakers. * `max_speaker`: (int, default None) Max number of speaker if speaker_count=None. -* `speaker_names`: (string, optional) List of target speaker names, speaker identification (if speaker samples are provided only). Possible values are +* `speaker_names`: (string, default None) List of target speaker names, speaker identification (if speaker samples are provided only). Possible values are * empty string "": no speaker identification * wild card "`*`": speaker identification for all speakers * list of speaker names in json format (ex: "`["speaker1", ..., "speakerN"]`") or separated by `|` (ex: "`speaker1|...|speakerN`"): speaker identification for the listed speakers only From 3b2c1fb80c40700b24700618fc9b56e5b724ebe2 Mon Sep 17 00:00:00 2001 From: htagourti Date: Mon, 18 Nov 2024 09:37:59 +0000 Subject: [PATCH 21/21] README update --- README.md | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c55cef7..a2e9503 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,18 @@ In what follow, you can replace "pyannote" by "simple" or "pybk" to try other me ### HTTP Server -1. If you want to use speaker identification, make sure Qdrant is running. You can start Qdrant using the following Docker command: +1. If you want to use speaker identification, make sure Qdrant is running. 
+First, create a custom bridge network so the diarization container can communicate with Qdrant:
+
+```bash
+docker network create diarization_network
+```
+
You can start Qdrant using the following Docker command:
```bash
# 6333 is Qdrant's default REST port
docker run \
+    --name qdrant \
+    --network diarization_network \
    -p 6333:6333 \
    -v ./qdrant_storage:/qdrant/storage:z \
    qdrant/qdrant
@@ -33,23 +41,78 @@ docker run
```bash
docker build . -t linto-diarization-pyannote:latest -f pyannote/Dockerfile
```
3. Launch docker container (and keep it running)
+If you want to enable speaker identification, make sure to mount reference speaker audio samples to `/opt/speaker_samples`. The `QDRANT_*` variables below are only needed when speaker identification is enabled.
+
```bash
docker run -it --rm \
+    --name linto-diarization \
+    --network diarization_network \
    -p 8080:80 \
+    -v ./data/speakers_samples:/opt/speaker_samples \
    --shm-size=1gb --tmpfs /run/user/0 \
    --env SERVICE_MODE=http \
    --env QDRANT_HOST=qdrant \
    --env QDRANT_PORT=6333 \
    --env QDRANT_COLLECTION_NAME=speaker_embeddings \
    --env QDRANT_RECREATE_COLLECTION=true \
    linto-diarization-pyannote:latest
```
+
+Alternatively, you can use docker-compose:
+
+```yaml
+
+services:
+  qdrant:
+    image: qdrant/qdrant
+    container_name: qdrant
+    ports:
+      - "6333:6333" # Qdrant default port
+    volumes:
+      - ./qdrant_storage:/qdrant/storage:z
+
+  diarization_app:
+    build:
+      context: .
+      dockerfile: pyannote/Dockerfile
+    container_name: diarization_app
+    shm_size: '1gb'
+    stdin_open: true
+    tty: true
+    ports:
+      - 8080:80
+    environment:
+      - QDRANT_HOST
+      - QDRANT_PORT
+      - QDRANT_COLLECTION_NAME
+      - QDRANT_RECREATE_COLLECTION
+      - SERVICE_MODE
+      - SERVICE_NAME
+      - SERVICES_BROKER
+      - CONCURRENCY
+    volumes:
+      - ./data/speakers_samples:/opt/speaker_samples # Reference speaker samples: enables speaker identification
+    depends_on:
+      - qdrant # Ensure Qdrant starts before the app
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+
+```
+
+Run it using this command:
+```bash
+docker compose up
+```
4. Open the swagger in a browser: [http://localhost:8080/docs](http://localhost:8080/docs)
Unfold `/diarization` route and click "Try it out". Then