From 5db6c78027034f95e082268db5752035119465ca Mon Sep 17 00:00:00 2001 From: Audran Bert Date: Tue, 26 Mar 2024 15:04:06 +0100 Subject: [PATCH 01/50] add whisper streaming support through websocket --- whisper/README.md | 8 + whisper/docker-entrypoint.sh | 5 +- whisper/stt/processing/__init__.py | 2 + whisper/stt/processing/streaming.py | 445 ++++++++++++++++++++++++++++ 4 files changed, 459 insertions(+), 1 deletion(-) create mode 100644 whisper/stt/processing/streaming.py diff --git a/whisper/README.md b/whisper/README.md index 41dc46a..6160c81 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -252,6 +252,14 @@ You may also want to add specific options: | `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | | `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | +### Websocket Server +Websocket server's mode deploy a streaming transcription service only. + +The SERVICE_MODE value in the .env should be set to ```websocket```. + +Usage is the same as the [http streaming API](#/streaming) + +The code is from [this repository](https://github.com/linto-ai/whisper_streaming) which is a fork of [ufal/whisper_streaming](https://github.com/ufal/whisper_streaming) with some modifications. They published a paper : ["Turning Whisper into Real-Time Transcription System" by Dominik Macháček, Raj Dabre, Ondřej Bojar](https://arxiv.org/abs/2307.14743). We strongly encourage you to take a look at their work. ## Usages ### HTTP API diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh index 97a3804..71ca438 100755 --- a/whisper/docker-entrypoint.sh +++ b/whisper/docker-entrypoint.sh @@ -41,7 +41,10 @@ else /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit 1 echo "RUNNING STT CELERY WORKER" celery --app=celery_app.celeryapp worker $OPT -Ofair --queues=${SERVICE_NAME} -c ${CONCURRENCY} -n ${SERVICE_NAME}_worker@%h - + elif [ "$SERVICE_MODE" == "websocket" ] + then + echo "Running Websocket server on port ${STREAMING_PORT:=80}" + python3 websocket/websocketserver.py else echo "ERROR: Wrong serving command: $SERVICE_MODE" exit -1 diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index b0e7f6d..3af614e 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -58,6 +58,8 @@ def __call__(self, *args, **kwargs): ) try: model = LazyLoadedModel(model_type, device=device) + if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]: + model.check_loaded() # model = load_whisper_model(model_type, device=device) except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py new file mode 100644 index 0000000..d90def3 --- /dev/null +++ b/whisper/stt/processing/streaming.py @@ -0,0 +1,445 @@ +import json +import sys +import string +import numpy as np +import torch +from .text_normalize import normalize_text, remove_emoji, remove_punctuation +from .decoding import decode_ct2 +from stt import logger, USE_CTRANSLATE2 +from websockets.legacy.server import WebSocketServerProtocol +import whisper_timestamped + +# CITATION (Please look at their github repository): +# Code from https://github.com/linto-ai/whisper_streaming which is a fork of https://github.com/ufal/whisper_streaming with some modifications +# They published a paper : "Turning Whisper into Real-Time Transcription System" by Dominik Macháček, Raj Dabre, Ondřej Bojar +# https://arxiv.org/abs/2307.14743 +# + +def bytes_to_array(bytes): + return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 + +def processor_output_to_text(o): + if o[0] is None: + return "" + return o[2] + +def whisper_to_json(o): + result = dict() + result["text"] = processor_output_to_text(o) + json_res = json.dumps(result) + return json_res + +async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): + """Async Decode function endpoint""" + res = await ws.recv() + try: + config = json.loads(res)["config"] + sample_rate = config["sample_rate"] + logger.info(f"Received config: {config}") + except Exception as e: + logger.error("Failed to read stream configuration") + await ws.close(reason="Failed to load configuration") + model, alignementmodel = model_and_alignementmodel + if USE_CTRANSLATE2: + logger.info("Using ctranslate2 for decoding") + asr = FasterWhisperASR(model=model, lan="fr") + else: + logger.info("Using whisper_timestamped for decoding") + asr = WhisperTimestampedASR(model=model, lan="fr") + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8) + logger.info("Waiting for chunks") + while True: + try: + message = await ws.recv() + if message is None or message == "": # Timeout + logger.info("Connection closed by client") + ws.close() + except Exception as e: + print("Connection closed by client: {}".format(str(e))) + break + if "eof" in str(message): + # await ws.send(json.dumps("")) + o = online.finish() + await ws.send(whisper_to_json(o)) + logger.info(f"End of stream {message}") + await ws.close(reason="End of stream") + break + online.insert_audio_chunk(bytes_to_array(message)) + o, _ = online.process_iter() + logger.info(o) + await ws.send(whisper_to_json(o)) + + +class HypothesisBuffer: + + def __init__(self, logfile=sys.stderr): + self.commited_in_buffer = [] + self.buffer = [] + self.new = [] + + self.last_commited_time = 0 + self.last_commited_word = None + self.last_buffered_time = -1 + + self.logfile = logfile + + def insert(self, new, offset): + # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content + # the new tail is added to self.new + + new = [(a+offset,b+offset,t) for a,b,t in new] + self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1] + + if len(self.new) >= 1: + a,b,t = self.new[0] + if abs(a - self.last_commited_time) < 1: + if self.commited_in_buffer: + # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped. + cn = len(self.commited_in_buffer) + nn = len(self.new) + for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum + c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1]) + tail = " ".join(self.new[j-1][2] for j in range(1,i+1)) + if c == tail: + logger.debug(f"removing last {i} words:") + for j in range(i): + logger.debug(f"\t{self.new.pop(0)}") + break + + def flush(self): + # returns commited chunk = the longest common prefix of 2 last inserts. + + commit = [] + while self.new: + na, nb, nt = self.new[0] + + if len(self.buffer) == 0: + break + + if nt.lower().translate(str.maketrans('', '', string.punctuation)) == self.buffer[0][2].lower().translate(str.maketrans('', '', string.punctuation)): + commit.append((na,nb,nt)) + self.last_commited_word = nt + self.last_commited_time = nb + self.buffer.pop(0) + self.new.pop(0) + else: + # print(f"SStop committing at '{nt}' and '{self.buffer[0][2]}'") + break + self.buffer = self.new + new_non_commit = [i for i in self.buffer if i[1] > self.last_buffered_time-0.1] + self.last_buffered_time = self.buffer[-1][1] if self.buffer else -1 + self.new = [] + self.commited_in_buffer.extend(commit) + return commit, new_non_commit + + def pop_commited(self, time): + while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time: + self.commited_in_buffer.pop(0) + + def complete(self): + return self.buffer + +class OnlineASRProcessor: + + SAMPLING_RATE = 16000 + + def __init__(self, asr, buffer_trimming=15, logfile=sys.stderr): + """asr: WhisperASR object + tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. + ("segment", 15) + buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option. + logfile: where to store the log. + """ + self.asr = asr + self.logfile = logfile + + self.init() + + self.buffer_trimming_sec = buffer_trimming + + def init(self): + """run this when starting or restarting processing""" + self.audio_buffer = np.array([],dtype=np.float32) + self.buffer_time_offset = 0 + + self.transcript_buffer = HypothesisBuffer(logfile=self.logfile) + self.commited = [] + self.last_chunked_at = 0 + + self.silence_iters = 0 + + def insert_audio_chunk(self, audio): + self.audio_buffer = np.append(self.audio_buffer, audio) + + def prompt(self): + """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer. + "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons. + """ + k = max(0,len(self.commited)-1) + while k > 0 and self.commited[k-1][1] > self.last_chunked_at: + k -= 1 + + p = self.commited[:k] + p = [t for _,_,t in p] + prompt = [] + l = 0 + while p and l < 200: # 200 characters prompt size + x = p.pop(-1) + l += len(x)+1 + prompt.append(x) + non_prompt = self.commited[k:] + return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt) + + def process_iter(self): + """Runs on the current audio buffer. + Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, ""). + The non-emty text is confirmed (committed) partial transcript. + """ + vad = True + prompt, non_prompt = self.prompt() + logger.debug(f"PROMPT:{prompt}") + logger.debug(f"CONTEXT:{non_prompt}") + logger.debug(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") + # print(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") + # use VAD to filter out the silence + if vad: + from whisper_timestamped.transcribe import remove_non_speech + tensor_buffer = torch.tensor(self.audio_buffer) + audio_speech, segments, convertion_function = remove_non_speech(tensor_buffer, method="silero", sample_rate=self.SAMPLING_RATE, dilatation=0.5) + audio_speech = audio_speech.numpy() + res = self.asr.transcribe(audio_speech, init_prompt=prompt) + else: + res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) + # transform to [(beg,end,"word1"), ...] + tsw = self.asr.ts_words(res, convertion_function if vad else None) + self.transcript_buffer.insert(tsw, self.buffer_time_offset) + o, buffer = self.transcript_buffer.flush() + self.commited.extend(o) + # print(f"{buffer}") + if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.SAMPLING_RATE)-buffer[-1][1]<0.05: + buffer.pop(-1) + logger.debug(f">>>>COMPLETE NOW:{self.to_flush(o)}") + logger.debug(f"INCOMPLETE:{self.to_flush(self.transcript_buffer.complete())}") + + if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: + self.chunk_completed_segment(res, chunk_silence=vad, speech_segments=segments if vad else False) + + logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}") + return self.to_flush(o), self.to_flush(buffer) + + def chunk_completed_sentence(self): + if self.commited == []: return + logger.info(self.commited) + sents = self.words_to_sentences(self.commited) + for s in sents: + logger.debug("\t\tSENT:",s) + if len(sents) < 2: + return + while len(sents) > 2: + sents.pop(0) + # we will continue with audio processing at this timestamp + chunk_at = sents[-2][1] + + logger.debug(f"--- sentence chunked at {chunk_at:2.2f}") + self.chunk_at(chunk_at) + + def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None): + if self.commited == [] and not chunk_silence: + return + + ends = self.asr.segments_end_ts(res) + t = self.commited[-1][1] + if len(ends) > 1: + e = ends[-2]+self.buffer_time_offset + while len(ends) > 2 and e > t: + ends.pop(-1) + e = ends[-2]+self.buffer_time_offset + if e <= t: + logger.debug(f"--- segment chunked at {e:2.2f}") + # print(f"--- segment chunked at {e:2.2f}") + self.chunk_at(e) + else: + logger.debug(f"--- last segment not within commited area") + elif chunk_silence: + lenght = len(self.audio_buffer)/self.SAMPLING_RATE + e = self.buffer_time_offset + lenght - 2 + if speech_segments: + end_silence = lenght - speech_segments[-1][1] + if end_silence > 2: + logger.debug(f"--- Silence segment chunked at {e:2.2f}") + self.chunk_at(e) + elif speech_segments is not None: + logger.debug(f"--- Silence segment chunked at {e:2.2f}") + self.chunk_at(e) + else: + logger.debug(f"--- not enough segments to chunk") + + + + + + def chunk_at(self, time): + """trims the hypothesis and audio buffer at "time" + """ + # print(f"chunking at {time:2.2f}") + self.transcript_buffer.pop_commited(time) + cut_seconds = time - self.buffer_time_offset + self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):] + self.buffer_time_offset = time + self.last_chunked_at = time + + def words_to_sentences(self, words): + """Uses self.tokenizer for sentence segmentation of words. + Returns: [(beg,end,"sentence 1"),...] + """ + + cwords = [w for w in words] + t = " ".join(o[2] for o in cwords) + s = self.tokenizer.split(t) + out = [] + while s: + beg = None + end = None + sent = s.pop(0).strip() + fsent = sent + while cwords: + b,e,w = cwords.pop(0) + w = w.strip() + if beg is None and sent.startswith(w): + beg = b + elif end is None and sent == w: + end = e + out.append((beg,end,fsent)) + break + sent = sent[len(w):].strip() + return out + + def finish(self): + """Flush the incomplete text when the whole processing ends. + Returns: the same format as self.process_iter() + """ + o = self.transcript_buffer.complete() + f = self.to_flush(o) + logger.debug(f"last, noncommited:{f}") + return f + + + def to_flush(self, sents, sep=None, offset=0, ): + # concatenates the timestamped words or sentences into one sequence that is flushed in one line + # sents: [(beg1, end1, "sentence1"), ...] or [] if empty + # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty + if sep is None: + sep = self.asr.sep + t = sep.join(s[2] for s in sents) + if len(sents) == 0: + b = None + e = None + else: + b = offset + sents[0][0] + e = offset + sents[-1][1] + return (b,e,t) + + +class ASRBase: + + sep = " " # join transcribe words with this character (" " for whisper_timestamped, + # "" for faster-whisper because it emits the spaces when needed) + + def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + self.logfile = logfile + + self.transcribe_kargs = {} + self.original_language = lan + self.model = model + + def transcribe(self, audio, init_prompt=""): + raise NotImplemented("must be implemented in the child class") + + def use_vad(self, vad_name=None): + raise NotImplemented("must be implemented in the child class") + + +class FasterWhisperASR(ASRBase): + """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version. + """ + + sep = "" + + def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + super().__init__(lan, model=model, logfile=logfile) + self.transcribe_kargs['beam_size'] = 1 + self.transcribe_kargs['best_of'] = 1 + self.transcribe_kargs['temperature'] = 0 + self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text + + def transcribe(self, audio, init_prompt=""): + # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01) + segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, word_timestamps=True, **self.transcribe_kargs) + return list(segments) + + def ts_words(self, segments, timestamps_convert_function=None): + o = [] + for segment in segments: + for word in segment.words: + # not stripping the spaces -- should not be merged with them! + w = word.word + if timestamps_convert_function is not None: + start, end = timestamps_convert_function(word.start, word.end) + t = (start, end, w) + else: + t = (word.start, word.end, w) + o.append(t) + return o + + def segments_end_ts(self, res): + return [s.end for s in res] + + def use_vad(self, vad_name=None): + self.transcribe_kargs["vad_filter"] = True + + +class WhisperTimestampedASR(ASRBase): + """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper. + On the other hand, the installation for GPU could be easier. + """ + + sep = " " + + def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + super().__init__(lan, model=model, logfile=logfile) + self.transcribe_kargs["verbose"] = None + self.transcribe_kargs["beam_size"] = None + self.transcribe_kargs["best_of"] = None + self.transcribe_kargs["temperature"] = 0 + self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text + + def transcribe(self, audio, init_prompt=""): + # result = whisper_timestamped.transcribe(self.model, audio, language=self.original_language, **self.transcribe_kargs) + result = whisper_timestamped.transcribe_timestamped(self.model, + audio, language=self.original_language, + initial_prompt=init_prompt, **self.transcribe_kargs) + return result + + def ts_words(self,r, timestamps_convert_function=None): + # return: transcribe result object to [(beg,end,"word1"), ...] + o = [] + for s in r["segments"]: + for w in s["words"]: + if timestamps_convert_function is not None: + # print(f"start: {word.start}->{timestamps_convert_function(word.start)}, end: {word.end}->{timestamps_convert_function(word.end)}") + start, end = timestamps_convert_function(w["start"], w['end']) + t = (start, end, w["text"]) + else: + t = (w["start"],w["end"],w["text"]) + o.append(t) + return o + + def segments_end_ts(self, res): + return [s["end"] for s in res["segments"]] + + def use_vad(self, vad_name=None): + if vad_name is None: + self.transcribe_kargs["vad"] = True + else: + self.transcribe_kargs["vad"] = vad_name + From 6d1996f86cbb229fbf5b401461908648a5e43120 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 26 Mar 2024 15:46:29 +0100 Subject: [PATCH 02/50] update readme and comments --- whisper/README.md | 5 ++--- whisper/stt/processing/streaming.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/whisper/README.md b/whisper/README.md index 6160c81..5649542 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -257,9 +257,7 @@ Websocket server's mode deploy a streaming transcription service only. The SERVICE_MODE value in the .env should be set to ```websocket```. -Usage is the same as the [http streaming API](#/streaming) - -The code is from [this repository](https://github.com/linto-ai/whisper_streaming) which is a fork of [ufal/whisper_streaming](https://github.com/ufal/whisper_streaming) with some modifications. They published a paper : ["Turning Whisper into Real-Time Transcription System" by Dominik Macháček, Raj Dabre, Ondřej Bojar](https://arxiv.org/abs/2307.14743). We strongly encourage you to take a look at their work. +Usage is the same as the [http streaming API](#/streaming). ## Usages ### HTTP API @@ -347,3 +345,4 @@ This project is developped under the AGPLv3 License (see LICENSE). * [HuggingFace Transformers](https://github.com/huggingface/transformers) * [SpeechBrain](https://github.com/speechbrain/speechbrain) * [TorchAudio](https://github.com/pytorch/audio) +* [Whisper_Streaming](https://github.com/ufal/whisper_streaming) \ No newline at end of file diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index d90def3..d5e0706 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -9,11 +9,7 @@ from websockets.legacy.server import WebSocketServerProtocol import whisper_timestamped -# CITATION (Please look at their github repository): -# Code from https://github.com/linto-ai/whisper_streaming which is a fork of https://github.com/ufal/whisper_streaming with some modifications -# They published a paper : "Turning Whisper into Real-Time Transcription System" by Dominik Macháček, Raj Dabre, Ondřej Bojar -# https://arxiv.org/abs/2307.14743 -# + def bytes_to_array(bytes): return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 @@ -373,7 +369,6 @@ def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_te self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text def transcribe(self, audio, init_prompt=""): - # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01) segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, word_timestamps=True, **self.transcribe_kargs) return list(segments) From 5c22a36ffa8249f88b083eb92a48c677b22d50af Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 26 Mar 2024 17:41:32 +0100 Subject: [PATCH 03/50] remove torch as requirement for vad for streaming --- whisper/stt/processing/streaming.py | 358 ++++++++++++++++++++++++++-- 1 file changed, 332 insertions(+), 26 deletions(-) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index d5e0706..2faf0b7 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -2,14 +2,16 @@ import sys import string import numpy as np -import torch +import os +import shutil from .text_normalize import normalize_text, remove_emoji, remove_punctuation from .decoding import decode_ct2 from stt import logger, USE_CTRANSLATE2 from websockets.legacy.server import WebSocketServerProtocol -import whisper_timestamped - +_silero_vad_model = {} +_has_onnx = None +_vad_import = None def bytes_to_array(bytes): return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 @@ -42,7 +44,7 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8) + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=True, vad_method="auditok") logger.info("Waiting for chunks") while True: try: @@ -139,7 +141,7 @@ class OnlineASRProcessor: SAMPLING_RATE = 16000 - def __init__(self, asr, buffer_trimming=15, logfile=sys.stderr): + def __init__(self, asr, buffer_trimming=15, use_vad=True, vad_method="silero", logfile=sys.stderr): """asr: WhisperASR object tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. ("segment", 15) @@ -152,6 +154,10 @@ def __init__(self, asr, buffer_trimming=15, logfile=sys.stderr): self.init() self.buffer_trimming_sec = buffer_trimming + self.use_vad = use_vad + self.vad_method = vad_method + if self.use_vad and self.vad_method is None: + self.vad_method = "silero" def init(self): """run this when starting or restarting processing""" @@ -191,23 +197,20 @@ def process_iter(self): Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, ""). The non-emty text is confirmed (committed) partial transcript. """ - vad = True prompt, non_prompt = self.prompt() logger.debug(f"PROMPT:{prompt}") logger.debug(f"CONTEXT:{non_prompt}") logger.debug(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") # print(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") # use VAD to filter out the silence - if vad: - from whisper_timestamped.transcribe import remove_non_speech - tensor_buffer = torch.tensor(self.audio_buffer) - audio_speech, segments, convertion_function = remove_non_speech(tensor_buffer, method="silero", sample_rate=self.SAMPLING_RATE, dilatation=0.5) - audio_speech = audio_speech.numpy() + if self.use_vad: + np_buffer = np.array(self.audio_buffer) + audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.vad_method, sample_rate=self.SAMPLING_RATE, dilatation=0.5) res = self.asr.transcribe(audio_speech, init_prompt=prompt) else: res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) # transform to [(beg,end,"word1"), ...] - tsw = self.asr.ts_words(res, convertion_function if vad else None) + tsw = self.asr.ts_words(res, convertion_function if self.use_vad else None) self.transcript_buffer.insert(tsw, self.buffer_time_offset) o, buffer = self.transcript_buffer.flush() self.commited.extend(o) @@ -218,7 +221,7 @@ def process_iter(self): logger.debug(f"INCOMPLETE:{self.to_flush(self.transcript_buffer.complete())}") if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: - self.chunk_completed_segment(res, chunk_silence=vad, speech_segments=segments if vad else False) + self.chunk_completed_segment(res, chunk_silence=self.use_vad, speech_segments=segments if self.use_vad else False) logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}") return self.to_flush(o), self.to_flush(buffer) @@ -271,9 +274,6 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None logger.debug(f"--- not enough segments to chunk") - - - def chunk_at(self, time): """trims the hypothesis and audio buffer at "time" """ @@ -389,10 +389,6 @@ def ts_words(self, segments, timestamps_convert_function=None): def segments_end_ts(self, res): return [s.end for s in res] - def use_vad(self, vad_name=None): - self.transcribe_kargs["vad_filter"] = True - - class WhisperTimestampedASR(ASRBase): """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper. On the other hand, the installation for GPU could be easier. @@ -407,10 +403,12 @@ def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_te self.transcribe_kargs["best_of"] = None self.transcribe_kargs["temperature"] = 0 self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text + from whisper_timestamped import transcribe_timestamped + self.transcribe_timestamped = transcribe_timestamped + def transcribe(self, audio, init_prompt=""): - # result = whisper_timestamped.transcribe(self.model, audio, language=self.original_language, **self.transcribe_kargs) - result = whisper_timestamped.transcribe_timestamped(self.model, + result = self.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, **self.transcribe_kargs) return result @@ -432,9 +430,317 @@ def ts_words(self,r, timestamps_convert_function=None): def segments_end_ts(self, res): return [s["end"] for s in res["segments"]] - def use_vad(self, vad_name=None): - if vad_name is None: - self.transcribe_kargs["vad"] = True + +def remove_non_speech(audio, + use_sample=False, + min_speech_duration=0.1, + min_silence_duration=1, + dilatation=0.5, + sample_rate=16000, + method="silero", + avoid_empty_speech=False, + ): + """ + Remove non-speech segments from audio (using Silero VAD), + glue the speech segments together and return the result along with + a function to convert timestamps from the new audio to the original audio + + parameters: + audio: torch.Tensor + audio data *in 16kHz* + use_sample: bool + if True, return start and end in samples instead of seconds + min_speech_duration: float + minimum duration (in sec) of a speech segment + min_silence_duration: float + minimum duration (in sec) of a silence segment + dilatation: float + how much (in sec) to enlarge each speech segment detected by the VAD + method: str + method to use to remove non-speech segments + avoid_empty_speech: bool + if True, avoid returning an empty speech segment (re) + """ + + segments = get_vad_segments( + audio, + sample_rate=sample_rate, + output_sample=True, + min_speech_duration=min_speech_duration, + min_silence_duration=min_silence_duration, + dilatation=dilatation, + method=method, + ) + + segments = [(seg["start"], seg["end"]) for seg in segments] + if len(segments) == 0: + if avoid_empty_speech: + segments = [(0, audio.shape[-1])] else: - self.transcribe_kargs["vad"] = vad_name + np.array([]), [], lambda t, t2 = None: t if t2 is None else [t, t2] + + audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) + # audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1) + + if not use_sample: + segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments] + + return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2) + +def do_convert_timestamps(segments, t, t2 = None): + """ + Convert timestamp from audio without non-speech segments to original audio (with non-speech segments) + parameters: + segments: list of tuple (start, end) corresponding to non-speech segments in original audio + t: timestamp to convert + t2: second timestamp to convert (optional), when the two timestamps should be in the same segment + """ + assert len(segments) + ioffset = 0 # Input offset + ooffset = 0 # Output offset + ipreviousend = 0 + result = [] + for istart, iend in segments: + ostart = ooffset + oend = ostart + (iend - istart) + ooffset = oend + ioffset += istart - ipreviousend + ipreviousend = iend + t_in = t <= oend + t2_in = t_in if t2 is None else t2 <= oend + if t_in or t2_in: + result.append([ + max(istart, min(iend, ioffset + t)), + max(istart, min(iend, ioffset + t2)) if t2 is not None else None + ]) + if t_in and t2_in: + break + if not len(result): + result.append( + [ioffset + t, ioffset + t2 if t2 is not None else None] + ) + + if len(result) > 1: + # Minimize difference between durations + result = sorted(result, key=lambda x: abs(abs(t2-t) - abs(x[1]-x[0]))) + result = result[0] + if t2 is None: + result = round(result[0], 2) + else: + result = [round(x, 2) for x in result] + return result + + + +def get_vad_segments(audio, + sample_rate=16000, + output_sample=False, + min_speech_duration=0.1, + min_silence_duration=0.1, + dilatation=0.5, + method="silero", + ): + """ + Get speech segments from audio using Silero VAD + parameters: + audio: torch.Tensor + audio data *in 16kHz* + output_sample: bool + if True, return start and end in samples instead of seconds + min_speech_duration: float + minimum duration (in sec) of a speech segment + min_silence_duration: float + minimum duration (in sec) of a silence segment + dilatation: float + how much (in sec) to enlarge each speech segment detected by the VAD + method: str or list + VAD method to use (auditok, silero, silero:v3.1) + """ + global _silero_vad_model, _silero_get_speech_ts, _has_onnx, _vad_import + + if isinstance(method, list): + # Explicit timestamps + segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method] + dilatation = 0 + + elif isinstance(method, str) and method.startswith("silero"): + version = None + _, version = check_vad_method(method, True) + # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287 + need_folder_hack = version and (version < "v4") + + if _silero_vad_model.get(version) is None: + # ONNX support since 3.1 in silero + if (version is None or version >= "v3.1") and (_has_onnx is not False): + onnx=True + try: + import onnxruntime + onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." + _has_onnx = True + except ImportError as err: + logger.warning(f"Please install onnxruntime to use more efficiently silero VAD") + _has_onnx = False + onnx=False + else: + onnx=False + + # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 + torch_home = os.environ.get('TORCH_HOME', '~/.cache/torch') + repo_or_dir_master = os.path.expanduser(torch_home + "/hub/snakers4_silero-vad_master") + repo_or_dir_specific = os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master + repo_or_dir = repo_or_dir_specific + tmp_folder = None + def apply_folder_hack(): + nonlocal tmp_folder + if os.path.exists(repo_or_dir_master): + tmp_folder = repo_or_dir_master + ".tmp" + shutil.move(repo_or_dir_master, tmp_folder) + # Make a symlink to the v3.1 model, otherwise it fails + input_exists = os.path.exists(repo_or_dir_specific) + if not input_exists: + # Make dummy file for the symlink to work + os.makedirs(repo_or_dir_specific, exist_ok=True) + os.symlink(repo_or_dir_specific, repo_or_dir_master) + if not input_exists: + shutil.rmtree(repo_or_dir_specific) + + source = "local" + if not os.path.exists(repo_or_dir): + # Load specific version of silero + repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad" + source = "github" + if need_folder_hack: + apply_folder_hack() + try: + if _vad_import is None: + from torch.hub import load as torch_load + _vad_import = torch_load + silero_vad_model, utils = _vad_import(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source) + _silero_vad_model[version] = silero_vad_model + except ImportError as err: + raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err + except Exception as err: + raise RuntimeError(f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err + finally: + if need_folder_hack: + if os.path.exists(repo_or_dir_master): + os.remove(repo_or_dir_master) + if tmp_folder: + shutil.move(tmp_folder, repo_or_dir_master) + assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}" + + _silero_get_speech_ts = utils[0] + + # Cheap normalization of the volume + # audio = audio / max(0.1, audio.abs().max()) + audio = audio / max(0.1, np.max(np.abs(audio))) + + segments = _silero_get_speech_ts(audio, _silero_vad_model[version], + sampling_rate = sample_rate, + min_speech_duration_ms = round(min_speech_duration * 1000), + min_silence_duration_ms = round(min_silence_duration * 1000), + return_seconds = False, + ) + + elif method == "auditok": + # import auditok + if _vad_import is None: + from auditok import split + _vad_import = split + + # Cheap normalization of the volume + # audio = audio / max(0.1, audio.abs().max()) + audio = audio / max(0.1, np.max(np.abs(audio))) + data = (audio * 32767).astype(np.int16).tobytes() + + audio_duration = len(audio) / sample_rate + + segments = _vad_import( + data, + sampling_rate=sample_rate, # sampling frequency in Hz + channels=1, # number of channels + sample_width=2, # number of bytes per sample + min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds + max_dur=audio_duration, # maximum duration of an event + max_silence=min(audio_duration*.95, min_silence_duration), # maximum duration of tolerated continuous silence within an event + energy_threshold=50, + drop_trailing_silence=True, + ) + + segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments] + + else: + raise ValueError(f"Got unexpected VAD method {method}") + + if dilatation > 0: + dilatation = round(dilatation * sample_rate) + new_segments = [] + for seg in segments: + new_seg = { + "start": max(0, seg["start"] - dilatation), + "end": min(len(audio), seg["end"] + dilatation) + } + if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]: + new_segments[-1]["end"] = new_seg["end"] + else: + new_segments.append(new_seg) + segments = new_segments + + ratio = 1 if output_sample else 1 / sample_rate + + if ratio != 1: + for seg in segments: + seg["start"] *= ratio + seg["end"] *= ratio + if output_sample: + for seg in segments: + seg["start"] = round(seg["start"]) + seg["end"] = round(seg["end"]) + return segments + +def check_vad_method(method, with_version=False): + """ + Check whether the VAD method is valid and return the method in a consistent format + + method: str or list or True or False + """ + if method in [True, "True", "true"]: + return check_vad_method("silero") # default method + elif method in [None, False, "False", "false", "None", "none"]: + return None + elif not isinstance(method, str) and hasattr(method, '__iter__'): + # list of explicit timestamps + checked_pairs = [] + for s_e in method: + assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs" + checked_pairs.append(tuple(s_e)) + return checked_pairs + elif isinstance(method, str) and method.startswith("silero"): + version = None + if method != "silero": + assert method.startswith("silero:"), f"Got unexpected VAD method {method}" + version = method.split(":")[1] + if not version.startswith("v"): + version = "v" + version + try: + assert float(version[1:]) >= 1 + except: + raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)") + if with_version: + return ("silero", version) + else: + return method + elif method == "auditok": + try: + import auditok + except ImportError: + raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)") + else: + try: + method = eval(method) + assert hasattr(method, '__iter__') + except: + raise ValueError(f"Got unexpected VAD method {method}") + return check_vad_method(method, with_version=with_version) + return method \ No newline at end of file From e54249d2d687239589e7d4ba438162d1d6bc757c Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 27 Mar 2024 11:00:22 +0100 Subject: [PATCH 04/50] use faster-whisper silero when using faster whisper --- whisper/stt/processing/streaming.py | 322 +---------------------- whisper/stt/processing/streaming_vad.py | 334 ++++++++++++++++++++++++ 2 files changed, 335 insertions(+), 321 deletions(-) create mode 100644 whisper/stt/processing/streaming_vad.py diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 2faf0b7..e03d78a 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -2,16 +2,10 @@ import sys import string import numpy as np -import os -import shutil -from .text_normalize import normalize_text, remove_emoji, remove_punctuation -from .decoding import decode_ct2 +from stt.processing.streaming_vad import remove_non_speech from stt import logger, USE_CTRANSLATE2 from websockets.legacy.server import WebSocketServerProtocol -_silero_vad_model = {} -_has_onnx = None -_vad_import = None def bytes_to_array(bytes): return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 @@ -430,317 +424,3 @@ def ts_words(self,r, timestamps_convert_function=None): def segments_end_ts(self, res): return [s["end"] for s in res["segments"]] - -def remove_non_speech(audio, - use_sample=False, - min_speech_duration=0.1, - min_silence_duration=1, - dilatation=0.5, - sample_rate=16000, - method="silero", - avoid_empty_speech=False, - ): - """ - Remove non-speech segments from audio (using Silero VAD), - glue the speech segments together and return the result along with - a function to convert timestamps from the new audio to the original audio - - parameters: - audio: torch.Tensor - audio data *in 16kHz* - use_sample: bool - if True, return start and end in samples instead of seconds - min_speech_duration: float - minimum duration (in sec) of a speech segment - min_silence_duration: float - minimum duration (in sec) of a silence segment - dilatation: float - how much (in sec) to enlarge each speech segment detected by the VAD - method: str - method to use to remove non-speech segments - avoid_empty_speech: bool - if True, avoid returning an empty speech segment (re) - """ - - segments = get_vad_segments( - audio, - sample_rate=sample_rate, - output_sample=True, - min_speech_duration=min_speech_duration, - min_silence_duration=min_silence_duration, - dilatation=dilatation, - method=method, - ) - - segments = [(seg["start"], seg["end"]) for seg in segments] - if len(segments) == 0: - if avoid_empty_speech: - segments = [(0, audio.shape[-1])] - else: - np.array([]), [], lambda t, t2 = None: t if t2 is None else [t, t2] - - audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) - # audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1) - - if not use_sample: - segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments] - - return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2) - -def do_convert_timestamps(segments, t, t2 = None): - """ - Convert timestamp from audio without non-speech segments to original audio (with non-speech segments) - - parameters: - segments: list of tuple (start, end) corresponding to non-speech segments in original audio - t: timestamp to convert - t2: second timestamp to convert (optional), when the two timestamps should be in the same segment - """ - assert len(segments) - ioffset = 0 # Input offset - ooffset = 0 # Output offset - ipreviousend = 0 - result = [] - for istart, iend in segments: - ostart = ooffset - oend = ostart + (iend - istart) - ooffset = oend - ioffset += istart - ipreviousend - ipreviousend = iend - t_in = t <= oend - t2_in = t_in if t2 is None else t2 <= oend - if t_in or t2_in: - result.append([ - max(istart, min(iend, ioffset + t)), - max(istart, min(iend, ioffset + t2)) if t2 is not None else None - ]) - if t_in and t2_in: - break - if not len(result): - result.append( - [ioffset + t, ioffset + t2 if t2 is not None else None] - ) - - if len(result) > 1: - # Minimize difference between durations - result = sorted(result, key=lambda x: abs(abs(t2-t) - abs(x[1]-x[0]))) - result = result[0] - if t2 is None: - result = round(result[0], 2) - else: - result = [round(x, 2) for x in result] - return result - - - -def get_vad_segments(audio, - sample_rate=16000, - output_sample=False, - min_speech_duration=0.1, - min_silence_duration=0.1, - dilatation=0.5, - method="silero", - ): - """ - Get speech segments from audio using Silero VAD - parameters: - audio: torch.Tensor - audio data *in 16kHz* - output_sample: bool - if True, return start and end in samples instead of seconds - min_speech_duration: float - minimum duration (in sec) of a speech segment - min_silence_duration: float - minimum duration (in sec) of a silence segment - dilatation: float - how much (in sec) to enlarge each speech segment detected by the VAD - method: str or list - VAD method to use (auditok, silero, silero:v3.1) - """ - global _silero_vad_model, _silero_get_speech_ts, _has_onnx, _vad_import - - if isinstance(method, list): - # Explicit timestamps - segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method] - dilatation = 0 - - elif isinstance(method, str) and method.startswith("silero"): - version = None - _, version = check_vad_method(method, True) - # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287 - need_folder_hack = version and (version < "v4") - - if _silero_vad_model.get(version) is None: - # ONNX support since 3.1 in silero - if (version is None or version >= "v3.1") and (_has_onnx is not False): - onnx=True - try: - import onnxruntime - onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." - _has_onnx = True - except ImportError as err: - logger.warning(f"Please install onnxruntime to use more efficiently silero VAD") - _has_onnx = False - onnx=False - else: - onnx=False - - # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 - torch_home = os.environ.get('TORCH_HOME', '~/.cache/torch') - repo_or_dir_master = os.path.expanduser(torch_home + "/hub/snakers4_silero-vad_master") - repo_or_dir_specific = os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master - repo_or_dir = repo_or_dir_specific - tmp_folder = None - def apply_folder_hack(): - nonlocal tmp_folder - if os.path.exists(repo_or_dir_master): - tmp_folder = repo_or_dir_master + ".tmp" - shutil.move(repo_or_dir_master, tmp_folder) - # Make a symlink to the v3.1 model, otherwise it fails - input_exists = os.path.exists(repo_or_dir_specific) - if not input_exists: - # Make dummy file for the symlink to work - os.makedirs(repo_or_dir_specific, exist_ok=True) - os.symlink(repo_or_dir_specific, repo_or_dir_master) - if not input_exists: - shutil.rmtree(repo_or_dir_specific) - - source = "local" - if not os.path.exists(repo_or_dir): - # Load specific version of silero - repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad" - source = "github" - if need_folder_hack: - apply_folder_hack() - try: - if _vad_import is None: - from torch.hub import load as torch_load - _vad_import = torch_load - silero_vad_model, utils = _vad_import(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source) - _silero_vad_model[version] = silero_vad_model - except ImportError as err: - raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err - except Exception as err: - raise RuntimeError(f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err - finally: - if need_folder_hack: - if os.path.exists(repo_or_dir_master): - os.remove(repo_or_dir_master) - if tmp_folder: - shutil.move(tmp_folder, repo_or_dir_master) - assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}" - - _silero_get_speech_ts = utils[0] - - # Cheap normalization of the volume - # audio = audio / max(0.1, audio.abs().max()) - audio = audio / max(0.1, np.max(np.abs(audio))) - - segments = _silero_get_speech_ts(audio, _silero_vad_model[version], - sampling_rate = sample_rate, - min_speech_duration_ms = round(min_speech_duration * 1000), - min_silence_duration_ms = round(min_silence_duration * 1000), - return_seconds = False, - ) - - elif method == "auditok": - # import auditok - if _vad_import is None: - from auditok import split - _vad_import = split - - # Cheap normalization of the volume - # audio = audio / max(0.1, audio.abs().max()) - audio = audio / max(0.1, np.max(np.abs(audio))) - data = (audio * 32767).astype(np.int16).tobytes() - - audio_duration = len(audio) / sample_rate - - segments = _vad_import( - data, - sampling_rate=sample_rate, # sampling frequency in Hz - channels=1, # number of channels - sample_width=2, # number of bytes per sample - min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds - max_dur=audio_duration, # maximum duration of an event - max_silence=min(audio_duration*.95, min_silence_duration), # maximum duration of tolerated continuous silence within an event - energy_threshold=50, - drop_trailing_silence=True, - ) - - segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments] - - else: - raise ValueError(f"Got unexpected VAD method {method}") - - if dilatation > 0: - dilatation = round(dilatation * sample_rate) - new_segments = [] - for seg in segments: - new_seg = { - "start": max(0, seg["start"] - dilatation), - "end": min(len(audio), seg["end"] + dilatation) - } - if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]: - new_segments[-1]["end"] = new_seg["end"] - else: - new_segments.append(new_seg) - segments = new_segments - - ratio = 1 if output_sample else 1 / sample_rate - - if ratio != 1: - for seg in segments: - seg["start"] *= ratio - seg["end"] *= ratio - if output_sample: - for seg in segments: - seg["start"] = round(seg["start"]) - seg["end"] = round(seg["end"]) - return segments - -def check_vad_method(method, with_version=False): - """ - Check whether the VAD method is valid and return the method in a consistent format - - method: str or list or True or False - """ - if method in [True, "True", "true"]: - return check_vad_method("silero") # default method - elif method in [None, False, "False", "false", "None", "none"]: - return None - elif not isinstance(method, str) and hasattr(method, '__iter__'): - # list of explicit timestamps - checked_pairs = [] - for s_e in method: - assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs" - checked_pairs.append(tuple(s_e)) - return checked_pairs - elif isinstance(method, str) and method.startswith("silero"): - version = None - if method != "silero": - assert method.startswith("silero:"), f"Got unexpected VAD method {method}" - version = method.split(":")[1] - if not version.startswith("v"): - version = "v" + version - try: - assert float(version[1:]) >= 1 - except: - raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)") - if with_version: - return ("silero", version) - else: - return method - elif method == "auditok": - try: - import auditok - except ImportError: - raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)") - else: - try: - method = eval(method) - assert hasattr(method, '__iter__') - except: - raise ValueError(f"Got unexpected VAD method {method}") - return check_vad_method(method, with_version=with_version) - return method \ No newline at end of file diff --git a/whisper/stt/processing/streaming_vad.py b/whisper/stt/processing/streaming_vad.py new file mode 100644 index 0000000..05bd89a --- /dev/null +++ b/whisper/stt/processing/streaming_vad.py @@ -0,0 +1,334 @@ +import numpy as np +import os +import shutil +from stt import logger, USE_CTRANSLATE2 + + +_silero_vad_model = {} +_has_onnx = None +_vad_import = None + + +def remove_non_speech(audio, + use_sample=False, + min_speech_duration=0.1, + min_silence_duration=1, + dilatation=0.5, + sample_rate=16000, + method="silero", + avoid_empty_speech=False, + ): + """ + Remove non-speech segments from audio (using Silero VAD), + glue the speech segments together and return the result along with + a function to convert timestamps from the new audio to the original audio + + parameters: + audio: torch.Tensor + audio data *in 16kHz* + use_sample: bool + if True, return start and end in samples instead of seconds + min_speech_duration: float + minimum duration (in sec) of a speech segment + min_silence_duration: float + minimum duration (in sec) of a silence segment + dilatation: float + how much (in sec) to enlarge each speech segment detected by the VAD + method: str + method to use to remove non-speech segments + avoid_empty_speech: bool + if True, avoid returning an empty speech segment (re) + """ + + if USE_CTRANSLATE2 and method.startswith("silero"): + from faster_whisper.vad import VadOptions + options = VadOptions( + min_speech_duration=min_speech_duration*1000, + min_silence_duration=min_silence_duration*1000, + ) + from faster_whisper.vad import get_speech_timestamps + segments = get_speech_timestamps(audio, vad_options=options) + else: + segments = get_vad_segments( + audio, + sample_rate=sample_rate, + output_sample=True, + min_speech_duration=min_speech_duration, + min_silence_duration=min_silence_duration, + dilatation=dilatation, + method=method, + ) + + print(segments) + segments = [(seg["start"], seg["end"]) for seg in segments] + if len(segments) == 0: + if avoid_empty_speech: + segments = [(0, audio.shape[-1])] + else: + np.array([]), [], lambda t, t2 = None: t if t2 is None else [t, t2] + + audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) + # audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1) + + if not use_sample: + segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments] + + return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2) + +def do_convert_timestamps(segments, t, t2 = None): + """ + Convert timestamp from audio without non-speech segments to original audio (with non-speech segments) + + parameters: + segments: list of tuple (start, end) corresponding to non-speech segments in original audio + t: timestamp to convert + t2: second timestamp to convert (optional), when the two timestamps should be in the same segment + """ + assert len(segments) + ioffset = 0 # Input offset + ooffset = 0 # Output offset + ipreviousend = 0 + result = [] + for istart, iend in segments: + ostart = ooffset + oend = ostart + (iend - istart) + ooffset = oend + ioffset += istart - ipreviousend + ipreviousend = iend + t_in = t <= oend + t2_in = t_in if t2 is None else t2 <= oend + if t_in or t2_in: + result.append([ + max(istart, min(iend, ioffset + t)), + max(istart, min(iend, ioffset + t2)) if t2 is not None else None + ]) + if t_in and t2_in: + break + if not len(result): + result.append( + [ioffset + t, ioffset + t2 if t2 is not None else None] + ) + + if len(result) > 1: + # Minimize difference between durations + result = sorted(result, key=lambda x: abs(abs(t2-t) - abs(x[1]-x[0]))) + result = result[0] + if t2 is None: + result = round(result[0], 2) + else: + result = [round(x, 2) for x in result] + return result + + + +def get_vad_segments(audio, + sample_rate=16000, + output_sample=False, + min_speech_duration=0.1, + min_silence_duration=0.1, + dilatation=0.5, + method="silero", + ): + """ + Get speech segments from audio using Silero VAD + parameters: + audio: torch.Tensor + audio data *in 16kHz* + output_sample: bool + if True, return start and end in samples instead of seconds + min_speech_duration: float + minimum duration (in sec) of a speech segment + min_silence_duration: float + minimum duration (in sec) of a silence segment + dilatation: float + how much (in sec) to enlarge each speech segment detected by the VAD + method: str or list + VAD method to use (auditok, silero, silero:v3.1) + """ + global _silero_vad_model, _silero_get_speech_ts, _has_onnx, _vad_import + + if isinstance(method, list): + # Explicit timestamps + segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method] + dilatation = 0 + + elif isinstance(method, str) and method.startswith("silero"): + version = None + _, version = check_vad_method(method, True) + # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287 + need_folder_hack = version and (version < "v4") + + if _silero_vad_model.get(version) is None: + # ONNX support since 3.1 in silero + if (version is None or version >= "v3.1") and (_has_onnx is not False): + onnx=True + try: + import onnxruntime + onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." + _has_onnx = True + except ImportError as err: + logger.warning(f"Please install onnxruntime to use more efficiently silero VAD") + _has_onnx = False + onnx=False + else: + onnx=False + + # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 + torch_home = os.environ.get('TORCH_HOME', '~/.cache/torch') + repo_or_dir_master = os.path.expanduser(torch_home + "/hub/snakers4_silero-vad_master") + repo_or_dir_specific = os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master + repo_or_dir = repo_or_dir_specific + tmp_folder = None + def apply_folder_hack(): + nonlocal tmp_folder + if os.path.exists(repo_or_dir_master): + tmp_folder = repo_or_dir_master + ".tmp" + shutil.move(repo_or_dir_master, tmp_folder) + # Make a symlink to the v3.1 model, otherwise it fails + input_exists = os.path.exists(repo_or_dir_specific) + if not input_exists: + # Make dummy file for the symlink to work + os.makedirs(repo_or_dir_specific, exist_ok=True) + os.symlink(repo_or_dir_specific, repo_or_dir_master) + if not input_exists: + shutil.rmtree(repo_or_dir_specific) + + source = "local" + if not os.path.exists(repo_or_dir): + # Load specific version of silero + repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad" + source = "github" + if need_folder_hack: + apply_folder_hack() + try: + if _vad_import is None: + from torch.hub import load as torch_load + _vad_import = torch_load + silero_vad_model, utils = _vad_import(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source) + _silero_vad_model[version] = silero_vad_model + except ImportError as err: + raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err + except Exception as err: + raise RuntimeError(f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err + finally: + if need_folder_hack: + if os.path.exists(repo_or_dir_master): + os.remove(repo_or_dir_master) + if tmp_folder: + shutil.move(tmp_folder, repo_or_dir_master) + assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}" + + _silero_get_speech_ts = utils[0] + + # Cheap normalization of the volume + # audio = audio / max(0.1, audio.abs().max()) + audio = audio / max(0.1, np.max(np.abs(audio))) + + segments = _silero_get_speech_ts(audio, _silero_vad_model[version], + sampling_rate = sample_rate, + min_speech_duration_ms = round(min_speech_duration * 1000), + min_silence_duration_ms = round(min_silence_duration * 1000), + return_seconds = False, + ) + + elif method == "auditok": + # import auditok + if _vad_import is None: + from auditok import split + _vad_import = split + + # Cheap normalization of the volume + # audio = audio / max(0.1, audio.abs().max()) + audio = audio / max(0.1, np.max(np.abs(audio))) + data = (audio * 32767).astype(np.int16).tobytes() + + audio_duration = len(audio) / sample_rate + + segments = _vad_import( + data, + sampling_rate=sample_rate, # sampling frequency in Hz + channels=1, # number of channels + sample_width=2, # number of bytes per sample + min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds + max_dur=audio_duration, # maximum duration of an event + max_silence=min(audio_duration*.95, min_silence_duration), # maximum duration of tolerated continuous silence within an event + energy_threshold=50, + drop_trailing_silence=True, + ) + + segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments] + + else: + raise ValueError(f"Got unexpected VAD method {method}") + + if dilatation > 0: + dilatation = round(dilatation * sample_rate) + new_segments = [] + for seg in segments: + new_seg = { + "start": max(0, seg["start"] - dilatation), + "end": min(len(audio), seg["end"] + dilatation) + } + if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]: + new_segments[-1]["end"] = new_seg["end"] + else: + new_segments.append(new_seg) + segments = new_segments + + ratio = 1 if output_sample else 1 / sample_rate + + if ratio != 1: + for seg in segments: + seg["start"] *= ratio + seg["end"] *= ratio + if output_sample: + for seg in segments: + seg["start"] = round(seg["start"]) + seg["end"] = round(seg["end"]) + return segments + +def check_vad_method(method, with_version=False): + """ + Check whether the VAD method is valid and return the method in a consistent format + + method: str or list or True or False + """ + if method in [True, "True", "true"]: + return check_vad_method("silero") # default method + elif method in [None, False, "False", "false", "None", "none"]: + return None + elif not isinstance(method, str) and hasattr(method, '__iter__'): + # list of explicit timestamps + checked_pairs = [] + for s_e in method: + assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs" + checked_pairs.append(tuple(s_e)) + return checked_pairs + elif isinstance(method, str) and method.startswith("silero"): + version = None + if method != "silero": + assert method.startswith("silero:"), f"Got unexpected VAD method {method}" + version = method.split(":")[1] + if not version.startswith("v"): + version = "v" + version + try: + assert float(version[1:]) >= 1 + except: + raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)") + if with_version: + return ("silero", version) + else: + return method + elif method == "auditok": + try: + import auditok + except ImportError: + raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)") + else: + try: + method = eval(method) + assert hasattr(method, '__iter__') + except: + raise ValueError(f"Got unexpected VAD method {method}") + return check_vad_method(method, with_version=with_version) + return method \ No newline at end of file From 800f48e67e8399dfa107fdcf33a8e5951fa43614 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 27 Mar 2024 14:08:22 +0100 Subject: [PATCH 05/50] refactor USE_VAD --- whisper/stt/__init__.py | 7 ++++++ whisper/stt/processing/__init__.py | 5 ++++- whisper/stt/processing/decoding.py | 7 +++--- whisper/stt/processing/streaming.py | 29 +++++++++++++++---------- whisper/stt/processing/streaming_vad.py | 7 +++--- 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index f5551af..cae6274 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -12,6 +12,13 @@ # see https://github.com/guillaumekln/faster-whisper/issues/150 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order +if os.environ.get("USE_VAD","auditok") in [True, "true", 1]: + USE_VAD = "auditok" +elif os.environ.get("USE_VAD","auditok") in [False, "false", 0]: + USE_VAD = False +else: + USE_VAD = os.environ.get("USE_VAD","auditok") + try: import faster_whisper diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index 3af614e..116984b 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -2,7 +2,7 @@ import os from lockfile import FileLock -from stt import USE_CTRANSLATE2, logger +from stt import USE_CTRANSLATE2, USE_VAD, logger from .alignment_model import get_alignment_model, load_alignment_model from .decoding import decode @@ -51,6 +51,9 @@ def __call__(self, *args, **kwargs): language = get_language() logger.info(f"Using language {language}") +logger.info(f"USE_VAD={USE_VAD}") +logger.info(f"USE_CTRANSLATE2={USE_CTRANSLATE2}") + # Load ASR model model_type = os.environ.get("MODEL", "medium") logger.info( diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 5f032e4..e87e89b 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -5,7 +5,7 @@ from typing import Tuple, Union import numpy as np -from stt import USE_CTRANSLATE2, logger +from stt import USE_CTRANSLATE2, USE_VAD, logger from .alignment_model import get_alignment_model, load_alignment_model from .text_normalize import normalize_text, remove_emoji, remove_punctuation @@ -17,7 +17,6 @@ import whisper_timestamped USE_ACCURATE = True -USE_VAD = True if USE_ACCURATE: default_beam_size = 5 @@ -80,7 +79,7 @@ def decode_ct2( kwargs["beam_size"] = 1 if kwargs.get("best_of") is None: kwargs["best_of"] = 1 - + logger.info(f"Transcribing...") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, @@ -90,7 +89,7 @@ def decode_ct2( vad_filter=USE_VAD, **kwargs, ) - + logger.info(f"Transcription done.") segments = list(segments) return format_faster_whisper_response( diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index e03d78a..8464229 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -3,9 +3,9 @@ import string import numpy as np from stt.processing.streaming_vad import remove_non_speech -from stt import logger, USE_CTRANSLATE2 +from stt import logger, USE_CTRANSLATE2, USE_VAD from websockets.legacy.server import WebSocketServerProtocol - +from simple_websocket.ws import Server as WSServer def bytes_to_array(bytes): return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 @@ -31,14 +31,14 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): except Exception as e: logger.error("Failed to read stream configuration") await ws.close(reason="Failed to load configuration") - model, alignementmodel = model_and_alignementmodel + model, _ = model_and_alignementmodel if USE_CTRANSLATE2: logger.info("Using ctranslate2 for decoding") asr = FasterWhisperASR(model=model, lan="fr") else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=True, vad_method="auditok") + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD) logger.info("Waiting for chunks") while True: try: @@ -50,7 +50,6 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): print("Connection closed by client: {}".format(str(e))) break if "eof" in str(message): - # await ws.send(json.dumps("")) o = online.finish() await ws.send(whisper_to_json(o)) logger.info(f"End of stream {message}") @@ -60,7 +59,16 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): o, _ = online.process_iter() logger.info(o) await ws.send(whisper_to_json(o)) + + +def ws_streaming(websocket_server: WSServer, model): + """Sync Decode function endpoint""" + # Wait for config + res = websocket_server.receive(timeout=10) + # Timeout + if res is None: + pass class HypothesisBuffer: @@ -135,7 +143,7 @@ class OnlineASRProcessor: SAMPLING_RATE = 16000 - def __init__(self, asr, buffer_trimming=15, use_vad=True, vad_method="silero", logfile=sys.stderr): + def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stderr): """asr: WhisperASR object tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. ("segment", 15) @@ -149,9 +157,6 @@ def __init__(self, asr, buffer_trimming=15, use_vad=True, vad_method="silero", l self.buffer_trimming_sec = buffer_trimming self.use_vad = use_vad - self.vad_method = vad_method - if self.use_vad and self.vad_method is None: - self.vad_method = "silero" def init(self): """run this when starting or restarting processing""" @@ -196,10 +201,10 @@ def process_iter(self): logger.debug(f"CONTEXT:{non_prompt}") logger.debug(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") # print(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") - # use VAD to filter out the silence + # use VAD to filter out the silence if self.use_vad: np_buffer = np.array(self.audio_buffer) - audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.vad_method, sample_rate=self.SAMPLING_RATE, dilatation=0.5) + audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.use_vad, sample_rate=self.SAMPLING_RATE, dilatation=0.5) res = self.asr.transcribe(audio_speech, init_prompt=prompt) else: res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) @@ -208,8 +213,8 @@ def process_iter(self): self.transcript_buffer.insert(tsw, self.buffer_time_offset) o, buffer = self.transcript_buffer.flush() self.commited.extend(o) - # print(f"{buffer}") if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.SAMPLING_RATE)-buffer[-1][1]<0.05: + # remove the last word if it is too close to the end of the buffer buffer.pop(-1) logger.debug(f">>>>COMPLETE NOW:{self.to_flush(o)}") logger.debug(f"INCOMPLETE:{self.to_flush(self.transcript_buffer.complete())}") diff --git a/whisper/stt/processing/streaming_vad.py b/whisper/stt/processing/streaming_vad.py index 05bd89a..149de6d 100644 --- a/whisper/stt/processing/streaming_vad.py +++ b/whisper/stt/processing/streaming_vad.py @@ -43,8 +43,8 @@ def remove_non_speech(audio, if USE_CTRANSLATE2 and method.startswith("silero"): from faster_whisper.vad import VadOptions options = VadOptions( - min_speech_duration=min_speech_duration*1000, - min_silence_duration=min_silence_duration*1000, + min_speech_duration_ms =min_speech_duration*1000, + min_silence_duration_ms =min_silence_duration*1000, ) from faster_whisper.vad import get_speech_timestamps segments = get_speech_timestamps(audio, vad_options=options) @@ -59,7 +59,6 @@ def remove_non_speech(audio, method=method, ) - print(segments) segments = [(seg["start"], seg["end"]) for seg in segments] if len(segments) == 0: if avoid_empty_speech: @@ -130,7 +129,7 @@ def get_vad_segments(audio, method="silero", ): """ - Get speech segments from audio using Silero VAD + Get speech segments from audio using the method VAD parameters: audio: torch.Tensor audio data *in 16kHz* From d2a1be671bca6f6b7e97671298de87a5ec50fff9 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 27 Mar 2024 14:35:44 +0100 Subject: [PATCH 06/50] add auditok in dockerfiles --- whisper/Dockerfile.ctranslate2.cpu | 2 +- whisper/Dockerfile.torch.cpu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu index df5eac7..1e5c544 100644 --- a/whisper/Dockerfile.ctranslate2.cpu +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -6,7 +6,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins # Install python dependencies COPY whisper/requirements.ctranslate2.txt ./ RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt - +RUN pip install --no-cache-dir auditok WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu index 17a3fb8..4a0f97f 100644 --- a/whisper/Dockerfile.torch.cpu +++ b/whisper/Dockerfile.torch.cpu @@ -12,7 +12,7 @@ RUN pip3 install \ # Install python dependencies COPY whisper/requirements.torch.txt ./ RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt - +RUN pip install --no-cache-dir auditok WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app From 8b4015799c39b1ed4382d8f4f82c15bbd0034997 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 27 Mar 2024 15:26:51 +0100 Subject: [PATCH 07/50] Control the number of threads --- whisper/stt/__init__.py | 23 ++++++++++++++++++++++- whisper/stt/processing/__init__.py | 14 +++++++++----- whisper/stt/processing/decoding.py | 3 +-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index cae6274..70058dd 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -18,7 +18,10 @@ USE_VAD = False else: USE_VAD = os.environ.get("USE_VAD","auditok") - + +NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS")) +NUM_THREADS = int(NUM_THREADS) + try: import faster_whisper @@ -43,3 +46,21 @@ USE_TORCHAUDIO = True except ImportError: USE_TORCHAUDIO = False + +if USE_CTRANSLATE2: + def set_num_threads(n): + # os.environ["OMP_NUM_THREADS"] = str(n) + pass +else: + import torch + DEFAULT_NUM_THREADS = torch.get_num_threads() + def set_num_threads(n): + torch.set_num_threads(n) + +# Number of CPU threads +if NUM_THREADS is None: + NUM_THREADS = DEFAULT_NUM_THREADS +if NUM_THREADS is not None: + NUM_THREADS = int(NUM_THREADS) +# For Torch, we will set it afterward, because setting that before loading the model can hang the process (see https://github.com/pytorch/pytorch/issues/58962) +set_num_threads(1) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index 116984b..edaea94 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -2,7 +2,7 @@ import os from lockfile import FileLock -from stt import USE_CTRANSLATE2, USE_VAD, logger +from stt import USE_CTRANSLATE2, USE_VAD, logger, set_num_threads, NUM_THREADS from .alignment_model import get_alignment_model, load_alignment_model from .decoding import decode @@ -20,10 +20,12 @@ class LazyLoadedModel: - def __init__(self, model_type, device): + def __init__(self, model_type, device, num_threads): self.model_type = model_type self.device = device + self.num_threads = num_threads self._model = None + self.has_set_num_threads = False def check_loaded(self): if self._model is None: @@ -36,6 +38,9 @@ def __getattr__(self, name): return getattr(self._model, name) def __call__(self, *args, **kwargs): + if not self.has_set_num_threads and self.num_threads: + set_num_threads(self.num_threads) + self.has_set_num_threads = True self.check_loaded() return self._model(*args, **kwargs) @@ -60,10 +65,9 @@ def __call__(self, *args, **kwargs): f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." ) try: - model = LazyLoadedModel(model_type, device=device) - if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]: + model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) + if not USE_CTRANSLATE2 or device.lower() != "cpu": model.check_loaded() - # model = load_whisper_model(model_type, device=device) except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index e87e89b..2113d59 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -63,7 +63,6 @@ def decode( kwargs.pop("alignment_model") res = decode_ct2(**kwargs) else: - print("OK") res = decode_torch(**kwargs) logger.info("Transcription complete (t={}s)".format(time.time() - start_t)) @@ -89,8 +88,8 @@ def decode_ct2( vad_filter=USE_VAD, **kwargs, ) - logger.info(f"Transcription done.") segments = list(segments) + logger.info(f"Transcription done.") return format_faster_whisper_response( segments, info, remove_punctuation_from_words=remove_punctuation_from_words From a456e2ca339358c87f738c7f6273fa2147b02b8b Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 27 Mar 2024 15:27:08 +0100 Subject: [PATCH 08/50] Update doc with NUM_THREADS variable --- whisper/.envdefault | 2 +- whisper/README.md | 27 ++++++++++++++------------- whisper/RELEASE.md | 5 +++++ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/whisper/.envdefault b/whisper/.envdefault index 75919c0..9a40b79 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -40,7 +40,7 @@ PROMPT= # CUDA_VISIBLE_DEVICES=0 # Number of threads per worker when running on CPU -OMP_NUM_THREADS=4 +NUM_THREADS=4 # Number of workers CONCURRENCY=2 diff --git a/whisper/README.md b/whisper/README.md index 5649542..1f1bf92 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -114,17 +114,18 @@ cp whisper/.envdefault whisper/.env | PARAMETER | DESCRIPTION | EXEMPLE | |---|---|---| -| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | -| MODEL | Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... | -| LANGUAGE | (Optional) Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... | -| PROMPT | (Optional) Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | -| ALIGNMENT_MODEL | (Optional and deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... | -| DEVICE | (Optional) Device to use for the model | `cpu` \| `cuda` ... | -| CUDA_VISIBLE_DEVICES | (Optional) GPU device index to use, if several. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | +| SERVICE_MODE | (Required) STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | +| MODEL | (Required) Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... | +| LANGUAGE | Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... | +| PROMPT | Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | +| DEVICE | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` | +| NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | +| CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | | CONCURRENCY | Maximum number of parallel requests | `2` | -| SERVICE_NAME | (For the task mode) queue's name for task processing | `my-stt` | -| SERVICE_BROKER | (For the task mode) URL of the message broker | `redis://my-broker:6379` | +| SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` | +| SERVICE_BROKER | (For the task mode only) URL of the message broker | `redis://my-broker:6379` | | BROKER_PASS | (For the task mode only) broker password | `my-password` \| (empty) | +| ALIGNMENT_MODEL | (Deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... | #### MODEL environment variable @@ -217,9 +218,9 @@ You may also want to add specific options: | Variables | Description | Example | |:-|:-|:-| | `HOST_SERVING_PORT` | Host serving port | 8080 | -| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | +| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | | `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | -| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | +| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | ### Micro-service within LinTO-Platform stack The TASK serving mode connect a celery worker to a message broker. @@ -248,9 +249,9 @@ You may also want to add specific options: | Variables | Description | Example | |:-|:-|:-| | `` | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model | -| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | +| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache | | `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | -| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | +| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | ### Websocket Server Websocket server's mode deploy a streaming transcription service only. diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md index f54537e..1f5688c 100644 --- a/whisper/RELEASE.md +++ b/whisper/RELEASE.md @@ -1,3 +1,8 @@ +# 1.0.3 +- Streaming support +- New NUM_THREADS env variable to control the number of threads +- Load the model when launching the service (not at the first request) + # 1.0.2 - ct2/faster_whisper: Upgrade faster_whisper and support recent distilled models - ct2/faster_whisper: Fix possible gluing of different words together From c60bd8ea2a7203b0fc0963bda569528857cc9552 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 27 Mar 2024 15:53:41 +0100 Subject: [PATCH 09/50] cosm --- whisper/stt/processing/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index edaea94..a632a52 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -33,15 +33,19 @@ def check_loaded(self): with FileLock(lockfile): self._model = load_whisper_model(self.model_type, device=self.device) + def check_num_threads(self): + if not self.has_set_num_threads and self.num_threads: + set_num_threads(self.num_threads) + self.has_set_num_threads = True + def __getattr__(self, name): self.check_loaded() + self.check_num_threads() return getattr(self._model, name) def __call__(self, *args, **kwargs): - if not self.has_set_num_threads and self.num_threads: - set_num_threads(self.num_threads) - self.has_set_num_threads = True self.check_loaded() + self.check_num_threads() return self._model(*args, **kwargs) From 8f827a103928086ba907d5622569e4564b99fd01 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 27 Mar 2024 15:44:52 +0100 Subject: [PATCH 10/50] add auditok to requirements --- whisper/Dockerfile.ctranslate2.cpu | 1 - whisper/Dockerfile.torch.cpu | 1 - whisper/requirements.ctranslate2.txt | 1 + whisper/requirements.torch.txt | 3 ++- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu index 1e5c544..5c4817c 100644 --- a/whisper/Dockerfile.ctranslate2.cpu +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -6,7 +6,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins # Install python dependencies COPY whisper/requirements.ctranslate2.txt ./ RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt -RUN pip install --no-cache-dir auditok WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu index 4a0f97f..549a767 100644 --- a/whisper/Dockerfile.torch.cpu +++ b/whisper/Dockerfile.torch.cpu @@ -12,7 +12,6 @@ RUN pip3 install \ # Install python dependencies COPY whisper/requirements.torch.txt ./ RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt -RUN pip install --no-cache-dir auditok WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt index e471fd7..87b5e80 100644 --- a/whisper/requirements.ctranslate2.txt +++ b/whisper/requirements.ctranslate2.txt @@ -11,6 +11,7 @@ regex requests>=2.26.0 wavio>=0.0.4 websockets +auditok #faster_whisper==1.0.1 # This is version faster_whisper==1.0.1 + option for (persistent) prompt + fix for large-v3 git+https://github.com/linto-ai/faster-whisper.git \ No newline at end of file diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt index 3976414..e5f5f93 100644 --- a/whisper/requirements.torch.txt +++ b/whisper/requirements.torch.txt @@ -15,4 +15,5 @@ wavio>=0.0.4 websockets whisper-timestamped onnxruntime -torchaudio \ No newline at end of file +torchaudio +auditok \ No newline at end of file From 70c1a7c41f11f9815f12a871b7799ef8726369f9 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 27 Mar 2024 17:43:32 +0100 Subject: [PATCH 11/50] add streaming through http mode --- whisper/stt/processing/decoding.py | 2 -- whisper/stt/processing/streaming.py | 44 ++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 2113d59..8230000 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -78,7 +78,6 @@ def decode_ct2( kwargs["beam_size"] = 1 if kwargs.get("best_of") is None: kwargs["best_of"] = 1 - logger.info(f"Transcribing...") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, @@ -89,7 +88,6 @@ def decode_ct2( **kwargs, ) segments = list(segments) - logger.info(f"Transcription done.") return format_faster_whisper_response( segments, info, remove_punctuation_from_words=remove_punctuation_from_words diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 8464229..9a4d7c6 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -47,7 +47,7 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): logger.info("Connection closed by client") ws.close() except Exception as e: - print("Connection closed by client: {}".format(str(e))) + logger.info(f"Connection closed by client: {e}") break if "eof" in str(message): o = online.finish() @@ -61,14 +61,44 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): await ws.send(whisper_to_json(o)) -def ws_streaming(websocket_server: WSServer, model): +def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): """Sync Decode function endpoint""" - # Wait for config res = websocket_server.receive(timeout=10) - - # Timeout - if res is None: - pass + try: + config = json.loads(res)["config"] + sample_rate = config["sample_rate"] + logger.info(f"Received config: {config}") + except Exception as e: + logger.error("Failed to read stream configuration") + websocket_server.close() + model, _ = model_and_alignementmodel + if USE_CTRANSLATE2: + logger.info("Using ctranslate2 for decoding") + asr = FasterWhisperASR(model=model, lan="fr") + else: + logger.info("Using whisper_timestamped for decoding") + asr = WhisperTimestampedASR(model=model, lan="fr") + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD) + logger.info("Waiting for chunks") + while True: + try: + message = websocket_server.receive(timeout=10) + if message is None or message == "": # Timeout + logger.info("Connection closed by client") + websocket_server.close() + except Exception as e: + logger.info(f"Connection closed by client: {e}") + break + if "eof" in str(message): + o = online.finish() + websocket_server.send(whisper_to_json(o)) + logger.info(f"End of stream {message}") + websocket_server.close() + break + online.insert_audio_chunk(bytes_to_array(message)) + o, _ = online.process_iter() + logger.info(o) + websocket_server.send(whisper_to_json(o)) class HypothesisBuffer: From a6fb111d1628088edacd9facebc9d85eb70d196e Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 28 Mar 2024 15:36:30 +0100 Subject: [PATCH 12/50] upd .env + clean code --- whisper/.envdefault | 10 +++++++++- whisper/RELEASE.md | 1 + whisper/stt/processing/streaming.py | 29 +++++------------------------ 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/whisper/.envdefault b/whisper/.envdefault index 9a40b79..795a5ff 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -1,7 +1,7 @@ ############################################ # SERVING PARAMETERS ############################################ -# "http" or "task" +# "http" or "task" or "websocket" SERVICE_MODE=http # Below: used when SERVICE_MODE=task @@ -44,3 +44,11 @@ NUM_THREADS=4 # Number of workers CONCURRENCY=2 + +# WEBSOCKET PARAMETERS +STREAMING_PORT=80 + +# HTTP PARAMETERS +ENABLE_STREAMING=true + +USE_VAD=auditok \ No newline at end of file diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md index 1f5688c..03922f3 100644 --- a/whisper/RELEASE.md +++ b/whisper/RELEASE.md @@ -1,5 +1,6 @@ # 1.0.3 - Streaming support +- Refactoring VAD system - New NUM_THREADS env variable to control the number of threads - Load the model when launching the service (not at the first request) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 9a4d7c6..98b32a6 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -97,7 +97,7 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): break online.insert_audio_chunk(bytes_to_array(message)) o, _ = online.process_iter() - logger.info(o) + # logger.info(o) websocket_server.send(whisper_to_json(o)) class HypothesisBuffer: @@ -246,35 +246,18 @@ def process_iter(self): if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.SAMPLING_RATE)-buffer[-1][1]<0.05: # remove the last word if it is too close to the end of the buffer buffer.pop(-1) - logger.debug(f">>>>COMPLETE NOW:{self.to_flush(o)}") - logger.debug(f"INCOMPLETE:{self.to_flush(self.transcript_buffer.complete())}") + logger.debug(f"New committed text:{self.to_flush(o)}") + logger.debug(f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}") if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: self.chunk_completed_segment(res, chunk_silence=self.use_vad, speech_segments=segments if self.use_vad else False) - logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}") + logger.debug(f"Len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s") return self.to_flush(o), self.to_flush(buffer) - def chunk_completed_sentence(self): - if self.commited == []: return - logger.info(self.commited) - sents = self.words_to_sentences(self.commited) - for s in sents: - logger.debug("\t\tSENT:",s) - if len(sents) < 2: - return - while len(sents) > 2: - sents.pop(0) - # we will continue with audio processing at this timestamp - chunk_at = sents[-2][1] - - logger.debug(f"--- sentence chunked at {chunk_at:2.2f}") - self.chunk_at(chunk_at) - def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None): - if self.commited == [] and not chunk_silence: + if self.commited == [] and not chunk_silence: return - ends = self.asr.segments_end_ts(res) t = self.commited[-1][1] if len(ends) > 1: @@ -284,7 +267,6 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None e = ends[-2]+self.buffer_time_offset if e <= t: logger.debug(f"--- segment chunked at {e:2.2f}") - # print(f"--- segment chunked at {e:2.2f}") self.chunk_at(e) else: logger.debug(f"--- last segment not within commited area") @@ -306,7 +288,6 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None def chunk_at(self, time): """trims the hypothesis and audio buffer at "time" """ - # print(f"chunking at {time:2.2f}") self.transcript_buffer.pop_commited(time) cut_seconds = time - self.buffer_time_offset self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):] From e6e993489e519c0b6182c58847593613eae5c9f3 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 28 Mar 2024 16:18:08 +0100 Subject: [PATCH 13/50] add test doc and code for streaming --- test/test_streaming.py | 120 +++++++++++++++++++++++++++++++++++++++++ whisper/README.md | 22 +++++++- 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 test/test_streaming.py diff --git a/test/test_streaming.py b/test/test_streaming.py new file mode 100644 index 0000000..d78f6cf --- /dev/null +++ b/test/test_streaming.py @@ -0,0 +1,120 @@ +import asyncio +import websockets +import json +import shutil + +def linstt_streaming(*kargs, **kwargs): + text = asyncio.run(_linstt_streaming(*kargs, **kwargs)) + return text + +async def _linstt_streaming( + audio_file, + ws_api = "ws://localhost:8080/streaming", + verbose = False, +): + + if audio_file is None: + import pyaudio + # Init pyaudio + audio = pyaudio.PyAudio() + stream = audio.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=2048) + if verbose > 1: + print("Start recording") + else: + stream = open(audio_file, "rb") + + alive = True + text = "" + partial = None + + try: + async with websockets.connect(ws_api) as websocket: + await websocket.send(json.dumps({"config" : {"sample_rate": 16000 }})) + while alive: + try: + data = stream.read(32000) + if audio_file and not data: + if verbose > 1: + print("\nAudio file finished") + alive = False + await websocket.send(data) + res = await websocket.recv() + message = json.loads(res) + if message is None: + if verbose > 1: + print("\n Received None") + continue + if "partial" in message.keys(): + partial = message["partial"] + if verbose: + print_partial(partial) + elif "text" in message.keys(): + line = message["text"] + if verbose: + print_final(line) + if line: + if text: + text += "\n" + text += line + elif verbose: + print("???", message) + except KeyboardInterrupt: + if verbose > 1: + print("\nKeyboard interrupt") + alive = False + await websocket.send(json.dumps({"eof" : 1})) + res = await websocket.recv() + message = json.loads(res) + if isinstance(message, str): + message = json.loads(message) + if text: + text += " " + text += message["text"] + try: + res = await websocket.recv() + except websockets.ConnectionClosedOK: + if verbose > 1: + print("Websocket Closed") + except KeyboardInterrupt: + if verbose > 1: + print("\nKeyboard interrupt") + if verbose: + print_final("= FULL TRANSCRIPTION ", background="=") + print(text) + + return text + +def print_partial(text): + text = text + "…" + terminal_size = shutil.get_terminal_size() + width = terminal_size.columns + start = ((len(text) - 1)// width) * width + if start > 0: + print(" "*width, end="\r") + if start < len(text) - 1: + print("…"+text[start+1:]+" "*(width-len(text)-start-1), end="\r") + else: + print(text[-width:], end="\r") + else: + print(text, end="\r") + +def print_final(text, background=" "): + terminal_size = shutil.get_terminal_size() + width = terminal_size.columns + print(background * width, end="\r") + print(text) + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser(description='Transcribe input streaming (from mic or a file) with LinSTT', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--server', help='Transcription server', + default="ws://localhost:8080/streaming", + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose mode") + parser.add_argument("--audio_file", default=None, help="A path to an audio file to transcribe (if not provided, use mic)") + args = parser.parse_args() + + res = linstt_streaming(args.audio_file, args.server, verbose=2 if args.verbose else 1) \ No newline at end of file diff --git a/whisper/README.md b/whisper/README.md index 1f1bf92..2683661 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -293,6 +293,18 @@ Return the transcripted text using "text/plain" or a json object when using "app } ``` +#### /streaming +The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true. + +The route accepts websocket connexions. Exchanges are structured as followed: +1. Client send a json {"config": {"sample_rate":16000}}. +2. Client send audio chunk (go to 3- ) or {"eof" : 1} (go to 5-). +3. Server send either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}. +4. Back to 2- +5. Server send a final result and close the connexion. + +> Connexion will be closed and the worker will be freed if no chunk are received for 10s. + #### /docs The /docs route offers a OpenAPI/swagger interface. @@ -329,11 +341,19 @@ On a successfull transcription the returned object is a json object structured a ## Test ### Curl -You can test you http API using curl: +You can test your http API using curl: + ```bash curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav" ``` +### Streaming +You can test your streaming API using a websocket: + +```bash +python test/test_streaming.py --server ws://YOUR_SERVICE:YOUR_PORT/streaming --audio_file test/bonjour.wav +``` + ## License This project is developped under the AGPLv3 License (see LICENSE). From d81952ef4c6654b204ba480caad41c57ecdac4c4 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 28 Mar 2024 17:09:48 +0100 Subject: [PATCH 14/50] set sample rate with config --- whisper/stt/processing/streaming.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 98b32a6..1e077eb 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -38,7 +38,7 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD) + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD, sample_rate=sample_rate) logger.info("Waiting for chunks") while True: try: @@ -78,7 +78,7 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD) + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD, sample_rate=sample_rate) logger.info("Waiting for chunks") while True: try: @@ -171,9 +171,7 @@ def complete(self): class OnlineASRProcessor: - SAMPLING_RATE = 16000 - - def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stderr): + def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stderr, sample_rate=16000): """asr: WhisperASR object tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. ("segment", 15) @@ -187,6 +185,8 @@ def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stder self.buffer_trimming_sec = buffer_trimming self.use_vad = use_vad + self.sampling_rate = sample_rate + def init(self): """run this when starting or restarting processing""" @@ -229,12 +229,12 @@ def process_iter(self): prompt, non_prompt = self.prompt() logger.debug(f"PROMPT:{prompt}") logger.debug(f"CONTEXT:{non_prompt}") - logger.debug(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") + logger.debug(f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") # print(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") # use VAD to filter out the silence if self.use_vad: np_buffer = np.array(self.audio_buffer) - audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.use_vad, sample_rate=self.SAMPLING_RATE, dilatation=0.5) + audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.use_vad, sample_rate=self.sampling_rate, dilatation=0.5) res = self.asr.transcribe(audio_speech, init_prompt=prompt) else: res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) @@ -243,16 +243,16 @@ def process_iter(self): self.transcript_buffer.insert(tsw, self.buffer_time_offset) o, buffer = self.transcript_buffer.flush() self.commited.extend(o) - if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.SAMPLING_RATE)-buffer[-1][1]<0.05: + if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.sampling_rate)-buffer[-1][1]<0.05: # remove the last word if it is too close to the end of the buffer buffer.pop(-1) logger.debug(f"New committed text:{self.to_flush(o)}") logger.debug(f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}") - if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: + if len(self.audio_buffer)/self.sampling_rate > self.buffer_trimming_sec: self.chunk_completed_segment(res, chunk_silence=self.use_vad, speech_segments=segments if self.use_vad else False) - logger.debug(f"Len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s") + logger.debug(f"Len of buffer now: {len(self.audio_buffer)/self.sampling_rate:2.2f}s") return self.to_flush(o), self.to_flush(buffer) def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None): @@ -271,7 +271,7 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None else: logger.debug(f"--- last segment not within commited area") elif chunk_silence: - lenght = len(self.audio_buffer)/self.SAMPLING_RATE + lenght = len(self.audio_buffer)/self.sampling_rate e = self.buffer_time_offset + lenght - 2 if speech_segments: end_silence = lenght - speech_segments[-1][1] @@ -290,7 +290,7 @@ def chunk_at(self, time): """ self.transcript_buffer.pop_commited(time) cut_seconds = time - self.buffer_time_offset - self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):] + self.audio_buffer = self.audio_buffer[int(cut_seconds*self.sampling_rate):] self.buffer_time_offset = time self.last_chunked_at = time From 4c9f48c0f314843c3477f1c2a4b4a4bf1a0d3020 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Fri, 29 Mar 2024 10:29:20 +0100 Subject: [PATCH 15/50] Do not mention stack anymore --- kaldi/README.md | 4 ++-- whisper/README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kaldi/README.md b/kaldi/README.md index 7ebfa85..70584f2 100644 --- a/kaldi/README.md +++ b/kaldi/README.md @@ -68,7 +68,7 @@ cp kaldi/.envdefault kaldi/.env STT can be used three ways: * Through an [HTTP API](#http-server) using the **http**'s mode. -* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. +* Through a [message broker](#celery-task) using the **task**'s mode. * Through a [websocket server](#websocket-server) **websocket**'s mode. Mode is specified using the .env value or environment variable ```SERVING_MODE```. @@ -99,7 +99,7 @@ This will run a container providing an [HTTP API](#http-api) binded on the host | LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | | MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | -### Micro-service within LinTO-Platform stack +### Celery task The TASK serving mode connect a celery worker to a message broker. The SERVICE_MODE value in the .env should be set to ```task```. diff --git a/whisper/README.md b/whisper/README.md index 2683661..776d82f 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -185,7 +185,7 @@ and also `yue(cantonese)` since large-v3. STT can be used in two ways: * Through an [HTTP API](#http-server) using the **http**'s mode. -* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. +* Through a [message broker](#celery-task) using the **task**'s mode. Mode is specified using the .env value or environment variable ```SERVING_MODE```. ```bash @@ -222,7 +222,7 @@ You may also want to add specific options: | `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt | | `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec | -### Micro-service within LinTO-Platform stack +### Celery task The TASK serving mode connect a celery worker to a message broker. The SERVICE_MODE value in the .env should be set to ```task```. From 5fa518812ece64250343faa9a22cee4207a8174a Mon Sep 17 00:00:00 2001 From: AudranBert Date: Fri, 29 Mar 2024 17:03:18 +0100 Subject: [PATCH 16/50] fix workers on cpu --- http_server/ingress.py | 17 +++++++++++++++++ whisper/stt/processing/__init__.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/http_server/ingress.py b/http_server/ingress.py index ec21d33..73de3d6 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -127,6 +127,20 @@ def server_error(error): else: serving_type = GunicornServing logger.debug("Serving with gunicorn") + # serving_type = GunicornServing + # logger.debug("Serving with gunicorn") + + def worker_started(worker): + logger.info(f"Worker started {worker.pid}") + MODEL[0].check_loaded() + logger.info("Worker fully initialized") + + + # def post_fork(server, worker): + # logger.info("Worker post fork") + # MODEL[0].check_loaded() + # logger.info("Worker f") + serving = serving_type( app, @@ -134,6 +148,9 @@ def server_error(error): "bind": f"0.0.0.0:{args.service_port}", "workers": args.workers, "timeout": 3600 * 24, + # "on_starting": lambda server: logger.info("Server started"), + "post_worker_init": worker_started, + # "post_fork": post_fork }, ) logger.info(args) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index a632a52..0fa99ee 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs): ) try: model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) - if not USE_CTRANSLATE2 or device.lower() != "cpu": + if device.lower() != "cpu": model.check_loaded() except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err From d02ec5d609db803e9322048f5fad97eee7e26878 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 2 Apr 2024 16:05:24 +0200 Subject: [PATCH 17/50] add: external VAD and refactor VAD --- whisper/requirements.ctranslate2.txt | 2 +- whisper/stt/__init__.py | 6 +++--- whisper/stt/processing/__init__.py | 6 +++--- whisper/stt/processing/decoding.py | 12 +++++++++--- whisper/stt/processing/streaming.py | 12 ++++++------ .../stt/processing/{streaming_vad.py => vad.py} | 17 ++++++++++------- 6 files changed, 32 insertions(+), 23 deletions(-) rename whisper/stt/processing/{streaming_vad.py => vad.py} (96%) diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt index 87b5e80..5fc25d2 100644 --- a/whisper/requirements.ctranslate2.txt +++ b/whisper/requirements.ctranslate2.txt @@ -14,4 +14,4 @@ websockets auditok #faster_whisper==1.0.1 # This is version faster_whisper==1.0.1 + option for (persistent) prompt + fix for large-v3 -git+https://github.com/linto-ai/faster-whisper.git \ No newline at end of file +git+https://github.com/linto-ai/faster-whisper.git@external_vad \ No newline at end of file diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index 70058dd..93a04e5 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -13,11 +13,11 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order if os.environ.get("USE_VAD","auditok") in [True, "true", 1]: - USE_VAD = "auditok" + VAD = "auditok" elif os.environ.get("USE_VAD","auditok") in [False, "false", 0]: - USE_VAD = False + VAD = False else: - USE_VAD = os.environ.get("USE_VAD","auditok") + VAD = os.environ.get("USE_VAD","auditok") NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS")) NUM_THREADS = int(NUM_THREADS) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index 0fa99ee..2140768 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -2,7 +2,7 @@ import os from lockfile import FileLock -from stt import USE_CTRANSLATE2, USE_VAD, logger, set_num_threads, NUM_THREADS +from stt import USE_CTRANSLATE2, VAD, logger, set_num_threads, NUM_THREADS from .alignment_model import get_alignment_model, load_alignment_model from .decoding import decode @@ -60,7 +60,7 @@ def __call__(self, *args, **kwargs): language = get_language() logger.info(f"Using language {language}") -logger.info(f"USE_VAD={USE_VAD}") +logger.info(f"VAD={VAD}") logger.info(f"USE_CTRANSLATE2={USE_CTRANSLATE2}") # Load ASR model @@ -70,7 +70,7 @@ def __call__(self, *args, **kwargs): ) try: model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) - if device.lower() != "cpu": + if str(device).lower() != "cpu": model.check_loaded() except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 8230000..3d50579 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -5,8 +5,9 @@ from typing import Tuple, Union import numpy as np -from stt import USE_CTRANSLATE2, USE_VAD, logger +from stt import USE_CTRANSLATE2, VAD, logger +from .vad import remove_non_speech from .alignment_model import get_alignment_model, load_alignment_model from .text_normalize import normalize_text, remove_emoji, remove_punctuation from .utils import SAMPLE_RATE, get_language @@ -78,13 +79,15 @@ def decode_ct2( kwargs["beam_size"] = 1 if kwargs.get("best_of") is None: kwargs["best_of"] = 1 + if VAD: + _, speech_segments, _ = remove_non_speech(audio, VAD, return_format="dict") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, language=language, # Careful with the following options max_initial_timestamp=10000.0, - vad_filter=USE_VAD, + vad_filter=speech_segments if VAD else False, **kwargs, ) segments = list(segments) @@ -114,6 +117,9 @@ def decode_torch( fp16 = model.device != torch.device("cpu") + if VAD: + _, speech_segments, _ = remove_non_speech(audio, VAD) + kwargs = dict( language=language, fp16=fp16, @@ -123,7 +129,7 @@ def decode_torch( condition_on_previous_text=condition_on_previous_text, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, - vad=USE_VAD, + vad=speech_segments if VAD else False, initial_prompt=prompt, ) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 1e077eb..f4ad44d 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -2,8 +2,8 @@ import sys import string import numpy as np -from stt.processing.streaming_vad import remove_non_speech -from stt import logger, USE_CTRANSLATE2, USE_VAD +from .vad import remove_non_speech +from stt import logger, USE_CTRANSLATE2, VAD from websockets.legacy.server import WebSocketServerProtocol from simple_websocket.ws import Server as WSServer @@ -38,8 +38,8 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD, sample_rate=sample_rate) - logger.info("Waiting for chunks") + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate) + logger.info("Starting transcription ...") while True: try: message = await ws.recv() @@ -78,8 +78,8 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=USE_VAD, sample_rate=sample_rate) - logger.info("Waiting for chunks") + online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate) + logger.info("Starting transcription ...") while True: try: message = websocket_server.receive(timeout=10) diff --git a/whisper/stt/processing/streaming_vad.py b/whisper/stt/processing/vad.py similarity index 96% rename from whisper/stt/processing/streaming_vad.py rename to whisper/stt/processing/vad.py index 149de6d..eb83bd0 100644 --- a/whisper/stt/processing/streaming_vad.py +++ b/whisper/stt/processing/vad.py @@ -17,6 +17,7 @@ def remove_non_speech(audio, sample_rate=16000, method="silero", avoid_empty_speech=False, + return_format="tuple", ): """ Remove non-speech segments from audio (using Silero VAD), @@ -40,11 +41,11 @@ def remove_non_speech(audio, if True, avoid returning an empty speech segment (re) """ - if USE_CTRANSLATE2 and method.startswith("silero"): + if USE_CTRANSLATE2 and method=="silero": from faster_whisper.vad import VadOptions options = VadOptions( - min_speech_duration_ms =min_speech_duration*1000, - min_silence_duration_ms =min_silence_duration*1000, + min_speech_duration_ms=min_speech_duration*1000, + min_silence_duration_ms=min_silence_duration*1000, ) from faster_whisper.vad import get_speech_timestamps segments = get_speech_timestamps(audio, vad_options=options) @@ -58,7 +59,6 @@ def remove_non_speech(audio, dilatation=dilatation, method=method, ) - segments = [(seg["start"], seg["end"]) for seg in segments] if len(segments) == 0: if avoid_empty_speech: @@ -71,7 +71,8 @@ def remove_non_speech(audio, if not use_sample: segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments] - + if return_format == "dict": + segments = [{"start": s, "end": e} for s, e in segments] return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2) def do_convert_timestamps(segments, t, t2 = None): @@ -220,8 +221,10 @@ def apply_folder_hack(): _silero_get_speech_ts = utils[0] # Cheap normalization of the volume - # audio = audio / max(0.1, audio.abs().max()) - audio = audio / max(0.1, np.max(np.abs(audio))) + if isinstance(audio, np.ndarray): + audio = audio / max(0.1, np.max(np.abs(audio))) + else: + audio = audio / max(0.1, audio.abs().max()) segments = _silero_get_speech_ts(audio, _silero_vad_model[version], sampling_rate = sample_rate, From a24d3f77d3b217ce77590e274f45dfaf417711e6 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 2 Apr 2024 16:13:55 +0200 Subject: [PATCH 18/50] add streaming options in readme --- whisper/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/whisper/README.md b/whisper/README.md index 776d82f..f5119a5 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -122,11 +122,14 @@ cp whisper/.envdefault whisper/.env | NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | | CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | | CONCURRENCY | Maximum number of parallel requests | `2` | +| ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | true\|false | +| STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. | 80 | | SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` | | SERVICE_BROKER | (For the task mode only) URL of the message broker | `redis://my-broker:6379` | | BROKER_PASS | (For the task mode only) broker password | `my-password` \| (empty) | | ALIGNMENT_MODEL | (Deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... | + #### MODEL environment variable **Warning:** From 414c5df070ab9e5f36966a2091504b93e59fcf9a Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 2 Apr 2024 16:58:57 +0200 Subject: [PATCH 19/50] fix vad decoding + reformat files --- http_server/ingress.py | 17 +-- whisper/stt/processing/decoding.py | 45 ++++-- whisper/stt/processing/streaming.py | 228 ++++++++++++++++++---------- whisper/stt/processing/vad.py | 190 ++++++++++++++--------- 4 files changed, 305 insertions(+), 175 deletions(-) diff --git a/http_server/ingress.py b/http_server/ingress.py index 73de3d6..c70fa52 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -84,7 +84,9 @@ def transcribe(): logger.error(traceback.format_exc()) logger.error(repr(error)) - return "Server Error: {}".format(str(error)), 400 if isinstance(error, ValueError) else 500 + return "Server Error: {}".format(str(error)), ( + 400 if isinstance(error, ValueError) else 500 + ) @app.errorhandler(405) @@ -127,20 +129,11 @@ def server_error(error): else: serving_type = GunicornServing logger.debug("Serving with gunicorn") - # serving_type = GunicornServing - # logger.debug("Serving with gunicorn") - + def worker_started(worker): logger.info(f"Worker started {worker.pid}") MODEL[0].check_loaded() logger.info("Worker fully initialized") - - - # def post_fork(server, worker): - # logger.info("Worker post fork") - # MODEL[0].check_loaded() - # logger.info("Worker f") - serving = serving_type( app, @@ -148,9 +141,7 @@ def worker_started(worker): "bind": f"0.0.0.0:{args.service_port}", "workers": args.workers, "timeout": 3600 * 24, - # "on_starting": lambda server: logger.info("Server started"), "post_worker_init": worker_started, - # "post_fork": post_fork }, ) logger.info(args) diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 3d50579..ffa2024 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -72,7 +72,12 @@ def decode( def decode_ct2( - audio, model, with_word_timestamps, language, remove_punctuation_from_words, **kwargs + audio, + model, + with_word_timestamps, + language, + remove_punctuation_from_words, + **kwargs, ): kwargs["no_speech_threshold"] = 1 # To avoid empty output if kwargs.get("beam_size") is None: @@ -80,7 +85,7 @@ def decode_ct2( if kwargs.get("best_of") is None: kwargs["best_of"] = 1 if VAD: - _, speech_segments, _ = remove_non_speech(audio, VAD, return_format="dict") + _, speech_segments, _ = remove_non_speech(audio, method=VAD, return_format="dict") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, @@ -118,7 +123,7 @@ def decode_torch( fp16 = model.device != torch.device("cpu") if VAD: - _, speech_segments, _ = remove_non_speech(audio, VAD) + _, speech_segments, _ = remove_non_speech(audio, method=VAD) kwargs = dict( language=language, @@ -135,7 +140,9 @@ def decode_torch( if alignment_model is None: # Use Whisper cross-attention weights - whisper_res = whisper_timestamped.transcribe(model, audio, verbose=None, **kwargs) + whisper_res = whisper_timestamped.transcribe( + model, audio, verbose=None, **kwargs + ) if language is None: language = whisper_res["language"] logger.info(f"Detected language: {language}") @@ -177,7 +184,9 @@ def decode_torch( result["text"] = text result["language"] = language result["confidence-score"] = ( - np.exp(np.array([r["avg_logprob"] for r in segments])).mean() if len(segments) else 0.0 + np.exp(np.array([r["avg_logprob"] for r in segments])).mean() + if len(segments) + else 0.0 ) if not with_word_timestamps: @@ -253,7 +262,9 @@ def decode_torch( return result -def format_whisper_timestamped_response(transcription, remove_punctuation_from_words=False): +def format_whisper_timestamped_response( + transcription, remove_punctuation_from_words=False +): """Format Whisper response.""" for i, seg in enumerate(transcription["segments"][:-1]): @@ -283,9 +294,11 @@ def format_whisper_timestamped_response(transcription, remove_punctuation_from_w return { "text": transcription["text"].strip(), "language": transcription["language"], - "confidence-score": round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2) - if len(segments) - else 0.0, + "confidence-score": ( + round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2) + if len(segments) + else 0.0 + ), "words": words, } @@ -309,7 +322,10 @@ def checked_timestamps(start, end=None): if end == start: pass # end = start + 0.01 else: - print("WARNING, end timestamp %f is smaller than start timestamp %f" % (end, start)) + print( + "WARNING, end timestamp %f is smaller than start timestamp %f" + % (end, start) + ) if end is None: return start return (start, end) @@ -329,7 +345,11 @@ def checked_timestamps(start, end=None): and len(words) and len(word_strip) > 1 and word_strip[0] in glue_punctuations - and (word_strip == word_string or not contains_alphanum(words[-1]["text"]) or not contains_alphanum(word_strip)) + and ( + word_strip == word_string + or not contains_alphanum(words[-1]["text"]) + or not contains_alphanum(word_strip) + ) ): words[-1]["text"] += word_strip words[-1]["confidence"].append(word.probability) @@ -370,5 +390,6 @@ def checked_timestamps(start, end=None): transcription, remove_punctuation_from_words=remove_punctuation_from_words ) + def contains_alphanum(text: str) -> bool: - return re.search(r"[^\W\'\-_]", text) \ No newline at end of file + return re.search(r"[^\W\'\-_]", text) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index f4ad44d..2382717 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -7,20 +7,24 @@ from websockets.legacy.server import WebSocketServerProtocol from simple_websocket.ws import Server as WSServer + def bytes_to_array(bytes): return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 + def processor_output_to_text(o): if o[0] is None: return "" return o[2] + def whisper_to_json(o): result = dict() result["text"] = processor_output_to_text(o) json_res = json.dumps(result) return json_res + async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): """Async Decode function endpoint""" res = await ws.recv() @@ -38,7 +42,9 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate) + online = OnlineASRProcessor( + asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate + ) logger.info("Starting transcription ...") while True: try: @@ -59,8 +65,8 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): o, _ = online.process_iter() logger.info(o) await ws.send(whisper_to_json(o)) - - + + def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): """Sync Decode function endpoint""" res = websocket_server.receive(timeout=10) @@ -78,7 +84,9 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): else: logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") - online = OnlineASRProcessor(asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate) + online = OnlineASRProcessor( + asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate + ) logger.info("Starting transcription ...") while True: try: @@ -97,9 +105,9 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): break online.insert_audio_chunk(bytes_to_array(message)) o, _ = online.process_iter() - # logger.info(o) websocket_server.send(whisper_to_json(o)) + class HypothesisBuffer: def __init__(self, logfile=sys.stderr): @@ -116,20 +124,24 @@ def __init__(self, logfile=sys.stderr): def insert(self, new, offset): # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content # the new tail is added to self.new - - new = [(a+offset,b+offset,t) for a,b,t in new] - self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1] + + new = [(a + offset, b + offset, t) for a, b, t in new] + self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1] if len(self.new) >= 1: - a,b,t = self.new[0] + a, b, t = self.new[0] if abs(a - self.last_commited_time) < 1: if self.commited_in_buffer: # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped. cn = len(self.commited_in_buffer) nn = len(self.new) - for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum - c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1]) - tail = " ".join(self.new[j-1][2] for j in range(1,i+1)) + for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum + c = " ".join( + [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][ + ::-1 + ] + ) + tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1)) if c == tail: logger.debug(f"removing last {i} words:") for j in range(i): @@ -137,8 +149,7 @@ def insert(self, new, offset): break def flush(self): - # returns commited chunk = the longest common prefix of 2 last inserts. - + # returns commited chunk = the longest common prefix of 2 last inserts. commit = [] while self.new: na, nb, nt = self.new[0] @@ -146,17 +157,22 @@ def flush(self): if len(self.buffer) == 0: break - if nt.lower().translate(str.maketrans('', '', string.punctuation)) == self.buffer[0][2].lower().translate(str.maketrans('', '', string.punctuation)): - commit.append((na,nb,nt)) + if nt.lower().translate( + str.maketrans("", "", string.punctuation) + ) == self.buffer[0][2].lower().translate( + str.maketrans("", "", string.punctuation) + ): + commit.append((na, nb, nt)) self.last_commited_word = nt self.last_commited_time = nb self.buffer.pop(0) self.new.pop(0) else: - # print(f"SStop committing at '{nt}' and '{self.buffer[0][2]}'") break self.buffer = self.new - new_non_commit = [i for i in self.buffer if i[1] > self.last_buffered_time-0.1] + new_non_commit = [ + i for i in self.buffer if i[1] > self.last_buffered_time - 0.1 + ] self.last_buffered_time = self.buffer[-1][1] if self.buffer else -1 self.new = [] self.commited_in_buffer.extend(commit) @@ -169,14 +185,22 @@ def pop_commited(self, time): def complete(self): return self.buffer + class OnlineASRProcessor: - def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stderr, sample_rate=16000): + def __init__( + self, + asr, + buffer_trimming=15, + use_vad="auditok", + logfile=sys.stderr, + sample_rate=16000, + ): """asr: WhisperASR object tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. ("segment", 15) buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option. - logfile: where to store the log. + logfile: where to store the log. """ self.asr = asr self.logfile = logfile @@ -186,11 +210,10 @@ def __init__(self, asr, buffer_trimming=15, use_vad="auditok", logfile=sys.stder self.buffer_trimming_sec = buffer_trimming self.use_vad = use_vad self.sampling_rate = sample_rate - def init(self): """run this when starting or restarting processing""" - self.audio_buffer = np.array([],dtype=np.float32) + self.audio_buffer = np.array([], dtype=np.float32) self.buffer_time_offset = 0 self.transcript_buffer = HypothesisBuffer(logfile=self.logfile) @@ -203,38 +226,45 @@ def insert_audio_chunk(self, audio): self.audio_buffer = np.append(self.audio_buffer, audio) def prompt(self): - """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer. + """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer. "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons. """ - k = max(0,len(self.commited)-1) - while k > 0 and self.commited[k-1][1] > self.last_chunked_at: + k = max(0, len(self.commited) - 1) + while k > 0 and self.commited[k - 1][1] > self.last_chunked_at: k -= 1 p = self.commited[:k] - p = [t for _,_,t in p] + p = [t for _, _, t in p] prompt = [] l = 0 while p and l < 200: # 200 characters prompt size x = p.pop(-1) - l += len(x)+1 + l += len(x) + 1 prompt.append(x) non_prompt = self.commited[k:] - return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt) + return self.asr.sep.join(prompt[::-1]), self.asr.sep.join( + t for _, _, t in non_prompt + ) def process_iter(self): """Runs on the current audio buffer. - Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, ""). + Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, ""). The non-emty text is confirmed (committed) partial transcript. """ prompt, non_prompt = self.prompt() logger.debug(f"PROMPT:{prompt}") logger.debug(f"CONTEXT:{non_prompt}") - logger.debug(f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") - # print(f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s") - # use VAD to filter out the silence + logger.debug( + f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s" + ) if self.use_vad: np_buffer = np.array(self.audio_buffer) - audio_speech, segments, convertion_function = remove_non_speech(np_buffer, method=self.use_vad, sample_rate=self.sampling_rate, dilatation=0.5) + audio_speech, segments, convertion_function = remove_non_speech( + np_buffer, + method=self.use_vad, + sample_rate=self.sampling_rate, + dilatation=0.5, + ) res = self.asr.transcribe(audio_speech, init_prompt=prompt) else: res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) @@ -243,16 +273,29 @@ def process_iter(self): self.transcript_buffer.insert(tsw, self.buffer_time_offset) o, buffer = self.transcript_buffer.flush() self.commited.extend(o) - if buffer and (self.buffer_time_offset+len(self.audio_buffer)/self.sampling_rate)-buffer[-1][1]<0.05: + if ( + buffer + and (self.buffer_time_offset + len(self.audio_buffer) / self.sampling_rate) + - buffer[-1][1] + < 0.05 + ): # remove the last word if it is too close to the end of the buffer buffer.pop(-1) logger.debug(f"New committed text:{self.to_flush(o)}") - logger.debug(f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}") - - if len(self.audio_buffer)/self.sampling_rate > self.buffer_trimming_sec: - self.chunk_completed_segment(res, chunk_silence=self.use_vad, speech_segments=segments if self.use_vad else False) - - logger.debug(f"Len of buffer now: {len(self.audio_buffer)/self.sampling_rate:2.2f}s") + logger.debug( + f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}" + ) + + if len(self.audio_buffer) / self.sampling_rate > self.buffer_trimming_sec: + self.chunk_completed_segment( + res, + chunk_silence=self.use_vad, + speech_segments=segments if self.use_vad else False, + ) + + logger.debug( + f"Len of buffer now: {len(self.audio_buffer)/self.sampling_rate:2.2f}s" + ) return self.to_flush(o), self.to_flush(buffer) def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None): @@ -261,17 +304,17 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None ends = self.asr.segments_end_ts(res) t = self.commited[-1][1] if len(ends) > 1: - e = ends[-2]+self.buffer_time_offset + e = ends[-2] + self.buffer_time_offset while len(ends) > 2 and e > t: ends.pop(-1) - e = ends[-2]+self.buffer_time_offset + e = ends[-2] + self.buffer_time_offset if e <= t: logger.debug(f"--- segment chunked at {e:2.2f}") self.chunk_at(e) else: logger.debug(f"--- last segment not within commited area") elif chunk_silence: - lenght = len(self.audio_buffer)/self.sampling_rate + lenght = len(self.audio_buffer) / self.sampling_rate e = self.buffer_time_offset + lenght - 2 if speech_segments: end_silence = lenght - speech_segments[-1][1] @@ -280,17 +323,15 @@ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None self.chunk_at(e) elif speech_segments is not None: logger.debug(f"--- Silence segment chunked at {e:2.2f}") - self.chunk_at(e) + self.chunk_at(e) else: logger.debug(f"--- not enough segments to chunk") - def chunk_at(self, time): - """trims the hypothesis and audio buffer at "time" - """ + """trims the hypothesis and audio buffer at "time" """ self.transcript_buffer.pop_commited(time) cut_seconds = time - self.buffer_time_offset - self.audio_buffer = self.audio_buffer[int(cut_seconds*self.sampling_rate):] + self.audio_buffer = self.audio_buffer[int(cut_seconds * self.sampling_rate) :] self.buffer_time_offset = time self.last_chunked_at = time @@ -298,7 +339,7 @@ def words_to_sentences(self, words): """Uses self.tokenizer for sentence segmentation of words. Returns: [(beg,end,"sentence 1"),...] """ - + cwords = [w for w in words] t = " ".join(o[2] for o in cwords) s = self.tokenizer.split(t) @@ -309,15 +350,15 @@ def words_to_sentences(self, words): sent = s.pop(0).strip() fsent = sent while cwords: - b,e,w = cwords.pop(0) + b, e, w = cwords.pop(0) w = w.strip() if beg is None and sent.startswith(w): beg = b elif end is None and sent == w: end = e - out.append((beg,end,fsent)) + out.append((beg, end, fsent)) break - sent = sent[len(w):].strip() + sent = sent[len(w) :].strip() return out def finish(self): @@ -329,8 +370,12 @@ def finish(self): logger.debug(f"last, noncommited:{f}") return f - - def to_flush(self, sents, sep=None, offset=0, ): + def to_flush( + self, + sents, + sep=None, + offset=0, + ): # concatenates the timestamped words or sentences into one sequence that is flushed in one line # sents: [(beg1, end1, "sentence1"), ...] or [] if empty # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty @@ -343,19 +388,21 @@ def to_flush(self, sents, sep=None, offset=0, ): else: b = offset + sents[0][0] e = offset + sents[-1][1] - return (b,e,t) - - + return (b, e, t) + + class ASRBase: - sep = " " # join transcribe words with this character (" " for whisper_timestamped, - # "" for faster-whisper because it emits the spaces when needed) + sep = " " # join transcribe words with this character (" " for whisper_timestamped, + # "" for faster-whisper because it emits the spaces when needed) - def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + def __init__( + self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None + ): self.logfile = logfile self.transcribe_kargs = {} - self.original_language = lan + self.original_language = lan self.model = model def transcribe(self, audio, init_prompt=""): @@ -363,23 +410,32 @@ def transcribe(self, audio, init_prompt=""): def use_vad(self, vad_name=None): raise NotImplemented("must be implemented in the child class") - - + + class FasterWhisperASR(ASRBase): - """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version. - """ + """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.""" sep = "" - def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + def __init__( + self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None + ): super().__init__(lan, model=model, logfile=logfile) - self.transcribe_kargs['beam_size'] = 1 - self.transcribe_kargs['best_of'] = 1 - self.transcribe_kargs['temperature'] = 0 - self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text + self.transcribe_kargs["beam_size"] = 1 + self.transcribe_kargs["best_of"] = 1 + self.transcribe_kargs["temperature"] = 0 + self.transcribe_kargs["condition_on_previous_text"] = ( + False if condition_on_previous_text is None else condition_on_previous_text + ) def transcribe(self, audio, init_prompt=""): - segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, word_timestamps=True, **self.transcribe_kargs) + segments, info = self.model.transcribe( + audio, + language=self.original_language, + initial_prompt=init_prompt, + word_timestamps=True, + **self.transcribe_kargs, + ) return list(segments) def ts_words(self, segments, timestamps_convert_function=None): @@ -399,6 +455,7 @@ def ts_words(self, segments, timestamps_convert_function=None): def segments_end_ts(self, res): return [s.end for s in res] + class WhisperTimestampedASR(ASRBase): """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper. On the other hand, the installation for GPU could be easier. @@ -406,37 +463,44 @@ class WhisperTimestampedASR(ASRBase): sep = " " - def __init__(self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None): + def __init__( + self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None + ): super().__init__(lan, model=model, logfile=logfile) self.transcribe_kargs["verbose"] = None self.transcribe_kargs["beam_size"] = None self.transcribe_kargs["best_of"] = None self.transcribe_kargs["temperature"] = 0 - self.transcribe_kargs['condition_on_previous_text'] = False if condition_on_previous_text is None else condition_on_previous_text + self.transcribe_kargs["condition_on_previous_text"] = ( + False if condition_on_previous_text is None else condition_on_previous_text + ) from whisper_timestamped import transcribe_timestamped - self.transcribe_timestamped = transcribe_timestamped + self.transcribe_timestamped = transcribe_timestamped def transcribe(self, audio, init_prompt=""): - result = self.transcribe_timestamped(self.model, - audio, language=self.original_language, - initial_prompt=init_prompt, **self.transcribe_kargs) + result = self.transcribe_timestamped( + self.model, + audio, + language=self.original_language, + initial_prompt=init_prompt, + **self.transcribe_kargs, + ) return result - - def ts_words(self,r, timestamps_convert_function=None): + + def ts_words(self, r, timestamps_convert_function=None): # return: transcribe result object to [(beg,end,"word1"), ...] o = [] for s in r["segments"]: for w in s["words"]: if timestamps_convert_function is not None: # print(f"start: {word.start}->{timestamps_convert_function(word.start)}, end: {word.end}->{timestamps_convert_function(word.end)}") - start, end = timestamps_convert_function(w["start"], w['end']) + start, end = timestamps_convert_function(w["start"], w["end"]) t = (start, end, w["text"]) else: - t = (w["start"],w["end"],w["text"]) + t = (w["start"], w["end"], w["text"]) o.append(t) return o def segments_end_ts(self, res): return [s["end"] for s in res["segments"]] - diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py index eb83bd0..d7d5716 100644 --- a/whisper/stt/processing/vad.py +++ b/whisper/stt/processing/vad.py @@ -9,7 +9,8 @@ _vad_import = None -def remove_non_speech(audio, +def remove_non_speech( + audio, use_sample=False, min_speech_duration=0.1, min_silence_duration=1, @@ -18,7 +19,7 @@ def remove_non_speech(audio, method="silero", avoid_empty_speech=False, return_format="tuple", - ): +): """ Remove non-speech segments from audio (using Silero VAD), glue the speech segments together and return the result along with @@ -41,15 +42,17 @@ def remove_non_speech(audio, if True, avoid returning an empty speech segment (re) """ - if USE_CTRANSLATE2 and method=="silero": + if USE_CTRANSLATE2 and method == "silero": from faster_whisper.vad import VadOptions + options = VadOptions( - min_speech_duration_ms=min_speech_duration*1000, - min_silence_duration_ms=min_silence_duration*1000, + min_speech_duration_ms=min_speech_duration * 1000, + min_silence_duration_ms=min_silence_duration * 1000, ) from faster_whisper.vad import get_speech_timestamps + segments = get_speech_timestamps(audio, vad_options=options) - else: + else: segments = get_vad_segments( audio, sample_rate=sample_rate, @@ -64,18 +67,25 @@ def remove_non_speech(audio, if avoid_empty_speech: segments = [(0, audio.shape[-1])] else: - np.array([]), [], lambda t, t2 = None: t if t2 is None else [t, t2] + np.array([]), [], lambda t, t2=None: t if t2 is None else [t, t2] audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) # audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1) if not use_sample: - segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments] + segments = [ + (float(s) / sample_rate, float(e) / sample_rate) for s, e in segments + ] if return_format == "dict": segments = [{"start": s, "end": e} for s, e in segments] - return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2) + return ( + audio_speech, + segments, + lambda t, t2=None: do_convert_timestamps(segments, t, t2), + ) -def do_convert_timestamps(segments, t, t2 = None): + +def do_convert_timestamps(segments, t, t2=None): """ Convert timestamp from audio without non-speech segments to original audio (with non-speech segments) @@ -85,8 +95,8 @@ def do_convert_timestamps(segments, t, t2 = None): t2: second timestamp to convert (optional), when the two timestamps should be in the same segment """ assert len(segments) - ioffset = 0 # Input offset - ooffset = 0 # Output offset + ioffset = 0 # Input offset + ooffset = 0 # Output offset ipreviousend = 0 result = [] for istart, iend in segments: @@ -98,20 +108,20 @@ def do_convert_timestamps(segments, t, t2 = None): t_in = t <= oend t2_in = t_in if t2 is None else t2 <= oend if t_in or t2_in: - result.append([ - max(istart, min(iend, ioffset + t)), - max(istart, min(iend, ioffset + t2)) if t2 is not None else None - ]) + result.append( + [ + max(istart, min(iend, ioffset + t)), + max(istart, min(iend, ioffset + t2)) if t2 is not None else None, + ] + ) if t_in and t2_in: break if not len(result): - result.append( - [ioffset + t, ioffset + t2 if t2 is not None else None] - ) - + result.append([ioffset + t, ioffset + t2 if t2 is not None else None]) + if len(result) > 1: # Minimize difference between durations - result = sorted(result, key=lambda x: abs(abs(t2-t) - abs(x[1]-x[0]))) + result = sorted(result, key=lambda x: abs(abs(t2 - t) - abs(x[1] - x[0]))) result = result[0] if t2 is None: result = round(result[0], 2) @@ -120,15 +130,15 @@ def do_convert_timestamps(segments, t, t2 = None): return result - -def get_vad_segments(audio, +def get_vad_segments( + audio, sample_rate=16000, output_sample=False, min_speech_duration=0.1, min_silence_duration=0.1, dilatation=0.5, method="silero", - ): +): """ Get speech segments from audio using the method VAD parameters: @@ -146,10 +156,11 @@ def get_vad_segments(audio, VAD method to use (auditok, silero, silero:v3.1) """ global _silero_vad_model, _silero_get_speech_ts, _has_onnx, _vad_import - if isinstance(method, list): # Explicit timestamps - segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method] + segments = [ + {"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method + ] dilatation = 0 elif isinstance(method, str) and method.startswith("silero"): @@ -161,24 +172,36 @@ def get_vad_segments(audio, if _silero_vad_model.get(version) is None: # ONNX support since 3.1 in silero if (version is None or version >= "v3.1") and (_has_onnx is not False): - onnx=True + onnx = True try: import onnxruntime - onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." + + onnxruntime.set_default_logger_severity( + 3 + ) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model." _has_onnx = True except ImportError as err: - logger.warning(f"Please install onnxruntime to use more efficiently silero VAD") + logger.warning( + f"Please install onnxruntime to use more efficiently silero VAD" + ) _has_onnx = False - onnx=False + onnx = False else: - onnx=False + onnx = False # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 - torch_home = os.environ.get('TORCH_HOME', '~/.cache/torch') - repo_or_dir_master = os.path.expanduser(torch_home + "/hub/snakers4_silero-vad_master") - repo_or_dir_specific = os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") if version else repo_or_dir_master + torch_home = os.environ.get("TORCH_HOME", "~/.cache/torch") + repo_or_dir_master = os.path.expanduser( + torch_home + "/hub/snakers4_silero-vad_master" + ) + repo_or_dir_specific = ( + os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") + if version + else repo_or_dir_master + ) repo_or_dir = repo_or_dir_specific tmp_folder = None + def apply_folder_hack(): nonlocal tmp_folder if os.path.exists(repo_or_dir_master): @@ -196,69 +219,93 @@ def apply_folder_hack(): source = "local" if not os.path.exists(repo_or_dir): # Load specific version of silero - repo_or_dir = f"snakers4/silero-vad:{version}" if version else "snakers4/silero-vad" + repo_or_dir = ( + f"snakers4/silero-vad:{version}" + if version + else "snakers4/silero-vad" + ) source = "github" if need_folder_hack: apply_folder_hack() try: if _vad_import is None: from torch.hub import load as torch_load + _vad_import = torch_load - silero_vad_model, utils = _vad_import(repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, source=source) + silero_vad_model, utils = _vad_import( + repo_or_dir=repo_or_dir, + model="silero_vad", + onnx=onnx, + source=source, + ) _silero_vad_model[version] = silero_vad_model except ImportError as err: - raise RuntimeError(f"Please install what is needed to use the silero VAD (or use another VAD method)") from err + raise RuntimeError( + f"Please install what is needed to use the silero VAD (or use another VAD method)" + ) from err except Exception as err: - raise RuntimeError(f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models") from err + raise RuntimeError( + f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models" + ) from err finally: if need_folder_hack: if os.path.exists(repo_or_dir_master): os.remove(repo_or_dir_master) if tmp_folder: shutil.move(tmp_folder, repo_or_dir_master) - assert os.path.isdir(repo_or_dir_specific), f"Unexpected situation: missing {repo_or_dir_specific}" + assert os.path.isdir( + repo_or_dir_specific + ), f"Unexpected situation: missing {repo_or_dir_specific}" _silero_get_speech_ts = utils[0] # Cheap normalization of the volume - if isinstance(audio, np.ndarray): - audio = audio / max(0.1, np.max(np.abs(audio))) - else: - audio = audio / max(0.1, audio.abs().max()) - - segments = _silero_get_speech_ts(audio, _silero_vad_model[version], - sampling_rate = sample_rate, - min_speech_duration_ms = round(min_speech_duration * 1000), - min_silence_duration_ms = round(min_silence_duration * 1000), - return_seconds = False, + audio = audio / max(0.1, audio.abs().max()) + + segments = _silero_get_speech_ts( + audio, + _silero_vad_model[version], + sampling_rate=sample_rate, + min_speech_duration_ms=round(min_speech_duration * 1000), + min_silence_duration_ms=round(min_silence_duration * 1000), + return_seconds=False, ) elif method == "auditok": - # import auditok if _vad_import is None: from auditok import split + _vad_import = split # Cheap normalization of the volume # audio = audio / max(0.1, audio.abs().max()) - audio = audio / max(0.1, np.max(np.abs(audio))) - data = (audio * 32767).astype(np.int16).tobytes() - + if isinstance(audio, np.ndarray): + audio = audio / max(0.1, np.max(np.abs(audio))) + data = (audio * 32767).astype(np.int16).tobytes() + else: + audio = audio / max(0.1, audio.abs().max()) + data = (audio.numpy() * 32767).astype(np.int16).tobytes() + audio_duration = len(audio) / sample_rate segments = _vad_import( data, - sampling_rate=sample_rate, # sampling frequency in Hz - channels=1, # number of channels - sample_width=2, # number of bytes per sample - min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds - max_dur=audio_duration, # maximum duration of an event - max_silence=min(audio_duration*.95, min_silence_duration), # maximum duration of tolerated continuous silence within an event + sampling_rate=sample_rate, # sampling frequency in Hz + channels=1, # number of channels + sample_width=2, # number of bytes per sample + min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds + max_dur=audio_duration, # maximum duration of an event + max_silence=min( + audio_duration * 0.95, min_silence_duration + ), # maximum duration of tolerated continuous silence within an event energy_threshold=50, drop_trailing_silence=True, ) - segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments] + segments = [ + {"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} + for s in segments + ] else: raise ValueError(f"Got unexpected VAD method {method}") @@ -269,7 +316,7 @@ def apply_folder_hack(): for seg in segments: new_seg = { "start": max(0, seg["start"] - dilatation), - "end": min(len(audio), seg["end"] + dilatation) + "end": min(len(audio), seg["end"] + dilatation), } if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]: new_segments[-1]["end"] = new_seg["end"] @@ -289,6 +336,7 @@ def apply_folder_hack(): seg["end"] = round(seg["end"]) return segments + def check_vad_method(method, with_version=False): """ Check whether the VAD method is valid and return the method in a consistent format @@ -296,14 +344,16 @@ def check_vad_method(method, with_version=False): method: str or list or True or False """ if method in [True, "True", "true"]: - return check_vad_method("silero") # default method + return check_vad_method("silero") # default method elif method in [None, False, "False", "false", "None", "none"]: return None - elif not isinstance(method, str) and hasattr(method, '__iter__'): + elif not isinstance(method, str) and hasattr(method, "__iter__"): # list of explicit timestamps checked_pairs = [] for s_e in method: - assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs" + assert ( + len(s_e) == 2 + ), f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs" checked_pairs.append(tuple(s_e)) return checked_pairs elif isinstance(method, str) and method.startswith("silero"): @@ -316,7 +366,9 @@ def check_vad_method(method, with_version=False): try: assert float(version[1:]) >= 1 except: - raise ValueError(f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)") + raise ValueError( + f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)" + ) if with_version: return ("silero", version) else: @@ -325,12 +377,14 @@ def check_vad_method(method, with_version=False): try: import auditok except ImportError: - raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)") + raise ImportError( + "Please install auditok to use the auditok VAD (or use another VAD method)" + ) else: try: method = eval(method) - assert hasattr(method, '__iter__') + assert hasattr(method, "__iter__") except: raise ValueError(f"Got unexpected VAD method {method}") return check_vad_method(method, with_version=with_version) - return method \ No newline at end of file + return method From 3d6eadb650fd3cdaef2457eff22be50c71690435 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 2 Apr 2024 17:26:27 +0200 Subject: [PATCH 20/50] rename USE_VAD to VAD + add VAD in Readme --- whisper/.envdefault | 2 +- whisper/README.md | 5 +++-- whisper/stt/__init__.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/whisper/.envdefault b/whisper/.envdefault index 795a5ff..5e6d5a6 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -51,4 +51,4 @@ STREAMING_PORT=80 # HTTP PARAMETERS ENABLE_STREAMING=true -USE_VAD=auditok \ No newline at end of file +VAD=auditok \ No newline at end of file diff --git a/whisper/README.md b/whisper/README.md index f5119a5..91c260c 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -122,8 +122,9 @@ cp whisper/.envdefault whisper/.env | NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | | CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | | CONCURRENCY | Maximum number of parallel requests | `2` | -| ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | true\|false | -| STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. | 80 | +| VAD | Activate (and specify which method) Voice Activity Detection for removing non speech segments | `if the argument is not specidifed, it will use auditok` \| `true (will use auditok)` \| `false` \| `auditok` \| `silero` \| +| ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | `true\|false` | +| STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. | `80` | | SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` | | SERVICE_BROKER | (For the task mode only) URL of the message broker | `redis://my-broker:6379` | | BROKER_PASS | (For the task mode only) broker password | `my-password` \| (empty) | diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index 93a04e5..feb6276 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -12,12 +12,12 @@ # see https://github.com/guillaumekln/faster-whisper/issues/150 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order -if os.environ.get("USE_VAD","auditok") in [True, "true", 1]: +if os.environ.get("VAD","auditok") in [True, "true", 1]: VAD = "auditok" -elif os.environ.get("USE_VAD","auditok") in [False, "false", 0]: +elif os.environ.get("VAD","auditok") in [False, "false", 0]: VAD = False else: - VAD = os.environ.get("USE_VAD","auditok") + VAD = os.environ.get("VAD","auditok") NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS")) NUM_THREADS = int(NUM_THREADS) From 798c109943014dc358c10a33efda68eedb8a0cea Mon Sep 17 00:00:00 2001 From: AudranBert Date: Wed, 3 Apr 2024 17:43:37 +0200 Subject: [PATCH 21/50] add vad parameters + basic warmup + upd readme/env --- http_server/ingress.py | 12 +++++------ kaldi/stt/processing/__init__.py | 3 +++ whisper/.envdefault | 19 ++++++++++------- whisper/README.md | 2 +- whisper/stt/__init__.py | 8 +++++-- whisper/stt/processing/__init__.py | 4 ++++ whisper/stt/processing/decoding.py | 9 ++++---- whisper/stt/processing/streaming.py | 33 +++++++++++++++++++---------- whisper/stt/processing/vad.py | 20 ++++++++--------- 9 files changed, 68 insertions(+), 42 deletions(-) diff --git a/http_server/ingress.py b/http_server/ingress.py index c70fa52..3d25cfd 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -9,7 +9,7 @@ from flask import Flask, json, request from serving import GeventServing, GunicornServing from stt import logger as stt_logger -from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer +from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer, warmup from swagger import setupSwaggerUI app = Flask("__stt-standalone-worker__") @@ -130,10 +130,10 @@ def server_error(error): serving_type = GunicornServing logger.debug("Serving with gunicorn") - def worker_started(worker): - logger.info(f"Worker started {worker.pid}") - MODEL[0].check_loaded() - logger.info("Worker fully initialized") + def post_worker_init(worker): + logger.info(f"Worker {worker.pid} init") + warmup() + logger.info(f"Worker {worker.pid} fully initialized") serving = serving_type( app, @@ -141,7 +141,7 @@ def worker_started(worker): "bind": f"0.0.0.0:{args.service_port}", "workers": args.workers, "timeout": 3600 * 24, - "post_worker_init": worker_started, + "post_worker_init": post_worker_init, }, ) logger.info(args) diff --git a/kaldi/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py index 9f99406..e0476a1 100644 --- a/kaldi/stt/processing/__init__.py +++ b/kaldi/stt/processing/__init__.py @@ -29,5 +29,8 @@ sys.exit(-1) logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start)) +def warmup(): + pass + # Not implemented yet in Kaldi USE_GPU = False diff --git a/whisper/.envdefault b/whisper/.envdefault index 5e6d5a6..500d4a3 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -9,6 +9,12 @@ SERVICE_NAME=stt SERVICES_BROKER=redis://172.17.0.1:6379 BROKER_PASS= +# HTTP PARAMETERS +ENABLE_STREAMING=true + +# WEBSOCKET PARAMETERS +STREAMING_PORT=80 + ############################################ # STT MODELING PARAMETERS ############################################ @@ -30,6 +36,11 @@ PROMPT= # This option is experimental (and not implemented with ctranslate2). # ALIGNMENT_MODEL=wav2vec +VAD=auditok +VAD_dilatation=0.1 +VAD_min_speech_duration=0.1 +VAD_min_silence_duration=0.1 + ############################################ # EFFICIENCY PARAMETERS ############################################ @@ -44,11 +55,3 @@ NUM_THREADS=4 # Number of workers CONCURRENCY=2 - -# WEBSOCKET PARAMETERS -STREAMING_PORT=80 - -# HTTP PARAMETERS -ENABLE_STREAMING=true - -VAD=auditok \ No newline at end of file diff --git a/whisper/README.md b/whisper/README.md index 91c260c..54af00e 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -122,7 +122,7 @@ cp whisper/.envdefault whisper/.env | NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | | CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | | CONCURRENCY | Maximum number of parallel requests | `2` | -| VAD | Activate (and specify which method) Voice Activity Detection for removing non speech segments | `if the argument is not specidifed, it will use auditok` \| `true (will use auditok)` \| `false` \| `auditok` \| `silero` \| +| VAD | Voice Activity Detection method. Use "false" to disable. If not specified, the default is auditok VAD. | `true` \| `false` \| `1` \| `0` \| `auditok` \| `silero` | ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | `true\|false` | | STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. | `80` | | SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` | diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index feb6276..d679fd3 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -12,13 +12,17 @@ # see https://github.com/guillaumekln/faster-whisper/issues/150 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order -if os.environ.get("VAD","auditok") in [True, "true", 1]: +if os.environ.get("VAD","auditok") in ["true", 1]: VAD = "auditok" -elif os.environ.get("VAD","auditok") in [False, "false", 0]: +elif os.environ.get("VAD","auditok") in ["false", 0]: VAD = False else: VAD = os.environ.get("VAD","auditok") +VAD_DILATATION = float(os.environ.get("VAD_DILATATION", 0.5)) +VAD_MIN_SPEECH_DURATION = float(os.environ.get("VAD_MIN_SPEECH_DURATION", 0.1)) +VAD_MIN_SILENCE_DURATION = float(os.environ.get("VAD_MAX_SILENCE_DURATION", 0.1)) + NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS")) NUM_THREADS = int(NUM_THREADS) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index 2140768..b574b71 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -90,4 +90,8 @@ def __call__(self, *args, **kwargs): ) alignment_model = {} # Alignement model(s) will be loaded on the fly + +def warmup(): + model.check_loaded() + MODEL = (model, alignment_model) diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index ffa2024..9ead692 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -5,7 +5,7 @@ from typing import Tuple, Union import numpy as np -from stt import USE_CTRANSLATE2, VAD, logger +from stt import USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SILENCE_DURATION, VAD_MIN_SPEECH_DURATION, logger from .vad import remove_non_speech from .alignment_model import get_alignment_model, load_alignment_model @@ -47,7 +47,6 @@ def decode( ) -> dict: if language is None: language = get_language() - kwargs = copy.copy(locals()) kwargs.pop("model_and_alignementmodel") kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel @@ -85,7 +84,8 @@ def decode_ct2( if kwargs.get("best_of") is None: kwargs["best_of"] = 1 if VAD: - _, speech_segments, _ = remove_non_speech(audio, method=VAD, return_format="dict") + _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \ + min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, return_format="dict") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, @@ -123,7 +123,8 @@ def decode_torch( fp16 = model.device != torch.device("cpu") if VAD: - _, speech_segments, _ = remove_non_speech(audio, method=VAD) + _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \ + min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION,) kwargs = dict( language=language, diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py index 2382717..7d2efce 100644 --- a/whisper/stt/processing/streaming.py +++ b/whisper/stt/processing/streaming.py @@ -3,7 +3,7 @@ import string import numpy as np from .vad import remove_non_speech -from stt import logger, USE_CTRANSLATE2, VAD +from stt import logger, USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SPEECH_DURATION, VAD_MIN_SILENCE_DURATION from websockets.legacy.server import WebSocketServerProtocol from simple_websocket.ws import Server as WSServer @@ -43,7 +43,8 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") online = OnlineASRProcessor( - asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate + asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \ + dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION ) logger.info("Starting transcription ...") while True: @@ -85,7 +86,8 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): logger.info("Using whisper_timestamped for decoding") asr = WhisperTimestampedASR(model=model, lan="fr") online = OnlineASRProcessor( - asr, logfile=sys.stderr, buffer_trimming=8, use_vad=VAD, sample_rate=sample_rate + asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \ + dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION ) logger.info("Starting transcription ...") while True: @@ -192,9 +194,12 @@ def __init__( self, asr, buffer_trimming=15, - use_vad="auditok", + vad="auditok", logfile=sys.stderr, sample_rate=16000, + min_speech_duration=0.1, + min_silence_duration=0.1, + dilatation=0.5, ): """asr: WhisperASR object tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. @@ -208,7 +213,10 @@ def __init__( self.init() self.buffer_trimming_sec = buffer_trimming - self.use_vad = use_vad + self.vad = vad + self.vad_dilatation = dilatation + self.vad_min_speech_duration = min_speech_duration + self.vad_min_silence_duration = min_silence_duration self.sampling_rate = sample_rate def init(self): @@ -257,19 +265,22 @@ def process_iter(self): logger.debug( f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s" ) - if self.use_vad: + if self.vad: np_buffer = np.array(self.audio_buffer) audio_speech, segments, convertion_function = remove_non_speech( np_buffer, - method=self.use_vad, + method=self.vad, + use_sample=True, sample_rate=self.sampling_rate, - dilatation=0.5, + dilatation=self.vad_dilatation, + min_speech_duration=self.vad_min_speech_duration, + min_silence_duration=self.vad_min_silence_duration, ) res = self.asr.transcribe(audio_speech, init_prompt=prompt) else: res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) # transform to [(beg,end,"word1"), ...] - tsw = self.asr.ts_words(res, convertion_function if self.use_vad else None) + tsw = self.asr.ts_words(res, convertion_function if self.vad else None) self.transcript_buffer.insert(tsw, self.buffer_time_offset) o, buffer = self.transcript_buffer.flush() self.commited.extend(o) @@ -289,8 +300,8 @@ def process_iter(self): if len(self.audio_buffer) / self.sampling_rate > self.buffer_trimming_sec: self.chunk_completed_segment( res, - chunk_silence=self.use_vad, - speech_segments=segments if self.use_vad else False, + chunk_silence=self.vad, + speech_segments=segments if self.vad else False, ) logger.debug( diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py index d7d5716..22af275 100644 --- a/whisper/stt/processing/vad.py +++ b/whisper/stt/processing/vad.py @@ -62,27 +62,25 @@ def remove_non_speech( dilatation=dilatation, method=method, ) + segments = apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=True) segments = [(seg["start"], seg["end"]) for seg in segments] if len(segments) == 0: if avoid_empty_speech: segments = [(0, audio.shape[-1])] else: np.array([]), [], lambda t, t2=None: t if t2 is None else [t, t2] - - audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) - # audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1) - if not use_sample: segments = [ (float(s) / sample_rate, float(e) / sample_rate) for s, e in segments ] + if return_format == "dict": segments = [{"start": s, "end": e} for s, e in segments] - return ( - audio_speech, - segments, - lambda t, t2=None: do_convert_timestamps(segments, t, t2), - ) + return None, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2) + + audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1) + + return audio_speech, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2) def do_convert_timestamps(segments, t, t2=None): @@ -309,7 +307,10 @@ def apply_folder_hack(): else: raise ValueError(f"Got unexpected VAD method {method}") + return segments + +def apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=False): if dilatation > 0: dilatation = round(dilatation * sample_rate) new_segments = [] @@ -336,7 +337,6 @@ def apply_folder_hack(): seg["end"] = round(seg["end"]) return segments - def check_vad_method(method, with_version=False): """ Check whether the VAD method is valid and return the method in a consistent format From 56533bb248f402cda46eb8ea622afd6a29cb798f Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 4 Apr 2024 16:43:23 +0200 Subject: [PATCH 22/50] add warmup + update CONCURRENCY doc --- whisper/.envdefault | 2 +- whisper/Dockerfile.ctranslate2 | 1 + whisper/Dockerfile.ctranslate2.cpu | 1 + whisper/Dockerfile.torch | 1 + whisper/Dockerfile.torch.cpu | 1 + whisper/README.md | 2 +- whisper/stt/processing/__init__.py | 38 +++++++++++++++++------------- 7 files changed, 27 insertions(+), 19 deletions(-) diff --git a/whisper/.envdefault b/whisper/.envdefault index 500d4a3..512105b 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -53,5 +53,5 @@ VAD_min_silence_duration=0.1 # Number of threads per worker when running on CPU NUM_THREADS=4 -# Number of workers +# Number of workers minus one (all except from the main one) CONCURRENCY=2 diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2 index c2b3cd5..5fd3c53 100644 --- a/whisper/Dockerfile.ctranslate2 +++ b/whisper/Dockerfile.ctranslate2 @@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu index 5c4817c..1f0f40c 100644 --- a/whisper/Dockerfile.ctranslate2.cpu +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -14,6 +14,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.torch b/whisper/Dockerfile.torch index 06b22f3..acd0b08 100644 --- a/whisper/Dockerfile.torch +++ b/whisper/Dockerfile.torch @@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu index 549a767..2d45336 100644 --- a/whisper/Dockerfile.torch.cpu +++ b/whisper/Dockerfile.torch.cpu @@ -20,6 +20,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/README.md b/whisper/README.md index 54af00e..ec1953a 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -121,7 +121,7 @@ cp whisper/.envdefault whisper/.env | DEVICE | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` | | NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | | CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | -| CONCURRENCY | Maximum number of parallel requests | `2` | +| CONCURRENCY | Maximum number of parallel requests (number of workers minus one) | `2` | | VAD | Voice Activity Detection method. Use "false" to disable. If not specified, the default is auditok VAD. | `true` \| `false` \| `1` \| `0` \| `auditok` \| `silero` | ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | `true\|false` | | STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. | `80` | diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index b574b71..95e103a 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -18,7 +18,12 @@ "USE_GPU", ] - +def warmup(): + model.check_loaded() + audio_data = load_audiofile("test/bonjour.wav") + transcription = decode(audio_data, MODEL, False) + logger.info(f"Warmup result: {transcription}") + class LazyLoadedModel: def __init__(self, model_type, device, num_threads): self.model_type = model_type @@ -63,18 +68,6 @@ def __call__(self, *args, **kwargs): logger.info(f"VAD={VAD}") logger.info(f"USE_CTRANSLATE2={USE_CTRANSLATE2}") -# Load ASR model -model_type = os.environ.get("MODEL", "medium") -logger.info( - f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." -) -try: - model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) - if str(device).lower() != "cpu": - model.check_loaded() -except Exception as err: - raise Exception("Failed to load transcription model: {}".format(str(err))) from err - # Load alignment model (if any) alignment_model = get_alignment_model(os.environ.get("alignment_model"), language) if alignment_model: @@ -91,7 +84,18 @@ def __call__(self, *args, **kwargs): alignment_model = {} # Alignement model(s) will be loaded on the fly -def warmup(): - model.check_loaded() - -MODEL = (model, alignment_model) +# Load ASR model +model_type = os.environ.get("MODEL", "medium") +logger.info( + f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." +) +try: + model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) + MODEL = (model, alignment_model) + if str(device).lower() != "cpu": + warmup() +except Exception as err: + raise Exception("Failed to load transcription model: {}".format(str(err))) from err + + + \ No newline at end of file From 0563487d2b46403f3bccee6be3fb80e511486eaa Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 4 Apr 2024 18:03:27 +0200 Subject: [PATCH 23/50] add tests for http and whisper --- test/.envtest | 55 +++++++++++ test/run_server.sh | 6 ++ test/test.sh | 240 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 301 insertions(+) create mode 100644 test/.envtest create mode 100644 test/run_server.sh create mode 100644 test/test.sh diff --git a/test/.envtest b/test/.envtest new file mode 100644 index 0000000..e2f047c --- /dev/null +++ b/test/.envtest @@ -0,0 +1,55 @@ +############################################ +# SERVING PARAMETERS +############################################ +# "http" or "task" or "websocket" +SERVICE_MODE=http + +# Below: used when SERVICE_MODE=task +SERVICE_NAME=stt +SERVICES_BROKER=redis://172.17.0.1:6379 +BROKER_PASS= + +# HTTP PARAMETERS +ENABLE_STREAMING=true + +# WEBSOCKET PARAMETERS +STREAMING_PORT=80 + +############################################ +# STT MODELING PARAMETERS +############################################ + +# The model can be a path to a model (e.g. "/root/.cache/whisper/large-v3.pt", "/root/.cache/huggingface/hub/models--openai--whisper-large-v3"), +# or a model size ("tiny", "base", "small", "medium", "large-v1", "large-v2" or "large-v3") +# or a HuggingFace model name (e.g. "distil-whisper/distil-large-v2") +MODEL=large-v3 + +# The language can be in different formats: "en", "en-US", "English", ... +# If not set or set to "*", the language will be detected automatically. +LANGUAGE=* + +# Prompt to use for the model. This can be used to provide context to the model, to encourage disfluencies or a special behaviour regarding punctuation and capitalization. +PROMPT= + +# An alignment wav2vec model can be used to get word timestamps. +# It can be a path to a model, a language code (fr, en, ...), or "wav2vec" to automatically chose a model for the language +# This option is experimental (and not implemented with ctranslate2). +# ALIGNMENT_MODEL=wav2vec + +VAD_dilatation=0.1 +VAD_min_speech_duration=0.1 +VAD_min_silence_duration=0.1 + +############################################ +# EFFICIENCY PARAMETERS +############################################ + +# Device to use. It can be "cuda" to force/check GPU, "cpu" to force computation on CPU, or a specific GPU ("cuda:0", "cuda:1", ...) +# CUDA_DEVICE_ORDER=PCI_BUS_ID +# CUDA_VISIBLE_DEVICES=0 + +# Number of threads per worker when running on CPU +NUM_THREADS=4 + +# Number of workers minus one (all except from the main one) +CONCURRENCY=1 diff --git a/test/run_server.sh b/test/run_server.sh new file mode 100644 index 0000000..5c9c126 --- /dev/null +++ b/test/run_server.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +docker build . -f $1 -t linto-stt-whisper:latest +cp $2 whisper/.env +touch build_finished +docker run --rm -p 8080:80 --name test_container --env-file whisper/.env --gpus all -v /home/abert/.cache:/root/.cache linto-stt-whisper:latest diff --git a/test/test.sh b/test/test.sh new file mode 100644 index 0000000..74603d6 --- /dev/null +++ b/test/test.sh @@ -0,0 +1,240 @@ +#!/bin/bash + +tests_run=0 +passed=0 +failed=0 + +function test_failed() { + mkdir -p test/tests_failed + mv $2 .envtmp test/tests_failed/$local_test_id.env + echo 'Test failed.' + echo 'See test/test.log for more details.' + echo '.Env file has been moved to tests_failed directory.' >> test/test.log + echo 'Test failed.' >> test/test.log + failed=$((failed + 1)) + docker stop test_container + pkill -P $pid + echo '' >> test/test.log + # exit 1 +} + +function test_finished(){ + echo 'Test passed.' + echo 'Test passed.' >> test/test.log + passed=$((passed + 1)) + docker stop test_container + pkill -P $pid + echo '' >> test/test.log +} + +function ending() { + echo '' + echo 'Ending the tests...' + echo $passed/$tests_run tests passed. + echo $failed/$tests_run tests failed. + if [ $failed -eq 0 ]; then + echo 'TEST PASSED.' + else + echo 'TEST FAILED.' + fi + docker stop test_container + pkill -P $pid + exit 1 +} + +# Fonction pour construire l'image Docker +build_docker_image() { + local docker_image="$1" + local config_file="$2" + test/run_server.sh $docker_image $2 > /dev/null 2>&1 +} + +function ctrl_c() { + echo '' + echo "Ctrl + C happened, attempting to stop the server..." + rm build_finished + rm .envtmp + ending +} + + +# Attend la création du fichier avec un timeout de 600 secondes +wait_for_file_creation_with_timeout() { + local file="$1" + local timeout=600 # 10 minutes en secondes + local start_time=$(date +%s) + + while [ ! -f "$file" ]; do + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + if [ $elapsed_time -ge $timeout ]; then + echo "Timeout. The docker image took too long to be built." >> test/test.log + return 1 + fi + sleep 1 + done + sleep 1 + if ps -p $pid > /dev/null; then + process_running=true + else + echo "Docker building process failed." >> test/test.log + rm $file + return 1 + fi + echo "File $file has been created. Docker image has been successfully built in $elapsed_time sec." >> test/test.log + rm $file + return 0 +} + + + +check_http_server_availability() { + local server="$1" + local total_wait_time=600 # 10 minutes en secondes + local retry_interval=5 # Interval entre les tentatives (en secondes) + local elapsed_time=0 + + while [ $elapsed_time -lt $total_wait_time ]; do + # Test de la disponibilité du serveur HTTP + curl -s --head --request GET "$server" | grep "200 OK" + if [ $? -eq 0 ]; then + echo "The server $server is available after $elapsed_time sec." >> test/test.log + sleep 2 + return 0 + fi + + # Attendre avant la prochaine tentative + sleep $retry_interval + elapsed_time=$((elapsed_time + retry_interval)) + done + + echo "The server $server is not available after $total_wait_time seconds, server launching must have failed." >> test/test.log + return 1 +} + +make_env() +{ + local env_file="$1" + cp $env_file .envtmp + if [ -z "$2" ]; then + return 0 + else + echo $2 >> test/test.log + echo $2 >> .envtmp + fi + if [ -z "$3" ]; then + return 0 + else + echo $3 >> test/test.log + echo $3 >> .envtmp + fi + if [ -z "$4" ]; then + return 0 + else + echo $4 >> test/test.log + echo $4 >> .envtmp + fi + +} + +process_test() +{ + echo '' >> test/test.log + + echo Test $test_id >> test/test.log + echo Docker image: $1 >> test/test.log + echo Audio file: $3 >> test/test.log + echo Test type: $4 >> test/test.log + echo '' + echo Starting test $test_id + local config_file="$2" + make_env $config_file $5 $6 $7 + local local_test_id=$test_id + test_id=$((test_id + 1)) + tests_run=$((tests_run + 1)) + local docker_image="$1" + local test_file="$3" + local test_type="$4" + # Exécute la fonction de construction dans un sous-processus + build_docker_image $docker_image .envtmp & + pid=$! + echo "The server is creating and will be running with the PID $pid." >> test/test.log + + # Attend la création du fichier avec un timeout de 600 secondes + wait_for_file_creation_with_timeout build_finished + local r=$? + if [ "$r" -ne 0 ]; then + mv $2 tests_failed/$local_test_id.env + test_failed $2 + return 1 + fi + check_http_server_availability "http://localhost:8080/healthcheck" + local r=$? + if [ "$r" -ne 0 ]; then + mv $2 tests_failed/$local_test_id.env + test_failed $2 + return 1 + fi + if [ "$test_file" == "test/GOLE7.wav" ] ; then + regex=".*Je crois que j'avais des profs.*" + elif [ "$test_file" == "test/bonjour.wav" ]; then + regex=".*Bonjour.*" + fi + if [ "$test_type" == "decoding" ]; then + local start_time=$(date +%s) + local res=$(curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav" 2>/dev/null) + local end_time=$(date +%s) + if [ -z "$res" ]; then + echo "The server didn't transcribed, retrying in 10sec">> test/test.log + sleep 10 + start_time=$(date +%s) + res=$(curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav" 2>/dev/null) + end_time=$(date +%s) + fi + echo "The server has transcribed $3 in $((end_time - start_time)) sec." >> test/test.log + + if [[ $res =~ $regex ]]; then + echo "The string is matching the regex ($regex), the server must has successfully transcribed." >> test/test.log + test_finished $2 + return 0 + else + echo "The string is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" >> test/test.log + test_failed $2 + fi + elif [ "$test_type" == "streaming" ]; then + echo "Starting streaming test" >> test/test.log + res=$(python3 test/test_streaming.py --audio_file $test_file) + if [[ $res =~ $regex ]]; then + echo "The string is matching the regex ($regex), the server must has successfully transcribed." >> test/test.log + test_finished $2 + return 0 + else + echo "The string is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" >> test/test.log + test_failed $2 + fi + else + echo "Test type $test_type not supported." >> test/test.log + test_failed $2 + fi + return 1 +} + +test_id=0 +trap ctrl_c INT +echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') > test/test.log +echo '' >> test/test.log + +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cpu vad=auditok + +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=False +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=auditok +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=silero + +process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=False +# process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=auditok # if auditok works for faster whisper it will work for torch +process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=silero + +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav streaming device=cuda vad=False +process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav streaming device=cuda vad=auditok + +ending \ No newline at end of file From b0a570d8034fe9ed86a9a779b649db5e0a8f01bf Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 19:00:20 +0200 Subject: [PATCH 24/50] more robust conditions --- http_server/ingress.py | 2 +- whisper/stt/__init__.py | 4 ++-- whisper/stt/processing/__init__.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/http_server/ingress.py b/http_server/ingress.py index 3d25cfd..424e71c 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -24,7 +24,7 @@ logger.setLevel(logging.INFO) # If websocket streaming route is enabled -if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]: +if os.environ.get("ENABLE_STREAMING", "false").lower() in ["true", "1"]: from flask_sock import Sock from stt.processing.streaming import ws_streaming diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index d679fd3..8bac458 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -12,9 +12,9 @@ # see https://github.com/guillaumekln/faster-whisper/issues/150 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order -if os.environ.get("VAD","auditok") in ["true", 1]: +if os.environ.get("VAD","auditok").lower() in ["true", "1"]: VAD = "auditok" -elif os.environ.get("VAD","auditok") in ["false", 0]: +elif os.environ.get("VAD","auditok").lower() in ["false", "0"]: VAD = False else: VAD = os.environ.get("VAD","auditok") diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index 95e103a..b82a482 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -92,7 +92,7 @@ def __call__(self, *args, **kwargs): try: model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) MODEL = (model, alignment_model) - if str(device).lower() != "cpu": + if USE_GPU: warmup() except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err From acab0e58066e958086eea4efbb89e509e26ca191 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 19:01:01 +0200 Subject: [PATCH 25/50] correction and factorization --- test/run_server.sh | 0 test/test.sh | 31 +++++++++++++------------------ 2 files changed, 13 insertions(+), 18 deletions(-) mode change 100644 => 100755 test/run_server.sh mode change 100644 => 100755 test/test.sh diff --git a/test/run_server.sh b/test/run_server.sh old mode 100644 new mode 100755 diff --git a/test/test.sh b/test/test.sh old mode 100644 new mode 100755 index 74603d6..7311c37 --- a/test/test.sh +++ b/test/test.sh @@ -46,7 +46,7 @@ function ending() { build_docker_image() { local docker_image="$1" local config_file="$2" - test/run_server.sh $docker_image $2 > /dev/null 2>&1 + test/run_server.sh $docker_image $2 # > /dev/null 2>&1 } function ctrl_c() { @@ -77,11 +77,11 @@ wait_for_file_creation_with_timeout() { if ps -p $pid > /dev/null; then process_running=true else - echo "Docker building process failed." >> test/test.log + echo "Docker building process failed." | tee -a test/test.log rm $file return 1 fi - echo "File $file has been created. Docker image has been successfully built in $elapsed_time sec." >> test/test.log + echo "File $file has been created. Docker image has been successfully built in $elapsed_time sec." | tee -a test/test.log rm $file return 0 } @@ -98,7 +98,7 @@ check_http_server_availability() { # Test de la disponibilité du serveur HTTP curl -s --head --request GET "$server" | grep "200 OK" if [ $? -eq 0 ]; then - echo "The server $server is available after $elapsed_time sec." >> test/test.log + echo "The server $server is available after $elapsed_time sec." | tee -a test/test.log sleep 2 return 0 fi @@ -108,7 +108,7 @@ check_http_server_availability() { elapsed_time=$((elapsed_time + retry_interval)) done - echo "The server $server is not available after $total_wait_time seconds, server launching must have failed." >> test/test.log + echo "The server $server is not available after $total_wait_time seconds, server launching must have failed." | tee -a test/test.log return 1 } @@ -158,7 +158,7 @@ process_test() # Exécute la fonction de construction dans un sous-processus build_docker_image $docker_image .envtmp & pid=$! - echo "The server is creating and will be running with the PID $pid." >> test/test.log + echo "The server is creating and will be running with the PID $pid." | tee -a test/test.log # Attend la création du fichier avec un timeout de 600 secondes wait_for_file_creation_with_timeout build_finished @@ -224,17 +224,12 @@ trap ctrl_c INT echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') > test/test.log echo '' >> test/test.log -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cpu vad=auditok - -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=False -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=auditok -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav decoding device=cuda vad=silero - -process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=False -# process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=auditok # if auditok works for faster whisper it will work for torch -process_test whisper/Dockerfile.torch test/.envtest test/bonjour.wav decoding device=cuda vad=silero - -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav streaming device=cuda vad=False -process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav streaming device=cuda vad=auditok +for serving in decoding streaming;do + for vad in False auditok silero; do + for device in cpu cuda; do + process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav $serving DEVICE=$device VAD=$vad + done + done +done ending \ No newline at end of file From d181983e9ab02eba9e63f01e159cb9af2f30db91 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 19:31:04 +0200 Subject: [PATCH 26/50] In testing: Use .envdefault. Use tiny model. Use local cache or not. --- test/.envtest | 55 ---------------------------------------------- test/run_server.sh | 15 ++++++++++--- test/test.sh | 40 +++++++++++++++++++-------------- 3 files changed, 36 insertions(+), 74 deletions(-) diff --git a/test/.envtest b/test/.envtest index e2f047c..e69de29 100644 --- a/test/.envtest +++ b/test/.envtest @@ -1,55 +0,0 @@ -############################################ -# SERVING PARAMETERS -############################################ -# "http" or "task" or "websocket" -SERVICE_MODE=http - -# Below: used when SERVICE_MODE=task -SERVICE_NAME=stt -SERVICES_BROKER=redis://172.17.0.1:6379 -BROKER_PASS= - -# HTTP PARAMETERS -ENABLE_STREAMING=true - -# WEBSOCKET PARAMETERS -STREAMING_PORT=80 - -############################################ -# STT MODELING PARAMETERS -############################################ - -# The model can be a path to a model (e.g. "/root/.cache/whisper/large-v3.pt", "/root/.cache/huggingface/hub/models--openai--whisper-large-v3"), -# or a model size ("tiny", "base", "small", "medium", "large-v1", "large-v2" or "large-v3") -# or a HuggingFace model name (e.g. "distil-whisper/distil-large-v2") -MODEL=large-v3 - -# The language can be in different formats: "en", "en-US", "English", ... -# If not set or set to "*", the language will be detected automatically. -LANGUAGE=* - -# Prompt to use for the model. This can be used to provide context to the model, to encourage disfluencies or a special behaviour regarding punctuation and capitalization. -PROMPT= - -# An alignment wav2vec model can be used to get word timestamps. -# It can be a path to a model, a language code (fr, en, ...), or "wav2vec" to automatically chose a model for the language -# This option is experimental (and not implemented with ctranslate2). -# ALIGNMENT_MODEL=wav2vec - -VAD_dilatation=0.1 -VAD_min_speech_duration=0.1 -VAD_min_silence_duration=0.1 - -############################################ -# EFFICIENCY PARAMETERS -############################################ - -# Device to use. It can be "cuda" to force/check GPU, "cpu" to force computation on CPU, or a specific GPU ("cuda:0", "cuda:1", ...) -# CUDA_DEVICE_ORDER=PCI_BUS_ID -# CUDA_VISIBLE_DEVICES=0 - -# Number of threads per worker when running on CPU -NUM_THREADS=4 - -# Number of workers minus one (all except from the main one) -CONCURRENCY=1 diff --git a/test/run_server.sh b/test/run_server.sh index 5c9c126..3ec1757 100755 --- a/test/run_server.sh +++ b/test/run_server.sh @@ -1,6 +1,15 @@ #!/bin/bash -docker build . -f $1 -t linto-stt-whisper:latest -cp $2 whisper/.env +dockerfile=$1 +shift +env_file=$1 +shift + +tag=test_`basename $dockerfile` + +docker build . -f $dockerfile -t linto-stt-whisper:$tag touch build_finished -docker run --rm -p 8080:80 --name test_container --env-file whisper/.env --gpus all -v /home/abert/.cache:/root/.cache linto-stt-whisper:latest + +CMD="docker run --rm -p 8080:80 --name test_container --env-file $env_file --gpus all $* linto-stt-whisper:$tag" +echo $CMD +eval $CMD diff --git a/test/test.sh b/test/test.sh index 7311c37..65662b9 100755 --- a/test/test.sh +++ b/test/test.sh @@ -44,9 +44,7 @@ function ending() { # Fonction pour construire l'image Docker build_docker_image() { - local docker_image="$1" - local config_file="$2" - test/run_server.sh $docker_image $2 # > /dev/null 2>&1 + test/run_server.sh $* # > /dev/null 2>&1 } function ctrl_c() { @@ -147,16 +145,22 @@ process_test() echo Test type: $4 >> test/test.log echo '' echo Starting test $test_id + local docker_image="$1" local config_file="$2" - make_env $config_file $5 $6 $7 + local test_file="$3" + local test_type="$4" + local use_local_cache="$5" + make_env $config_file $6 $7 $8 local local_test_id=$test_id test_id=$((test_id + 1)) tests_run=$((tests_run + 1)) - local docker_image="$1" - local test_file="$3" - local test_type="$4" + if [ $use_local_cache -gt 0 ];then + build_args="-v $HOME/.cache:/root/.cache" + else + build_args="" + fi # Exécute la fonction de construction dans un sous-processus - build_docker_image $docker_image .envtmp & + build_docker_image $docker_image .envtmp $build_args & pid=$! echo "The server is creating and will be running with the PID $pid." | tee -a test/test.log @@ -164,15 +168,13 @@ process_test() wait_for_file_creation_with_timeout build_finished local r=$? if [ "$r" -ne 0 ]; then - mv $2 tests_failed/$local_test_id.env - test_failed $2 + test_failed .envtmp return 1 fi check_http_server_availability "http://localhost:8080/healthcheck" local r=$? if [ "$r" -ne 0 ]; then - mv $2 tests_failed/$local_test_id.env - test_failed $2 + test_failed .envtmp return 1 fi if [ "$test_file" == "test/GOLE7.wav" ] ; then @@ -224,10 +226,16 @@ trap ctrl_c INT echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') > test/test.log echo '' >> test/test.log -for serving in decoding streaming;do - for vad in False auditok silero; do - for device in cpu cuda; do - process_test whisper/Dockerfile.ctranslate2 test/.envtest test/bonjour.wav $serving DEVICE=$device VAD=$vad +# Prepare env file for tests +cat whisper/.envdefault | grep -v "DEVICE=" | grep -v "VAD=" | grep -v "MODEL=" > test/.env +echo "MODEL=tiny" >> test/.env + +for use_local_cache in 0 1;do + for serving in decoding streaming;do + for vad in False auditok silero; do + for device in cuda cpu; do + process_test whisper/Dockerfile.ctranslate2 test/.env test/bonjour.wav $serving $use_local_cache DEVICE=$device VAD=$vad + done done done done From 13f23f5fb96d8eac9e179333eddd5212e7d5ff88 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 19:37:13 +0200 Subject: [PATCH 27/50] Update files to ignore --- .gitignore | 4 ++-- test/.envtest | 0 2 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 test/.envtest diff --git a/.gitignore b/.gitignore index 06b349b..1aa14be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -start_container.sh .env* -test/* tmp* +test/tests_failed/* +test.log __pycache__ \ No newline at end of file diff --git a/test/.envtest b/test/.envtest deleted file mode 100644 index e69de29..0000000 From 687001b076ec9a7ebaa1bbfd2865a820aaf46b0b Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 20:03:37 +0200 Subject: [PATCH 28/50] clarify VAD parameters --- whisper/.envdefault | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/whisper/.envdefault b/whisper/.envdefault index 512105b..a8f8794 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -36,10 +36,14 @@ PROMPT= # This option is experimental (and not implemented with ctranslate2). # ALIGNMENT_MODEL=wav2vec -VAD=auditok -VAD_dilatation=0.1 -VAD_min_speech_duration=0.1 -VAD_min_silence_duration=0.1 +# Voice Activity Detection (VAD) method +# It can be either "0"/"false" (no VAD), "silero", or "1"/"true"/"auditok" (by default) +# VAD=auditok + +# Voice Activity Detection (VAD) parameters +# VAD_DILATATION=0.1 +# VAD_MIN_SPEECH_DURATION=0.1 +# VAD_MIN_SILENCE_DURATION=0.1 ############################################ # EFFICIENCY PARAMETERS From bc02226452f58bcee74f7a4811fdb5e714cff07e Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 20:04:46 +0200 Subject: [PATCH 29/50] Add timing information in tests + some tuning --- test/run_server.sh | 5 +++-- test/test.sh | 44 ++++++++++++++++++++++++-------------------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/test/run_server.sh b/test/run_server.sh index 3ec1757..6f09d2a 100755 --- a/test/run_server.sh +++ b/test/run_server.sh @@ -7,9 +7,10 @@ shift tag=test_`basename $dockerfile` -docker build . -f $dockerfile -t linto-stt-whisper:$tag +docker build . -f $dockerfile -t linto-stt-whisper:$tag > /dev/null 2>&1 touch build_finished CMD="docker run --rm -p 8080:80 --name test_container --env-file $env_file --gpus all $* linto-stt-whisper:$tag" echo $CMD -eval $CMD +grep -v "^#" $env_file | grep "=" | grep -v SERVICE_NAME | grep -v BROKER | grep -v PORT +eval $CMD > /dev/null 2>&1 diff --git a/test/test.sh b/test/test.sh index 65662b9..07f0025 100755 --- a/test/test.sh +++ b/test/test.sh @@ -5,40 +5,41 @@ passed=0 failed=0 function test_failed() { + end=$(date +%s) mkdir -p test/tests_failed mv $2 .envtmp test/tests_failed/$local_test_id.env - echo 'Test failed.' + echo 'Test failed in '$((end-start))' seconds.' | tee -a test/test.log echo 'See test/test.log for more details.' - echo '.Env file has been moved to tests_failed directory.' >> test/test.log - echo 'Test failed.' >> test/test.log + echo '.env file has been moved to test/tests_failed directory.' >> test/test.log failed=$((failed + 1)) - docker stop test_container + docker stop test_container > /dev/null pkill -P $pid echo '' >> test/test.log # exit 1 } function test_finished(){ - echo 'Test passed.' - echo 'Test passed.' >> test/test.log + end=$(date +%s) + echo 'Test passed in '$((end-start))' seconds.' | tee -a test/test.log passed=$((passed + 1)) - docker stop test_container + docker stop test_container > /dev/null pkill -P $pid echo '' >> test/test.log } function ending() { - echo '' - echo 'Ending the tests...' - echo $passed/$tests_run tests passed. - echo $failed/$tests_run tests failed. + docker stop test_container > /dev/null + pkill -P $pid + global_end=$(date +%s) + echo '' | tee -a test/test.log + echo 'Time to run tests: '$((global_end-global_start))' seconds.' | tee -a test/test.log + echo $passed/$tests_run tests passed. | tee -a test/test.log + echo $failed/$tests_run tests failed. | tee -a test/test.log if [ $failed -eq 0 ]; then - echo 'TEST PASSED.' + echo 'TEST PASSED.' | tee -a test/test.log else - echo 'TEST FAILED.' + echo 'TEST FAILED.' | tee -a test/test.log fi - docker stop test_container - pkill -P $pid exit 1 } @@ -49,9 +50,9 @@ build_docker_image() { function ctrl_c() { echo '' - echo "Ctrl + C happened, attempting to stop the server..." - rm build_finished - rm .envtmp + echo "Ctrl + C happened, stopping the server... (do not press Ctrl + C again)" + rm -f build_finished + rm -f .envtmp ending } @@ -79,7 +80,7 @@ wait_for_file_creation_with_timeout() { rm $file return 1 fi - echo "File $file has been created. Docker image has been successfully built in $elapsed_time sec." | tee -a test/test.log + echo "Docker image has been successfully built in $elapsed_time sec." | tee -a test/test.log rm $file return 0 } @@ -159,6 +160,7 @@ process_test() else build_args="" fi + start=$(date +%s) # Exécute la fonction de construction dans un sous-processus build_docker_image $docker_image .envtmp $build_args & pid=$! @@ -230,7 +232,9 @@ echo '' >> test/test.log cat whisper/.envdefault | grep -v "DEVICE=" | grep -v "VAD=" | grep -v "MODEL=" > test/.env echo "MODEL=tiny" >> test/.env -for use_local_cache in 0 1;do +global_start=$(date +%s) + +for use_local_cache in 1 0;do for serving in decoding streaming;do for vad in False auditok silero; do for device in cuda cpu; do From fbc00a1d6045a974a6a2409fe49c3468613915c2 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 4 Apr 2024 20:10:36 +0200 Subject: [PATCH 30/50] limit the number of tests --- test/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test.sh b/test/test.sh index 07f0025..b02d393 100755 --- a/test/test.sh +++ b/test/test.sh @@ -234,7 +234,7 @@ echo "MODEL=tiny" >> test/.env global_start=$(date +%s) -for use_local_cache in 1 0;do +for use_local_cache in 0;do for serving in decoding streaming;do for vad in False auditok silero; do for device in cuda cpu; do From 484bdcaec9cc9e19682bdf48e9e31d9998a2e328 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Fri, 5 Apr 2024 09:13:29 +0200 Subject: [PATCH 31/50] Simplify code, and fail early if docker build or run does not complete correctly, add test of celery tasks --- kaldi/docker-entrypoint.sh | 4 +- test/build_and_run_container.sh | 22 ++ test/launch_redis.sh | 3 + test/run_server.sh | 16 -- test/test.sh | 403 ++++++++++++++++++++------------ test/test_celery.py | 11 + whisper/docker-entrypoint.sh | 4 +- 7 files changed, 291 insertions(+), 172 deletions(-) create mode 100755 test/build_and_run_container.sh create mode 100755 test/launch_redis.sh delete mode 100755 test/run_server.sh create mode 100755 test/test_celery.py diff --git a/kaldi/docker-entrypoint.sh b/kaldi/docker-entrypoint.sh index 212b145..74d3b15 100755 --- a/kaldi/docker-entrypoint.sh +++ b/kaldi/docker-entrypoint.sh @@ -25,7 +25,7 @@ fi # Launch parameters, environement variables and dependencies check if [ -z "$SERVICE_MODE" ] then - echo "ERROR: Must specify a serving mode: [ http | task | websocket ]" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)" exit -1 else if [ "$SERVICE_MODE" = "http" ] @@ -48,7 +48,7 @@ else echo "Running Websocket server on port ${STREAMING_PORT:=80}" python websocket/websocketserver.py else - echo "ERROR: Wrong serving command: $1" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)" exit -1 fi fi diff --git a/test/build_and_run_container.sh b/test/build_and_run_container.sh new file mode 100755 index 0000000..c3685e3 --- /dev/null +++ b/test/build_and_run_container.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +dockerfile=$1 +shift +env_file=$1 +shift + +tag=test_`basename $dockerfile` + +CMD="docker build . -f $dockerfile -t linto-stt-test:$tag" +echo "$ "$CMD +eval $CMD > /dev/null 2>&1 +if [ $? -ne 0 ]; then + echo "Build failed" + exit 1 +fi +touch build_finished + +CMD="docker run --rm -p 8080:80 --name test_container --env-file $env_file --gpus all $* linto-stt-test:$tag" +# grep -v "^#" $env_file | grep "=" | grep -v SERVICE_NAME | grep -v BROKER | grep -v PORT +echo "$ "$CMD +eval $CMD > /dev/null 2>&1 diff --git a/test/launch_redis.sh b/test/launch_redis.sh new file mode 100755 index 0000000..7fe4f68 --- /dev/null +++ b/test/launch_redis.sh @@ -0,0 +1,3 @@ +CMD="docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" +echo "$ "$CMD +eval $CMD 2> /dev/null > /dev/null \ No newline at end of file diff --git a/test/run_server.sh b/test/run_server.sh deleted file mode 100755 index 6f09d2a..0000000 --- a/test/run_server.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -dockerfile=$1 -shift -env_file=$1 -shift - -tag=test_`basename $dockerfile` - -docker build . -f $dockerfile -t linto-stt-whisper:$tag > /dev/null 2>&1 -touch build_finished - -CMD="docker run --rm -p 8080:80 --name test_container --env-file $env_file --gpus all $* linto-stt-whisper:$tag" -echo $CMD -grep -v "^#" $env_file | grep "=" | grep -v SERVICE_NAME | grep -v BROKER | grep -v PORT -eval $CMD > /dev/null 2>&1 diff --git a/test/test.sh b/test/test.sh index b02d393..7ed3d35 100755 --- a/test/test.sh +++ b/test/test.sh @@ -3,56 +3,79 @@ tests_run=0 passed=0 failed=0 +global_start=$(date +%s) +test_log=test/test.log + +function echo_success() { + # Print a green tick (with colour only on the terminal, not the log file) + printf '\033[0;32m' + printf '\xE2\x9C\x94 ' | tee -a $test_log + printf '\033[0m' # No Color + echo $* | tee -a $test_log +} + +function echo_failure() { + # Print a red cross (with colour only on the terminal, not the log file) + printf '\033[0;31m' + printf '\xE2\x9C\x96 ' | tee -a $test_log + printf '\033[0m' # No Color + echo $* | tee -a $test_log +} + +function echo_note() { + printf '🕓 ' | tee -a $test_log + echo $* | tee -a $test_log +} function test_failed() { - end=$(date +%s) - mkdir -p test/tests_failed - mv $2 .envtmp test/tests_failed/$local_test_id.env - echo 'Test failed in '$((end-start))' seconds.' | tee -a test/test.log - echo 'See test/test.log for more details.' - echo '.env file has been moved to test/tests_failed directory.' >> test/test.log + local end=$(date +%s) failed=$((failed + 1)) - docker stop test_container > /dev/null - pkill -P $pid - echo '' >> test/test.log + echo "-----------------------" | tee -a $test_log + echo_failure "Test failed after "$((end-start))" seconds ($passed/$tests_run tests succeeded in "$((end-global_start))" seconds)" + test_teardown + echo 'See $test_log for more details.' # exit 1 } -function test_finished(){ - end=$(date +%s) - echo 'Test passed in '$((end-start))' seconds.' | tee -a test/test.log +function test_succeeded(){ + local end=$(date +%s) passed=$((passed + 1)) - docker stop test_container > /dev/null - pkill -P $pid - echo '' >> test/test.log + echo "-----------------------" | tee -a $test_log + echo_success "Test passed in "$((end-start))" seconds ($passed/$tests_run tests succeeded in "$((end-global_start))" seconds)" + test_teardown +} + +function test_teardown(){ + rm -f build_finished + local end=$(date +%s) + docker stop test_redis > /dev/null 2> /dev/null + docker stop test_container > /dev/null 2> /dev/null + pkill -P $pids + echo | tee -a $test_log } function ending() { - docker stop test_container > /dev/null - pkill -P $pid - global_end=$(date +%s) - echo '' | tee -a test/test.log - echo 'Time to run tests: '$((global_end-global_start))' seconds.' | tee -a test/test.log - echo $passed/$tests_run tests passed. | tee -a test/test.log - echo $failed/$tests_run tests failed. | tee -a test/test.log - if [ $failed -eq 0 ]; then - echo 'TEST PASSED.' | tee -a test/test.log + local end=$(date +%s) + echo_note 'Time to run tests: '$((end-global_start))' seconds.' + if [ $passed -gt 0 ];then + echo_success $passed/$tests_run tests passed. + fi + if [ $failed -gt 0 ];then + echo_failure $failed/$tests_run tests failed. + fi + if [ $passed -eq $tests_run ]; then + echo_success 'TEST PASSED.' + exit 0 else - echo 'TEST FAILED.' | tee -a test/test.log + echo_failure 'TEST FAILED.' + exit 1 fi - exit 1 -} - -# Fonction pour construire l'image Docker -build_docker_image() { - test/run_server.sh $* # > /dev/null 2>&1 } function ctrl_c() { echo '' - echo "Ctrl + C happened, stopping the server... (do not press Ctrl + C again)" - rm -f build_finished - rm -f .envtmp + echo_failure "Interruption signal received, stopping the server... (do not press Ctrl + C again)" + test_teardown ending } @@ -60,6 +83,7 @@ function ctrl_c() { # Attend la création du fichier avec un timeout de 600 secondes wait_for_file_creation_with_timeout() { local file="$1" + local pid="$2" local timeout=600 # 10 minutes en secondes local start_time=$(date +%s) @@ -67,181 +91,256 @@ wait_for_file_creation_with_timeout() { current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then - echo "Timeout. The docker image took too long to be built." >> test/test.log - return 1 + echo "Fatal Error: Timeout. The docker image took too long to be built." | tee -a $test_log + exit 1 + fi + # Vérifie si le processus est toujours en cours d'exécution + if ! ps -p $pid > /dev/null; then + echo "Fatal Error: Docker build failed." | tee -a $test_log + exit 1 fi sleep 1 done - sleep 1 - if ps -p $pid > /dev/null; then - process_running=true - else - echo "Docker building process failed." | tee -a test/test.log - rm $file - return 1 - fi - echo "Docker image has been successfully built in $elapsed_time sec." | tee -a test/test.log + end_time=$(date +%s) + echo_note "Docker image has been successfully built in "$((end_time - start_time))" sec." rm $file + if [[ "$(ps -p $pid > /dev/null)" ]]; then + echo_failure "Fatal Error: Docker container start failed immediately." + exit 1 + fi return 0 } - - check_http_server_availability() { local server="$1" local total_wait_time=600 # 10 minutes en secondes - local retry_interval=5 # Interval entre les tentatives (en secondes) + local retry_interval=1 # Interval entre les tentatives (en secondes) local elapsed_time=0 while [ $elapsed_time -lt $total_wait_time ]; do # Test de la disponibilité du serveur HTTP curl -s --head --request GET "$server" | grep "200 OK" if [ $? -eq 0 ]; then - echo "The server $server is available after $elapsed_time sec." | tee -a test/test.log - sleep 2 + echo_note "$server is available after $elapsed_time sec." return 0 fi + if [[ `docker ps -a -q -f name=test_container | wc -l` -eq 0 ]];then + echo_failure "Fatal error: the server container has stopped for unexpected reason." + exit 1 + fi + # Attendre avant la prochaine tentative sleep $retry_interval elapsed_time=$((elapsed_time + retry_interval)) done - echo "The server $server is not available after $total_wait_time seconds, server launching must have failed." | tee -a test/test.log - return 1 + echo_failure "$server is not available after $total_wait_time seconds, server launching must have failed." + exit 1 } -make_env() +build_and_run_container() { - local env_file="$1" - cp $env_file .envtmp - if [ -z "$2" ]; then - return 0 - else - echo $2 >> test/test.log - echo $2 >> .envtmp - fi - if [ -z "$3" ]; then - return 0 - else - echo $3 >> test/test.log - echo $3 >> .envtmp - fi - if [ -z "$4" ]; then - return 0 - else - echo $4 >> test/test.log - echo $4 >> .envtmp - fi - -} + # Input parameters + local serving="$1" + local docker_image="$2" + local use_local_cache="$3" + env_variables=$(echo $@ | cut -d' ' -f4-) -process_test() -{ - echo '' >> test/test.log - - echo Test $test_id >> test/test.log - echo Docker image: $1 >> test/test.log - echo Audio file: $3 >> test/test.log - echo Test type: $4 >> test/test.log - echo '' - echo Starting test $test_id - local docker_image="$1" - local config_file="$2" - local test_file="$3" - local test_type="$4" - local use_local_cache="$5" - make_env $config_file $6 $7 $8 - local local_test_id=$test_id - test_id=$((test_id + 1)) tests_run=$((tests_run + 1)) + echo "=== Starting test $tests_run ===" | tee -a $test_log + echo "* Docker image: $docker_image" | tee -a $test_log + echo "* Audio file..: $test_file" | tee -a $test_log + build_args="" + for env in $env_variables; do + build_args="$build_args --env $env" + done + build_args="$build_args --env SERVICE_MODE=$serving" if [ $use_local_cache -gt 0 ];then - build_args="-v $HOME/.cache:/root/.cache" - else - build_args="" + build_args="$build_args -v $HOME/.cache:/root/.cache" fi + echo "* Options.....:$build_args" | tee -a $test_log + echo "-----------------------" | tee -a $test_log + + pids="" + if [ "$serving" == "task" ]; then + build_args="$build_args -v `pwd`:/opt/audio" + # Launch Redis server + test/launch_redis.sh & + if [ $? -ne 0 ]; then + echo_failure "Redis server failed to start." + test_failed + exit 1 + fi + pids=$! + fi + start=$(date +%s) # Exécute la fonction de construction dans un sous-processus - build_docker_image $docker_image .envtmp $build_args & - pid=$! - echo "The server is creating and will be running with the PID $pid." | tee -a test/test.log + rm -f build_finished + test/build_and_run_container.sh $docker_image test/.env $build_args & + local pid=$! + pids="$pids $pid" # Attend la création du fichier avec un timeout de 600 secondes - wait_for_file_creation_with_timeout build_finished - local r=$? - if [ "$r" -ne 0 ]; then - test_failed .envtmp - return 1 + wait_for_file_creation_with_timeout build_finished $pid + if [ $? -ne 0 ]; then + test_failed + exit 1 fi - check_http_server_availability "http://localhost:8080/healthcheck" - local r=$? - if [ "$r" -ne 0 ]; then - test_failed .envtmp - return 1 +} + +run_test() +{ + local serving="$1" + shift + if [ "$serving" == "http" ]; then + run_test_http $* + elif [ "$serving" == "task" ]; then + run_test_task $* + else + echo_failure "Error: Unknown serving mode '$serving'." + exit 1 fi +} + +run_test() +{ + # Input parameters + local serving="$1" + shift + local test_file="$1" + shift if [ "$test_file" == "test/GOLE7.wav" ] ; then regex=".*Je crois que j'avais des profs.*" elif [ "$test_file" == "test/bonjour.wav" ]; then regex=".*Bonjour.*" fi - if [ "$test_type" == "decoding" ]; then + + build_and_run_container $serving $* + + if [ "$serving" == "http" ]; then + check_http_server_availability "http://localhost:8080/healthcheck" + if [ $? -ne 0 ]; then + test_failed + return 1 + fi + + # Test HTTP + CMD='curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav"' + echo "$ "$CMD local start_time=$(date +%s) - local res=$(curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav" 2>/dev/null) + local res=$(eval $CMD 2>/dev/null) local end_time=$(date +%s) if [ -z "$res" ]; then - echo "The server didn't transcribed, retrying in 10sec">> test/test.log - sleep 10 - start_time=$(date +%s) - res=$(curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav" 2>/dev/null) + echo_failure "The server didn't transcribed, retrying in 2 sec..." + sleep 2 + res=$(eval $CMD 2>/dev/null) end_time=$(date +%s) fi - echo "The server has transcribed $3 in $((end_time - start_time)) sec." >> test/test.log - - if [[ $res =~ $regex ]]; then - echo "The string is matching the regex ($regex), the server must has successfully transcribed." >> test/test.log - test_finished $2 - return 0 - else - echo "The string is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" >> test/test.log - test_failed $2 + echo_note "HTTP route 'transcribe' has transcribed $test_file in $((end_time - start_time)) sec." + if [[ ! $res =~ $regex ]]; then + echo_note "Error: The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" + test_failed + return 1 fi - elif [ "$test_type" == "streaming" ]; then - echo "Starting streaming test" >> test/test.log - res=$(python3 test/test_streaming.py --audio_file $test_file) - if [[ $res =~ $regex ]]; then - echo "The string is matching the regex ($regex), the server must has successfully transcribed." >> test/test.log - test_finished $2 - return 0 - else - echo "The string is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" >> test/test.log - test_failed $2 + + # Test streaming + CMD="python3 test/test_streaming.py --audio_file $test_file" + echo "$ "$CMD + start_time=$(date +%s) + res=$(eval $CMD 2> >(tee -a $test_log >&2)) + end_time=$(date +%s) + echo_note "HTTP websocket has transcribed $test_file in $((end_time - start_time)) sec." + if [[ ! $res =~ $regex ]]; then + echo_failure "The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" + test_failed + return 1 fi - else - echo "Test type $test_type not supported." >> test/test.log - test_failed $2 + + elif [ "$serving" == "task" ]; then + + CMD="python3 test/test_celery.py $test_file" + echo "$ "$CMD + local start_time=$(date +%s) + local res=$(eval $CMD 2> >(tee -a $test_log >&2)) + local end_time=$(date +%s) + if [ $? -ne 0 ]; then + test_failed + return 1 + fi + echo_note "Celery task has transcribed $test_file in $((end_time - start_time)) sec." + if [[ ! $res =~ $regex ]]; then + echo_failure "The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" + test_failed + return 1 + fi + fi - return 1 + + test_succeeded + return 0 } -test_id=0 trap ctrl_c INT -echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') > test/test.log -echo '' >> test/test.log +echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') | tee $test_log +echo '' | tee -a $test_log # Prepare env file for tests -cat whisper/.envdefault | grep -v "DEVICE=" | grep -v "VAD=" | grep -v "MODEL=" > test/.env -echo "MODEL=tiny" >> test/.env +cat whisper/.envdefault | grep -v "DEVICE=" | grep -v "VAD=" | grep -v "MODEL=" | grep -v "SERVICE_MODE=" > test/.env -global_start=$(date +%s) +####################### +# List of what to test -for use_local_cache in 0;do - for serving in decoding streaming;do - for vad in False auditok silero; do - for device in cuda cpu; do - process_test whisper/Dockerfile.ctranslate2 test/.env test/bonjour.wav $serving $use_local_cache DEVICE=$device VAD=$vad - done - done - done +dockerfiles+=" whisper/Dockerfile.ctranslate2" +dockerfiles+=" whisper/Dockerfile.ctranslate2.cpu" +dockerfiles+=" whisper/Dockerfile.torch" +dockerfiles+=" whisper/Dockerfile.torch.cpu" + +use_local_caches+=" 1" +# use_local_caches+=" 0" + +servings+=" task" +servings+=" http" + +vads+=" NONE" +vads+=" false" +vads+=" auditok" +vads+=" silero" + +devices+=" NONE" +devices+=" cpu" +devices+=" cuda" + +models+=" tiny" + +####################### +# Run tests + +for use_local_cache in $use_local_caches;do +for dockerfile in $dockerfiles; do +for device in $devices; do +for vad in $vads; do +for model in $models; do +for serving in $servings; do + + # Tests to skip + if [[ "$device" != "cpu" ]] && [[ `echo $dockerfile | grep cpu | wc -l` -gt 0 ]]; then continue; fi + + # Set env variables + envs="" + if [ "$vad" != "NONE" ]; then envs="$envs VAD=$vad"; fi + if [ "$device" != "NONE" ]; then envs="$envs DEVICE=$device"; fi + envs="$envs MODEL=$model" + + # Run test + run_test $serving test/bonjour.wav $dockerfile $use_local_cache $envs + +done +done +done +done +done done ending \ No newline at end of file diff --git a/test/test_celery.py b/test/test_celery.py new file mode 100755 index 0000000..64537a9 --- /dev/null +++ b/test/test_celery.py @@ -0,0 +1,11 @@ +import sys +from celery import Celery +celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1') +r = celery.send_task( + 'transcribe_task', + ( + sys.argv[1], + True, + ), + queue='stt') +print(r.get()) diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh index 71ca438..09ea120 100755 --- a/whisper/docker-entrypoint.sh +++ b/whisper/docker-entrypoint.sh @@ -14,7 +14,7 @@ fi # Launch parameters, environement variables and dependencies check if [ -z "$SERVICE_MODE" ] then - echo "ERROR: Must specify a serving mode: [ http | task | websocket ]" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)" exit -1 else if [ "$SERVICE_MODE" = "http" ] @@ -46,7 +46,7 @@ else echo "Running Websocket server on port ${STREAMING_PORT:=80}" python3 websocket/websocketserver.py else - echo "ERROR: Wrong serving command: $SERVICE_MODE" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)" exit -1 fi fi From f83211a7083c8464adf66f063f903cfbc48cff6d Mon Sep 17 00:00:00 2001 From: AudranBert Date: Fri, 5 Apr 2024 11:56:56 +0200 Subject: [PATCH 32/50] fix: streaming+torch+silero --- whisper/stt/processing/vad.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py index 22af275..c94239a 100644 --- a/whisper/stt/processing/vad.py +++ b/whisper/stt/processing/vad.py @@ -6,7 +6,6 @@ _silero_vad_model = {} _has_onnx = None -_vad_import = None def remove_non_speech( @@ -153,7 +152,7 @@ def get_vad_segments( method: str or list VAD method to use (auditok, silero, silero:v3.1) """ - global _silero_vad_model, _silero_get_speech_ts, _has_onnx, _vad_import + global _silero_vad_model, _silero_get_speech_ts, _has_onnx if isinstance(method, list): # Explicit timestamps segments = [ @@ -226,11 +225,8 @@ def apply_folder_hack(): if need_folder_hack: apply_folder_hack() try: - if _vad_import is None: - from torch.hub import load as torch_load - - _vad_import = torch_load - silero_vad_model, utils = _vad_import( + from torch.hub import load as torch_load + silero_vad_model, utils = torch_load( repo_or_dir=repo_or_dir, model="silero_vad", onnx=onnx, @@ -258,8 +254,11 @@ def apply_folder_hack(): _silero_get_speech_ts = utils[0] # Cheap normalization of the volume - audio = audio / max(0.1, audio.abs().max()) - + + if isinstance(audio, np.ndarray): + audio = audio / max(0.1, np.max(np.abs(audio))) + else: + audio = audio / max(0.1, audio.abs().max()) segments = _silero_get_speech_ts( audio, _silero_vad_model[version], @@ -270,13 +269,7 @@ def apply_folder_hack(): ) elif method == "auditok": - if _vad_import is None: - from auditok import split - - _vad_import = split - # Cheap normalization of the volume - # audio = audio / max(0.1, audio.abs().max()) if isinstance(audio, np.ndarray): audio = audio / max(0.1, np.max(np.abs(audio))) data = (audio * 32767).astype(np.int16).tobytes() @@ -285,8 +278,8 @@ def apply_folder_hack(): data = (audio.numpy() * 32767).astype(np.int16).tobytes() audio_duration = len(audio) / sample_rate - - segments = _vad_import( + from auditok import split + segments = split( data, sampling_rate=sample_rate, # sampling frequency in Hz channels=1, # number of channels From 114e0d05fbc29f0f105bfd19652b3ae52c375681 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Fri, 5 Apr 2024 12:21:26 +0200 Subject: [PATCH 33/50] Update gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1aa14be..725405d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ .env* tmp* -test/tests_failed/* test.log __pycache__ \ No newline at end of file From 78f84d3ed71d6849e6b9f93e0b91e44767534cc3 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Fri, 5 Apr 2024 12:34:39 +0200 Subject: [PATCH 34/50] fix args --- whisper/stt/processing/vad.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py index c94239a..7137410 100644 --- a/whisper/stt/processing/vad.py +++ b/whisper/stt/processing/vad.py @@ -15,7 +15,7 @@ def remove_non_speech( min_silence_duration=1, dilatation=0.5, sample_rate=16000, - method="silero", + method="auditok", avoid_empty_speech=False, return_format="tuple", ): @@ -55,10 +55,8 @@ def remove_non_speech( segments = get_vad_segments( audio, sample_rate=sample_rate, - output_sample=True, min_speech_duration=min_speech_duration, min_silence_duration=min_silence_duration, - dilatation=dilatation, method=method, ) segments = apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=True) @@ -130,11 +128,9 @@ def do_convert_timestamps(segments, t, t2=None): def get_vad_segments( audio, sample_rate=16000, - output_sample=False, min_speech_duration=0.1, min_silence_duration=0.1, - dilatation=0.5, - method="silero", + method="auditok", ): """ Get speech segments from audio using the method VAD @@ -158,8 +154,6 @@ def get_vad_segments( segments = [ {"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method ] - dilatation = 0 - elif isinstance(method, str) and method.startswith("silero"): version = None _, version = check_vad_method(method, True) From 54e22b69930087aa0a109abcd6bdbd5a2d49f2cd Mon Sep 17 00:00:00 2001 From: AudranBert Date: Fri, 5 Apr 2024 13:54:17 +0200 Subject: [PATCH 35/50] rm unused func --- test/test.sh | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/test.sh b/test/test.sh index 7ed3d35..83acc71 100755 --- a/test/test.sh +++ b/test/test.sh @@ -190,20 +190,6 @@ build_and_run_container() fi } -run_test() -{ - local serving="$1" - shift - if [ "$serving" == "http" ]; then - run_test_http $* - elif [ "$serving" == "task" ]; then - run_test_task $* - else - echo_failure "Error: Unknown serving mode '$serving'." - exit 1 - fi -} - run_test() { # Input parameters From 9c74a8f760fcbc21c80af6278f3d04d278a65d49 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Tue, 9 Apr 2024 09:54:09 +0200 Subject: [PATCH 36/50] update requirements faster-whisper --- whisper/requirements.ctranslate2.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt index 5fc25d2..87b5e80 100644 --- a/whisper/requirements.ctranslate2.txt +++ b/whisper/requirements.ctranslate2.txt @@ -14,4 +14,4 @@ websockets auditok #faster_whisper==1.0.1 # This is version faster_whisper==1.0.1 + option for (persistent) prompt + fix for large-v3 -git+https://github.com/linto-ai/faster-whisper.git@external_vad \ No newline at end of file +git+https://github.com/linto-ai/faster-whisper.git \ No newline at end of file From 3533c22682481afdfcf3a650f6a37ec28b1e93fb Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 11 Apr 2024 16:10:33 +0200 Subject: [PATCH 37/50] fix kaldi task mode + convert test.sh to test.py + add kaldi test --- kaldi/.envdefault | 4 +- kaldi/RELEASE.md | 3 + test/build_and_run_container.sh | 22 --- test/launch_redis.sh | 3 - test/test.py | 265 +++++++++++++++++++++++++ test/test.sh | 332 -------------------------------- test/test_config.ini | 4 + 7 files changed, 274 insertions(+), 359 deletions(-) delete mode 100755 test/build_and_run_container.sh delete mode 100755 test/launch_redis.sh create mode 100644 test/test.py delete mode 100755 test/test.sh create mode 100644 test/test_config.ini diff --git a/kaldi/.envdefault b/kaldi/.envdefault index 33a394c..f22fa40 100644 --- a/kaldi/.envdefault +++ b/kaldi/.envdefault @@ -7,8 +7,8 @@ ENABLE_STREAMING=true # TASK PARAMETERS SERVICE_NAME=stt -SERVICES_BROKER=redis://192.168.0.1:6379 -BROKER_PASS=password +SERVICES_BROKER=redis://172.17.0.1:6379 +BROKER_PASS= # WEBSOCKET PARAMETERS STREAMING_PORT=80 diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md index 6ce152a..be44814 100644 --- a/kaldi/RELEASE.md +++ b/kaldi/RELEASE.md @@ -1,3 +1,6 @@ +# 1.0.3 +- Fix task mode for kaldi by updating SERVICES_BROKER and BROKER_PASS in .envdefault + # 1.0.1 - Fix streaming mode (websocket) in linto-stt-kaldi diff --git a/test/build_and_run_container.sh b/test/build_and_run_container.sh deleted file mode 100755 index c3685e3..0000000 --- a/test/build_and_run_container.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -dockerfile=$1 -shift -env_file=$1 -shift - -tag=test_`basename $dockerfile` - -CMD="docker build . -f $dockerfile -t linto-stt-test:$tag" -echo "$ "$CMD -eval $CMD > /dev/null 2>&1 -if [ $? -ne 0 ]; then - echo "Build failed" - exit 1 -fi -touch build_finished - -CMD="docker run --rm -p 8080:80 --name test_container --env-file $env_file --gpus all $* linto-stt-test:$tag" -# grep -v "^#" $env_file | grep "=" | grep -v SERVICE_NAME | grep -v BROKER | grep -v PORT -echo "$ "$CMD -eval $CMD > /dev/null 2>&1 diff --git a/test/launch_redis.sh b/test/launch_redis.sh deleted file mode 100755 index 7fe4f68..0000000 --- a/test/launch_redis.sh +++ /dev/null @@ -1,3 +0,0 @@ -CMD="docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" -echo "$ "$CMD -eval $CMD 2> /dev/null > /dev/null \ No newline at end of file diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000..d754037 --- /dev/null +++ b/test/test.py @@ -0,0 +1,265 @@ +import unittest +import os +import time +import subprocess +import requests +import argparse +import re +from ddt import ddt, data, idata +import signal +import sys + + + +class TestContainer(): + def __init__(self, use_kaldi=False): + self.use_kaldi = use_kaldi + self.cleanup() + + def echo_success(self, message): + print('\033[0;32m' + u'\u2714' + '\033[0m ' + message) + + def echo_failure(self, message): + print('\033[0;31m' + u'\u2716' + '\033[0m ' + message) + + def echo_note(self, message): + print(u'\u231B' + ' ' + message) + + def echo_command(self, message): + print(f"$ {message}") + + def test_failed(self, message): + self.echo_failure(message) + self.cleanup() + + def test_succeeded(self): + self.echo_success(f"Test passed.") + self.cleanup() + + def cleanup(self): + try: + os.remove("build_finished") + except FileNotFoundError: + pass + subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + def process_output(self, p): + l = p.communicate()[0].decode('utf-8').replace('\n', '\n\t') + e = p.communicate()[1].decode('utf-8').replace('\n', '\n\t') + return f" \u2192 Log Message:\n\t{l}\n \u2192 Error Message:\n\t{e}" + + + def check_http_server_availability(self, server, pid): + total_wait_time = 60 # 10 minutes in seconds + retry_interval = 1 # Interval between attempts (in seconds) + elapsed_time = 0 + + while elapsed_time < total_wait_time: + try: + response = requests.head(server) + if response.status_code == 200: + self.echo_note(f"Server: {server} is available after {elapsed_time} sec.") + return + except requests.ConnectionError: + pass + if pid.poll() is not None: + return f"The server container has stopped for an unexpected reason.\n{self.process_output(pid)}" + + time.sleep(retry_interval) + elapsed_time += retry_interval + + return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed." + + def build_and_run_container(self, serving, docker_image, use_local_cache, env_variables): + self.echo_note(f"* Docker image: {docker_image}") + self.echo_note(f"* Options.....: {env_variables}") + build_args = "" + for i, env in enumerate(env_variables.split()): + if i>0 and env_variables.split()[i-1] =="-v": + build_args += f"-v {env} " + elif env=="-v": + continue + else: + build_args += f"--env {env} " + build_args += f"--env SERVICE_MODE={serving} " + if use_local_cache > 0: + from pathlib import Path + home = str(Path.home()) + build_args += f"-v {home}/.cache:/root/.cache " + + if serving == "task": + build_args += "-v {}/:/opt/audio ".format(os.getcwd()) + CMD = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" + self.echo_command(CMD) + p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.poll() is not None: + return f"Redis server failed to start.\n{self.process_output(p)}", None + time.sleep(2) + tag = f"test_{os.path.basename(docker_image)}" + CMD = f'docker build . -f {docker_image} -t linto-stt-test:{tag}' + self.echo_command(CMD) + start_time = time.time() + p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + end_time = time.time() + if p.poll() != 0: + return f"Docker build failed.\n{self.process_output(p)}", None + self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.") + CMD=f"docker run --rm -p 8080:80 --name test_container --env-file test/.env --gpus all {build_args} linto-stt-test:{tag}" + self.echo_command(CMD) + p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.poll() is not None: + return f"Docker container failed to start.\n{self.process_output(p)}", None + return None, p + + def transcribe(self, command, regex, test_file, error_message, success_message, timeout=None): + start = time.time() + res = subprocess.run(command, shell=True, timeout=timeout, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + end = time.time() + res = res.stdout.decode('utf-8') + if not re.search(regex, res): + message = f"{error_message}: The string '{res}' is not matching the regex ({regex}), the server didn't transcribe correctly." + self.test_failed(message) + return message + self.echo_note(f"{success_message} has transcribed {test_file} in {end - start:.0f} sec.") + return + + def run_test(self, serving, test_file, docker_image, use_local_cache, env_variables): + import warnings + warnings.simplefilter("ignore", ResourceWarning) + regex = "" + if test_file == "test/bonjour.wav": + regex = re.compile("[b|B]onjour") + r, pid=self.build_and_run_container(serving, docker_image, use_local_cache, env_variables) + if r!=None: + self.test_failed(r) + return r + if serving == "http": + r=self.check_http_server_availability("http://localhost:8080/healthcheck", pid) + if r!=None: + self.test_failed(r) + return r + CMD = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"' + self.echo_command(CMD) + r = self.transcribe(CMD, regex, test_file, "Error transcription", "HTTP route 'transcribe'") + if r!=None: + return r + CMD = f'python3 test/test_streaming.py --audio_file {test_file}' + self.echo_command(CMD) + r = self.transcribe(CMD, regex, test_file, "Error streaming", "HTTP route 'streaming'") + if r!=None: + return r + elif serving == "task": + # you can be stuck here if the server crashed bc the task will be in the queue forever + CMD = f"python3 test/test_celery.py {test_file}" + self.echo_command(CMD) + r = self.transcribe(CMD, regex, test_file, "Error task", "TASK route", timeout=60) + if r!=None: + return r + self.test_succeeded() + return True + + +def generate_whisper_test_setups(): + dockerfiles = ["whisper/Dockerfile.ctranslate2", "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch", "whisper/Dockerfile.torch.cpu"] + + use_local_caches = [1] # Add 0 for additional cache usage + + servings = ["task", "http"] + # servings = ["task"] + # servings = ["http"] + + vads = ["NONE", "false", "auditok", "silero"] + # vads = ["NONE"] + devices = ["NONE", "cpu", "cuda"] + # devices = ["NONE"] + models = ["tiny"] + + for use_local_cache in use_local_caches: + for dockerfile in dockerfiles: + for device in devices: + for vad in vads: + for model in models: + for serving in servings: + # try: + if dockerfile.endswith("cpu") and device != "cpu": + continue + envs = "" + if vad != "NONE": + envs += f"VAD={vad} " + if device != "NONE": + envs += f"DEVICE={device} " + envs += f"MODEL={model}" + + yield serving, "test/bonjour.wav", dockerfile, use_local_cache, envs + +def generate_kaldi_test_setups(): + dockerfiles = ["kaldi/Dockerfile"] + + use_local_caches = [1] # Add 0 for additional cache usage + + servings = ["task", "http"] + # servings = ["http"] + + for use_local_cache in use_local_caches: + for dockerfile in dockerfiles: + for serving in servings: + envs = "" + yield serving, "test/bonjour.wav", dockerfile, use_local_cache, envs + +def copy_env_file(env_file, key_words_to_remove): + with open(env_file, "r") as f: + lines = f.readlines() + with open("test/.env", "w") as f: + for line in lines: + if not any([word in line for word in key_words_to_remove]): + f.write(line) + +@ddt +class TestRunner(unittest.TestCase): + + @idata(generate_kaldi_test_setups()) + def test_kaldi_integration(self, setup): + print() + if AM_PATH is None or LM_PATH is None: + self.fail("AM or LM path not provided. Skipping kaldi test.") + if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): + self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") + copy_env_file("kaldi/.envdefault", ["SERVICE_MODE"]) + serving, test_file, dockerfile, use_local_cache, envs = setup + envs += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" + testobject = TestContainer(use_kaldi=True) + test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) + if test_result!=True: + self.fail(test_result) + + + @idata(generate_whisper_test_setups()) + def test_whisper_integration(self, setup): + print() + copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) + serving, test_file, dockerfile, use_local_cache, envs = setup + testobject = TestContainer() + test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) + if test_result!=True: + self.fail(test_result) + + + + +AM_PATH = None +LM_PATH = None + +if __name__ == '__main__': + from configparser import ConfigParser + config = ConfigParser() + + config.read('test/test_config.ini') + + AM_PATH = config.get('kaldi', 'AM_PATH') + LM_PATH = config.get('kaldi', 'LM_PATH') + + unittest.main(verbosity=2) + # unittest.main() diff --git a/test/test.sh b/test/test.sh deleted file mode 100755 index 83acc71..0000000 --- a/test/test.sh +++ /dev/null @@ -1,332 +0,0 @@ -#!/bin/bash - -tests_run=0 -passed=0 -failed=0 -global_start=$(date +%s) -test_log=test/test.log - -function echo_success() { - # Print a green tick (with colour only on the terminal, not the log file) - printf '\033[0;32m' - printf '\xE2\x9C\x94 ' | tee -a $test_log - printf '\033[0m' # No Color - echo $* | tee -a $test_log -} - -function echo_failure() { - # Print a red cross (with colour only on the terminal, not the log file) - printf '\033[0;31m' - printf '\xE2\x9C\x96 ' | tee -a $test_log - printf '\033[0m' # No Color - echo $* | tee -a $test_log -} - -function echo_note() { - printf '🕓 ' | tee -a $test_log - echo $* | tee -a $test_log -} - -function test_failed() { - local end=$(date +%s) - failed=$((failed + 1)) - echo "-----------------------" | tee -a $test_log - echo_failure "Test failed after "$((end-start))" seconds ($passed/$tests_run tests succeeded in "$((end-global_start))" seconds)" - test_teardown - echo 'See $test_log for more details.' - # exit 1 -} - -function test_succeeded(){ - local end=$(date +%s) - passed=$((passed + 1)) - echo "-----------------------" | tee -a $test_log - echo_success "Test passed in "$((end-start))" seconds ($passed/$tests_run tests succeeded in "$((end-global_start))" seconds)" - test_teardown -} - -function test_teardown(){ - rm -f build_finished - local end=$(date +%s) - docker stop test_redis > /dev/null 2> /dev/null - docker stop test_container > /dev/null 2> /dev/null - pkill -P $pids - echo | tee -a $test_log -} - -function ending() { - local end=$(date +%s) - echo_note 'Time to run tests: '$((end-global_start))' seconds.' - if [ $passed -gt 0 ];then - echo_success $passed/$tests_run tests passed. - fi - if [ $failed -gt 0 ];then - echo_failure $failed/$tests_run tests failed. - fi - if [ $passed -eq $tests_run ]; then - echo_success 'TEST PASSED.' - exit 0 - else - echo_failure 'TEST FAILED.' - exit 1 - fi -} - -function ctrl_c() { - echo '' - echo_failure "Interruption signal received, stopping the server... (do not press Ctrl + C again)" - test_teardown - ending -} - - -# Attend la création du fichier avec un timeout de 600 secondes -wait_for_file_creation_with_timeout() { - local file="$1" - local pid="$2" - local timeout=600 # 10 minutes en secondes - local start_time=$(date +%s) - - while [ ! -f "$file" ]; do - current_time=$(date +%s) - elapsed_time=$((current_time - start_time)) - if [ $elapsed_time -ge $timeout ]; then - echo "Fatal Error: Timeout. The docker image took too long to be built." | tee -a $test_log - exit 1 - fi - # Vérifie si le processus est toujours en cours d'exécution - if ! ps -p $pid > /dev/null; then - echo "Fatal Error: Docker build failed." | tee -a $test_log - exit 1 - fi - sleep 1 - done - end_time=$(date +%s) - echo_note "Docker image has been successfully built in "$((end_time - start_time))" sec." - rm $file - if [[ "$(ps -p $pid > /dev/null)" ]]; then - echo_failure "Fatal Error: Docker container start failed immediately." - exit 1 - fi - return 0 -} - -check_http_server_availability() { - local server="$1" - local total_wait_time=600 # 10 minutes en secondes - local retry_interval=1 # Interval entre les tentatives (en secondes) - local elapsed_time=0 - - while [ $elapsed_time -lt $total_wait_time ]; do - # Test de la disponibilité du serveur HTTP - curl -s --head --request GET "$server" | grep "200 OK" - if [ $? -eq 0 ]; then - echo_note "$server is available after $elapsed_time sec." - return 0 - fi - - if [[ `docker ps -a -q -f name=test_container | wc -l` -eq 0 ]];then - echo_failure "Fatal error: the server container has stopped for unexpected reason." - exit 1 - fi - - # Attendre avant la prochaine tentative - sleep $retry_interval - elapsed_time=$((elapsed_time + retry_interval)) - done - - echo_failure "$server is not available after $total_wait_time seconds, server launching must have failed." - exit 1 -} - -build_and_run_container() -{ - # Input parameters - local serving="$1" - local docker_image="$2" - local use_local_cache="$3" - env_variables=$(echo $@ | cut -d' ' -f4-) - - tests_run=$((tests_run + 1)) - echo "=== Starting test $tests_run ===" | tee -a $test_log - echo "* Docker image: $docker_image" | tee -a $test_log - echo "* Audio file..: $test_file" | tee -a $test_log - build_args="" - for env in $env_variables; do - build_args="$build_args --env $env" - done - build_args="$build_args --env SERVICE_MODE=$serving" - if [ $use_local_cache -gt 0 ];then - build_args="$build_args -v $HOME/.cache:/root/.cache" - fi - echo "* Options.....:$build_args" | tee -a $test_log - echo "-----------------------" | tee -a $test_log - - pids="" - if [ "$serving" == "task" ]; then - build_args="$build_args -v `pwd`:/opt/audio" - # Launch Redis server - test/launch_redis.sh & - if [ $? -ne 0 ]; then - echo_failure "Redis server failed to start." - test_failed - exit 1 - fi - pids=$! - fi - - start=$(date +%s) - # Exécute la fonction de construction dans un sous-processus - rm -f build_finished - test/build_and_run_container.sh $docker_image test/.env $build_args & - local pid=$! - pids="$pids $pid" - - # Attend la création du fichier avec un timeout de 600 secondes - wait_for_file_creation_with_timeout build_finished $pid - if [ $? -ne 0 ]; then - test_failed - exit 1 - fi -} - -run_test() -{ - # Input parameters - local serving="$1" - shift - local test_file="$1" - shift - if [ "$test_file" == "test/GOLE7.wav" ] ; then - regex=".*Je crois que j'avais des profs.*" - elif [ "$test_file" == "test/bonjour.wav" ]; then - regex=".*Bonjour.*" - fi - - build_and_run_container $serving $* - - if [ "$serving" == "http" ]; then - check_http_server_availability "http://localhost:8080/healthcheck" - if [ $? -ne 0 ]; then - test_failed - return 1 - fi - - # Test HTTP - CMD='curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@$test_file;type=audio/wav"' - echo "$ "$CMD - local start_time=$(date +%s) - local res=$(eval $CMD 2>/dev/null) - local end_time=$(date +%s) - if [ -z "$res" ]; then - echo_failure "The server didn't transcribed, retrying in 2 sec..." - sleep 2 - res=$(eval $CMD 2>/dev/null) - end_time=$(date +%s) - fi - echo_note "HTTP route 'transcribe' has transcribed $test_file in $((end_time - start_time)) sec." - if [[ ! $res =~ $regex ]]; then - echo_note "Error: The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" - test_failed - return 1 - fi - - # Test streaming - CMD="python3 test/test_streaming.py --audio_file $test_file" - echo "$ "$CMD - start_time=$(date +%s) - res=$(eval $CMD 2> >(tee -a $test_log >&2)) - end_time=$(date +%s) - echo_note "HTTP websocket has transcribed $test_file in $((end_time - start_time)) sec." - if [[ ! $res =~ $regex ]]; then - echo_failure "The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" - test_failed - return 1 - fi - - elif [ "$serving" == "task" ]; then - - CMD="python3 test/test_celery.py $test_file" - echo "$ "$CMD - local start_time=$(date +%s) - local res=$(eval $CMD 2> >(tee -a $test_log >&2)) - local end_time=$(date +%s) - if [ $? -ne 0 ]; then - test_failed - return 1 - fi - echo_note "Celery task has transcribed $test_file in $((end_time - start_time)) sec." - if [[ ! $res =~ $regex ]]; then - echo_failure "The string '$res' is not matching the regex ($regex), the server didn't transcribed correctly. Output text : $res" - test_failed - return 1 - fi - - fi - - test_succeeded - return 0 -} - -trap ctrl_c INT -echo Starting tests at $(date '+%d/%m/%Y %H:%M:%S') | tee $test_log -echo '' | tee -a $test_log - -# Prepare env file for tests -cat whisper/.envdefault | grep -v "DEVICE=" | grep -v "VAD=" | grep -v "MODEL=" | grep -v "SERVICE_MODE=" > test/.env - -####################### -# List of what to test - -dockerfiles+=" whisper/Dockerfile.ctranslate2" -dockerfiles+=" whisper/Dockerfile.ctranslate2.cpu" -dockerfiles+=" whisper/Dockerfile.torch" -dockerfiles+=" whisper/Dockerfile.torch.cpu" - -use_local_caches+=" 1" -# use_local_caches+=" 0" - -servings+=" task" -servings+=" http" - -vads+=" NONE" -vads+=" false" -vads+=" auditok" -vads+=" silero" - -devices+=" NONE" -devices+=" cpu" -devices+=" cuda" - -models+=" tiny" - -####################### -# Run tests - -for use_local_cache in $use_local_caches;do -for dockerfile in $dockerfiles; do -for device in $devices; do -for vad in $vads; do -for model in $models; do -for serving in $servings; do - - # Tests to skip - if [[ "$device" != "cpu" ]] && [[ `echo $dockerfile | grep cpu | wc -l` -gt 0 ]]; then continue; fi - - # Set env variables - envs="" - if [ "$vad" != "NONE" ]; then envs="$envs VAD=$vad"; fi - if [ "$device" != "NONE" ]; then envs="$envs DEVICE=$device"; fi - envs="$envs MODEL=$model" - - # Run test - run_test $serving test/bonjour.wav $dockerfile $use_local_cache $envs - -done -done -done -done -done -done - -ending \ No newline at end of file diff --git a/test/test_config.ini b/test/test_config.ini new file mode 100644 index 0000000..334727a --- /dev/null +++ b/test/test_config.ini @@ -0,0 +1,4 @@ + +[kaldi] +AM_PATH=/home/abert/Linagora/models/linSTT_AM_fr-FR_v2.0.0 +LM_PATH=/home/abert/Linagora/models/decoding_graph_fr-FR_Medium_v2.1.0 \ No newline at end of file From 753334e4800c1f03ca63361e3cd90eb636f54098 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Thu, 11 Apr 2024 19:10:36 +0200 Subject: [PATCH 38/50] add more tests --- test/test.py | 66 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/test/test.py b/test/test.py index d754037..b6391e2 100644 --- a/test/test.py +++ b/test/test.py @@ -12,8 +12,8 @@ class TestContainer(): - def __init__(self, use_kaldi=False): - self.use_kaldi = use_kaldi + def __init__(self, show_failed_tests=True): + self.show_failed_tests = show_failed_tests self.cleanup() def echo_success(self, message): @@ -29,8 +29,9 @@ def echo_command(self, message): print(f"$ {message}") def test_failed(self, message): - self.echo_failure(message) - self.cleanup() + if self.show_failed_tests: + self.echo_failure(message) + self.cleanup() def test_succeeded(self): self.echo_success(f"Test passed.") @@ -117,6 +118,8 @@ def transcribe(self, command, regex, test_file, error_message, success_message, start = time.time() res = subprocess.run(command, shell=True, timeout=timeout, stdout=subprocess.PIPE, stderr=subprocess.PIPE) end = time.time() + if res.returncode != 0: + raise FileNotFoundError(f"Error: {res.stderr.decode('utf-8')}") res = res.stdout.decode('utf-8') if not re.search(regex, res): message = f"{error_message}: The string '{res}' is not matching the regex ({regex}), the server didn't transcribe correctly." @@ -168,13 +171,9 @@ def generate_whisper_test_setups(): use_local_caches = [1] # Add 0 for additional cache usage servings = ["task", "http"] - # servings = ["task"] - # servings = ["http"] vads = ["NONE", "false", "auditok", "silero"] - # vads = ["NONE"] devices = ["NONE", "cpu", "cuda"] - # devices = ["NONE"] models = ["tiny"] for use_local_cache in use_local_caches: @@ -201,7 +200,6 @@ def generate_kaldi_test_setups(): use_local_caches = [1] # Add 0 for additional cache usage servings = ["task", "http"] - # servings = ["http"] for use_local_cache in use_local_caches: for dockerfile in dockerfiles: @@ -230,7 +228,7 @@ def test_kaldi_integration(self, setup): copy_env_file("kaldi/.envdefault", ["SERVICE_MODE"]) serving, test_file, dockerfile, use_local_cache, envs = setup envs += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" - testobject = TestContainer(use_kaldi=True) + testobject = TestContainer() test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) if test_result!=True: self.fail(test_result) @@ -246,7 +244,52 @@ def test_whisper_integration(self, setup): if test_result!=True: self.fail(test_result) - + def test_whisper_curl_not_existing_file(self): + print() + copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) + serving = "http" + test_file = "notexisting" + dockerfile = "whisper/Dockerfile.ctranslate2" + use_local_cache = 1 + envs = "MODEL=tiny " + testobject = TestContainer() + with self.assertRaises(FileNotFoundError): + testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) + + def test_cuda_on_cpu_dockerfile(self): + print() + copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) + serving = "http" + test_file = "test/bonjour.wav" + dockerfile = "whisper/Dockerfile.ctranslate2.cpu" + use_local_cache = 1 + envs = "MODEL=tiny DEVICE=cuda" + testobject = TestContainer(show_failed_tests=False) + self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) + + def test_model_whisper(self): + print() + copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) + serving = "http" + test_file = "test/bonjour.wav" + dockerfile = "whisper/Dockerfile.ctranslate2" + use_local_cache = 1 + envs = "MODEL=small" + testobject = TestContainer() + test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) + if test_result!=True: + self.fail(test_result) + + def test_vad_whisper(self): + print() + copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) + serving = "http" + test_file = "test/bonjour.wav" + dockerfile = "whisper/Dockerfile.ctranslate2" + use_local_cache = 1 + envs = "VAD=whatever" + testobject = TestContainer(show_failed_tests=False) + self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) AM_PATH = None @@ -262,4 +305,3 @@ def test_whisper_integration(self, setup): LM_PATH = config.get('kaldi', 'LM_PATH') unittest.main(verbosity=2) - # unittest.main() From e24dbbc40efe8a921a56bc02f1cebf7ece92f99c Mon Sep 17 00:00:00 2001 From: AudranBert Date: Fri, 12 Apr 2024 11:45:09 +0200 Subject: [PATCH 39/50] fix .ini --- test/test.py | 2 +- test/test_celery.py | 23 ++++++++++++++--------- test/test_config.ini | 4 ++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/test/test.py b/test/test.py index b6391e2..accc6f0 100644 --- a/test/test.py +++ b/test/test.py @@ -221,7 +221,7 @@ class TestRunner(unittest.TestCase): @idata(generate_kaldi_test_setups()) def test_kaldi_integration(self, setup): print() - if AM_PATH is None or LM_PATH is None: + if AM_PATH is None or LM_PATH is None or AM_PATH=="" or LM_PATH=="": self.fail("AM or LM path not provided. Skipping kaldi test.") if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") diff --git a/test/test_celery.py b/test/test_celery.py index 64537a9..59ed62e 100755 --- a/test/test_celery.py +++ b/test/test_celery.py @@ -1,11 +1,16 @@ import sys from celery import Celery -celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1') -r = celery.send_task( - 'transcribe_task', - ( - sys.argv[1], - True, - ), - queue='stt') -print(r.get()) + +def transcribe_task(file_path): + celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1') + r = celery.send_task( + 'transcribe_task', + ( + file_path, + True, + ), + queue='stt') + return r.get() + +if __name__ == '__main__': + print(transcribe_task(sys.argv[1])) \ No newline at end of file diff --git a/test/test_config.ini b/test/test_config.ini index 334727a..73d70ab 100644 --- a/test/test_config.ini +++ b/test/test_config.ini @@ -1,4 +1,4 @@ [kaldi] -AM_PATH=/home/abert/Linagora/models/linSTT_AM_fr-FR_v2.0.0 -LM_PATH=/home/abert/Linagora/models/decoding_graph_fr-FR_Medium_v2.1.0 \ No newline at end of file +AM_PATH= +LM_PATH= \ No newline at end of file From d6d74f65b056820ecb708230233e5d8640698e15 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Mon, 15 Apr 2024 11:05:26 +0200 Subject: [PATCH 40/50] add docs for tests --- test/README.md | 70 +++++++++++++++++++++++++++++++++++++++++ test/test_deployment.sh | 2 +- 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 test/README.md diff --git a/test/README.md b/test/README.md new file mode 100644 index 0000000..dc1d3f9 --- /dev/null +++ b/test/README.md @@ -0,0 +1,70 @@ +# LinTO-STT-Tests + +## Use tests + +### HTTP - transcribe + +You can test your http server by using: + +```bash +test_deployment.sh +``` + +> ⚠️ Be sure to check that you use the right port (default port for testing: 8080). + +### HTTP - streaming + +You can test your http streaming route by using: +```bash +test_streaming.py +``` +Be sure to have a working microphone. +> ⚠️ Be sure to check that you use the right port (default port for testing: 8080). + +If you want to test the streaming on a file: +```bash +test_streaming.py --audio_file bonjour.wav +``` + +### Task + +You can test your deployment of the task service mode by using: + +```bash +test_celery.py AUDIO.wav +``` + +with AUDIO.wav the file you want to test on, for example, you can use bonjour.wav. + +> ⚠️ Be sure to check that you use the same port in your .env and in test_celery.py (default port for testing: 6379) + + +## Unit tests + +You will need to install: +```bash +pip3 install ddt +``` + +To test the Kaldi models, you will need to download the models (see [Kaldi models](kaldi/README.md)) and then fill the test_config.ini AM_PATH and LM_PATH fields. +> ⚠️ If you don't specify the models, the tests about Kaldi will fail. + +To launch the test you can do : +```bash +python test/test.py +``` + +> ⚠️ Be sure to launch it from the root folder of the repository. + +If you want the test to stop at the first fail use the -f flag: +```bash +python test/test.py -f +``` +If you want to run a subset of test you can use -k with a part of a test name. for example only kaldi tests: +```bash +python test/test.py -k kaldi +``` +or test with VAD=auditok, DEVICE=cuda: +```bash +python test/test.py -k VAD_auditok_DEVICE_cuda +``` \ No newline at end of file diff --git a/test/test_deployment.sh b/test/test_deployment.sh index b1b8d36..84daac8 100755 --- a/test/test_deployment.sh +++ b/test/test_deployment.sh @@ -1 +1 @@ -curl -X POST "http://localhost:8888/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@bonjour.wav;type=audio/wav" +curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@bonjour.wav;type=audio/wav" From c6d7684834499c762b48b7d896ac9a41d837dda1 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 11:21:40 +0200 Subject: [PATCH 41/50] remove useless dependencies --- test/test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test.py b/test/test.py index accc6f0..f8ce4bb 100644 --- a/test/test.py +++ b/test/test.py @@ -3,11 +3,8 @@ import time import subprocess import requests -import argparse import re -from ddt import ddt, data, idata -import signal -import sys +from ddt import ddt, idata From d2a07b6f757d5776642085a7e09e506eef69f661 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 11:22:11 +0200 Subject: [PATCH 42/50] use contiguous version numbers --- kaldi/RELEASE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md index be44814..f5bc967 100644 --- a/kaldi/RELEASE.md +++ b/kaldi/RELEASE.md @@ -1,4 +1,4 @@ -# 1.0.3 +# 1.0.2 - Fix task mode for kaldi by updating SERVICES_BROKER and BROKER_PASS in .envdefault # 1.0.1 From 5808c0dd0d4e9519ad3668f63497bc3060544b45 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 11:23:22 +0200 Subject: [PATCH 43/50] more details about VAD --- whisper/RELEASE.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md index 03922f3..84de80f 100644 --- a/whisper/RELEASE.md +++ b/whisper/RELEASE.md @@ -1,6 +1,7 @@ # 1.0.3 +- Make Voice Activity Detection (VAD) configurable +- Change default VAD from silero (neural approach) to auditok (heuristical approach), because silero can have unpredictable behaviour on different corner cases - Streaming support -- Refactoring VAD system - New NUM_THREADS env variable to control the number of threads - Load the model when launching the service (not at the first request) From e6ff146527a27b67fd0647ccd1a1647e4fefed09 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Mon, 15 Apr 2024 11:30:19 +0200 Subject: [PATCH 44/50] fix and add links to test/readme --- README.md | 1 + kaldi/README.md | 5 ++++- test/README.md | 2 +- whisper/README.md | 5 ++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 10f860f..492dd8c 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ LinTO-STT can either be used as a standalone transcription service or deployed w The following families of STT models are currently supported (please refer to respective documentation for more details): * [Kaldi models](kaldi/README.md) * [Whisper models](whisper/README.md) +* [Test scripts](test/README.md) LinTO-STT can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. diff --git a/kaldi/README.md b/kaldi/README.md index 70584f2..e7c2036 100644 --- a/kaldi/README.md +++ b/kaldi/README.md @@ -205,7 +205,10 @@ On a successfull transcription the returned object is a json object structured a * The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) -## Test +## Tests + +See [Test scripts](../test/README.md) for more details about testing. + ### Curl You can test you http API using curl: ```bash diff --git a/test/README.md b/test/README.md index dc1d3f9..f8e1ebd 100644 --- a/test/README.md +++ b/test/README.md @@ -46,7 +46,7 @@ You will need to install: pip3 install ddt ``` -To test the Kaldi models, you will need to download the models (see [Kaldi models](kaldi/README.md)) and then fill the test_config.ini AM_PATH and LM_PATH fields. +To test the Kaldi models, you will need to download the models (see [Kaldi models](../kaldi/README.md)) and then fill the test_config.ini AM_PATH and LM_PATH fields. > ⚠️ If you don't specify the models, the tests about Kaldi will fail. To launch the test you can do : diff --git a/whisper/README.md b/whisper/README.md index ec1953a..52d2122 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -343,7 +343,10 @@ On a successfull transcription the returned object is a json object structured a * The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) -## Test +## Tests + +See [Test scripts](../test/README.md) for more details about testing. + ### Curl You can test your http API using curl: From 7c252657c2db674d82b438d1b12d55334a1e1343 Mon Sep 17 00:00:00 2001 From: AudranBert Date: Mon, 15 Apr 2024 14:38:29 +0200 Subject: [PATCH 45/50] timeout parameter + improve error message when timeout --- test/test.py | 17 ++++++++++------- test/test_config.ini | 2 ++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/test/test.py b/test/test.py index f8ce4bb..8f07893 100644 --- a/test/test.py +++ b/test/test.py @@ -9,8 +9,8 @@ class TestContainer(): - def __init__(self, show_failed_tests=True): - self.show_failed_tests = show_failed_tests + def __init__(self, ignore_failed_tests=None): + self.ignore_failed_tests = ignore_failed_tests self.cleanup() def echo_success(self, message): @@ -26,7 +26,7 @@ def echo_command(self, message): print(f"$ {message}") def test_failed(self, message): - if self.show_failed_tests: + if self.ignore_failed_tests is None or not message.startswith(self.ignore_failed_tests): self.echo_failure(message) self.cleanup() @@ -49,7 +49,7 @@ def process_output(self, p): def check_http_server_availability(self, server, pid): - total_wait_time = 60 # 10 minutes in seconds + total_wait_time = SERVER_STARTING_TIMEOUT # 10 minutes in seconds retry_interval = 1 # Interval between attempts (in seconds) elapsed_time = 0 @@ -67,7 +67,7 @@ def check_http_server_availability(self, server, pid): time.sleep(retry_interval) elapsed_time += retry_interval - return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed." + return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed.\n{self.process_output(pid)}" def build_and_run_container(self, serving, docker_image, use_local_cache, env_variables): self.echo_note(f"* Docker image: {docker_image}") @@ -261,7 +261,7 @@ def test_cuda_on_cpu_dockerfile(self): dockerfile = "whisper/Dockerfile.ctranslate2.cpu" use_local_cache = 1 envs = "MODEL=tiny DEVICE=cuda" - testobject = TestContainer(show_failed_tests=False) + testobject = TestContainer(ignore_failed_tests="The server container has stopped for an unexpected reason.") self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) def test_model_whisper(self): @@ -285,12 +285,13 @@ def test_vad_whisper(self): dockerfile = "whisper/Dockerfile.ctranslate2" use_local_cache = 1 envs = "VAD=whatever" - testobject = TestContainer(show_failed_tests=False) + testobject = TestContainer(ignore_failed_tests="The server container has stopped for an unexpected reason.") self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) AM_PATH = None LM_PATH = None +SERVER_STARTING_TIMEOUT = 60 if __name__ == '__main__': from configparser import ConfigParser @@ -298,6 +299,8 @@ def test_vad_whisper(self): config.read('test/test_config.ini') + SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else SERVER_STARTING_TIMEOUT + AM_PATH = config.get('kaldi', 'AM_PATH') LM_PATH = config.get('kaldi', 'LM_PATH') diff --git a/test/test_config.ini b/test/test_config.ini index 73d70ab..76bf72d 100644 --- a/test/test_config.ini +++ b/test/test_config.ini @@ -1,3 +1,5 @@ +[server] +STARTING_TIMEOUT=60 [kaldi] AM_PATH= From 31f608bca9a06be07bc389d4fbe3a3a0110e9c48 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 16:35:37 +0200 Subject: [PATCH 46/50] Factorize code. Do not build several time the same image. Can call the test script from anywhere --- test/README.md | 2 +- test/test.py | 326 +++++++++++++++++++++++-------------------------- 2 files changed, 155 insertions(+), 173 deletions(-) diff --git a/test/README.md b/test/README.md index f8e1ebd..1739688 100644 --- a/test/README.md +++ b/test/README.md @@ -46,7 +46,7 @@ You will need to install: pip3 install ddt ``` -To test the Kaldi models, you will need to download the models (see [Kaldi models](../kaldi/README.md)) and then fill the test_config.ini AM_PATH and LM_PATH fields. +To test the Kaldi models, you will need to download the models (see [Kaldi models](../kaldi/README.md)) and then fill the AM_PATH and LM_PATH fields in the [test_config.ini file](test_config.ini). > ⚠️ If you don't specify the models, the tests about Kaldi will fail. To launch the test you can do : diff --git a/test/test.py b/test/test.py index 8f07893..56f5877 100644 --- a/test/test.py +++ b/test/test.py @@ -5,12 +5,70 @@ import requests import re from ddt import ddt, idata +from pathlib import Path +import warnings +TESTDIR = os.path.dirname(os.path.realpath(__file__)) +ROOTDIR = os.path.dirname(TESTDIR) +os.chdir(ROOTDIR) +TESTDIR = os.path.basename(TESTDIR) -class TestContainer(): - def __init__(self, ignore_failed_tests=None): - self.ignore_failed_tests = ignore_failed_tests + +def generate_whisper_test_setups(): + dockerfiles = ["whisper/Dockerfile.ctranslate2", "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch", "whisper/Dockerfile.torch.cpu"] + + servings = ["http", "task"] + + vads = ["NONE", "false", "auditok", "silero"] + devices = ["NONE", "cpu", "cuda"] + models = ["tiny"] + + for dockerfile in dockerfiles: + for device in devices: + for vad in vads: + for model in models: + for serving in servings: + # try: + if dockerfile.endswith("cpu") and device != "cpu": + continue + env_variables = "" + if vad != "NONE": + env_variables += f"VAD={vad} " + if device != "NONE": + env_variables += f"DEVICE={device} " + env_variables += f"MODEL={model}" + + yield dockerfile, serving, env_variables + +def generate_kaldi_test_setups(): + dockerfiles = ["kaldi/Dockerfile"] + + servings = ["http", "task"] + + for dockerfile in dockerfiles: + for serving in servings: + env_variables = "" + yield dockerfile, serving, env_variables + +def copy_env_file(env_file, env_variables=""): + env_variables = env_variables.split() + env_variables.append("SERVICE_MODE=") + with open(env_file, "r") as f: + lines = f.readlines() + with open(f"{TESTDIR}/.env", "w") as f: + for line in lines: + if not any([line.startswith(b.split("=")[0] + "=") for b in env_variables]): + f.write(line) + +@ddt +class TestRunner(unittest.TestCase): + + built_images = [] + + def __init__(self, *args, **kwargs): + super(TestRunner, self).__init__(*args, **kwargs) self.cleanup() def echo_success(self, message): @@ -25,13 +83,16 @@ def echo_note(self, message): def echo_command(self, message): print(f"$ {message}") - def test_failed(self, message): - if self.ignore_failed_tests is None or not message.startswith(self.ignore_failed_tests): + def report_failure(self, message, expect_failure=True): + if expect_failure: self.echo_failure(message) - self.cleanup() + self.cleanup() + if expect_failure: + self.fail(message) + return message - def test_succeeded(self): - self.echo_success(f"Test passed.") + def report_success(self): + self.echo_success("Test passed.") self.cleanup() def cleanup(self): @@ -69,7 +130,7 @@ def check_http_server_availability(self, server, pid): return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed.\n{self.process_output(pid)}" - def build_and_run_container(self, serving, docker_image, use_local_cache, env_variables): + def build_and_run_container(self, serving, docker_image, env_variables, use_local_cache): self.echo_note(f"* Docker image: {docker_image}") self.echo_note(f"* Options.....: {env_variables}") build_args = "" @@ -81,32 +142,39 @@ def build_and_run_container(self, serving, docker_image, use_local_cache, env_va else: build_args += f"--env {env} " build_args += f"--env SERVICE_MODE={serving} " - if use_local_cache > 0: - from pathlib import Path + if use_local_cache: home = str(Path.home()) build_args += f"-v {home}/.cache:/root/.cache " if serving == "task": + # Launch redis build_args += "-v {}/:/opt/audio ".format(os.getcwd()) - CMD = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" - self.echo_command(CMD) - p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + cmd = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" + self.echo_command(cmd) + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) if p.poll() is not None: + self.cleanup() return f"Redis server failed to start.\n{self.process_output(p)}", None time.sleep(2) + tag = f"test_{os.path.basename(docker_image)}" - CMD = f'docker build . -f {docker_image} -t linto-stt-test:{tag}' - self.echo_command(CMD) - start_time = time.time() - p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - p.wait() - end_time = time.time() - if p.poll() != 0: - return f"Docker build failed.\n{self.process_output(p)}", None - self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.") - CMD=f"docker run --rm -p 8080:80 --name test_container --env-file test/.env --gpus all {build_args} linto-stt-test:{tag}" - self.echo_command(CMD) - p = subprocess.Popen(CMD.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if tag not in TestRunner.built_images: + # Only build images that have not been built yet + cmd = f'docker build . -f {docker_image} -t linto-stt-test:{tag}' + self.echo_command(cmd) + start_time = time.time() + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + end_time = time.time() + if p.poll() != 0: + self.cleanup() + return f"Docker build failed.\n{self.process_output(p)}", None + self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.") + TestRunner.built_images.append(tag) + + cmd=f"docker run --rm -p 8080:80 --name test_container --env-file {TESTDIR}/.env --gpus all {build_args} linto-stt-test:{tag}" + self.echo_command(cmd) + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) if p.poll() is not None: return f"Docker container failed to start.\n{self.process_output(p)}", None return None, p @@ -120,173 +188,87 @@ def transcribe(self, command, regex, test_file, error_message, success_message, res = res.stdout.decode('utf-8') if not re.search(regex, res): message = f"{error_message}: The string '{res}' is not matching the regex ({regex}), the server didn't transcribe correctly." - self.test_failed(message) - return message + return self.report_failure(message) self.echo_note(f"{success_message} has transcribed {test_file} in {end - start:.0f} sec.") return - def run_test(self, serving, test_file, docker_image, use_local_cache, env_variables): - import warnings + def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http", env_variables="", test_file=f"{TESTDIR}/bonjour.wav", use_local_cache=True, expect_failure=False): warnings.simplefilter("ignore", ResourceWarning) regex = "" - if test_file == "test/bonjour.wav": - regex = re.compile("[b|B]onjour") - r, pid=self.build_and_run_container(serving, docker_image, use_local_cache, env_variables) - if r!=None: - self.test_failed(r) - return r + if os.path.basename(test_file) == "bonjour.wav": + regex = re.compile("[bB]onjour") + r, pid = self.build_and_run_container(serving, docker_image, env_variables, use_local_cache) + if r: + return self.report_failure(r, expect_failure=expect_failure) if serving == "http": r=self.check_http_server_availability("http://localhost:8080/healthcheck", pid) - if r!=None: - self.test_failed(r) - return r - CMD = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"' - self.echo_command(CMD) - r = self.transcribe(CMD, regex, test_file, "Error transcription", "HTTP route 'transcribe'") - if r!=None: - return r - CMD = f'python3 test/test_streaming.py --audio_file {test_file}' - self.echo_command(CMD) - r = self.transcribe(CMD, regex, test_file, "Error streaming", "HTTP route 'streaming'") - if r!=None: - return r + if r: + return self.report_failure(r, expect_failure=expect_failure) + cmd = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"' + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error transcription", "HTTP route 'transcribe'") + if r: + return self.report_failure(r, expect_failure=expect_failure) + cmd = f"python3 {TESTDIR}/test_streaming.py --audio_file {test_file}" + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error streaming", "HTTP route 'streaming'") elif serving == "task": # you can be stuck here if the server crashed bc the task will be in the queue forever - CMD = f"python3 test/test_celery.py {test_file}" - self.echo_command(CMD) - r = self.transcribe(CMD, regex, test_file, "Error task", "TASK route", timeout=60) - if r!=None: - return r - self.test_succeeded() - return True - - -def generate_whisper_test_setups(): - dockerfiles = ["whisper/Dockerfile.ctranslate2", "whisper/Dockerfile.ctranslate2.cpu", - "whisper/Dockerfile.torch", "whisper/Dockerfile.torch.cpu"] - - use_local_caches = [1] # Add 0 for additional cache usage - - servings = ["task", "http"] - - vads = ["NONE", "false", "auditok", "silero"] - devices = ["NONE", "cpu", "cuda"] - models = ["tiny"] - - for use_local_cache in use_local_caches: - for dockerfile in dockerfiles: - for device in devices: - for vad in vads: - for model in models: - for serving in servings: - # try: - if dockerfile.endswith("cpu") and device != "cpu": - continue - envs = "" - if vad != "NONE": - envs += f"VAD={vad} " - if device != "NONE": - envs += f"DEVICE={device} " - envs += f"MODEL={model}" - - yield serving, "test/bonjour.wav", dockerfile, use_local_cache, envs - -def generate_kaldi_test_setups(): - dockerfiles = ["kaldi/Dockerfile"] - - use_local_caches = [1] # Add 0 for additional cache usage - - servings = ["task", "http"] - - for use_local_cache in use_local_caches: - for dockerfile in dockerfiles: - for serving in servings: - envs = "" - yield serving, "test/bonjour.wav", dockerfile, use_local_cache, envs - -def copy_env_file(env_file, key_words_to_remove): - with open(env_file, "r") as f: - lines = f.readlines() - with open("test/.env", "w") as f: - for line in lines: - if not any([word in line for word in key_words_to_remove]): - f.write(line) + cmd = f"python3 {TESTDIR}/test_celery.py {test_file}" + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error task", "TASK route", timeout=60) + else: + raise RuntimeError(f"Unknown serving mode: {serving}") + if r: + return self.report_failure(r, expect_failure=expect_failure) + if not expect_failure: + self.report_success() + return "" + + def setUp(self): + # Print an empty line because unittest prints the name of the test first, without a newline + print() -@ddt -class TestRunner(unittest.TestCase): - @idata(generate_kaldi_test_setups()) - def test_kaldi_integration(self, setup): - print() + def test_01_kaldi_integration(self, setup): + dockerfile, serving, env_variables = setup if AM_PATH is None or LM_PATH is None or AM_PATH=="" or LM_PATH=="": self.fail("AM or LM path not provided. Skipping kaldi test.") if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") - copy_env_file("kaldi/.envdefault", ["SERVICE_MODE"]) - serving, test_file, dockerfile, use_local_cache, envs = setup - envs += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" - testobject = TestContainer() - test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) - if test_result!=True: - self.fail(test_result) + copy_env_file("kaldi/.envdefault") + env_variables += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" + self.run_test(dockerfile, serving=serving, env_variables=env_variables) @idata(generate_whisper_test_setups()) - def test_whisper_integration(self, setup): - print() - copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) - serving, test_file, dockerfile, use_local_cache, envs = setup - testobject = TestContainer() - test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) - if test_result!=True: - self.fail(test_result) + def test_03_whisper_integration(self, setup): + dockerfile, serving, env_variables = setup + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(dockerfile, serving=serving, env_variables=env_variables) - def test_whisper_curl_not_existing_file(self): - print() - copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) - serving = "http" - test_file = "notexisting" - dockerfile = "whisper/Dockerfile.ctranslate2" - use_local_cache = 1 - envs = "MODEL=tiny " - testobject = TestContainer() + def test_02_whisper_failures_not_existing_file(self): + env_variables = "MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) with self.assertRaises(FileNotFoundError): - testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) + self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=False) + self.cleanup() - def test_cuda_on_cpu_dockerfile(self): - print() - copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) - serving = "http" - test_file = "test/bonjour.wav" + def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): + env_variables = "MODEL=tiny DEVICE=cuda" dockerfile = "whisper/Dockerfile.ctranslate2.cpu" - use_local_cache = 1 - envs = "MODEL=tiny DEVICE=cuda" - testobject = TestContainer(ignore_failed_tests="The server container has stopped for an unexpected reason.") - self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) - - def test_model_whisper(self): - print() - copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) - serving = "http" - test_file = "test/bonjour.wav" - dockerfile = "whisper/Dockerfile.ctranslate2" - use_local_cache = 1 - envs = "MODEL=small" - testobject = TestContainer() - test_result = testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs) - if test_result!=True: - self.fail(test_result) - - def test_vad_whisper(self): - print() - copy_env_file("whisper/.envdefault", ["VAD", "DEVICE", "MODEL", "SERVICE_MODE"]) - serving = "http" - test_file = "test/bonjour.wav" - dockerfile = "whisper/Dockerfile.ctranslate2" - use_local_cache = 1 - envs = "VAD=whatever" - testobject = TestContainer(ignore_failed_tests="The server container has stopped for an unexpected reason.") - self.assertIn("The server container has stopped for an unexpected reason.", testobject.run_test(serving, test_file, dockerfile, use_local_cache, envs)) + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=False)) + + def test_02_whisper_failures_wrong_vad(self): + env_variables = "VAD=whatever MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=False)) + + def test_04_model_whisper(self): + env_variables = "MODEL=small" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables) AM_PATH = None @@ -297,7 +279,7 @@ def test_vad_whisper(self): from configparser import ConfigParser config = ConfigParser() - config.read('test/test_config.ini') + config.read(f"{TESTDIR}/test_config.ini") SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else SERVER_STARTING_TIMEOUT From b6ca5c0bd73eb043f865644e76a6c3096ae082a0 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 17:03:51 +0200 Subject: [PATCH 47/50] Add separators between tests --- test/test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test.py b/test/test.py index 56f5877..625c5fd 100644 --- a/test/test.py +++ b/test/test.py @@ -228,6 +228,10 @@ def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http" def setUp(self): # Print an empty line because unittest prints the name of the test first, without a newline print() + print("-"*70) + + def tearDown(self): + print("-"*70) @idata(generate_kaldi_test_setups()) def test_01_kaldi_integration(self, setup): From 74fd2115681d23cf653c3359a84429097d16b13e Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 17:12:25 +0200 Subject: [PATCH 48/50] Do not run all the tests --- test/test.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/test/test.py b/test/test.py index 625c5fd..68af4d4 100644 --- a/test/test.py +++ b/test/test.py @@ -16,8 +16,12 @@ def generate_whisper_test_setups(): - dockerfiles = ["whisper/Dockerfile.ctranslate2", "whisper/Dockerfile.ctranslate2.cpu", - "whisper/Dockerfile.torch", "whisper/Dockerfile.torch.cpu"] + dockerfiles = [ + "whisper/Dockerfile.ctranslate2", + "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch", + "whisper/Dockerfile.torch.cpu", + ] servings = ["http", "task"] @@ -30,9 +34,16 @@ def generate_whisper_test_setups(): for vad in vads: for model in models: for serving in servings: - # try: + + # Test CPU dockerfile only on CPU if dockerfile.endswith("cpu") and device != "cpu": continue + + # Do not test all VAD settings if not on CPU + if vad not in ["NONE", "silero"]: + if device != "cpu": + continue + env_variables = "" if vad != "NONE": env_variables += f"VAD={vad} " @@ -67,9 +78,9 @@ class TestRunner(unittest.TestCase): built_images = [] - def __init__(self, *args, **kwargs): - super(TestRunner, self).__init__(*args, **kwargs) - self.cleanup() + # def __init__(self, *args, **kwargs): + # super(TestRunner, self).__init__(*args, **kwargs) + # self.cleanup() def echo_success(self, message): print('\033[0;32m' + u'\u2714' + '\033[0m ' + message) @@ -83,11 +94,11 @@ def echo_note(self, message): def echo_command(self, message): print(f"$ {message}") - def report_failure(self, message, expect_failure=True): - if expect_failure: + def report_failure(self, message, expect_failure=False): + if not expect_failure: self.echo_failure(message) self.cleanup() - if expect_failure: + if not expect_failure: self.fail(message) return message @@ -100,8 +111,9 @@ def cleanup(self): os.remove("build_finished") except FileNotFoundError: pass - subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + self.echo_command("docker stop test_container") subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) def process_output(self, p): l = p.communicate()[0].decode('utf-8').replace('\n', '\n\t') @@ -255,19 +267,19 @@ def test_02_whisper_failures_not_existing_file(self): env_variables = "MODEL=tiny" copy_env_file("whisper/.envdefault", env_variables) with self.assertRaises(FileNotFoundError): - self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=False) + self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=True) self.cleanup() def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): env_variables = "MODEL=tiny DEVICE=cuda" dockerfile = "whisper/Dockerfile.ctranslate2.cpu" copy_env_file("whisper/.envdefault", env_variables) - self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=False)) + self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True)) def test_02_whisper_failures_wrong_vad(self): env_variables = "VAD=whatever MODEL=tiny" copy_env_file("whisper/.envdefault", env_variables) - self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=False)) + self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=True)) def test_04_model_whisper(self): env_variables = "MODEL=small" From 38f93bd5fb72bdd91d62973bcf49fa708a623c0d Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 17:17:34 +0200 Subject: [PATCH 49/50] Use None instead of NONE string --- test/test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test.py b/test/test.py index 68af4d4..ff321f2 100644 --- a/test/test.py +++ b/test/test.py @@ -25,8 +25,8 @@ def generate_whisper_test_setups(): servings = ["http", "task"] - vads = ["NONE", "false", "auditok", "silero"] - devices = ["NONE", "cpu", "cuda"] + vads = [None, "false", "auditok", "silero"] + devices = [None, "cpu", "cuda"] models = ["tiny"] for dockerfile in dockerfiles: @@ -40,14 +40,14 @@ def generate_whisper_test_setups(): continue # Do not test all VAD settings if not on CPU - if vad not in ["NONE", "silero"]: + if vad not in [None, "silero"]: if device != "cpu": continue env_variables = "" - if vad != "NONE": + if vad: env_variables += f"VAD={vad} " - if device != "NONE": + if device: env_variables += f"DEVICE={device} " env_variables += f"MODEL={model}" From 180099059724aa54825d1d222ba4f64c7fcf6468 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 15 Apr 2024 17:36:03 +0200 Subject: [PATCH 50/50] Do launch redis once only --- test/test.py | 59 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/test/test.py b/test/test.py index ff321f2..f683583 100644 --- a/test/test.py +++ b/test/test.py @@ -77,6 +77,7 @@ def copy_env_file(env_file, env_variables=""): class TestRunner(unittest.TestCase): built_images = [] + redis_launched = False # def __init__(self, *args, **kwargs): # super(TestRunner, self).__init__(*args, **kwargs) @@ -107,13 +108,13 @@ def report_success(self): self.cleanup() def cleanup(self): - try: - os.remove("build_finished") - except FileNotFoundError: - pass - self.echo_command("docker stop test_container") - subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + # Check if the container is running + p = subprocess.Popen(["docker", "ps", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + if b"test_container" in out: + self.echo_command("docker stop test_container") + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + time.sleep(0.2) # Without this, the following tests can fail (The container name "/test_container" is already in use) def process_output(self, p): l = p.communicate()[0].decode('utf-8').replace('\n', '\n\t') @@ -141,7 +142,19 @@ def check_http_server_availability(self, server, pid): elapsed_time += retry_interval return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed.\n{self.process_output(pid)}" - + + def launch_redis(self): + if TestRunner.redis_launched: + return + cmd = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" + self.echo_command(cmd) + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + time.sleep(2) + if p.poll() is not None: + self.cleanup() + return f"Redis server failed to start.\n{self.process_output(p)}", None + TestRunner.redis_launched = True + def build_and_run_container(self, serving, docker_image, env_variables, use_local_cache): self.echo_note(f"* Docker image: {docker_image}") self.echo_note(f"* Options.....: {env_variables}") @@ -159,15 +172,8 @@ def build_and_run_container(self, serving, docker_image, env_variables, use_loca build_args += f"-v {home}/.cache:/root/.cache " if serving == "task": - # Launch redis + self.launch_redis() build_args += "-v {}/:/opt/audio ".format(os.getcwd()) - cmd = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" - self.echo_command(cmd) - p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if p.poll() is not None: - self.cleanup() - return f"Redis server failed to start.\n{self.process_output(p)}", None - time.sleep(2) tag = f"test_{os.path.basename(docker_image)}" if tag not in TestRunner.built_images: @@ -263,18 +269,18 @@ def test_03_whisper_integration(self, setup): copy_env_file("whisper/.envdefault", env_variables) self.run_test(dockerfile, serving=serving, env_variables=env_variables) + def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): + env_variables = "MODEL=tiny DEVICE=cuda" + dockerfile = "whisper/Dockerfile.ctranslate2.cpu" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True)) + def test_02_whisper_failures_not_existing_file(self): env_variables = "MODEL=tiny" copy_env_file("whisper/.envdefault", env_variables) with self.assertRaises(FileNotFoundError): self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=True) self.cleanup() - - def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): - env_variables = "MODEL=tiny DEVICE=cuda" - dockerfile = "whisper/Dockerfile.ctranslate2.cpu" - copy_env_file("whisper/.envdefault", env_variables) - self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True)) def test_02_whisper_failures_wrong_vad(self): env_variables = "VAD=whatever MODEL=tiny" @@ -286,6 +292,10 @@ def test_04_model_whisper(self): copy_env_file("whisper/.envdefault", env_variables) self.run_test(env_variables=env_variables) +def finalize_tests(): + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + AM_PATH = None LM_PATH = None @@ -302,4 +312,7 @@ def test_04_model_whisper(self): AM_PATH = config.get('kaldi', 'AM_PATH') LM_PATH = config.get('kaldi', 'LM_PATH') - unittest.main(verbosity=2) + try: + unittest.main(verbosity=2) + finally: + finalize_tests()