diff --git a/.env_default_http b/.env_default_http
new file mode 100644
index 0000000..6aed7a9
--- /dev/null
+++ b/.env_default_http
@@ -0,0 +1,8 @@
+# SERVING PARAMETERS
+SERVICE_MODE=http
+
+# SERVICE DISCOVERY
+SERVICE_NAME=MY_DIARIZATION_SERVICE
+
+# CONCURRENCY
+CONCURRENCY=2
\ No newline at end of file
diff --git a/.env_default_task b/.env_default_task
new file mode 100644
index 0000000..9669c52
--- /dev/null
+++ b/.env_default_task
@@ -0,0 +1,15 @@
+# SERVING PARAMETERS
+SERVICE_MODE=task
+
+# SERVICE PARAMETERS
+SERVICES_BROKER=redis://192.168.0.1:6379
+BROKER_PASS=password
+
+# SERVICE DISCOVERY
+SERVICE_NAME=my-diarization-service
+LANGUAGE=en-US/fr-FR/*
+QUEUE_NAME=(Optional)
+MODEL_INFO=This model does something
+
+# CONCURRENCY
+CONCURRENCY=2
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index acbaa60..edda097 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
-FROM python:3.9
-LABEL maintainer="irebai@linagora.com, rbaraglia@linagora.com, wghezaiel@linagora.com"
+FROM python:3.10
+LABEL maintainer="rbaraglia@linagora.com, wghezaiel@linagora.com"
RUN apt-get update &&\
apt-get install -y \
@@ -31,6 +31,10 @@ COPY document /usr/src/app/document
COPY pyBK/diarizationFunctions.py pyBK/diarizationFunctions.py
COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+# Grep CURRENT VERSION
+COPY RELEASE.md ./
+RUN export VERSION=$(awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //')
+
ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/diarization"
# Limits on OPENBLAS number of thread prevent SEGFAULT on machine with a large number of cpus
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..a9911a2
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,13 @@
+.DEFAULT_GOAL := help
+
+target_dirs := http_server pyBK diarization celery_app
+
+help:
+ @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+style: ## update code style.
+ black ${target_dirs}
+ isort ${target_dirs}
+
+lint: ## run pylint linter.
+ pylint ${target_dirs}
\ No newline at end of file
diff --git a/README.md b/README.md
index 0660928..9ccda73 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,21 @@
# LINTO-PLATFORM-DIARIZATION
-LinTO-platform-diarization is the speaker diarization service within the [LinTO stack](https://github.com/linto-ai/linto-platform-stack).
+LinTO-platform-diarization is the [LinTO](https://linto.ai/) service for speaker diarization.
-LinTO-platform-diarization can either be used as a standalone diarization service or deployed within a micro-services infrastructure using a message broker connector.
+LinTO-platform-diarization can either be used as a standalone diarization service or deployed as a micro-services.
+
+* [Prerequisites](#pre-requisites)
+* [Deploy](#deploy)
+ * [HTTP](#http)
+ * [MicroService](#micro-service)
+* [Usage](#usages)
+ * [HTTP API](#http-api)
+ * [/healthcheck](#healthcheck)
+ * [/diarization](#diarization)
+ * [/docs](#docs)
+ * [Using celery](#using-celery)
+
+* [License](#license)
+***
## Pre-requisites
@@ -9,11 +23,11 @@ LinTO-platform-diarization can either be used as a standalone diarization servic
The transcription service requires docker up and running.
### (micro-service) Service broker and shared folder
-The diarization only entry point in job mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS.
-On addition, as to prevent large audio from transiting through the message broker, lp-diarization use a shared storage folder.
+The only entry point of the diarization service in job mode is tasks posted on a Redis message broker.
+Furthermore, to prevent large audio from transiting through the message broker, diarization uses a shared storage folder mounted on /opt/audio.
-## Deploy linto-platform-diarization
-linto-platform-stt can be deployed three ways:
+## Deploy
+linto-platform-diarization can be deployed:
* As a standalone diarization service through an HTTP API.
* As a micro-service connected to a message broker.
@@ -22,17 +36,31 @@ linto-platform-stt can be deployed three ways:
```bash
git clone https://github.com/linto-ai/linto-platform-diarization.git
cd linto-platform-diarization
-git submodule init
-git submodule update
docker build . -t linto-platform-diarization:latest
```
-### HTTP API
+### HTTP
+
+**1- Fill the .env**
+```bash
+cp .env_default_http .env
+```
+
+Fill the .env with your values.
+
+**Parameters:**
+| Variables | Description | Example |
+|:-|:-|:-|
+| SERVICE_MODE | Specify launch mode | http |
+| CONCURRENCY | Number of HTTP workers* | 1+ |
+
+**2- Run the container**
```bash
docker run --rm \
+-v SHARED_FOLDER:/opt/audio \
-p HOST_SERVING_PORT:80 \
---env SERVICE_MODE=http \
+--env-file .env \
linto-platform-diarization:latest
```
@@ -42,37 +70,88 @@ This will run a container providing an http API binded on the host HOST_SERVING_
| Variables | Description | Example |
|:-|:-|:-|
| HOST_SERVING_PORT | Host serving port | 80 |
-| CONCURRENCY | Number of HTTP worker* | 1+ |
> *diarization uses all CPU available, adding workers will share the available CPU thus decreasing processing speed for concurrent requests
-### Micro-service within LinTO-Platform stack
->LinTO-platform-diarization can be deployed within the linto-platform-stack through the use of linto-platform-services-manager. Used this way, the container spawn celery worker waiting for diarization task on a message broker.
->LinTO-platform-diarization in task mode is not intended to be launch manually.
->However, if you intent to connect it to your custom message's broker here are the parameters:
+### Using celery
+>LinTO-platform-diarization can be deployed as a micro-service using celery. Used this way, the container spawns celery workers waiting for diarization tasks on a message broker.
-You need a message broker up and running at MY_SERVICE_BROKER.
+You need a message broker up and running at SERVICES_BROKER.
+**1- Fill the .env**
```bash
-docker run --rm \
--v AM_PATH:/opt/models/AM \
--v LM_PATH:/opt/models/LM \
--v SHARED_AUDIO_FOLDER:/opt/audio \
---env SERVICES_BROKER=MY_SERVICE_BROKER \
---env BROKER_PASS=MY_BROKER_PASS \
---env SERVICE_MODE=task \
---env CONCURRENCY=1 \
-linto-platform-diarization:latest
+cp .env_default_task .env
```
+Fill the .env with your values.
+
**Parameters:**
| Variables | Description | Example |
|:-|:-|:-|
+| SERVICE_MODE | Specify launch mode | task |
| SERVICES_BROKER | Service broker uri | redis://my_redis_broker:6379 |
| BROKER_PASS | Service broker password (Leave empty if there is no password) | my_password |
-| CONCURRENCY | Number of celery worker* | 1+ |
+| QUEUE_NAME | (Optional) override the generated queue's name (See Queue name below) | my_queue |
+| SERVICE_NAME | Service's name | diarization-ml |
+| LANGUAGE | Language code as a BCP-47 code | en-US or * or languages separated by "\|" |
+| MODEL_INFO | Human readable description of the model | Multilingual diarization model |
+| CONCURRENCY | Number of workers (1 worker = 1 cpu) | >1 |
+
+**2- Fill the docker-compose.yml**
+
+`#docker-compose.yml`
+```yaml
+version: '3.7'
+
+services:
+  diarization-service:
+ image: linto-platform-diarization:latest
+ volumes:
+ - /path/to/shared/folder:/opt/audio
+ env_file: .env
+ deploy:
+ replicas: 1
+ networks:
+ - your-net
+
+networks:
+ your-net:
+ external: true
+```
+
+**3- Run with docker swarm**
+
+```bash
+docker stack deploy --resolve-image always --compose-file docker-compose.yml your_stack
+```
+
+**Queue name:**
+
+By default the service queue name is generated using SERVICE_NAME and LANGUAGE: `diarization_{LANGUAGE}_{SERVICE_NAME}`.
+
+The queue name can be overridden using the QUEUE_NAME env variable.
+
+**Service discovery:**
+
+As a micro-service, the instance will register itself in the service registry for discovery. The service information is stored as a JSON object in Redis db0 under the id `service:{HOST_NAME}`.
+
+The following information is registered:
+
+```json
+{
+ "service_name": $SERVICE_NAME,
+ "host_name": $HOST_NAME,
+ "service_type": "diarization",
+ "service_language": $LANGUAGE,
+ "queue_name": $QUEUE_NAME,
+ "version": "1.2.0", # This repository's version
+ "info": "Multilingual diarization model",
+ "last_alive": 65478213,
+ "concurrency": 1
+}
+```
+
-> *diarization uses all CPU available, adding workers will share the available CPU thus decreasing processing speed for concurrent requests
## Usages
@@ -92,9 +171,9 @@ Diarization API
* Method: POST
* Response content: application/json
-* File: An Wave file
-* spk_number: (integer - optional) Number of speakers. If empty, diarization will guess.
-* max_speaker: (interger - optional) Max number of speakers if spk_number is empty.
+* File: A Wave file
+* spk_number: (integer - optional) Number of speakers. If empty, diarization will estimate the number of speakers automatically.
+* max_speaker: (integer - optional) Max number of speakers if spk_number is unknown.
Return a json object when using structured as followed:
```json
@@ -116,7 +195,7 @@ The /docs route offers a OpenAPI/swagger interface.
### Through the message broker
STT-Worker accepts requests with the following arguments:
-```file_path: str, with_metadata: bool```
+```file_path: str, speaker_count: int (None), max_speaker: int (None)```
* file_path: (str) Is the location of the file within the shared_folder. /.../SHARED_FOLDER/{file_path}
* speaker_count: (int default None) Fixed number of speakers.
diff --git a/RELEASE.md b/RELEASE.md
index 7d67ad5..45f944f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,21 @@
+# 1.1.2
+- Added service registration.
+- Updated healthcheck to add heartbeat.
+- Added possibility to override the generated queue name.
+# 1.1.1
+- Fixed: silences (and short occurrences <1 sec between silences) occurring inside a speaker turn were postponed at the end of the speaker turn (and could be arbitrarily assigned to next speaker)
+- Fixed: make diarization deterministic (random seed is fixed)
+- Tune length of short occurrences to consider as silences (0.3 sec)
+
+# 1.1.0
+- Changed: loading audio file by AudioSegment toolbox.
+- Changed: mfcc are extracted by python_speech_features toolbox.
+- Fixed windowRate <= maximumKBMWindowRate.
+- Likelihood table is only calculated for the top five gaussian, computation time is reduced.
+- Similarity matrix is calculated by Binary keys and cumulative vectors
+- Removed: unused AHC.
+- Code formatted to PEP 8
+
# 1.0.3
- Fixed: diarization failing on short audio when n_speaker > 1
- Fixed (TBT): diarization returning segfault on machine with a lot of CPU
diff --git a/celery_app/celeryapp.py b/celery_app/celeryapp.py
index ed2edfa..2a75dd0 100644
--- a/celery_app/celeryapp.py
+++ b/celery_app/celeryapp.py
@@ -1,24 +1,24 @@
import os
+
from celery import Celery
from diarization import logger
-celery = Celery(__name__, include=['celery_app.tasks'])
+celery = Celery(__name__, include=["celery_app.tasks"])
service_name = os.environ.get("SERVICE_NAME")
broker_url = os.environ.get("SERVICES_BROKER")
if os.environ.get("BROKER_PASS", False):
- components = broker_url.split('//')
+ components = broker_url.split("//")
broker_url = f'{components[0]}//:{os.environ.get("BROKER_PASS")}@{components[1]}'
celery.conf.broker_url = "{}/0".format(broker_url)
celery.conf.result_backend = "{}/1".format(broker_url)
-celery.conf.update(
- result_expires=3600,
- task_acks_late=True,
- task_track_started=True)
+celery.conf.update(result_expires=3600, task_acks_late=True, task_track_started=True)
# Queues
celery.conf.update(
- {'task_routes': {
- 'diarization_task': {'queue': 'diarization'}, }
- }
+ {
+ "task_routes": {
+ "diarization_task": {"queue": "diarization"},
+ }
+ }
)
diff --git a/celery_app/register.py b/celery_app/register.py
new file mode 100644
index 0000000..9d2b732
--- /dev/null
+++ b/celery_app/register.py
@@ -0,0 +1,90 @@
+"""The register Module allow registering and unregistering operations within the service stack for service discovery purposes"""
+import os
+import sys
+import uuid
+from socket import gethostname
+from time import time
+from xmlrpc.client import ResponseError
+
+import redis
+from redis.commands.json.path import Path
+from redis.commands.search.field import NumericField, TextField
+from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+
+SERVICE_DISCOVERY_DB = 0
+SERVICE_TYPE = "diarization"
+
+service_name = os.environ.get("SERVICE_NAME", SERVICE_TYPE)
+service_lang = os.environ.get("LANGUAGE", "?")
+host_name = gethostname()
+
+
+def register(is_heartbeat: bool = False) -> bool:
+ """Registers the service and act as heartbeat.
+
+ Returns:
+ bool: registering status
+ """
+ host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":")
+ password = os.environ.get("BROKER_PASS", None)
+ r = redis.Redis(
+ host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password=password
+ )
+
+ res = r.json().set(f"service:{host_name}", Path.root_path(), service_info())
+ if is_heartbeat:
+ return res
+ else:
+ print(f"Service registered as service:{host_name}")
+ schema = (
+ TextField("$.service_name", as_name="service_name"),
+ TextField("$.service_type", as_name="service_type"),
+ TextField("$.service_language", as_name="service_language"),
+ TextField("$.queue_name", as_name="queue_name"),
+ TextField("$.version", as_name="version"),
+ TextField("$.info", as_name="info"),
+ NumericField("$.last_alive", as_name="last_alive"),
+ NumericField("$.concurrency", as_name="concurrency"),
+ )
+ try:
+ r.ft().create_index(
+ schema,
+ definition=IndexDefinition(prefix=["service:"], index_type=IndexType.JSON),
+ )
+ except Exception as error:
+ pass
+ return res
+
+
+def unregister() -> None:
+ """Un-register the service"""
+ try:
+ host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":")
+ r = redis.Redis(
+ host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password="password"
+ )
+ r.json().delete(f"service:{host_name}")
+ except Exception as error:
+ print(f"Failed to unregister: {repr(error)}")
+
+
+def queue() -> str:
+ return os.environ.get("QUEUE_NAME", f"{SERVICE_TYPE}_{service_lang}_{service_name}")
+
+
+def service_info() -> dict:
+ return {
+ "service_name": service_name,
+ "host_name": host_name,
+ "service_type": SERVICE_TYPE,
+ "service_language": service_lang,
+ "queue_name": queue(),
+ "version": "1.1.2",
+ "info": os.environ.get("MODEL_INFO", "unknown"),
+ "last_alive": int(time()),
+ "concurrency": int(os.environ.get("CONCURRENCY")),
+ }
+
+
+if __name__ == "__main__":
+ sys.exit(register())
diff --git a/celery_app/tasks.py b/celery_app/tasks.py
index bb245cd..b3ab130 100644
--- a/celery_app/tasks.py
+++ b/celery_app/tasks.py
@@ -1,12 +1,15 @@
-import os
import json
+import os
+
from celery_app.celeryapp import celery
from diarization.processing.speakerdiarization import SpeakerDiarization
@celery.task(name="diarization_task")
-def diarization_task(file_name: str, speaker_count: int = None, max_speaker: int = None):
- """ transcribe_task do a synchronous call to the transcribe worker API """
+def diarization_task(
+ file_name: str, speaker_count: int = None, max_speaker: int = None
+):
+ """transcribe_task do a synchronous call to the transcribe worker API"""
if not os.path.isfile(os.path.join("/opt/audio", file_name)):
raise Exception("Could not find ressource {}".format(file_name))
@@ -20,8 +23,11 @@ def diarization_task(file_name: str, speaker_count: int = None, max_speaker: int
# Processing
try:
diarizationworker = SpeakerDiarization()
- result = diarizationworker.run(os.path.join(
- "/opt/audio", file_name), number_speaker=speaker_count, max_speaker=max_speaker)
+ result = diarizationworker.run(
+ os.path.join("/opt/audio", file_name),
+ number_speaker=speaker_count,
+ max_speaker=max_speaker,
+ )
response = diarizationworker.format_response(result)
except Exception as e:
raise Exception("Diarization has failed : {}".format(e))
diff --git a/diarization/__init__.py b/diarization/__init__.py
index 33b20af..82ff433 100644
--- a/diarization/__init__.py
+++ b/diarization/__init__.py
@@ -1,5 +1,7 @@
-import os
import logging
-logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
-logger = logging.getLogger("__diarization-serving__")
\ No newline at end of file
+logging.basicConfig(
+ format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
+)
+logger = logging.getLogger("__diarization-serving__")
diff --git a/diarization/processing/speakerdiarization.py b/diarization/processing/speakerdiarization.py
index d4e7750..3c24d58 100644
--- a/diarization/processing/speakerdiarization.py
+++ b/diarization/processing/speakerdiarization.py
@@ -1,21 +1,24 @@
#!/usr/bin/env python3
+import logging
import os
import time
-import logging
import uuid
-import numpy as np
+
import librosa
-import webrtcvad
+import numpy as np
import pyBK.diarizationFunctions as pybk
-#from spafe.features.mfcc import mfcc, imfcc
+
+# from spafe.features.mfcc import mfcc, imfcc
from pydub import AudioSegment
from python_speech_features import mfcc
+import pyBK.diarizationFunctions as pybk
+
class SpeakerDiarization:
def __init__(self):
- self.log = logging.getLogger('__speaker-diarization__' + __name__)
+ self.log = logging.getLogger("__speaker-diarization__" + __name__)
if os.environ.get("DEBUG", False) in ["1", 1, "true", "True"]:
self.log.setLevel(logging.DEBUG)
@@ -48,21 +51,23 @@ def __init__(self):
# BINARY_KEY
self.topGaussiansPerFrame = 5 # Number of top selected components per frame
- self.bitsPerSegmentFactor = 0.2 # Percentage of bits set to 1 in the binary keys
+ self.bitsPerSegmentFactor = (
+ 0.2 # Percentage of bits set to 1 in the binary keys
+ )
# CLUSTERING
self.N_init = 25 # Number of initial clusters
# Linkage criterion used if linkage==1 ('average', 'single', 'complete')
- self.linkageCriterion = 'average'
+ self.linkageCriterion = "average"
# Similarity metric: 'cosine' for cumulative vectors, and 'jaccard' for binary keys
- self.metric = 'cosine'
+ self.metric = "cosine"
# CLUSTERING_SELECTION
# Distance metric used in the selection of the output clustering solution ('jaccard','cosine')
- self.metric_clusteringSelection = 'cosine'
+ self.metric_clusteringSelection = "cosine"
# Method employed for number of clusters selection. Can be either 'elbow' for an elbow criterion based on within-class sum of squares (WCSS) or 'spectral' for spectral clustering
- self.bestClusteringCriterion = 'spectral'
+ self.bestClusteringCriterion = "spectral"
self.sigma = 1 # Spectral clustering parameters, employed if bestClusteringCriterion == spectral
self.percentile = 80
self.maxNrSpeakers = 20 # If known, max nr of speakers in a sesssion in the database. This is to limit the effect of changes in very small meaningless eigenvalues values generating huge eigengaps
@@ -70,52 +75,57 @@ def __init__(self):
# RESEGMENTATION
self.resegmentation = 1 # Set to 1 to perform re-segmentation
self.modelSize = 16 # Number of GMM components
- self.modelSize = 16 # Number of GMM components
self.nbIter = 5 # Number of expectation-maximization (EM) iterations
self.smoothWin = 100 # Size of the likelihood smoothing window in nb of frames
+ # Pseudo-randomness
+ self.seed = 0
+
+ # Short segments to ignore
+ self.min_duration = 0.3
+
def compute_feat_Librosa(self, audioFile):
try:
if type(audioFile) is not str:
filename = str(uuid.uuid4())
- file_path = "/tmp/"+filename
+ file_path = "/tmp/" + filename
audioFile.save(file_path)
else:
file_path = audioFile
self.sr = 16000
- y = AudioSegment.from_wav(file_path)
- self.data = np.array(y.get_array_of_samples())
+ audio = AudioSegment.from_wav(file_path)
+ audio = audio.set_frame_rate(self.sr)
+ audio = audio.set_channels(1)
+ self.data = np.array(audio.get_array_of_samples())
if type(audioFile) is not str:
os.remove(file_path)
frame_length_inSample = self.frame_length_s * self.sr
hop = int(self.frame_shift_s * self.sr)
- NFFT = int(2**np.ceil(np.log2(frame_length_inSample)))
-
+ NFFT = int(2 ** np.ceil(np.log2(frame_length_inSample)))
+
framelength_in_samples = self.frame_length_s * self.sr
n_fft = int(2 ** np.ceil(np.log2(framelength_in_samples)))
-
+
additional_kwargs = {}
if self.sr >= 16000:
- additional_kwargs.update({"lowfreq": 20, "highfreq": 7600})
+ additional_kwargs.update({"lowfreq": 20, "highfreq": 7600})
-
mfcc_coef = mfcc(
- signal=self.data,
- samplerate=self.sr,
- numcep=30,
- nfilt=30,
- nfft=n_fft,
- winlen=0.03,
- winstep=0.01,
- **additional_kwargs,
- )
-
+ signal=self.data,
+ samplerate=self.sr,
+ numcep=30,
+ nfilt=30,
+ nfft=n_fft,
+ winlen=0.03,
+ winstep=0.01,
+ **additional_kwargs,
+ )
+
except Exception as e:
self.log.error(e)
- raise ValueError(
- "Speaker diarization failed when extracting features!!!")
+ raise ValueError("Speaker diarization failed when extracting features!!!")
else:
return mfcc_coef
@@ -126,51 +136,60 @@ def computeVAD_WEBRTC(self, data, sr, nFeatures):
sr = 16000
va_framed = pybk.py_webrtcvad(
- data, fs=sr, fs_vad=sr, hoplength=30, vad_mode=0)
+ data, fs=sr, fs_vad=sr, hoplength=30, vad_mode=0
+ )
segments = pybk.get_py_webrtcvad_segments(va_framed, sr)
maskSAD = np.zeros([1, nFeatures])
for seg in segments:
- start = int(np.round(seg[0]/self.frame_shift_s))
- end = int(np.round(seg[1]/self.frame_shift_s))
+ start = int(np.round(seg[0] / self.frame_shift_s))
+ end = int(np.round(seg[1] / self.frame_shift_s))
maskSAD[0][start:end] = 1
except Exception as e:
self.log.error(e)
raise ValueError(
- "Speaker diarization failed while voice activity detection!!!")
+ "Speaker diarization failed while voice activity detection!!!"
+ )
else:
return maskSAD
def getSegments(self, frameshift, finalSegmentTable, finalClusteringTable, dur):
- numberOfSpeechFeatures = finalSegmentTable[-1, 2].astype(int)+1
+ numberOfSpeechFeatures = finalSegmentTable[-1, 2].astype(int) + 1
solutionVector = np.zeros([1, numberOfSpeechFeatures])
- for i in np.arange(np.size(finalSegmentTable, 0)):
- solutionVector[0, np.arange(
- finalSegmentTable[i, 1], finalSegmentTable[i, 2]+1).astype(int)] = finalClusteringTable[i]
+ for i in range(np.size(finalSegmentTable, 0)):
+ solutionVector[
+ 0,
+ np.arange(finalSegmentTable[i, 1], finalSegmentTable[i, 2] + 1).astype(
+ int
+ ),
+ ] = finalClusteringTable[i]
seg = np.empty([0, 3])
solutionDiff = np.diff(solutionVector)[0]
first = 0
- for i in np.arange(0, np.size(solutionDiff, 0)):
+ for i in range(0, np.size(solutionDiff, 0)):
if solutionDiff[i]:
- last = i+1
- seg1 = (first)*frameshift
- seg2 = (last-first)*frameshift
- seg3 = solutionVector[0, last-1]
- if seg.shape[0] != 0 and seg3 == seg[-1][2]:
- seg[-1][1] += seg2
- elif seg3 and seg2 > 1: # and seg2 > 0.1
- seg = np.vstack((seg, [seg1, seg2, seg3]))
- first = i+1
+ last = i + 1
+ start = (first) * frameshift
+ duration = (last - first) * frameshift
+ spklabel = solutionVector[0, last - 1]
+ silence = not spklabel or duration <= self.min_duration
+ if seg.shape[0] != 0 and (spklabel == seg[-1][2] or silence):
+ seg[-1][1] += duration
+ elif not silence:
+ seg = np.vstack((seg, [start, duration, spklabel]))
+ else: # First silence
+ continue
+ first = i + 1
last = np.size(solutionVector, 1)
- seg1 = (first-1)*frameshift
- seg2 = (last-first+1)*frameshift
- seg3 = solutionVector[0, last-1]
- if seg3 == seg[-1][2]:
- seg[-1][1] += seg2
- elif seg3 and seg2 > 1: # and seg2 > 0.1
- seg = np.vstack((seg, [seg1, seg2, seg3]))
- seg = np.vstack((seg, [dur, -1, -1]))
- seg[0][0] = 0.0
+ start = (first - 1) * frameshift
+ duration = (last - first + 1) * frameshift
+ spklabel = solutionVector[0, last - 1]
+ silence = not spklabel or duration <= self.min_duration
+ if spklabel == seg[-1][2] or silence:
+ seg[-1][1] += duration
+ else:
+ seg = np.vstack((seg, [start, duration, spklabel]))
+ seg = np.vstack((seg, [dur, -1, -1])) # Why?
return seg
def format_response(self, segments: list) -> dict:
@@ -216,52 +235,56 @@ def format_response(self, segments: list) -> dict:
# Remove the last line of the segments.
# It indicates the end of the file and segments.
- if segments[len(segments)-1][2] == -1:
- segments = segments[:len(segments)-1]
+ if segments[len(segments) - 1][2] == -1:
+ segments = segments[: len(segments) - 1]
for seg in segments:
segment = {}
- segment['seg_id'] = seg_id
+ segment["seg_id"] = seg_id
# Ensure speaker id continuity and numbers speaker by order of appearance.
if seg[2] not in spk_i_dict.keys():
spk_i_dict[seg[2]] = spk_i
spk_i += 1
- segment['spk_id'] = 'spk'+str(spk_i_dict[seg[2]])
- segment['seg_begin'] = float("{:.2f}".format(seg[0]))
- segment['seg_end'] = float("{:.2f}".format(seg[0] + seg[1]))
+ segment["spk_id"] = "spk" + str(spk_i_dict[seg[2]])
+ segment["seg_begin"] = float("{:.2f}".format(seg[0]))
+ segment["seg_end"] = float("{:.2f}".format(seg[0] + seg[1]))
- if segment['spk_id'] not in _speakers:
- _speakers[segment['spk_id']] = {}
- _speakers[segment['spk_id']]['spk_id'] = segment['spk_id']
- _speakers[segment['spk_id']]['duration'] = float(
- "{:.2f}".format(seg[1]))
- _speakers[segment['spk_id']]['nbr_seg'] = 1
+ if segment["spk_id"] not in _speakers:
+ _speakers[segment["spk_id"]] = {}
+ _speakers[segment["spk_id"]]["spk_id"] = segment["spk_id"]
+ _speakers[segment["spk_id"]]["duration"] = float(
+ "{:.2f}".format(seg[1])
+ )
+ _speakers[segment["spk_id"]]["nbr_seg"] = 1
else:
- _speakers[segment['spk_id']]['duration'] += seg[1]
- _speakers[segment['spk_id']]['nbr_seg'] += 1
- _speakers[segment['spk_id']]['duration'] = float(
- "{:.2f}".format(_speakers[segment['spk_id']]['duration']))
+ _speakers[segment["spk_id"]]["duration"] += seg[1]
+ _speakers[segment["spk_id"]]["nbr_seg"] += 1
+ _speakers[segment["spk_id"]]["duration"] = float(
+ "{:.2f}".format(_speakers[segment["spk_id"]]["duration"])
+ )
_segments.append(segment)
seg_id += 1
- json['speakers'] = list(_speakers.values())
- json['segments'] = _segments
+ json["speakers"] = list(_speakers.values())
+ json["segments"] = _segments
return json
def run(self, audioFile, number_speaker: int = None, max_speaker: int = None):
self.log.debug(f"Starting diarization on file {audioFile}")
try:
start_time = time.time()
- self.log.debug("Extracting features ... (t={:.2f}s)".format(
- time.time() - start_time))
+ self.log.debug(
+ "Extracting features ... (t={:.2f}s)".format(time.time() - start_time)
+ )
feats = self.compute_feat_Librosa(audioFile)
nFeatures = feats.shape[0]
duration = nFeatures * self.frame_shift_s
- self.log.debug("Computing SAD Mask ... (t={:.2f}s)".format(
- time.time() - start_time))
+ self.log.debug(
+ "Computing SAD Mask ... (t={:.2f}s)".format(time.time() - start_time)
+ )
maskSAD = self.computeVAD_WEBRTC(self.data, self.sr, nFeatures)
maskUEM = np.ones([1, nFeatures])
@@ -271,118 +294,137 @@ def run(self, audioFile, number_speaker: int = None, max_speaker: int = None):
speechMapping = np.zeros(nFeatures)
# you need to start the mapping from 1 and end it in the actual number of features independently of the indexing style
# so that we don't lose features on the way
- speechMapping[np.nonzero(mask)] = np.arange(1, nSpeechFeatures+1)
+ speechMapping[np.nonzero(mask)] = np.arange(1, nSpeechFeatures + 1)
data = feats[np.where(mask == 1)]
del feats
- self.log.debug("Computing segment table ... (t={:.2f}s)".format(
- time.time() - start_time))
+ self.log.debug(
+ "Computing segment table ... (t={:.2f}s)".format(
+ time.time() - start_time
+ )
+ )
segmentTable = pybk.getSegmentTable(
- mask, speechMapping, self.seg_length, self.seg_increment, self.seg_rate)
+ mask, speechMapping, self.seg_length, self.seg_increment, self.seg_rate
+ )
numberOfSegments = np.size(segmentTable, 0)
self.log.debug(f"Number of segment: {numberOfSegments}")
if numberOfSegments == 1:
self.log.debug(f"Single segment: returning")
- return [[0, duration, 1],
- [duration, -1, -1]]
+ return [[0, duration, 1], [duration, -1, -1]]
# create the KBM
# set the window rate in order to obtain "minimumNumberOfInitialGaussians" gaussians
- windowRate = np.floor((nSpeechFeatures-self.windowLength)/self.minimumNumberOfInitialGaussians)
+ windowRate = np.floor(
+ (nSpeechFeatures - self.windowLength)
+ / self.minimumNumberOfInitialGaussians
+ )
if windowRate > self.maximumKBMWindowRate:
windowRate = self.maximumKBMWindowRate
elif windowRate == 0:
windowRate = 1
-
-
- poolSize = np.floor((nSpeechFeatures-self.windowLength)/windowRate)
+ poolSize = np.floor((nSpeechFeatures - self.windowLength) / windowRate)
if self.useRelativeKBMsize:
- kbmSize = int(np.floor(poolSize*self.relKBMsize))
+ kbmSize = int(np.floor(poolSize * self.relKBMsize))
else:
kbmSize = int(self.kbmSize)
# Training pool of',int(poolSize),'gaussians with a rate of',int(windowRate),'frames'
- self.log.debug("Training KBM ... (t={:.2f}s)".format(
- time.time() - start_time))
- kbm, gmPool = pybk.trainKBM(
- data, self.windowLength, windowRate, kbmSize)
+ self.log.debug(
+ "Training KBM ... (t={:.2f}s)".format(time.time() - start_time)
+ )
+ kbm, gmPool = pybk.trainKBM(data, self.windowLength, windowRate, kbmSize)
#'Selected',kbmSize,'gaussians from the pool'
Vg = pybk.getVgMatrix(data, gmPool, kbm, self.topGaussiansPerFrame)
#'Computing binary keys for all segments... '
- self.log.debug("Computing binary keys ... (t={:.2f}s)".format(
- time.time() - start_time))
- segmentBKTable, segmentCVTable = pybk.getSegmentBKs(segmentTable,
- kbmSize,
- Vg,
- self.bitsPerSegmentFactor,
- speechMapping)
+ self.log.debug(
+ "Computing binary keys ... (t={:.2f}s)".format(time.time() - start_time)
+ )
+ segmentBKTable, segmentCVTable = pybk.getSegmentBKs(
+ segmentTable, kbmSize, Vg, self.bitsPerSegmentFactor, speechMapping
+ )
#'Performing initial clustering... '
- self.log.debug("Performing initial clustering ... (t={:.2f}s)".format(
- time.time() - start_time))
- initialClustering = np.digitize(np.arange(numberOfSegments),
- np.arange(0, numberOfSegments, numberOfSegments/self.N_init))
-
- #'Performing agglomerative clustering... '
- self.log.debug("Performing agglomerative clustering ... (t={:.2f}s)".format(
- time.time() - start_time))
- finalClusteringTable, k = pybk.performClusteringLinkage(segmentBKTable,
- segmentCVTable,
- self.N_init,
- self.linkageCriterion,
- self.metric)
+ self.log.debug(
+ "Performing initial clustering ... (t={:.2f}s)".format(
+ time.time() - start_time
+ )
+ )
#'Selecting best clustering...'
# self.bestClusteringCriterion == 'spectral':
- self.log.debug("Selecting best clustering ... (t={:.2f}s)".format(
- time.time() - start_time))
- bestClusteringID = pybk.getSpectralClustering(self.metric_clusteringSelection,
- self.N_init,
- segmentBKTable,
- segmentCVTable,
- number_speaker,
- k,
- self.sigma,
- self.percentile,
- max_speaker if max_speaker is not None else self.maxNrSpeakers)+1
+ self.log.debug(
+ "Selecting best clustering ... (t={:.2f}s)".format(
+ time.time() - start_time
+ )
+ )
+ bestClusteringID = (
+ pybk.getSpectralClustering(
+ self.metric_clusteringSelection,
+ self.N_init,
+ segmentBKTable,
+ segmentCVTable,
+ number_speaker,
+ self.sigma,
+ self.percentile,
+ max_speaker if max_speaker is not None else self.maxNrSpeakers,
+ random_state=self.seed,
+ )
+ + 1
+ )
if self.resegmentation and np.size(np.unique(bestClusteringID), 0) > 1:
- self.log.debug("Performing resegmentation ... (t={:.2f}s)".format(
- time.time() - start_time))
- finalClusteringTableResegmentation, finalSegmentTable = pybk.performResegmentation(data,
- speechMapping,
- mask,
- bestClusteringID,
- segmentTable,
- self.modelSize,
- self.nbIter,
- self.smoothWin,
- nSpeechFeatures)
- self.log.debug("Get segments ... (t={:.2f}s)".format(
- time.time() - start_time))
- segments = self.getSegments(self.frame_shift_s,
- finalSegmentTable,
- np.squeeze(
- finalClusteringTableResegmentation),
- duration)
+ self.log.debug(
+ "Performing resegmentation ... (t={:.2f}s)".format(
+ time.time() - start_time
+ )
+ )
+ (
+ finalClusteringTableResegmentation,
+ finalSegmentTable,
+ ) = pybk.performResegmentation(
+ data,
+ speechMapping,
+ mask,
+ bestClusteringID,
+ segmentTable,
+ self.modelSize,
+ self.nbIter,
+ self.smoothWin,
+ nSpeechFeatures,
+ )
+ self.log.debug(
+ "Get segments ... (t={:.2f}s)".format(time.time() - start_time)
+ )
+ segments = self.getSegments(
+ self.frame_shift_s,
+ finalSegmentTable,
+ np.squeeze(finalClusteringTableResegmentation),
+ duration,
+ )
else:
- return [[0, duration, 1],
- [duration, -1, -1]]
+ return [[0, duration, 1], [duration, -1, -1]]
- self.log.info("Speaker Diarization took %d[s] with a speed %0.2f[xRT]" %
- (int(time.time() - start_time), float(int(time.time() - start_time)/duration)))
+ self.log.info(
+ "Speaker Diarization took %d[s] with a speed %0.2f[xRT]"
+ % (
+ int(time.time() - start_time),
+ float(int(time.time() - start_time) / duration),
+ )
+ )
except ValueError as v:
self.log.error(v)
raise ValueError(
- 'Speaker diarization failed during processing the speech signal')
+ "Speaker diarization failed during processing the speech signal"
+ )
except Exception as e:
self.log.error(e)
raise Exception(
- 'Speaker diarization failed during processing the speech signal')
+ "Speaker diarization failed during processing the speech signal"
+ )
else:
self.log.debug(self.format_response(segments))
return segments
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..477128a
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,16 @@
+version: '3.7'
+
+services:
+ my-diarization-service:
+ image: linto-platform-diarization:latest
+ volumes:
+ - /path/to/shared/folder:/opt/audio
+ env_file: .env
+ deploy:
+ replicas: 1
+ networks:
+ - your-net
+
+networks:
+ your-net:
+ external: true
diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh
index b67a196..0c5fd94 100755
--- a/docker-entrypoint.sh
+++ b/docker-entrypoint.sh
@@ -1,6 +1,4 @@
#!/bin/bash
-set -ea
-
echo "RUNNING Diarization"
# Launch parameters, environement variables and dependencies check
@@ -18,11 +16,26 @@ else
if [[ -z "$SERVICES_BROKER" ]]
then
echo "ERROR: SERVICES_BROKER variable not specified, cannot start celery worker."
- return -1
+ exit -1
fi
+ echo "Running celery worker"
/usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up"
- echo "RUNNING STT CELERY WORKER"
- celery --app=celery_app.celeryapp worker -Ofair -n diarization_worker@%h --queues=diarization -c $CONCURRENCY
+ # MICRO SERVICE
+ ## QUEUE NAME
+ QUEUE=$(python -c "from celery_app.register import queue; exit(queue())" 2>&1)
+ echo "Service set to $QUEUE"
+
+ ## REGISTRATION
+ python -c "from celery_app.register import register; register()"
+ echo "Service registered"
+
+ ## WORKER
+ celery --app=celery_app.celeryapp worker -Ofair -n diarization_worker@%h --queues=$QUEUE -c $CONCURRENCY
+
+ ## UNREGISTERING
+ python -c "from celery_app.register import unregister; unregister()"
+ echo "Service unregistered"
+
else
echo "ERROR: Wrong serving command: $1"
exit -1
diff --git a/healthcheck.sh b/healthcheck.sh
index 693877e..68d6b0c 100755
--- a/healthcheck.sh
+++ b/healthcheck.sh
@@ -6,5 +6,9 @@ if [ "$SERVICE_MODE" = "http" ]
then
curl --fail http://localhost:80/healthcheck || exit 1
else
+ # Update last alive
+ python -c "from celery_app.register import register; register()"
+
+ # Ping worker
celery --app=celery_app.celeryapp inspect ping -d diarization_worker@$HOSTNAME || exit 1
fi
diff --git a/http_server/confparser.py b/http_server/confparser.py
index 49513db..aa3e2f3 100644
--- a/http_server/confparser.py
+++ b/http_server/confparser.py
@@ -1,5 +1,5 @@
-import os
import argparse
+import os
__all__ = ["createParser"]
@@ -9,44 +9,39 @@ def createParser() -> argparse.ArgumentParser:
# SERVICE
parser.add_argument(
- '--service_name',
+ "--service_name",
type=str,
- help='Service Name',
- default=os.environ.get('SERVICE_NAME', 'diarization'))
+ help="Service Name",
+ default=os.environ.get("SERVICE_NAME", "diarization"),
+ )
# GUNICORN
+ parser.add_argument("--service_port", type=int, help="Service port", default=80)
parser.add_argument(
- '--service_port',
- type=int,
- help='Service port',
- default=80)
- parser.add_argument(
- '--workers',
+ "--workers",
type=int,
help="Number of Gunicorn workers (default=CONCURRENCY + 1)",
- default=int(os.environ.get('CONCURRENCY', 1)) + 1)
+ default=int(os.environ.get("CONCURRENCY", 1)) + 1,
+ )
# SWAGGER
parser.add_argument(
- '--swagger_url',
- type=str,
- help='Swagger interface url',
- default='/docs')
+ "--swagger_url", type=str, help="Swagger interface url", default="/docs"
+ )
parser.add_argument(
- '--swagger_prefix',
+ "--swagger_prefix",
type=str,
- help='Swagger prefix',
- default=os.environ.get('SWAGGER_PREFIX', ''))
+ help="Swagger prefix",
+ default=os.environ.get("SWAGGER_PREFIX", ""),
+ )
parser.add_argument(
- '--swagger_path',
+ "--swagger_path",
type=str,
- help='Swagger file path',
- default=os.environ.get('SWAGGER_PATH', '/usr/src/app/document/swagger.yml'))
+ help="Swagger file path",
+ default=os.environ.get("SWAGGER_PATH", "/usr/src/app/document/swagger.yml"),
+ )
# MISC
- parser.add_argument(
- '--debug',
- action='store_true',
- help='Display debug logs')
+ parser.add_argument("--debug", action="store_true", help="Display debug logs")
return parser
diff --git a/http_server/ingress.py b/http_server/ingress.py
index d5a6112..def9693 100644
--- a/http_server/ingress.py
+++ b/http_server/ingress.py
@@ -1,14 +1,13 @@
#!/usr/bin/env python3
+import json
+import logging
import os
from time import time
-import logging
-import json
-from flask import Flask, request, abort, Response, json
-
-from serving import GunicornServing
from confparser import createParser
+from flask import Flask, Response, abort, json, request
+from serving import GunicornServing
from swagger import setupSwaggerUI
from diarization.processing.speakerdiarization import SpeakerDiarization
@@ -16,32 +15,34 @@
app = Flask("__diarization-serving__")
logging.basicConfig(
- format='%(asctime)s %(name)s %(levelname)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
+ format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
+)
logger = logging.getLogger("__diarization-serving__")
-@app.route('/healthcheck', methods=['GET'])
+@app.route("/healthcheck", methods=["GET"])
def healthcheck():
return json.dumps({"healthcheck": "OK"}), 200
-@app.route("/oas_docs", methods=['GET'])
+@app.route("/oas_docs", methods=["GET"])
def oas_docs():
return "Not Implemented", 501
-@app.route('/diarization', methods=['POST'])
+@app.route("/diarization", methods=["POST"])
def transcribe():
try:
- logger.info('Diarization request received')
+ logger.info("Diarization request received")
# get response content type
- logger.debug(request.headers.get('accept').lower())
- if not request.headers.get('accept').lower() == 'application/json':
- raise ValueError('Not accepted header')
+ logger.debug(request.headers.get("accept").lower())
+ if not request.headers.get("accept").lower() == "application/json":
+ raise ValueError("Not accepted header")
# get input file
- if 'file' in request.files.keys():
+ if "file" in request.files.keys():
spk_number = request.form.get("spk_number", None)
if spk_number is not None:
spk_number = int(spk_number)
@@ -50,20 +51,21 @@ def transcribe():
max_spk_number = int(max_spk_number)
start_t = time()
else:
- raise ValueError('No audio file was uploaded')
+ raise ValueError("No audio file was uploaded")
except ValueError as error:
return str(error), 400
except Exception as e:
logger.error(e)
- return 'Server Error: {}'.format(str(e)), 500
+ return "Server Error: {}".format(str(e)), 500
# Diarization
try:
diarizationworker = SpeakerDiarization()
result = diarizationworker.run(
- request.files['file'], number_speaker=spk_number, max_speaker=max_spk_number)
+ request.files["file"], number_speaker=spk_number, max_speaker=max_spk_number
+ )
except Exception as e:
- return 'Diarization has failed: {}'.format(str(e)), 500
+ return "Diarization has failed: {}".format(str(e)), 500
response = diarizationworker.format_response(result)
logger.debug("Diarization complete (t={}s)".format(time() - start_t))
@@ -74,21 +76,21 @@ def transcribe():
# Rejected request handlers
@app.errorhandler(405)
def method_not_allowed(error):
- return 'The method is not allowed for the requested URL', 405
+ return "The method is not allowed for the requested URL", 405
@app.errorhandler(404)
def page_not_found(error):
- return 'The requested URL was not found', 404
+ return "The requested URL was not found", 404
@app.errorhandler(500)
def server_error(error):
logger.error(error)
- return 'Server Error', 500
+ return "Server Error", 500
-if __name__ == '__main__':
+if __name__ == "__main__":
logger.info("Startup...")
parser = createParser()
@@ -102,9 +104,14 @@ def server_error(error):
except Exception as e:
logger.warning("Could not setup swagger: {}".format(str(e)))
- serving = GunicornServing(app, {'bind': '{}:{}'.format("0.0.0.0", args.service_port),
- 'workers': args.workers,
- 'timeout': 3600})
+ serving = GunicornServing(
+ app,
+ {
+ "bind": "{}:{}".format("0.0.0.0", args.service_port),
+ "workers": args.workers,
+ "timeout": 3600,
+ },
+ )
logger.info(args)
try:
serving.run()
diff --git a/http_server/serving.py b/http_server/serving.py
index 3e3eead..d2dd7e8 100644
--- a/http_server/serving.py
+++ b/http_server/serving.py
@@ -2,15 +2,17 @@
class GunicornServing(gunicorn.app.base.BaseApplication):
-
def __init__(self, app, options=None):
self.options = options or {}
self.application = app
super().__init__()
def load_config(self):
- config = {key: value for key, value in self.options.items()
- if key in self.cfg.settings and value is not None}
+ config = {
+ key: value
+ for key, value in self.options.items()
+ if key in self.cfg.settings and value is not None
+ }
for key, value in config.items():
self.cfg.set(key.lower(), value)
diff --git a/http_server/swagger.py b/http_server/swagger.py
index 843205f..32d7432 100644
--- a/http_server/swagger.py
+++ b/http_server/swagger.py
@@ -1,17 +1,17 @@
import yaml
from flask_swagger_ui import get_swaggerui_blueprint
+
def setupSwaggerUI(app, args):
- '''Setup Swagger UI within the app'''
- swagger_yml = yaml.load(
- open(args.swagger_path, 'r'), Loader=yaml.Loader)
+ """Setup Swagger UI within the app"""
+ swagger_yml = yaml.load(open(args.swagger_path, "r"), Loader=yaml.Loader)
swaggerui = get_swaggerui_blueprint(
# Swagger UI static files will be mapped to '{SWAGGER_URL}/dist/'
args.swagger_prefix + args.swagger_url,
args.swagger_path,
config={ # Swagger UI config overrides
- 'app_name': "LinTO Platform Diarization",
- 'spec': swagger_yml
- }
+ "app_name": "LinTO Platform Diarization",
+ "spec": swagger_yml,
+ },
)
- app.register_blueprint(swaggerui, url_prefix=args.swagger_url)
\ No newline at end of file
+ app.register_blueprint(swaggerui, url_prefix=args.swagger_url)
diff --git a/pyBK/diarizationFunctions.py b/pyBK/diarizationFunctions.py
index 6cb4d54..db3d280 100644
--- a/pyBK/diarizationFunctions.py
+++ b/pyBK/diarizationFunctions.py
@@ -3,37 +3,45 @@
# http://www.eurecom.fr/en/people/patino-jose
# Contact: patino[at]eurecom[dot]fr, josempatinovillar[at]gmail[dot]com
+import numpy as np
+import scipy
+import scipy.sparse as sparse
+import sklearn
+from scipy import sparse
+from scipy.linalg import eigh
from scipy.ndimage import gaussian_filter
-from sklearn.neighbors import kneighbors_graph
-from scipy.sparse.csgraph import laplacian as csgraph_laplacian
+from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
+from scipy.sparse.csgraph import laplacian as csgraph_laplacian
from scipy.sparse.linalg import eigsh, lobpcg
-from scipy.sparse import csr_matrix
-from scipy.linalg import eigh
-from scipy import sparse
-import scipy.sparse as sparse
-import scipy
-import sklearn
-from sklearn.preprocessing import MinMaxScaler
-from sklearn.base import BaseEstimator, ClusterMixin
-from sklearn.cluster import KMeans
-from sklearn.utils import check_random_state, check_array, check_symmetric
-from sklearn.utils.validation import check_array
-from sklearn.utils.extmath import _deterministic_vector_sign_flip
-from sklearn.utils import check_random_state
-import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import multivariate_normal
from sklearn import mixture
+from sklearn.base import BaseEstimator, ClusterMixin
+from sklearn.cluster import KMeans
+from sklearn.neighbors import kneighbors_graph
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.utils import check_array, check_random_state, check_symmetric
+from sklearn.utils.extmath import _deterministic_vector_sign_flip
+from sklearn.utils.validation import check_array
-__all__ = ["py_webrtcvad", "getSegmentTable", "trainKBM", "getVgMatrix", "getSegmentBKs",
- "performClusteringLinkage", "getSpectralClustering", "performResegmentation"]
+__all__ = [
+ "py_webrtcvad",
+ "getSegmentTable",
+ "trainKBM",
+ "getVgMatrix",
+ "getSegmentBKs",
+ "performClusteringLinkage",
+ "getSpectralClustering",
+ "performResegmentation",
+]
def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0):
import webrtcvad
from librosa.core import resample
from librosa.util import frame
+
""" Voice activity detection.
This was implementioned for easier use of py-webrtcvad.
Thanks to: https://github.com/wiseman/py-webrtcvad.git
@@ -68,32 +76,33 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0):
# check argument
if fs_vad not in [8000, 16000, 32000, 48000]:
- raise ValueError('fs_vad must be 8000, 16000, 32000 or 48000.')
+ raise ValueError("fs_vad must be 8000, 16000, 32000 or 48000.")
if hoplength not in [10, 20, 30]:
- raise ValueError('hoplength must be 10, 20, or 30.')
+ raise ValueError("hoplength must be 10, 20, or 30.")
if vad_mode not in [0, 1, 2, 3]:
- raise ValueError('vad_mode must be 0, 1, 2 or 3.')
+ raise ValueError("vad_mode must be 0, 1, 2 or 3.")
# check data
- if data.dtype.kind == 'i':
- if data.max() > 2**15 - 1 or data.min() < -2**15:
+ if data.dtype.kind == "i":
+ if data.max() > 2**15 - 1 or data.min() < -(2**15):
raise ValueError(
- 'when data type is int, data must be -32768 < data < 32767.')
- data = data.astype('f')
+ "when data type is int, data must be -32768 < data < 32767."
+ )
+ data = data.astype("f")
- elif data.dtype.kind == 'f':
+ elif data.dtype.kind == "f":
if np.abs(data).max() >= 1:
data = data / np.abs(data).max() * 0.9
- print('Warning: input data was rescaled.')
- data = (data * 2**15).astype('f')
+ print("Warning: input data was rescaled.")
+ data = (data * 2**15).astype("f")
else:
- raise ValueError('data dtype must be int or float.')
+ raise ValueError("data dtype must be int or float.")
data = data.squeeze()
if not data.ndim == 1:
- raise ValueError('data must be mono (1 ch).')
+ raise ValueError("data must be mono (1 ch).")
# resampling
if fs != fs_vad:
@@ -101,12 +110,12 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0):
else:
resampled = data
- resampled = resampled.astype('int16')
+ resampled = resampled.astype("int16")
hop = fs_vad * hoplength // 1000
framelen = resampled.size // hop + 1
padlen = framelen * hop - resampled.size
- paded = np.lib.pad(resampled, (0, padlen), 'constant', constant_values=0)
+ paded = np.lib.pad(resampled, (0, padlen), "constant", constant_values=0)
framed = frame(paded, frame_length=hop, hop_length=hop).T
vad = webrtcvad.Vad()
@@ -117,7 +126,7 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0):
va_framed = np.zeros([len(valist), hop_origin])
va_framed[valist] = 1
- return va_framed.reshape(-1)[:data.size]
+ return va_framed.reshape(-1)[: data.size]
def get_py_webrtcvad_segments(vad_info, fs):
@@ -149,21 +158,22 @@ def getSegmentTable(mask, speechMapping, wLength, wIncr, wShift):
for i in range(nSegs):
begs = np.arange(segBeg[i], segEnd[i], wShift)
bbegs = np.maximum(segBeg[i], begs - wIncr)
- ends = np.minimum(begs + wLength-1, segEnd[i])
+ ends = np.minimum(begs + wLength - 1, segEnd[i])
eends = np.minimum(ends + wIncr, segEnd[i])
segmentTable = np.vstack(
- (segmentTable, np.vstack((bbegs, begs, ends, eends)).T))
+ (segmentTable, np.vstack((bbegs, begs, ends, eends)).T)
+ )
return segmentTable
def unravelMask(mask):
- changePoints = np.diff(1*mask)
+ changePoints = np.diff(1 * mask)
segBeg = np.where(changePoints == 1)[0] + 1
segEnd = np.where(changePoints == -1)[0]
if mask[0] == 1:
segBeg = np.insert(segBeg, 0, 0)
if mask[-1] == 1:
- segEnd = np.append(segEnd, np.size(mask)-1)
+ segEnd = np.append(segEnd, np.size(mask) - 1)
nSegs = np.size(segBeg)
return changePoints, segBeg, segEnd, nSegs
@@ -221,9 +231,9 @@ def trainKBM(data, windowLength, windowRate, kbmSize):
def getVgMatrix(data, gmPool, kbm, topGaussiansPerFrame):
-
+
logLikelihoodTable = getLikelihoodTable(data, gmPool, kbm)
-
+
# The original code was:
# Vg = np.argsort(-logLikelihoodTable)[:, 0:topGaussiansPerFrame]
# return Vg
@@ -232,7 +242,7 @@ def getVgMatrix(data, gmPool, kbm, topGaussiansPerFrame):
partition_args = np.argpartition(-logLikelihoodTable, 5, axis=1)[:, :5]
partition = np.take_along_axis(-logLikelihoodTable, partition_args, axis=1)
vg = np.take_along_axis(partition_args, np.argsort(partition), axis=1)
-
+
return vg
@@ -265,22 +275,27 @@ def getSegmentBKs(segmentTable, kbmSize, Vg, bitsPerSegmentFactor, speechMapping
# BITSPERSEGMENTFACTOR = proportion of bits that will be set to 1 in the binary keys
# Output:
# SEGMENTBKTABLE = NxKBMSIZE matrix containing N binary keys for each N segments in SEGMENTTABLE
- # SEGMENTCVTABLE = NxKBMSIZE matrix containing N cumulative vectors for each N segments in SEGMENTTABLE
-
- numberOfSegments = np.size(segmentTable,0)
- segmentBKTable = np.zeros([numberOfSegments,kbmSize])
- segmentCVTable = np.zeros([numberOfSegments,kbmSize])
+ # SEGMENTCVTABLE = NxKBMSIZE matrix containing N cumulative vectors for each N segments in SEGMENTTABLE
+
+ numberOfSegments = np.size(segmentTable, 0)
+ segmentBKTable = np.zeros([numberOfSegments, kbmSize])
+ segmentCVTable = np.zeros([numberOfSegments, kbmSize])
for i in range(numberOfSegments):
- # Conform the segment according to the segmentTable matrix
- beginningIndex = int(segmentTable[i,0])
- endIndex = int(segmentTable[i,3])
+ # Conform the segment according to the segmentTable matrix
+ beginningIndex = int(segmentTable[i, 0])
+ endIndex = int(segmentTable[i, 3])
# Store indices of features of the segment
# speechMapping is substracted one because 1-indexing is used for this variable
- A = np.arange(speechMapping[beginningIndex]-1,speechMapping[endIndex],dtype=int)
- segmentBKTable[i], segmentCVTable[i] = binarizeFeatures(kbmSize, Vg[A,:], bitsPerSegmentFactor)
- #print('done')
+ A = np.arange(
+ speechMapping[beginningIndex] - 1, speechMapping[endIndex], dtype=int
+ )
+ segmentBKTable[i], segmentCVTable[i] = binarizeFeatures(
+ kbmSize, Vg[A, :], bitsPerSegmentFactor
+ )
+ # print('done')
return segmentBKTable, segmentCVTable
+
def binarizeFeatures(binaryKeySize, topComponentIndicesMatrix, bitsPerSegmentFactor):
# BINARIZEMATRIX Extracts a binary key and a cumulative vector from the the
# rows of VG specified by vector A
@@ -309,23 +324,6 @@ def binarizeFeatures(binaryKeySize, topComponentIndicesMatrix, bitsPerSegmentFac
return binaryKey, v_f
-def performClusteringLinkage(segmentBKTable, segmentCVTable, N_init, linkageCriterion, linkageMetric):
- from scipy.cluster.hierarchy import linkage
- from scipy import cluster
- if linkageMetric == 'jaccard':
- observations = segmentBKTable
- elif linkageMetric == 'cosine':
- observations = segmentCVTable
- else:
- observations = segmentCVTable
- clusteringTable = np.zeros([np.size(segmentCVTable, 0), N_init])
- Z = linkage(observations, method=linkageCriterion, metric=linkageMetric)
- for i in np.arange(N_init):
- clusteringTable[:, i] = cluster.hierarchy.cut_tree(Z, N_init-i).T+1
- k = N_init
- return clusteringTable, k
-
-
def get_sim_mat(X):
"""Returns the similarity matrix based on cosine similarities.
Arguments
@@ -417,11 +415,11 @@ def _set_diag(laplacian, value, norm_laplacian):
# We need all entries in the diagonal to values
if not sparse.isspmatrix(laplacian):
if norm_laplacian:
- laplacian.flat[::n_nodes + 1] = value
+ laplacian.flat[:: n_nodes + 1] = value
else:
laplacian = laplacian.tocoo()
if norm_laplacian:
- diag_idx = (laplacian.row == laplacian.col)
+ diag_idx = laplacian.row == laplacian.col
laplacian.data[diag_idx] = value
n_diags = np.unique(laplacian.row - laplacian.col).size
if n_diags <= 7:
@@ -432,25 +430,36 @@ def _set_diag(laplacian, value, norm_laplacian):
return laplacian
-def spectral_clustering(affinity, n_clusters=8, n_components=None,
- eigen_solver=None, random_state=None, n_init=10,
- eigen_tol=0.0, assign_labels='kmeans'):
- if assign_labels not in ('kmeans', 'discretize'):
- raise ValueError("The 'assign_labels' parameter should be "
- "'kmeans' or 'discretize', but '%s' was given"
- % assign_labels)
+def spectral_clustering(
+ affinity,
+ n_clusters=8,
+ n_components=None,
+ eigen_solver=None,
+ random_state=None,
+ n_init=10,
+ eigen_tol=0.0,
+ assign_labels="kmeans",
+):
+ if assign_labels not in ("kmeans", "discretize"):
+ raise ValueError(
+ "The 'assign_labels' parameter should be "
+ "'kmeans' or 'discretize', but '%s' was given" % assign_labels
+ )
random_state = check_random_state(random_state)
n_components = n_clusters if n_components is None else n_components
- maps = spectral_embedding(affinity, n_components=n_components,
- eigen_solver=eigen_solver,
- random_state=random_state,
- eigen_tol=eigen_tol, drop_first=False)
-
- if assign_labels == 'kmeans':
- kmeans = KMeans(n_clusters, random_state=random_state,
- n_init=n_init).fit(maps)
+ maps = spectral_embedding(
+ affinity,
+ n_components=n_components,
+ eigen_solver=eigen_solver,
+ random_state=random_state,
+ eigen_tol=eigen_tol,
+ drop_first=False,
+ )
+
+ if assign_labels == "kmeans":
+ kmeans = KMeans(n_clusters, random_state=random_state, n_init=n_init).fit(maps)
labels = kmeans.labels_
else:
labels = discretize(maps, random_state=random_state)
@@ -458,31 +467,43 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None,
return labels
-def spectral_embedding(adjacency, n_components=20, eigen_solver=None,
- random_state=None, eigen_tol=0.0,
- norm_laplacian=True, drop_first=True):
+def spectral_embedding(
+ adjacency,
+ n_components=20,
+ eigen_solver=None,
+ random_state=None,
+ eigen_tol=0.0,
+ norm_laplacian=True,
+ drop_first=True,
+):
adjacency = check_symmetric(adjacency)
- eigen_solver = 'arpack'
+ eigen_solver = "arpack"
norm_laplacian = True
random_state = check_random_state(random_state)
n_nodes = adjacency.shape[0]
if not _graph_is_connected(adjacency):
- warnings.warn("Graph is not fully connected, spectral embedding"
- " may not work as expected.")
- laplacian, dd = csgraph_laplacian(adjacency, normed=norm_laplacian,
- return_diag=True)
- if (eigen_solver == 'arpack' or eigen_solver != 'lobpcg' and
- (not sparse.isspmatrix(laplacian) or n_nodes < 5 * n_components)):
+ warnings.warn(
+ "Graph is not fully connected, spectral embedding"
+ " may not work as expected."
+ )
+ laplacian, dd = csgraph_laplacian(
+ adjacency, normed=norm_laplacian, return_diag=True
+ )
+ if (
+ eigen_solver == "arpack"
+ or eigen_solver != "lobpcg"
+ and (not sparse.isspmatrix(laplacian) or n_nodes < 5 * n_components)
+ ):
# print("[INFILE] eigen_solver : ", eigen_solver, "norm_laplacian:", norm_laplacian)
laplacian = _set_diag(laplacian, 1, norm_laplacian)
try:
laplacian *= -1
v0 = random_state.uniform(-1, 1, laplacian.shape[0])
- lambdas, diffusion_map = eigsh(laplacian, k=n_components,
- sigma=1.0, which='LM',
- tol=eigen_tol, v0=v0)
+ lambdas, diffusion_map = eigsh(
+ laplacian, k=n_components, sigma=1.0, which="LM", tol=eigen_tol, v0=v0
+ )
embedding = diffusion_map.T[n_components::-1]
if norm_laplacian:
embedding = embedding / dd
@@ -518,8 +539,7 @@ def compute_sorted_eigenvectors(A):
EPS = 1e-10
-def compute_number_of_clusters(
- eigenvalues, max_clusters=None, stop_eigenvalue=1e-2):
+def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2):
"""
Compute number of clusters using EigenGap principle.
@@ -566,7 +586,7 @@ def row_threshold_mult(A, p=0.95, mult=0.01):
"""
For each row multiply elements smaller than the row's p'th percentile by mult
"""
- percentiles = np.percentile(A, p*100, axis=1)
+ percentiles = np.percentile(A, p * 100, axis=1)
mask = A < percentiles[:, np.newaxis]
A = (mask * mult * A) + (~mask * A)
@@ -578,21 +598,16 @@ def row_max_norm(A):
Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik})
"""
maxes = np.amax(A, axis=1)
- return A/maxes
+ return A / maxes
def sim_enhancement(A):
- func_order = [
- gaussian_blur,
- diagonal_fill,
- row_threshold_mult,
-
- row_max_norm
- ]
+ func_order = [gaussian_blur, diagonal_fill, row_threshold_mult, row_max_norm]
for f in func_order:
A = f(A)
return A
+
def binaryKeySimilarity_cdist(clusteringMetric, bkT1, cvT1, bkT2, cvT2):
if clusteringMetric == "cosine":
S = 1 - cdist(cvT1, cvT2, metric=clusteringMetric)
@@ -601,48 +616,46 @@ def binaryKeySimilarity_cdist(clusteringMetric, bkT1, cvT1, bkT2, cvT2):
else:
logging.info("Clustering metric must be cosine or jaccard")
return S
-
-def getSpectralClustering(bestClusteringMetric, N_init, bkT, cvT, number_speaker, n, sigma, percentile, maxNrSpeakers):
+
+
+def getSpectralClustering(
+ bestClusteringMetric,
+ N_init,
+ bkT,
+ cvT,
+ number_speaker,
+ sigma,
+ percentile,
+ maxNrSpeakers,
+    random_state=None,
+):
if number_speaker is None:
# Compute affinity matrix.
- simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric,bkT,cvT,bkT,cvT)
+ simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric, bkT, cvT, bkT, cvT)
# Laplacian calculation
affinity = sim_enhancement(simMatrix)
(eigenvalues, eigenvectors) = compute_sorted_eigenvectors(affinity)
# Get number of clusters.
- k = compute_number_of_clusters(eigenvalues, 15, 1e-1)
- # Get spectral embeddings.
- spectral_embeddings = eigenvectors[:, :k]
-
- # Run K-Means++ on spectral embeddings.
- # Note: The correct way should be using a K-Means implementation
- # that supports customized distance measure such as cosine distance.
- # This implemention from scikit-learn does NOT, which is inconsistent
- # with the paper.
-
- bestClusteringID = spectral_clustering(affinity,
- n_clusters=k,
- eigen_solver=None,
- random_state=None,
- n_init=25,
- eigen_tol=0.0,
- assign_labels='kmeans')
+ number_speaker = compute_number_of_clusters(eigenvalues, maxNrSpeakers, 1e-2)
else:
# Compute affinity matrix.
- simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric,bkT,cvT,bkT,cvT)
+ simMatrix = binaryKeySimilarity_cdist(bestClusteringMetric, bkT, cvT, bkT, cvT)
# Laplacian calculation
affinity = sim_enhancement(simMatrix)
- bestClusteringID = spectral_clustering(affinity,
- n_clusters=number_speaker,
- eigen_solver=None,
- random_state=None,
- n_init=25,
- eigen_tol=0.0,
- assign_labels='kmeans')
+
+ bestClusteringID = spectral_clustering(
+ affinity,
+ n_clusters=number_speaker,
+ eigen_solver=None,
+ random_state=random_state,
+ n_init=25,
+ eigen_tol=0.0,
+ assign_labels="kmeans",
+ )
return bestClusteringID
@@ -652,15 +665,26 @@ def smooth(a, WSZ):
# WSZ: smoothing window size needs, which must be odd number,
# as in the original MATLAB implementation
# From https://stackoverflow.com/a/40443565
- out0 = np.convolve(a, np.ones(WSZ, dtype=int), 'valid')/WSZ
- r = np.arange(1, WSZ-1, 2)
- start = np.cumsum(a[:WSZ-1])[::2]/r
- stop = (np.cumsum(a[:-WSZ:-1])[::2]/r)[::-1]
+ out0 = np.convolve(a, np.ones(WSZ, dtype=int), "valid") / WSZ
+ r = np.arange(1, WSZ - 1, 2)
+ start = np.cumsum(a[: WSZ - 1])[::2] / r
+ stop = (np.cumsum(a[:-WSZ:-1])[::2] / r)[::-1]
return np.concatenate((start, out0, stop))
-def performResegmentation(data, speechMapping, mask, finalClusteringTable, segmentTable, modelSize, nbIter, smoothWin, numberOfSpeechFeatures):
+def performResegmentation(
+ data,
+ speechMapping,
+ mask,
+ finalClusteringTable,
+ segmentTable,
+ modelSize,
+ nbIter,
+ smoothWin,
+ numberOfSpeechFeatures,
+):
from sklearn import mixture
+
np.random.seed(0)
changePoints, segBeg, segEnd, nSegs = unravelMask(mask)
@@ -671,39 +695,53 @@ def performResegmentation(data, speechMapping, mask, finalClusteringTable, segme
speakerFeaturesIndxs = []
idxs = np.where(finalClusteringTable == spkID)[0]
for l in np.arange(np.size(idxs, 0)):
- speakerFeaturesIndxs = np.append(speakerFeaturesIndxs, np.arange(
- int(segmentTable[idxs][:][l, 1]), int(segmentTable[idxs][:][l, 2])+1))
+ speakerFeaturesIndxs = np.append(
+ speakerFeaturesIndxs,
+ np.arange(
+ int(segmentTable[idxs][:][l, 1]),
+ int(segmentTable[idxs][:][l, 2]) + 1,
+ ),
+ )
formattedData = np.vstack(
- (np.tile(spkID, (1, np.size(speakerFeaturesIndxs, 0))), speakerFeaturesIndxs))
+ (
+ np.tile(spkID, (1, np.size(speakerFeaturesIndxs, 0))),
+ speakerFeaturesIndxs,
+ )
+ )
trainingData = np.hstack((trainingData, formattedData))
llkMatrix = np.zeros([np.size(speakerIDs, 0), numberOfSpeechFeatures])
for i in np.arange(np.size(speakerIDs, 0)):
spkIdxs = np.where(trainingData[0, :] == speakerIDs[i])[0]
- spkIdxs = speechMapping[trainingData[1,
- spkIdxs].astype(int)].astype(int)-1
+ spkIdxs = speechMapping[trainingData[1, spkIdxs].astype(int)].astype(int) - 1
msize = np.minimum(modelSize, np.size(spkIdxs, 0))
- w_init = np.ones([msize])/msize
- m_init = data[spkIdxs[np.random.randint(
- np.size(spkIdxs, 0), size=(1, msize))[0]], :]
+ w_init = np.ones([msize]) / msize
+ m_init = data[
+ spkIdxs[np.random.randint(np.size(spkIdxs, 0), size=(1, msize))[0]], :
+ ]
gmm = mixture.GaussianMixture(
- n_components=msize, covariance_type='diag', weights_init=w_init, means_init=m_init, verbose=0)
+ n_components=msize,
+ covariance_type="diag",
+ weights_init=w_init,
+ means_init=m_init,
+ verbose=0,
+ )
gmm.fit(data[spkIdxs, :])
llkSpk = gmm.score_samples(data)
llkSpkSmoothed = np.zeros([1, numberOfSpeechFeatures])
for jx in np.arange(nSegs):
sectionIdx = np.arange(
- speechMapping[segBeg[jx]]-1, speechMapping[segEnd[jx]]).astype(int)
+ speechMapping[segBeg[jx]] - 1, speechMapping[segEnd[jx]]
+ ).astype(int)
sectionWin = np.minimum(smoothWin, np.size(sectionIdx))
if sectionWin % 2 == 0:
sectionWin = sectionWin - 1
if sectionWin >= 2:
- llkSpkSmoothed[0, sectionIdx] = smooth(
- llkSpk[sectionIdx], sectionWin)
+ llkSpkSmoothed[0, sectionIdx] = smooth(llkSpk[sectionIdx], sectionWin)
else:
llkSpkSmoothed[0, sectionIdx] = llkSpk[sectionIdx]
llkMatrix[i, :] = llkSpkSmoothed[0].T
- segOut = np.argmax(llkMatrix, axis=0)+1
+ segOut = np.argmax(llkMatrix, axis=0) + 1
segChangePoints = np.diff(segOut)
changes = np.where(segChangePoints != 0)[0]
relSegEnds = speechMapping[segEnd]
@@ -716,15 +754,30 @@ def performResegmentation(data, speechMapping, mask, finalClusteringTable, segme
finalClusteringTableResegmentation = np.empty([0, 1])
for i in np.arange(np.size(changes, 0)):
- addedRow = np.hstack((np.tile(np.where(speechMapping == np.maximum(currentPoint, 1))[0], (1, 2)), np.tile(
- np.where(speechMapping == np.maximum(1, changes[i].astype(int)))[0], (1, 2))))
+ addedRow = np.hstack(
+ (
+ np.tile(
+ np.where(speechMapping == np.maximum(currentPoint, 1))[0], (1, 2)
+ ),
+ np.tile(
+ np.where(speechMapping == np.maximum(1, changes[i].astype(int)))[0],
+ (1, 2),
+ ),
+ )
+ )
finalSegmentTable = np.vstack((finalSegmentTable, addedRow[0]))
finalClusteringTableResegmentation = np.vstack(
- (finalClusteringTableResegmentation, segOut[(changes[i]).astype(int)]))
- currentPoint = changes[i]+1
- addedRow = np.hstack((np.tile(np.where(speechMapping == currentPoint)[0], (1, 2)), np.tile(
- np.where(speechMapping == numberOfSpeechFeatures)[0], (1, 2))))
+ (finalClusteringTableResegmentation, segOut[(changes[i]).astype(int)])
+ )
+ currentPoint = changes[i] + 1
+ addedRow = np.hstack(
+ (
+ np.tile(np.where(speechMapping == currentPoint)[0], (1, 2)),
+ np.tile(np.where(speechMapping == numberOfSpeechFeatures)[0], (1, 2)),
+ )
+ )
finalSegmentTable = np.vstack((finalSegmentTable, addedRow[0]))
finalClusteringTableResegmentation = np.vstack(
- (finalClusteringTableResegmentation, segOut[(changes[i]+1).astype(int)]))
+ (finalClusteringTableResegmentation, segOut[(changes[i] + 1).astype(int)])
+ )
return finalClusteringTableResegmentation, finalSegmentTable
diff --git a/requirements.txt b/requirements.txt
index f701936..4e9f7c5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ sklearn
spafe
pydub
python_speech_features
+redis