diff --git a/.env_default_http b/.env_default_http new file mode 100644 index 0000000..6aed7a9 --- /dev/null +++ b/.env_default_http @@ -0,0 +1,8 @@ +# SERVING PARAMETERS +SERVICE_MODE=http + +# SERVICE DISCOVERY +SERVICE_NAME=MY_PUNCTUATION_SERVICE + +# CONCURRENCY +CONCURRENCY=2 \ No newline at end of file diff --git a/.env_default_task b/.env_default_task new file mode 100644 index 0000000..9669c52 --- /dev/null +++ b/.env_default_task @@ -0,0 +1,15 @@ +# SERVING PARAMETERS +SERVICE_MODE=task + +# SERVICE PARAMETERS +SERVICES_BROKER=redis://192.168.0.1:6379 +BROKER_PASS=password + +# SERVICE DISCOVERY +SERVICE_NAME=my-diarization-service +LANGUAGE=en-US/fr-FR/* +QUEUE_NAME=(Optionnal) +MODEL_INFO=This model does something + +# CONCURRENCY +CONCURRENCY=2 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index acbaa60..edda097 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ -FROM python:3.9 -LABEL maintainer="irebai@linagora.com, rbaraglia@linagora.com, wghezaiel@linagora.com" +FROM python:3.10 +LABEL maintainer="rbaraglia@linagora.com, wghezaiel@linagora.com" RUN apt-get update &&\ apt-get install -y \ @@ -31,6 +31,10 @@ COPY document /usr/src/app/document COPY pyBK/diarizationFunctions.py pyBK/diarizationFunctions.py COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +# Grep CURRENT VERSION +COPY RELEASE.md ./ +RUN export VERSION=$(awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //') + ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/diarization" # Limits on OPENBLAS number of thread prevent SEGFAULT on machine with a large number of cpus diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a9911a2 --- /dev/null +++ b/Makefile @@ -0,0 +1,13 @@ +.DEFAULT_GOAL := help + +target_dirs := http_server pyBK diarization celery_app + +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +style: ## update 
code style. + black ${target_dirs} + isort ${target_dirs} + +lint: ## run pylint linter. + pylint ${target_dirs} \ No newline at end of file diff --git a/README.md b/README.md index 0660928..9ccda73 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,21 @@ # LINTO-PLATFORM-DIARIZATION -LinTO-platform-diarization is the speaker diarization service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack). +LinTO-platform-diarization is the [LinTO](https://linto.ai/) service for speaker diarization. -LinTO-platform-diarization can either be used as a standalone diarization service or deployed within a micro-services infrastructure using a message broker connector. +LinTO-platform-diarization can either be used as a standalone diarization service or deployed as a micro-service. + +* [Prerequisites](#pre-requisites) +* [Deploy](#deploy) + * [HTTP](#http) + * [MicroService](#using-celery) +* [Usage](#usages) + * [HTTP API](#http-api) + * [/healthcheck](#healthcheck) + * [/diarization](#diarization) + * [/docs](#docs) + * [Through the message broker](#through-the-message-broker) + +* [License](#license) +*** ## Pre-requisites @@ -9,11 +23,11 @@ LinTO-platform-diarization can either be used as a standalone diarization servic The transcription service requires docker up and running. ### (micro-service) Service broker and shared folder -The diarization only entry point in job mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS. -On addition, as to prevent large audio from transiting through the message broker, lp-diarization use a shared storage folder. +In task mode, the only entry point of the diarization service is a task posted on a Redis message broker. +Furthermore, to prevent large audio from transiting through the message broker, diarization uses a shared storage folder mounted on /opt/audio. 
-## Deploy linto-platform-diarization -linto-platform-stt can be deployed three ways: +## Deploy +linto-platform-diarization can be deployed: * As a standalone diarization service through an HTTP API. * As a micro-service connected to a message broker. @@ -22,17 +36,31 @@ linto-platform-stt can be deployed three ways: ```bash git clone https://github.com/linto-ai/linto-platform-diarization.git cd linto-platform-diarization -git submodule init -git submodule update docker build . -t linto-platform-diarization:latest ``` -### HTTP API +### HTTP + +**1- Fill the .env** +```bash +cp .env_default_http .env +``` + +Fill the .env with your values. + +**Parameters:** +| Variables | Description | Example | +|:-|:-|:-| +| SERVICE_MODE | Specify launch mode | http | +| CONCURRENCY | Number of HTTP workers* | 1+ | + +**2- Run the container** ```bash docker run --rm \ +-v SHARED_FOLDER:/opt/audio \ -p HOST_SERVING_PORT:80 \ ---env SERVICE_MODE=http \ +--env-file .env \ linto-platform-diarization:latest ``` @@ -42,37 +70,88 @@ This will run a container providing an http API binded on the host HOST_SERVING_ | Variables | Description | Example | |:-|:-|:-| | HOST_SERVING_PORT | Host serving port | 80 | -| CONCURRENCY | Number of HTTP worker* | 1+ | > *diarization uses all CPU available, adding workers will share the available CPU thus decreasing processing speed for concurrent requests -### Micro-service within LinTO-Platform stack ->LinTO-platform-diarization can be deployed within the linto-platform-stack through the use of linto-platform-services-manager. Used this way, the container spawn celery worker waiting for diarization task on a message broker. ->LinTO-platform-diarization in task mode is not intended to be launch manually. ->However, if you intent to connect it to your custom message's broker here are the parameters: +### Using celery +>LinTO-platform-diarization can be deployed as a micro-service using celery. 
Used this way, the container spawns celery workers waiting for diarization tasks on a message broker. -You need a message broker up and running at MY_SERVICE_BROKER. +You need a message broker up and running at SERVICES_BROKER. +**1- Fill the .env** ```bash -docker run --rm \ --v AM_PATH:/opt/models/AM \ --v LM_PATH:/opt/models/LM \ --v SHARED_AUDIO_FOLDER:/opt/audio \ ---env SERVICES_BROKER=MY_SERVICE_BROKER \ ---env BROKER_PASS=MY_BROKER_PASS \ ---env SERVICE_MODE=task \ ---env CONCURRENCY=1 \ -linto-platform-diarization:latest +cp .env_default_task .env ``` +Fill the .env with your values. + **Parameters:** | Variables | Description | Example | |:-|:-|:-| +| SERVICE_MODE | Specify launch mode | task | | SERVICES_BROKER | Service broker uri | redis://my_redis_broker:6379 | | BROKER_PASS | Service broker password (Leave empty if there is no password) | my_password | -| CONCURRENCY | Number of celery worker* | 1+ | +| QUEUE_NAME | (Optional) override the generated queue's name (See Queue name below) | my_queue | +| SERVICE_NAME | Service's name | diarization-ml | +| LANGUAGE | Language code as a BCP-47 code | en-US or * or languages separated by "\|" | +| MODEL_INFO | Human readable description of the model | Multilingual diarization model | +| CONCURRENCY | Number of workers (1 worker = 1 cpu) | 1+ | + +**2- Fill the docker-compose.yml** + +`#docker-compose.yml` +```yaml +version: '3.7' + +services: + my-diarization-service: + image: linto-platform-diarization:latest + volumes: + - /path/to/shared/folder:/opt/audio + env_file: .env + deploy: + replicas: 1 + networks: + - your-net + +networks: + your-net: + external: true +``` + +**3- Run with docker compose** + +```bash +docker stack deploy --resolve-image always --compose-file docker-compose.yml your_stack +``` + +**Queue name:** + +By default the service queue name is generated using SERVICE_NAME and LANGUAGE: `diarization_{LANGUAGE}_{SERVICE_NAME}`. + +The queue name can be overridden using the QUEUE_NAME env variable. 
+ +**Service discovery:** + +As a micro-service, the instance will register itself in the service registry for discovery. The service information is stored as a JSON object in redis's db0 under the id `service:{HOST_NAME}`. + +The following information is registered: + +```json +{ + "service_name": $SERVICE_NAME, + "host_name": $HOST_NAME, + "service_type": "diarization", + "service_language": $LANGUAGE, + "queue_name": $QUEUE_NAME, + "version": "1.1.2", # This repository's version + "info": "Multilingual diarization model", + "last_alive": 65478213, + "concurrency": 1 +} +``` + -> *diarization uses all CPU available, adding workers will share the available CPU thus decreasing processing speed for concurrent requests ## Usages @@ -92,9 +171,9 @@ Diarization API * Method: POST * Response content: application/json -* File: An Wave file -* spk_number: (integer - optional) Number of speakers. If empty, diarization will guess. -* max_speaker: (interger - optional) Max number of speakers if spk_number is empty. +* File: A Wave file +* spk_number: (integer - optional) Number of speakers. If empty, the number of speakers is estimated automatically. +* max_speaker: (integer - optional) Max number of speakers if spk_number is unknown. Return a json object when using structured as followed: ```json @@ -116,7 +195,7 @@ The /docs route offers a OpenAPI/swagger interface. ### Through the message broker STT-Worker accepts requests with the following arguments: -```file_path: str, with_metadata: bool``` +```file_path: str, speaker_count: int (None), max_speaker: int (None)``` * file_path: (str) Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path} * speaker_count: (int default None) Fixed number of speakers. diff --git a/RELEASE.md b/RELEASE.md index 7d67ad5..45f944f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,21 @@ +# 1.1.2 +- Added service registration. +- Updated healthcheck to add heartbeat. +- Added possibility to override generated queue name. 
+# 1.1.1 +- Fixed: silences (and short occurrences <1 sec between silences) occurring inside a speaker turn were postponed at the end of the speaker turn (and could be arbitrarily assigned to next speaker) +- Fixed: make diarization deterministic (random seed is fixed) +- Tune length of short occurrences to consider as silences (0.3 sec) + +# 1.1.0 +- Changed: loading audio file by AudioSegment toolbox. +- Changed: mfcc are extracted by python_speech_features toolbox. +- Fixed windowRate =< maximumKBMWindowRate. +- Likelihood table is only calculated for the top five gaussian, computation time is reduced. +- Similarity matrix is calculated by Binary keys and cumulative vectors +- Removed: unused AHC. +- Code formated to pep8 + # 1.0.3 - Fixed: diarization failing on short audio when n_speaker > 1 - Fixed (TBT): diarization returning segfault on machine with a lot of CPU diff --git a/celery_app/celeryapp.py b/celery_app/celeryapp.py index ed2edfa..2a75dd0 100644 --- a/celery_app/celeryapp.py +++ b/celery_app/celeryapp.py @@ -1,24 +1,24 @@ import os + from celery import Celery from diarization import logger -celery = Celery(__name__, include=['celery_app.tasks']) +celery = Celery(__name__, include=["celery_app.tasks"]) service_name = os.environ.get("SERVICE_NAME") broker_url = os.environ.get("SERVICES_BROKER") if os.environ.get("BROKER_PASS", False): - components = broker_url.split('//') + components = broker_url.split("//") broker_url = f'{components[0]}//:{os.environ.get("BROKER_PASS")}@{components[1]}' celery.conf.broker_url = "{}/0".format(broker_url) celery.conf.result_backend = "{}/1".format(broker_url) -celery.conf.update( - result_expires=3600, - task_acks_late=True, - task_track_started=True) +celery.conf.update(result_expires=3600, task_acks_late=True, task_track_started=True) # Queues celery.conf.update( - {'task_routes': { - 'diarization_task': {'queue': 'diarization'}, } - } + { + "task_routes": { + "diarization_task": {"queue": "diarization"}, + } + } ) 
diff --git a/celery_app/register.py b/celery_app/register.py new file mode 100644 index 0000000..9d2b732 --- /dev/null +++ b/celery_app/register.py @@ -0,0 +1,90 @@ +"""The register Module allow registering and unregistering operations within the service stack for service discovery purposes""" +import os +import sys +import uuid +from socket import gethostname +from time import time +from xmlrpc.client import ResponseError + +import redis +from redis.commands.json.path import Path +from redis.commands.search.field import NumericField, TextField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType + +SERVICE_DISCOVERY_DB = 0 +SERVICE_TYPE = "diarization" + +service_name = os.environ.get("SERVICE_NAME", SERVICE_TYPE) +service_lang = os.environ.get("LANGUAGE", "?") +host_name = gethostname() + + +def register(is_heartbeat: bool = False) -> bool: + """Registers the service and act as heartbeat. + + Returns: + bool: registering status + """ + host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":") + password = os.environ.get("BROKER_PASS", None) + r = redis.Redis( + host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password=password + ) + + res = r.json().set(f"service:{host_name}", Path.root_path(), service_info()) + if is_heartbeat: + return res + else: + print(f"Service registered as service:{host_name}") + schema = ( + TextField("$.service_name", as_name="service_name"), + TextField("$.service_type", as_name="service_type"), + TextField("$.service_language", as_name="service_language"), + TextField("$.queue_name", as_name="queue_name"), + TextField("$.version", as_name="version"), + TextField("$.info", as_name="info"), + NumericField("$.last_alive", as_name="last_alive"), + NumericField("$.concurrency", as_name="concurrency"), + ) + try: + r.ft().create_index( + schema, + definition=IndexDefinition(prefix=["service:"], index_type=IndexType.JSON), + ) + except Exception as error: + print(f"Index creation skipped: {repr(error)}") + return res + + +def unregister() -> 
None: + """Un-register the service""" + try: + host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":") + r = redis.Redis( + host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password=os.environ.get("BROKER_PASS", None) + ) + r.json().delete(f"service:{host_name}") + except Exception as error: + print(f"Failed to unregister: {repr(error)}") + + +def queue() -> str: + return os.environ.get("QUEUE_NAME") or f"{SERVICE_TYPE}_{service_lang}_{service_name}" + + +def service_info() -> dict: + return { + "service_name": service_name, + "host_name": host_name, + "service_type": SERVICE_TYPE, + "service_language": service_lang, + "queue_name": queue(), + "version": "1.1.2", + "info": os.environ.get("MODEL_INFO", "unknown"), + "last_alive": int(time()), + "concurrency": int(os.environ.get("CONCURRENCY")), + } + + +if __name__ == "__main__": + sys.exit(0 if register() else 1) diff --git a/celery_app/tasks.py b/celery_app/tasks.py index bb245cd..b3ab130 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -1,12 +1,15 @@ -import os import json +import os + from celery_app.celeryapp import celery from diarization.processing.speakerdiarization import SpeakerDiarization @celery.task(name="diarization_task") -def diarization_task(file_name: str, speaker_count: int = None, max_speaker: int = None): - """ transcribe_task do a synchronous call to the transcribe worker API """ +def diarization_task( + file_name: str, speaker_count: int = None, max_speaker: int = None +): + """diarization_task runs the speaker diarization on a file located in the shared audio folder (/opt/audio)""" if not os.path.isfile(os.path.join("/opt/audio", file_name)): raise Exception("Could not find ressource {}".format(file_name)) @@ -20,8 +23,11 @@ def diarization_task(file_name: str, speaker_count: int = None, max_speaker: int # Processing try: diarizationworker = SpeakerDiarization() - result = diarizationworker.run(os.path.join( - "/opt/audio", file_name), number_speaker=speaker_count, max_speaker=max_speaker) + result = diarizationworker.run( + 
os.path.join("/opt/audio", file_name), + number_speaker=speaker_count, + max_speaker=max_speaker, + ) response = diarizationworker.format_response(result) except Exception as e: raise Exception("Diarization has failed : {}".format(e)) diff --git a/diarization/__init__.py b/diarization/__init__.py index 33b20af..82ff433 100644 --- a/diarization/__init__.py +++ b/diarization/__init__.py @@ -1,5 +1,7 @@ -import os import logging -logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S') -logger = logging.getLogger("__diarization-serving__") \ No newline at end of file +logging.basicConfig( + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", +) +logger = logging.getLogger("__diarization-serving__") diff --git a/diarization/processing/speakerdiarization.py b/diarization/processing/speakerdiarization.py index d4e7750..3c24d58 100644 --- a/diarization/processing/speakerdiarization.py +++ b/diarization/processing/speakerdiarization.py @@ -1,21 +1,24 @@ #!/usr/bin/env python3 +import logging import os import time -import logging import uuid -import numpy as np + import librosa -import webrtcvad +import numpy as np import pyBK.diarizationFunctions as pybk -#from spafe.features.mfcc import mfcc, imfcc + +# from spafe.features.mfcc import mfcc, imfcc from pydub import AudioSegment from python_speech_features import mfcc +import pyBK.diarizationFunctions as pybk + class SpeakerDiarization: def __init__(self): - self.log = logging.getLogger('__speaker-diarization__' + __name__) + self.log = logging.getLogger("__speaker-diarization__" + __name__) if os.environ.get("DEBUG", False) in ["1", 1, "true", "True"]: self.log.setLevel(logging.DEBUG) @@ -48,21 +51,23 @@ def __init__(self): # BINARY_KEY self.topGaussiansPerFrame = 5 # Number of top selected components per frame - self.bitsPerSegmentFactor = 0.2 # Percentage of bits set to 1 in the binary keys + self.bitsPerSegmentFactor = ( + 0.2 # 
Percentage of bits set to 1 in the binary keys + ) # CLUSTERING self.N_init = 25 # Number of initial clusters # Linkage criterion used if linkage==1 ('average', 'single', 'complete') - self.linkageCriterion = 'average' + self.linkageCriterion = "average" # Similarity metric: 'cosine' for cumulative vectors, and 'jaccard' for binary keys - self.metric = 'cosine' + self.metric = "cosine" # CLUSTERING_SELECTION # Distance metric used in the selection of the output clustering solution ('jaccard','cosine') - self.metric_clusteringSelection = 'cosine' + self.metric_clusteringSelection = "cosine" # Method employed for number of clusters selection. Can be either 'elbow' for an elbow criterion based on within-class sum of squares (WCSS) or 'spectral' for spectral clustering - self.bestClusteringCriterion = 'spectral' + self.bestClusteringCriterion = "spectral" self.sigma = 1 # Spectral clustering parameters, employed if bestClusteringCriterion == spectral self.percentile = 80 self.maxNrSpeakers = 20 # If known, max nr of speakers in a sesssion in the database. 
This is to limit the effect of changes in very small meaningless eigenvalues values generating huge eigengaps @@ -70,52 +75,57 @@ def __init__(self): # RESEGMENTATION self.resegmentation = 1 # Set to 1 to perform re-segmentation self.modelSize = 16 # Number of GMM components - self.modelSize = 16 # Number of GMM components self.nbIter = 5 # Number of expectation-maximization (EM) iterations self.smoothWin = 100 # Size of the likelihood smoothing window in nb of frames + # Pseudo-randomness + self.seed = 0 + + # Short segments to ignore + self.min_duration = 0.3 + def compute_feat_Librosa(self, audioFile): try: if type(audioFile) is not str: filename = str(uuid.uuid4()) - file_path = "/tmp/"+filename + file_path = "/tmp/" + filename audioFile.save(file_path) else: file_path = audioFile self.sr = 16000 - y = AudioSegment.from_wav(file_path) - self.data = np.array(y.get_array_of_samples()) + audio = AudioSegment.from_wav(file_path) + audio = audio.set_frame_rate(self.sr) + audio = audio.set_channels(1) + self.data = np.array(audio.get_array_of_samples()) if type(audioFile) is not str: os.remove(file_path) frame_length_inSample = self.frame_length_s * self.sr hop = int(self.frame_shift_s * self.sr) - NFFT = int(2**np.ceil(np.log2(frame_length_inSample))) - + NFFT = int(2 ** np.ceil(np.log2(frame_length_inSample))) + framelength_in_samples = self.frame_length_s * self.sr n_fft = int(2 ** np.ceil(np.log2(framelength_in_samples))) - + additional_kwargs = {} if self.sr >= 16000: - additional_kwargs.update({"lowfreq": 20, "highfreq": 7600}) + additional_kwargs.update({"lowfreq": 20, "highfreq": 7600}) - mfcc_coef = mfcc( - signal=self.data, - samplerate=self.sr, - numcep=30, - nfilt=30, - nfft=n_fft, - winlen=0.03, - winstep=0.01, - **additional_kwargs, - ) - + signal=self.data, + samplerate=self.sr, + numcep=30, + nfilt=30, + nfft=n_fft, + winlen=0.03, + winstep=0.01, + **additional_kwargs, + ) + except Exception as e: self.log.error(e) - raise ValueError( - "Speaker 
diarization failed when extracting features!!!") + raise ValueError("Speaker diarization failed when extracting features!!!") else: return mfcc_coef @@ -126,51 +136,60 @@ def computeVAD_WEBRTC(self, data, sr, nFeatures): sr = 16000 va_framed = pybk.py_webrtcvad( - data, fs=sr, fs_vad=sr, hoplength=30, vad_mode=0) + data, fs=sr, fs_vad=sr, hoplength=30, vad_mode=0 + ) segments = pybk.get_py_webrtcvad_segments(va_framed, sr) maskSAD = np.zeros([1, nFeatures]) for seg in segments: - start = int(np.round(seg[0]/self.frame_shift_s)) - end = int(np.round(seg[1]/self.frame_shift_s)) + start = int(np.round(seg[0] / self.frame_shift_s)) + end = int(np.round(seg[1] / self.frame_shift_s)) maskSAD[0][start:end] = 1 except Exception as e: self.log.error(e) raise ValueError( - "Speaker diarization failed while voice activity detection!!!") + "Speaker diarization failed while voice activity detection!!!" + ) else: return maskSAD def getSegments(self, frameshift, finalSegmentTable, finalClusteringTable, dur): - numberOfSpeechFeatures = finalSegmentTable[-1, 2].astype(int)+1 + numberOfSpeechFeatures = finalSegmentTable[-1, 2].astype(int) + 1 solutionVector = np.zeros([1, numberOfSpeechFeatures]) - for i in np.arange(np.size(finalSegmentTable, 0)): - solutionVector[0, np.arange( - finalSegmentTable[i, 1], finalSegmentTable[i, 2]+1).astype(int)] = finalClusteringTable[i] + for i in range(np.size(finalSegmentTable, 0)): + solutionVector[ + 0, + np.arange(finalSegmentTable[i, 1], finalSegmentTable[i, 2] + 1).astype( + int + ), + ] = finalClusteringTable[i] seg = np.empty([0, 3]) solutionDiff = np.diff(solutionVector)[0] first = 0 - for i in np.arange(0, np.size(solutionDiff, 0)): + for i in range(0, np.size(solutionDiff, 0)): if solutionDiff[i]: - last = i+1 - seg1 = (first)*frameshift - seg2 = (last-first)*frameshift - seg3 = solutionVector[0, last-1] - if seg.shape[0] != 0 and seg3 == seg[-1][2]: - seg[-1][1] += seg2 - elif seg3 and seg2 > 1: # and seg2 > 0.1 - seg = np.vstack((seg, 
[seg1, seg2, seg3])) - first = i+1 + last = i + 1 + start = (first) * frameshift + duration = (last - first) * frameshift + spklabel = solutionVector[0, last - 1] + silence = not spklabel or duration <= self.min_duration + if seg.shape[0] != 0 and (spklabel == seg[-1][2] or silence): + seg[-1][1] += duration + elif not silence: + seg = np.vstack((seg, [start, duration, spklabel])) + else: # First silence + continue + first = i + 1 last = np.size(solutionVector, 1) - seg1 = (first-1)*frameshift - seg2 = (last-first+1)*frameshift - seg3 = solutionVector[0, last-1] - if seg3 == seg[-1][2]: - seg[-1][1] += seg2 - elif seg3 and seg2 > 1: # and seg2 > 0.1 - seg = np.vstack((seg, [seg1, seg2, seg3])) - seg = np.vstack((seg, [dur, -1, -1])) - seg[0][0] = 0.0 + start = max(first - 1, 0) * frameshift + duration = (last - first + 1) * frameshift + spklabel = solutionVector[0, last - 1] + silence = not spklabel or duration <= self.min_duration + if seg.shape[0] != 0 and (spklabel == seg[-1][2] or silence): + seg[-1][1] += duration + elif not silence: + seg = np.vstack((seg, [start, duration, spklabel])) + seg = np.vstack((seg, [dur, -1, -1])) # End-of-file sentinel row, stripped by format_response return seg def format_response(self, segments: list) -> dict: @@ -216,52 +235,56 @@ def format_response(self, segments: list) -> dict: # Remove the last line of the segments. # It indicates the end of the file and segments. - if segments[len(segments)-1][2] == -1: - segments = segments[:len(segments)-1] + if segments[len(segments) - 1][2] == -1: + segments = segments[: len(segments) - 1] for seg in segments: segment = {} - segment['seg_id'] = seg_id + segment["seg_id"] = seg_id # Ensure speaker id continuity and numbers speaker by order of appearance. 
if seg[2] not in spk_i_dict.keys(): spk_i_dict[seg[2]] = spk_i spk_i += 1 - segment['spk_id'] = 'spk'+str(spk_i_dict[seg[2]]) - segment['seg_begin'] = float("{:.2f}".format(seg[0])) - segment['seg_end'] = float("{:.2f}".format(seg[0] + seg[1])) + segment["spk_id"] = "spk" + str(spk_i_dict[seg[2]]) + segment["seg_begin"] = float("{:.2f}".format(seg[0])) + segment["seg_end"] = float("{:.2f}".format(seg[0] + seg[1])) - if segment['spk_id'] not in _speakers: - _speakers[segment['spk_id']] = {} - _speakers[segment['spk_id']]['spk_id'] = segment['spk_id'] - _speakers[segment['spk_id']]['duration'] = float( - "{:.2f}".format(seg[1])) - _speakers[segment['spk_id']]['nbr_seg'] = 1 + if segment["spk_id"] not in _speakers: + _speakers[segment["spk_id"]] = {} + _speakers[segment["spk_id"]]["spk_id"] = segment["spk_id"] + _speakers[segment["spk_id"]]["duration"] = float( + "{:.2f}".format(seg[1]) + ) + _speakers[segment["spk_id"]]["nbr_seg"] = 1 else: - _speakers[segment['spk_id']]['duration'] += seg[1] - _speakers[segment['spk_id']]['nbr_seg'] += 1 - _speakers[segment['spk_id']]['duration'] = float( - "{:.2f}".format(_speakers[segment['spk_id']]['duration'])) + _speakers[segment["spk_id"]]["duration"] += seg[1] + _speakers[segment["spk_id"]]["nbr_seg"] += 1 + _speakers[segment["spk_id"]]["duration"] = float( + "{:.2f}".format(_speakers[segment["spk_id"]]["duration"]) + ) _segments.append(segment) seg_id += 1 - json['speakers'] = list(_speakers.values()) - json['segments'] = _segments + json["speakers"] = list(_speakers.values()) + json["segments"] = _segments return json def run(self, audioFile, number_speaker: int = None, max_speaker: int = None): self.log.debug(f"Starting diarization on file {audioFile}") try: start_time = time.time() - self.log.debug("Extracting features ... (t={:.2f}s)".format( - time.time() - start_time)) + self.log.debug( + "Extracting features ... 
(t={:.2f}s)".format(time.time() - start_time) + ) feats = self.compute_feat_Librosa(audioFile) nFeatures = feats.shape[0] duration = nFeatures * self.frame_shift_s - self.log.debug("Computing SAD Mask ... (t={:.2f}s)".format( - time.time() - start_time)) + self.log.debug( + "Computing SAD Mask ... (t={:.2f}s)".format(time.time() - start_time) + ) maskSAD = self.computeVAD_WEBRTC(self.data, self.sr, nFeatures) maskUEM = np.ones([1, nFeatures]) @@ -271,118 +294,137 @@ def run(self, audioFile, number_speaker: int = None, max_speaker: int = None): speechMapping = np.zeros(nFeatures) # you need to start the mapping from 1 and end it in the actual number of features independently of the indexing style # so that we don't lose features on the way - speechMapping[np.nonzero(mask)] = np.arange(1, nSpeechFeatures+1) + speechMapping[np.nonzero(mask)] = np.arange(1, nSpeechFeatures + 1) data = feats[np.where(mask == 1)] del feats - self.log.debug("Computing segment table ... (t={:.2f}s)".format( - time.time() - start_time)) + self.log.debug( + "Computing segment table ... 
(t={:.2f}s)".format( + time.time() - start_time + ) + ) segmentTable = pybk.getSegmentTable( - mask, speechMapping, self.seg_length, self.seg_increment, self.seg_rate) + mask, speechMapping, self.seg_length, self.seg_increment, self.seg_rate + ) numberOfSegments = np.size(segmentTable, 0) self.log.debug(f"Number of segment: {numberOfSegments}") if numberOfSegments == 1: self.log.debug(f"Single segment: returning") - return [[0, duration, 1], - [duration, -1, -1]] + return [[0, duration, 1], [duration, -1, -1]] # create the KBM # set the window rate in order to obtain "minimumNumberOfInitialGaussians" gaussians - windowRate = np.floor((nSpeechFeatures-self.windowLength)/self.minimumNumberOfInitialGaussians) + windowRate = np.floor( + (nSpeechFeatures - self.windowLength) + / self.minimumNumberOfInitialGaussians + ) if windowRate > self.maximumKBMWindowRate: windowRate = self.maximumKBMWindowRate elif windowRate == 0: windowRate = 1 - - - poolSize = np.floor((nSpeechFeatures-self.windowLength)/windowRate) + poolSize = np.floor((nSpeechFeatures - self.windowLength) / windowRate) if self.useRelativeKBMsize: - kbmSize = int(np.floor(poolSize*self.relKBMsize)) + kbmSize = int(np.floor(poolSize * self.relKBMsize)) else: kbmSize = int(self.kbmSize) # Training pool of',int(poolSize),'gaussians with a rate of',int(windowRate),'frames' - self.log.debug("Training KBM ... (t={:.2f}s)".format( - time.time() - start_time)) - kbm, gmPool = pybk.trainKBM( - data, self.windowLength, windowRate, kbmSize) + self.log.debug( + "Training KBM ... (t={:.2f}s)".format(time.time() - start_time) + ) + kbm, gmPool = pybk.trainKBM(data, self.windowLength, windowRate, kbmSize) #'Selected',kbmSize,'gaussians from the pool' Vg = pybk.getVgMatrix(data, gmPool, kbm, self.topGaussiansPerFrame) #'Computing binary keys for all segments... ' - self.log.debug("Computing binary keys ... 
(t={:.2f}s)".format( - time.time() - start_time)) - segmentBKTable, segmentCVTable = pybk.getSegmentBKs(segmentTable, - kbmSize, - Vg, - self.bitsPerSegmentFactor, - speechMapping) + self.log.debug( + "Computing binary keys ... (t={:.2f}s)".format(time.time() - start_time) + ) + segmentBKTable, segmentCVTable = pybk.getSegmentBKs( + segmentTable, kbmSize, Vg, self.bitsPerSegmentFactor, speechMapping + ) #'Performing initial clustering... ' - self.log.debug("Performing initial clustering ... (t={:.2f}s)".format( - time.time() - start_time)) - initialClustering = np.digitize(np.arange(numberOfSegments), - np.arange(0, numberOfSegments, numberOfSegments/self.N_init)) - - #'Performing agglomerative clustering... ' - self.log.debug("Performing agglomerative clustering ... (t={:.2f}s)".format( - time.time() - start_time)) - finalClusteringTable, k = pybk.performClusteringLinkage(segmentBKTable, - segmentCVTable, - self.N_init, - self.linkageCriterion, - self.metric) + self.log.debug( + "Performing initial clustering ... (t={:.2f}s)".format( + time.time() - start_time + ) + ) #'Selecting best clustering...' # self.bestClusteringCriterion == 'spectral': - self.log.debug("Selecting best clustering ... (t={:.2f}s)".format( - time.time() - start_time)) - bestClusteringID = pybk.getSpectralClustering(self.metric_clusteringSelection, - self.N_init, - segmentBKTable, - segmentCVTable, - number_speaker, - k, - self.sigma, - self.percentile, - max_speaker if max_speaker is not None else self.maxNrSpeakers)+1 + self.log.debug( + "Selecting best clustering ... 
(t={:.2f}s)".format( + time.time() - start_time + ) + ) + bestClusteringID = ( + pybk.getSpectralClustering( + self.metric_clusteringSelection, + self.N_init, + segmentBKTable, + segmentCVTable, + number_speaker, + self.sigma, + self.percentile, + max_speaker if max_speaker is not None else self.maxNrSpeakers, + random_state=self.seed, + ) + + 1 + ) if self.resegmentation and np.size(np.unique(bestClusteringID), 0) > 1: - self.log.debug("Performing resegmentation ... (t={:.2f}s)".format( - time.time() - start_time)) - finalClusteringTableResegmentation, finalSegmentTable = pybk.performResegmentation(data, - speechMapping, - mask, - bestClusteringID, - segmentTable, - self.modelSize, - self.nbIter, - self.smoothWin, - nSpeechFeatures) - self.log.debug("Get segments ... (t={:.2f}s)".format( - time.time() - start_time)) - segments = self.getSegments(self.frame_shift_s, - finalSegmentTable, - np.squeeze( - finalClusteringTableResegmentation), - duration) + self.log.debug( + "Performing resegmentation ... (t={:.2f}s)".format( + time.time() - start_time + ) + ) + ( + finalClusteringTableResegmentation, + finalSegmentTable, + ) = pybk.performResegmentation( + data, + speechMapping, + mask, + bestClusteringID, + segmentTable, + self.modelSize, + self.nbIter, + self.smoothWin, + nSpeechFeatures, + ) + self.log.debug( + "Get segments ... 
(t={:.2f}s)".format(time.time() - start_time) + ) + segments = self.getSegments( + self.frame_shift_s, + finalSegmentTable, + np.squeeze(finalClusteringTableResegmentation), + duration, + ) else: - return [[0, duration, 1], - [duration, -1, -1]] + return [[0, duration, 1], [duration, -1, -1]] - self.log.info("Speaker Diarization took %d[s] with a speed %0.2f[xRT]" % - (int(time.time() - start_time), float(int(time.time() - start_time)/duration))) + self.log.info( + "Speaker Diarization took %d[s] with a speed %0.2f[xRT]" + % ( + int(time.time() - start_time), + float(int(time.time() - start_time) / duration), + ) + ) except ValueError as v: self.log.error(v) raise ValueError( - 'Speaker diarization failed during processing the speech signal') + "Speaker diarization failed during processing the speech signal" + ) except Exception as e: self.log.error(e) raise Exception( - 'Speaker diarization failed during processing the speech signal') + "Speaker diarization failed during processing the speech signal" + ) else: self.log.debug(self.format_response(segments)) return segments diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..477128a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,16 @@ +version: '3.7' + +services: + my-diarization-service: + image: linto-platform-diarization:latest + volumes: + - /path/to/shared/folder:/opt/audio + env_file: .env + deploy: + replicas: 1 + networks: + - your-net + +networks: + your-net: + external: true diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index b67a196..0c5fd94 100755 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -1,6 +1,4 @@ #!/bin/bash -set -ea - echo "RUNNING Diarization" # Launch parameters, environement variables and dependencies check @@ -18,11 +16,26 @@ else if [[ -z "$SERVICES_BROKER" ]] then echo "ERROR: SERVICES_BROKER variable not specified, cannot start celery worker." 
- return -1 + exit -1 fi + echo "Running celery worker" /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" - echo "RUNNING STT CELERY WORKER" - celery --app=celery_app.celeryapp worker -Ofair -n diarization_worker@%h --queues=diarization -c $CONCURRENCY + # MICRO SERVICE + ## QUEUE NAME + QUEUE=$(python -c "from celery_app.register import queue; exit(queue())" 2>&1) + echo "Service set to $QUEUE" + + ## REGISTRATION + python -c "from celery_app.register import register; register()" + echo "Service registered" + + ## WORKER + celery --app=celery_app.celeryapp worker -Ofair -n diarization_worker@%h --queues=$QUEUE -c $CONCURRENCY + + ## UNREGISTERING + python -c "from celery_app.register import unregister; unregister()" + echo "Service unregistered" + else echo "ERROR: Wrong serving command: $1" exit -1 diff --git a/healthcheck.sh b/healthcheck.sh index 693877e..68d6b0c 100755 --- a/healthcheck.sh +++ b/healthcheck.sh @@ -6,5 +6,9 @@ if [ "$SERVICE_MODE" = "http" ] then curl --fail http://localhost:80/healthcheck || exit 1 else + # Update last alive + python -c "from celery_app.register import register; register()" + + # Ping worker celery --app=celery_app.celeryapp inspect ping -d diarization_worker@$HOSTNAME || exit 1 fi diff --git a/http_server/confparser.py b/http_server/confparser.py index 49513db..aa3e2f3 100644 --- a/http_server/confparser.py +++ b/http_server/confparser.py @@ -1,5 +1,5 @@ -import os import argparse +import os __all__ = ["createParser"] @@ -9,44 +9,39 @@ def createParser() -> argparse.ArgumentParser: # SERVICE parser.add_argument( - '--service_name', + "--service_name", type=str, - help='Service Name', - default=os.environ.get('SERVICE_NAME', 'diarization')) + help="Service Name", + default=os.environ.get("SERVICE_NAME", "diarization"), + ) # GUNICORN + parser.add_argument("--service_port", type=int, help="Service port", default=80) parser.add_argument( 
- '--service_port', - type=int, - help='Service port', - default=80) - parser.add_argument( - '--workers', + "--workers", type=int, help="Number of Gunicorn workers (default=CONCURRENCY + 1)", - default=int(os.environ.get('CONCURRENCY', 1)) + 1) + default=int(os.environ.get("CONCURRENCY", 1)) + 1, + ) # SWAGGER parser.add_argument( - '--swagger_url', - type=str, - help='Swagger interface url', - default='/docs') + "--swagger_url", type=str, help="Swagger interface url", default="/docs" + ) parser.add_argument( - '--swagger_prefix', + "--swagger_prefix", type=str, - help='Swagger prefix', - default=os.environ.get('SWAGGER_PREFIX', '')) + help="Swagger prefix", + default=os.environ.get("SWAGGER_PREFIX", ""), + ) parser.add_argument( - '--swagger_path', + "--swagger_path", type=str, - help='Swagger file path', - default=os.environ.get('SWAGGER_PATH', '/usr/src/app/document/swagger.yml')) + help="Swagger file path", + default=os.environ.get("SWAGGER_PATH", "/usr/src/app/document/swagger.yml"), + ) # MISC - parser.add_argument( - '--debug', - action='store_true', - help='Display debug logs') + parser.add_argument("--debug", action="store_true", help="Display debug logs") return parser diff --git a/http_server/ingress.py b/http_server/ingress.py index d5a6112..def9693 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -1,14 +1,13 @@ #!/usr/bin/env python3 +import json +import logging import os from time import time -import logging -import json -from flask import Flask, request, abort, Response, json - -from serving import GunicornServing from confparser import createParser +from flask import Flask, Response, abort, json, request +from serving import GunicornServing from swagger import setupSwaggerUI from diarization.processing.speakerdiarization import SpeakerDiarization @@ -16,32 +15,34 @@ app = Flask("__diarization-serving__") logging.basicConfig( - format='%(asctime)s %(name)s %(levelname)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S') + 
format="%(asctime)s %(name)s %(levelname)s: %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", +) logger = logging.getLogger("__diarization-serving__") -@app.route('/healthcheck', methods=['GET']) +@app.route("/healthcheck", methods=["GET"]) def healthcheck(): return json.dumps({"healthcheck": "OK"}), 200 -@app.route("/oas_docs", methods=['GET']) +@app.route("/oas_docs", methods=["GET"]) def oas_docs(): return "Not Implemented", 501 -@app.route('/diarization', methods=['POST']) +@app.route("/diarization", methods=["POST"]) def transcribe(): try: - logger.info('Diarization request received') + logger.info("Diarization request received") # get response content type - logger.debug(request.headers.get('accept').lower()) - if not request.headers.get('accept').lower() == 'application/json': - raise ValueError('Not accepted header') + logger.debug(request.headers.get("accept").lower()) + if not request.headers.get("accept").lower() == "application/json": + raise ValueError("Not accepted header") # get input file - if 'file' in request.files.keys(): + if "file" in request.files.keys(): spk_number = request.form.get("spk_number", None) if spk_number is not None: spk_number = int(spk_number) @@ -50,20 +51,21 @@ def transcribe(): max_spk_number = int(max_spk_number) start_t = time() else: - raise ValueError('No audio file was uploaded') + raise ValueError("No audio file was uploaded") except ValueError as error: return str(error), 400 except Exception as e: logger.error(e) - return 'Server Error: {}'.format(str(e)), 500 + return "Server Error: {}".format(str(e)), 500 # Diarization try: diarizationworker = SpeakerDiarization() result = diarizationworker.run( - request.files['file'], number_speaker=spk_number, max_speaker=max_spk_number) + request.files["file"], number_speaker=spk_number, max_speaker=max_spk_number + ) except Exception as e: - return 'Diarization has failed: {}'.format(str(e)), 500 + return "Diarization has failed: {}".format(str(e)), 500 response = 
diarizationworker.format_response(result) logger.debug("Diarization complete (t={}s)".format(time() - start_t)) @@ -74,21 +76,21 @@ def transcribe(): # Rejected request handlers @app.errorhandler(405) def method_not_allowed(error): - return 'The method is not allowed for the requested URL', 405 + return "The method is not allowed for the requested URL", 405 @app.errorhandler(404) def page_not_found(error): - return 'The requested URL was not found', 404 + return "The requested URL was not found", 404 @app.errorhandler(500) def server_error(error): logger.error(error) - return 'Server Error', 500 + return "Server Error", 500 -if __name__ == '__main__': +if __name__ == "__main__": logger.info("Startup...") parser = createParser() @@ -102,9 +104,14 @@ def server_error(error): except Exception as e: logger.warning("Could not setup swagger: {}".format(str(e))) - serving = GunicornServing(app, {'bind': '{}:{}'.format("0.0.0.0", args.service_port), - 'workers': args.workers, - 'timeout': 3600}) + serving = GunicornServing( + app, + { + "bind": "{}:{}".format("0.0.0.0", args.service_port), + "workers": args.workers, + "timeout": 3600, + }, + ) logger.info(args) try: serving.run() diff --git a/http_server/serving.py b/http_server/serving.py index 3e3eead..d2dd7e8 100644 --- a/http_server/serving.py +++ b/http_server/serving.py @@ -2,15 +2,17 @@ class GunicornServing(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): self.options = options or {} self.application = app super().__init__() def load_config(self): - config = {key: value for key, value in self.options.items() - if key in self.cfg.settings and value is not None} + config = { + key: value + for key, value in self.options.items() + if key in self.cfg.settings and value is not None + } for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/http_server/swagger.py b/http_server/swagger.py index 843205f..32d7432 100644 --- a/http_server/swagger.py +++ 
b/http_server/swagger.py @@ -1,17 +1,17 @@ import yaml from flask_swagger_ui import get_swaggerui_blueprint + def setupSwaggerUI(app, args): - '''Setup Swagger UI within the app''' - swagger_yml = yaml.load( - open(args.swagger_path, 'r'), Loader=yaml.Loader) + """Setup Swagger UI within the app""" + swagger_yml = yaml.load(open(args.swagger_path, "r"), Loader=yaml.Loader) swaggerui = get_swaggerui_blueprint( # Swagger UI static files will be mapped to '{SWAGGER_URL}/dist/' args.swagger_prefix + args.swagger_url, args.swagger_path, config={ # Swagger UI config overrides - 'app_name': "LinTO Platform Diarization", - 'spec': swagger_yml - } + "app_name": "LinTO Platform Diarization", + "spec": swagger_yml, + }, ) - app.register_blueprint(swaggerui, url_prefix=args.swagger_url) \ No newline at end of file + app.register_blueprint(swaggerui, url_prefix=args.swagger_url) diff --git a/pyBK/diarizationFunctions.py b/pyBK/diarizationFunctions.py index 6cb4d54..db3d280 100644 --- a/pyBK/diarizationFunctions.py +++ b/pyBK/diarizationFunctions.py @@ -3,37 +3,45 @@ # http://www.eurecom.fr/en/people/patino-jose # Contact: patino[at]eurecom[dot]fr, josempatinovillar[at]gmail[dot]com +import numpy as np +import scipy +import scipy.sparse as sparse +import sklearn +from scipy import sparse +from scipy.linalg import eigh from scipy.ndimage import gaussian_filter -from sklearn.neighbors import kneighbors_graph -from scipy.sparse.csgraph import laplacian as csgraph_laplacian +from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +from scipy.sparse.csgraph import laplacian as csgraph_laplacian from scipy.sparse.linalg import eigsh, lobpcg -from scipy.sparse import csr_matrix -from scipy.linalg import eigh -from scipy import sparse -import scipy.sparse as sparse -import scipy -import sklearn -from sklearn.preprocessing import MinMaxScaler -from sklearn.base import BaseEstimator, ClusterMixin -from sklearn.cluster import KMeans -from sklearn.utils 
import check_random_state, check_array, check_symmetric -from sklearn.utils.validation import check_array -from sklearn.utils.extmath import _deterministic_vector_sign_flip -from sklearn.utils import check_random_state -import numpy as np from scipy.spatial.distance import cdist from scipy.stats import multivariate_normal from sklearn import mixture +from sklearn.base import BaseEstimator, ClusterMixin +from sklearn.cluster import KMeans +from sklearn.neighbors import kneighbors_graph +from sklearn.preprocessing import MinMaxScaler +from sklearn.utils import check_array, check_random_state, check_symmetric +from sklearn.utils.extmath import _deterministic_vector_sign_flip +from sklearn.utils.validation import check_array -__all__ = ["py_webrtcvad", "getSegmentTable", "trainKBM", "getVgMatrix", "getSegmentBKs", - "performClusteringLinkage", "getSpectralClustering", "performResegmentation"] +__all__ = [ + "py_webrtcvad", + "getSegmentTable", + "trainKBM", + "getVgMatrix", + "getSegmentBKs", + "performClusteringLinkage", + "getSpectralClustering", + "performResegmentation", +] def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): import webrtcvad from librosa.core import resample from librosa.util import frame + """ Voice activity detection. This was implementioned for easier use of py-webrtcvad. 
Thanks to: https://github.com/wiseman/py-webrtcvad.git @@ -68,32 +76,33 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): # check argument if fs_vad not in [8000, 16000, 32000, 48000]: - raise ValueError('fs_vad must be 8000, 16000, 32000 or 48000.') + raise ValueError("fs_vad must be 8000, 16000, 32000 or 48000.") if hoplength not in [10, 20, 30]: - raise ValueError('hoplength must be 10, 20, or 30.') + raise ValueError("hoplength must be 10, 20, or 30.") if vad_mode not in [0, 1, 2, 3]: - raise ValueError('vad_mode must be 0, 1, 2 or 3.') + raise ValueError("vad_mode must be 0, 1, 2 or 3.") # check data - if data.dtype.kind == 'i': - if data.max() > 2**15 - 1 or data.min() < -2**15: + if data.dtype.kind == "i": + if data.max() > 2**15 - 1 or data.min() < -(2**15): raise ValueError( - 'when data type is int, data must be -32768 < data < 32767.') - data = data.astype('f') + "when data type is int, data must be -32768 < data < 32767." + ) + data = data.astype("f") - elif data.dtype.kind == 'f': + elif data.dtype.kind == "f": if np.abs(data).max() >= 1: data = data / np.abs(data).max() * 0.9 - print('Warning: input data was rescaled.') - data = (data * 2**15).astype('f') + print("Warning: input data was rescaled.") + data = (data * 2**15).astype("f") else: - raise ValueError('data dtype must be int or float.') + raise ValueError("data dtype must be int or float.") data = data.squeeze() if not data.ndim == 1: - raise ValueError('data must be mono (1 ch).') + raise ValueError("data must be mono (1 ch).") # resampling if fs != fs_vad: @@ -101,12 +110,12 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): else: resampled = data - resampled = resampled.astype('int16') + resampled = resampled.astype("int16") hop = fs_vad * hoplength // 1000 framelen = resampled.size // hop + 1 padlen = framelen * hop - resampled.size - paded = np.lib.pad(resampled, (0, padlen), 'constant', constant_values=0) + paded = np.lib.pad(resampled, (0, padlen), 
"constant", constant_values=0) framed = frame(paded, frame_length=hop, hop_length=hop).T vad = webrtcvad.Vad() @@ -117,7 +126,7 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): va_framed = np.zeros([len(valist), hop_origin]) va_framed[valist] = 1 - return va_framed.reshape(-1)[:data.size] + return va_framed.reshape(-1)[: data.size] def get_py_webrtcvad_segments(vad_info, fs): @@ -149,21 +158,22 @@ def getSegmentTable(mask, speechMapping, wLength, wIncr, wShift): for i in range(nSegs): begs = np.arange(segBeg[i], segEnd[i], wShift) bbegs = np.maximum(segBeg[i], begs - wIncr) - ends = np.minimum(begs + wLength-1, segEnd[i]) + ends = np.minimum(begs + wLength - 1, segEnd[i]) eends = np.minimum(ends + wIncr, segEnd[i]) segmentTable = np.vstack( - (segmentTable, np.vstack((bbegs, begs, ends, eends)).T)) + (segmentTable, np.vstack((bbegs, begs, ends, eends)).T) + ) return segmentTable def unravelMask(mask): - changePoints = np.diff(1*mask) + changePoints = np.diff(1 * mask) segBeg = np.where(changePoints == 1)[0] + 1 segEnd = np.where(changePoints == -1)[0] if mask[0] == 1: segBeg = np.insert(segBeg, 0, 0) if mask[-1] == 1: - segEnd = np.append(segEnd, np.size(mask)-1) + segEnd = np.append(segEnd, np.size(mask) - 1) nSegs = np.size(segBeg) return changePoints, segBeg, segEnd, nSegs @@ -221,9 +231,9 @@ def trainKBM(data, windowLength, windowRate, kbmSize): def getVgMatrix(data, gmPool, kbm, topGaussiansPerFrame): - + logLikelihoodTable = getLikelihoodTable(data, gmPool, kbm) - + # The original code was: # Vg = np.argsort(-logLikelihoodTable)[:, 0:topGaussiansPerFrame] # return Vg @@ -232,7 +242,7 @@ def getVgMatrix(data, gmPool, kbm, topGaussiansPerFrame): partition_args = np.argpartition(-logLikelihoodTable, 5, axis=1)[:, :5] partition = np.take_along_axis(-logLikelihoodTable, partition_args, axis=1) vg = np.take_along_axis(partition_args, np.argsort(partition), axis=1) - + return vg @@ -265,22 +275,27 @@ def getSegmentBKs(segmentTable, kbmSize, Vg, 
bitsPerSegmentFactor, speechMapping # BITSPERSEGMENTFACTOR = proportion of bits that will be set to 1 in the binary keys # Output: # SEGMENTBKTABLE = NxKBMSIZE matrix containing N binary keys for each N segments in SEGMENTTABLE - # SEGMENTCVTABLE = NxKBMSIZE matrix containing N cumulative vectors for each N segments in SEGMENTTABLE - - numberOfSegments = np.size(segmentTable,0) - segmentBKTable = np.zeros([numberOfSegments,kbmSize]) - segmentCVTable = np.zeros([numberOfSegments,kbmSize]) + # SEGMENTCVTABLE = NxKBMSIZE matrix containing N cumulative vectors for each N segments in SEGMENTTABLE + + numberOfSegments = np.size(segmentTable, 0) + segmentBKTable = np.zeros([numberOfSegments, kbmSize]) + segmentCVTable = np.zeros([numberOfSegments, kbmSize]) for i in range(numberOfSegments): - # Conform the segment according to the segmentTable matrix - beginningIndex = int(segmentTable[i,0]) - endIndex = int(segmentTable[i,3]) + # Conform the segment according to the segmentTable matrix + beginningIndex = int(segmentTable[i, 0]) + endIndex = int(segmentTable[i, 3]) # Store indices of features of the segment # speechMapping is substracted one because 1-indexing is used for this variable - A = np.arange(speechMapping[beginningIndex]-1,speechMapping[endIndex],dtype=int) - segmentBKTable[i], segmentCVTable[i] = binarizeFeatures(kbmSize, Vg[A,:], bitsPerSegmentFactor) - #print('done') + A = np.arange( + speechMapping[beginningIndex] - 1, speechMapping[endIndex], dtype=int + ) + segmentBKTable[i], segmentCVTable[i] = binarizeFeatures( + kbmSize, Vg[A, :], bitsPerSegmentFactor + ) + # print('done') return segmentBKTable, segmentCVTable + def binarizeFeatures(binaryKeySize, topComponentIndicesMatrix, bitsPerSegmentFactor): # BINARIZEMATRIX Extracts a binary key and a cumulative vector from the the # rows of VG specified by vector A @@ -309,23 +324,6 @@ def binarizeFeatures(binaryKeySize, topComponentIndicesMatrix, bitsPerSegmentFac return binaryKey, v_f -def 
performClusteringLinkage(segmentBKTable, segmentCVTable, N_init, linkageCriterion, linkageMetric): - from scipy.cluster.hierarchy import linkage - from scipy import cluster - if linkageMetric == 'jaccard': - observations = segmentBKTable - elif linkageMetric == 'cosine': - observations = segmentCVTable - else: - observations = segmentCVTable - clusteringTable = np.zeros([np.size(segmentCVTable, 0), N_init]) - Z = linkage(observations, method=linkageCriterion, metric=linkageMetric) - for i in np.arange(N_init): - clusteringTable[:, i] = cluster.hierarchy.cut_tree(Z, N_init-i).T+1 - k = N_init - return clusteringTable, k - - def get_sim_mat(X): """Returns the similarity matrix based on cosine similarities. Arguments @@ -417,11 +415,11 @@ def _set_diag(laplacian, value, norm_laplacian): # We need all entries in the diagonal to values if not sparse.isspmatrix(laplacian): if norm_laplacian: - laplacian.flat[::n_nodes + 1] = value + laplacian.flat[:: n_nodes + 1] = value else: laplacian = laplacian.tocoo() if norm_laplacian: - diag_idx = (laplacian.row == laplacian.col) + diag_idx = laplacian.row == laplacian.col laplacian.data[diag_idx] = value n_diags = np.unique(laplacian.row - laplacian.col).size if n_diags <= 7: @@ -432,25 +430,36 @@ def _set_diag(laplacian, value, norm_laplacian): return laplacian -def spectral_clustering(affinity, n_clusters=8, n_components=None, - eigen_solver=None, random_state=None, n_init=10, - eigen_tol=0.0, assign_labels='kmeans'): - if assign_labels not in ('kmeans', 'discretize'): - raise ValueError("The 'assign_labels' parameter should be " - "'kmeans' or 'discretize', but '%s' was given" - % assign_labels) +def spectral_clustering( + affinity, + n_clusters=8, + n_components=None, + eigen_solver=None, + random_state=None, + n_init=10, + eigen_tol=0.0, + assign_labels="kmeans", +): + if assign_labels not in ("kmeans", "discretize"): + raise ValueError( + "The 'assign_labels' parameter should be " + "'kmeans' or 'discretize', but '%s' was 
given" % assign_labels + ) random_state = check_random_state(random_state) n_components = n_clusters if n_components is None else n_components - maps = spectral_embedding(affinity, n_components=n_components, - eigen_solver=eigen_solver, - random_state=random_state, - eigen_tol=eigen_tol, drop_first=False) - - if assign_labels == 'kmeans': - kmeans = KMeans(n_clusters, random_state=random_state, - n_init=n_init).fit(maps) + maps = spectral_embedding( + affinity, + n_components=n_components, + eigen_solver=eigen_solver, + random_state=random_state, + eigen_tol=eigen_tol, + drop_first=False, + ) + + if assign_labels == "kmeans": + kmeans = KMeans(n_clusters, random_state=random_state, n_init=n_init).fit(maps) labels = kmeans.labels_ else: labels = discretize(maps, random_state=random_state) @@ -458,31 +467,43 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None, return labels -def spectral_embedding(adjacency, n_components=20, eigen_solver=None, - random_state=None, eigen_tol=0.0, - norm_laplacian=True, drop_first=True): +def spectral_embedding( + adjacency, + n_components=20, + eigen_solver=None, + random_state=None, + eigen_tol=0.0, + norm_laplacian=True, + drop_first=True, +): adjacency = check_symmetric(adjacency) - eigen_solver = 'arpack' + eigen_solver = "arpack" norm_laplacian = True random_state = check_random_state(random_state) n_nodes = adjacency.shape[0] if not _graph_is_connected(adjacency): - warnings.warn("Graph is not fully connected, spectral embedding" - " may not work as expected.") - laplacian, dd = csgraph_laplacian(adjacency, normed=norm_laplacian, - return_diag=True) - if (eigen_solver == 'arpack' or eigen_solver != 'lobpcg' and - (not sparse.isspmatrix(laplacian) or n_nodes < 5 * n_components)): + warnings.warn( + "Graph is not fully connected, spectral embedding" + " may not work as expected." 
+ ) + laplacian, dd = csgraph_laplacian( + adjacency, normed=norm_laplacian, return_diag=True + ) + if ( + eigen_solver == "arpack" + or eigen_solver != "lobpcg" + and (not sparse.isspmatrix(laplacian) or n_nodes < 5 * n_components) + ): # print("[INFILE] eigen_solver : ", eigen_solver, "norm_laplacian:", norm_laplacian) laplacian = _set_diag(laplacian, 1, norm_laplacian) try: laplacian *= -1 v0 = random_state.uniform(-1, 1, laplacian.shape[0]) - lambdas, diffusion_map = eigsh(laplacian, k=n_components, - sigma=1.0, which='LM', - tol=eigen_tol, v0=v0) + lambdas, diffusion_map = eigsh( + laplacian, k=n_components, sigma=1.0, which="LM", tol=eigen_tol, v0=v0 + ) embedding = diffusion_map.T[n_components::-1] if norm_laplacian: embedding = embedding / dd @@ -518,8 +539,7 @@ def compute_sorted_eigenvectors(A): EPS = 1e-10 -def compute_number_of_clusters( - eigenvalues, max_clusters=None, stop_eigenvalue=1e-2): +def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2): """ Compute number of clusters using EigenGap principle. 
@@ -566,7 +586,7 @@ def row_threshold_mult(A, p=0.95, mult=0.01): """ For each row multiply elements smaller than the row's p'th percentile by mult """ - percentiles = np.percentile(A, p*100, axis=1) + percentiles = np.percentile(A, p * 100, axis=1) mask = A < percentiles[:, np.newaxis] A = (mask * mult * A) + (~mask * A) @@ -578,21 +598,16 @@ def row_max_norm(A): Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik}) """ maxes = np.amax(A, axis=1) - return A/maxes + return A / maxes def sim_enhancement(A): - func_order = [ - gaussian_blur, - diagonal_fill, - row_threshold_mult, - - row_max_norm - ] + func_order = [gaussian_blur, diagonal_fill, row_threshold_mult, row_max_norm] for f in func_order: A = f(A) return A + def binaryKeySimilarity_cdist(clusteringMetric, bkT1, cvT1, bkT2, cvT2): if clusteringMetric == "cosine": S = 1 - cdist(cvT1, cvT2, metric=clusteringMetric) @@ -601,48 +616,46 @@ def binaryKeySimilarity_cdist(clusteringMetric, bkT1, cvT1, bkT2, cvT2): else: logging.info("Clustering metric must be cosine or jaccard") return S - -def getSpectralClustering(bestClusteringMetric, N_init, bkT, cvT, number_speaker, n, sigma, percentile, maxNrSpeakers): + + +def getSpectralClustering( + bestClusteringMetric, + N_init, + bkT, + cvT, + number_speaker, + sigma, + percentile, + maxNrSpeakers, + random_state = None, +): if number_speaker is None: # Compute affinity matrix. - simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric,bkT,cvT,bkT,cvT) + simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric, bkT, cvT, bkT, cvT) # Laplacian calculation affinity = sim_enhancement(simMatrix) (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(affinity) # Get number of clusters. - k = compute_number_of_clusters(eigenvalues, 15, 1e-1) - # Get spectral embeddings. - spectral_embeddings = eigenvectors[:, :k] - - # Run K-Means++ on spectral embeddings. 
- # Note: The correct way should be using a K-Means implementation - # that supports customized distance measure such as cosine distance. - # This implemention from scikit-learn does NOT, which is inconsistent - # with the paper. - - bestClusteringID = spectral_clustering(affinity, - n_clusters=k, - eigen_solver=None, - random_state=None, - n_init=25, - eigen_tol=0.0, - assign_labels='kmeans') + number_speaker = compute_number_of_clusters(eigenvalues, maxNrSpeakers, 1e-2) else: # Compute affinity matrix. - simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric,bkT,cvT,bkT,cvT) + simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric, bkT, cvT, bkT, cvT) # Laplacian calculation affinity = sim_enhancement(simMatrix) - bestClusteringID = spectral_clustering(affinity, - n_clusters=number_speaker, - eigen_solver=None, - random_state=None, - n_init=25, - eigen_tol=0.0, - assign_labels='kmeans') + + bestClusteringID = spectral_clustering( + affinity, + n_clusters=number_speaker, + eigen_solver=None, + random_state=random_state, + n_init=25, + eigen_tol=0.0, + assign_labels="kmeans", + ) return bestClusteringID @@ -652,15 +665,26 @@ def smooth(a, WSZ): # WSZ: smoothing window size needs, which must be odd number, # as in the original MATLAB implementation # From https://stackoverflow.com/a/40443565 - out0 = np.convolve(a, np.ones(WSZ, dtype=int), 'valid')/WSZ - r = np.arange(1, WSZ-1, 2) - start = np.cumsum(a[:WSZ-1])[::2]/r - stop = (np.cumsum(a[:-WSZ:-1])[::2]/r)[::-1] + out0 = np.convolve(a, np.ones(WSZ, dtype=int), "valid") / WSZ + r = np.arange(1, WSZ - 1, 2) + start = np.cumsum(a[: WSZ - 1])[::2] / r + stop = (np.cumsum(a[:-WSZ:-1])[::2] / r)[::-1] return np.concatenate((start, out0, stop)) -def performResegmentation(data, speechMapping, mask, finalClusteringTable, segmentTable, modelSize, nbIter, smoothWin, numberOfSpeechFeatures): +def performResegmentation( + data, + speechMapping, + mask, + finalClusteringTable, + segmentTable, + modelSize, + nbIter, + 
smoothWin, + numberOfSpeechFeatures, +): from sklearn import mixture + np.random.seed(0) changePoints, segBeg, segEnd, nSegs = unravelMask(mask) @@ -671,39 +695,53 @@ def performResegmentation(data, speechMapping, mask, finalClusteringTable, segme speakerFeaturesIndxs = [] idxs = np.where(finalClusteringTable == spkID)[0] for l in np.arange(np.size(idxs, 0)): - speakerFeaturesIndxs = np.append(speakerFeaturesIndxs, np.arange( - int(segmentTable[idxs][:][l, 1]), int(segmentTable[idxs][:][l, 2])+1)) + speakerFeaturesIndxs = np.append( + speakerFeaturesIndxs, + np.arange( + int(segmentTable[idxs][:][l, 1]), + int(segmentTable[idxs][:][l, 2]) + 1, + ), + ) formattedData = np.vstack( - (np.tile(spkID, (1, np.size(speakerFeaturesIndxs, 0))), speakerFeaturesIndxs)) + ( + np.tile(spkID, (1, np.size(speakerFeaturesIndxs, 0))), + speakerFeaturesIndxs, + ) + ) trainingData = np.hstack((trainingData, formattedData)) llkMatrix = np.zeros([np.size(speakerIDs, 0), numberOfSpeechFeatures]) for i in np.arange(np.size(speakerIDs, 0)): spkIdxs = np.where(trainingData[0, :] == speakerIDs[i])[0] - spkIdxs = speechMapping[trainingData[1, - spkIdxs].astype(int)].astype(int)-1 + spkIdxs = speechMapping[trainingData[1, spkIdxs].astype(int)].astype(int) - 1 msize = np.minimum(modelSize, np.size(spkIdxs, 0)) - w_init = np.ones([msize])/msize - m_init = data[spkIdxs[np.random.randint( - np.size(spkIdxs, 0), size=(1, msize))[0]], :] + w_init = np.ones([msize]) / msize + m_init = data[ + spkIdxs[np.random.randint(np.size(spkIdxs, 0), size=(1, msize))[0]], : + ] gmm = mixture.GaussianMixture( - n_components=msize, covariance_type='diag', weights_init=w_init, means_init=m_init, verbose=0) + n_components=msize, + covariance_type="diag", + weights_init=w_init, + means_init=m_init, + verbose=0, + ) gmm.fit(data[spkIdxs, :]) llkSpk = gmm.score_samples(data) llkSpkSmoothed = np.zeros([1, numberOfSpeechFeatures]) for jx in np.arange(nSegs): sectionIdx = np.arange( - speechMapping[segBeg[jx]]-1, 
speechMapping[segEnd[jx]]).astype(int) + speechMapping[segBeg[jx]] - 1, speechMapping[segEnd[jx]] + ).astype(int) sectionWin = np.minimum(smoothWin, np.size(sectionIdx)) if sectionWin % 2 == 0: sectionWin = sectionWin - 1 if sectionWin >= 2: - llkSpkSmoothed[0, sectionIdx] = smooth( - llkSpk[sectionIdx], sectionWin) + llkSpkSmoothed[0, sectionIdx] = smooth(llkSpk[sectionIdx], sectionWin) else: llkSpkSmoothed[0, sectionIdx] = llkSpk[sectionIdx] llkMatrix[i, :] = llkSpkSmoothed[0].T - segOut = np.argmax(llkMatrix, axis=0)+1 + segOut = np.argmax(llkMatrix, axis=0) + 1 segChangePoints = np.diff(segOut) changes = np.where(segChangePoints != 0)[0] relSegEnds = speechMapping[segEnd] @@ -716,15 +754,30 @@ def performResegmentation(data, speechMapping, mask, finalClusteringTable, segme finalClusteringTableResegmentation = np.empty([0, 1]) for i in np.arange(np.size(changes, 0)): - addedRow = np.hstack((np.tile(np.where(speechMapping == np.maximum(currentPoint, 1))[0], (1, 2)), np.tile( - np.where(speechMapping == np.maximum(1, changes[i].astype(int)))[0], (1, 2)))) + addedRow = np.hstack( + ( + np.tile( + np.where(speechMapping == np.maximum(currentPoint, 1))[0], (1, 2) + ), + np.tile( + np.where(speechMapping == np.maximum(1, changes[i].astype(int)))[0], + (1, 2), + ), + ) + ) finalSegmentTable = np.vstack((finalSegmentTable, addedRow[0])) finalClusteringTableResegmentation = np.vstack( - (finalClusteringTableResegmentation, segOut[(changes[i]).astype(int)])) - currentPoint = changes[i]+1 - addedRow = np.hstack((np.tile(np.where(speechMapping == currentPoint)[0], (1, 2)), np.tile( - np.where(speechMapping == numberOfSpeechFeatures)[0], (1, 2)))) + (finalClusteringTableResegmentation, segOut[(changes[i]).astype(int)]) + ) + currentPoint = changes[i] + 1 + addedRow = np.hstack( + ( + np.tile(np.where(speechMapping == currentPoint)[0], (1, 2)), + np.tile(np.where(speechMapping == numberOfSpeechFeatures)[0], (1, 2)), + ) + ) finalSegmentTable = np.vstack((finalSegmentTable, 
addedRow[0])) finalClusteringTableResegmentation = np.vstack( - (finalClusteringTableResegmentation, segOut[(changes[i]+1).astype(int)])) + (finalClusteringTableResegmentation, segOut[(changes[i] + 1).astype(int)]) + ) return finalClusteringTableResegmentation, finalSegmentTable diff --git a/requirements.txt b/requirements.txt index f701936..4e9f7c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ sklearn spafe pydub python_speech_features +redis