Skip to content

Commit

Permalink
Support explicit VAD timestamps
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeronymous committed Mar 1, 2024
1 parent 79cc85e commit 8352601
Showing 1 changed file with 48 additions and 15 deletions.
63 changes: 48 additions & 15 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def transcribe_timestamped(
Whether to compute word confidence.
If True, a finer confidence for each segment will be computed as well.
vad: bool or str in ["silero", "silero:3.1", "auditok"]
vad: bool or str in ["silero", "silero:3.1", "auditok"] or list of start/end timestamps pairs corresponding to speech (ex: [(0.0, 3.50), (32.43, 36.43)])
Whether to perform voice activity detection (VAD) on the audio file, to remove silent parts before transcribing with Whisper model.
This should decrease hallucinations from the Whisper model.
When set to True, the default VAD algorithm is used (silero).
Expand Down Expand Up @@ -279,7 +279,7 @@ def transcribe_timestamped(

if vad:
audio = get_audio_tensor(audio)
audio, convert_timestamps = remove_non_speech(audio, method=vad, plot=plot_word_alignment)
audio, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment)

global num_alignment_for_plot
num_alignment_for_plot = 0
Expand Down Expand Up @@ -1844,11 +1844,23 @@ def split_tokens_on_spaces(tokens: torch.Tensor, tokenizer, remove_punctuation_f
return words, word_tokens, word_tokens_indices

def check_vad_method(method, with_version=False):
"""
Check whether the VAD method is valid and return the method in a consistent format
method: str or list or True or False
"""
if method in [True, "True", "true"]:
return check_vad_method("silero") # default method
elif method in [False, "False", "false"]:
return False
elif method.startswith("silero"):
elif not isinstance(method, str) and hasattr(method, '__iter__'):
# list of explicit timestamps
checked_pairs = []
for s_e in method:
assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs"
checked_pairs.append(tuple(s_e))
return checked_pairs
elif isinstance(method, str) and method.startswith("silero"):
version = None
if method != "silero":
assert method.startswith("silero:"), f"Got unexpected VAD method {method}"
Expand All @@ -1869,12 +1881,18 @@ def check_vad_method(method, with_version=False):
except ImportError:
raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)")
else:
raise ValueError(f"Got unexpected VAD method {method}")
try:
method = eval(method)
assert hasattr(method, '__iter__')
except:
raise ValueError(f"Got unexpected VAD method {method}")
return check_vad_method(method, with_version=with_version)
return method

_silero_vad_model = {}
_has_onnx = None
def get_vad_segments(audio,
sample_rate=SAMPLE_RATE,
output_sample=False,
min_speech_duration=0.1,
min_silence_duration=0.1,
Expand All @@ -1894,12 +1912,17 @@ def get_vad_segments(audio,
minimum duration (in sec) of a silence segment
dilatation: float
how much (in sec) to enlarge each speech segment detected by the VAD
method: str
method: str or list
VAD method to use (auditok, silero, silero:v3.1)
"""
global _silero_vad_model, _silero_get_speech_ts, _has_onnx

if method.startswith("silero"):
if isinstance(method, list):
# Explicit timestamps
segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method]
dilatation = 0

elif isinstance(method, str) and method.startswith("silero"):

version = None
_, version = check_vad_method(method, True)
Expand Down Expand Up @@ -1969,6 +1992,7 @@ def apply_folder_hack():
audio = audio / max(0.1, audio.abs().max())

segments = _silero_get_speech_ts(audio, _silero_vad_model[version],
sampling_rate = sample_rate,
min_speech_duration_ms = round(min_speech_duration * 1000),
min_silence_duration_ms = round(min_silence_duration * 1000),
return_seconds = False,
Expand All @@ -1982,11 +2006,11 @@ def apply_folder_hack():

data = (audio.numpy() * 32767).astype(np.int16).tobytes()

audio_duration = len(audio) / SAMPLE_RATE
audio_duration = len(audio) / sample_rate

segments = auditok.split(
data,
sampling_rate=SAMPLE_RATE, # sampling frequency in Hz
sampling_rate=sample_rate, # sampling frequency in Hz
channels=1, # number of channels
sample_width=2, # number of bytes per sample
min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds
Expand All @@ -1996,13 +2020,13 @@ def apply_folder_hack():
drop_trailing_silence=True,
)

segments = [{"start": s._meta.start * SAMPLE_RATE, "end": s._meta.end * SAMPLE_RATE} for s in segments]
segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments]

else:
raise ValueError(f"Got unexpected VAD method {method}")

if dilatation > 0:
dilatation = round(dilatation * SAMPLE_RATE)
dilatation = round(dilatation * sample_rate)
new_segments = []
for seg in segments:
new_seg = {
Expand All @@ -2015,7 +2039,7 @@ def apply_folder_hack():
new_segments.append(new_seg)
segments = new_segments

ratio = 1 if output_sample else 1 / SAMPLE_RATE
ratio = 1 if output_sample else 1 / sample_rate

if ratio != 1:
for seg in segments:
Expand All @@ -2031,6 +2055,8 @@ def remove_non_speech(audio,
use_sample=False,
min_speech_duration=0.1,
min_silence_duration=1,
dilatation=0.5,
sample_rate=SAMPLE_RATE,
method="silero",
plot=False,
):
Expand All @@ -2048,6 +2074,8 @@ def remove_non_speech(audio,
minimum duration (in sec) of a speech segment
min_silence_duration: float
minimum duration (in sec) of a silence segment
dilatation: float
how much (in sec) to enlarge each speech segment detected by the VAD
method: str
method to use to remove non-speech segments
plot: bool or str
Expand All @@ -2057,9 +2085,11 @@ def remove_non_speech(audio,

segments = get_vad_segments(
audio,
sample_rate=sample_rate,
output_sample=True,
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
dilatation=dilatation,
method=method,
)

Expand All @@ -2074,17 +2104,17 @@ def remove_non_speech(audio,
plt.figure()
max_num_samples = 10000
step = (audio.shape[-1] // max_num_samples) + 1
times = [i*step/SAMPLE_RATE for i in range((audio.shape[-1]-1) // step + 1)]
times = [i*step/sample_rate for i in range((audio.shape[-1]-1) // step + 1)]
plt.plot(times, audio[::step])
for s, e in segments:
plt.axvspan(s/SAMPLE_RATE, e/SAMPLE_RATE, color='red', alpha=0.1)
plt.axvspan(s/sample_rate, e/sample_rate, color='red', alpha=0.1)
if isinstance(plot, str):
plt.savefig(f"{plot}.VAD.jpg", bbox_inches='tight', pad_inches=0)
else:
plt.show()

if not use_sample:
segments = [(float(s)/SAMPLE_RATE, float(e)/SAMPLE_RATE) for s,e in segments]
segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments]

return audio_speech, lambda t, t2 = None: do_convert_timestamps(segments, t, t2)

Expand Down Expand Up @@ -2939,7 +2969,10 @@ def str2output_formats(string):
parser.add_argument('--language', help=f"language spoken in the audio, specify None to perform language detection.", choices=sorted(whisper.tokenizer.LANGUAGES.keys()) + sorted([k.title() for k in whisper.tokenizer.TO_LANGUAGE_CODE.keys()]), default=None)
# f"{', '.join(sorted(k+'('+v+')' for k,v in whisper.tokenizer.LANGUAGES.items()))}

parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations). Can be: True, False, silero, silero:3.1 (or another version), or autitok. Some additional libraries might be needed")
parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations). "
"Can be: True, False, auditok, silero (default when vad=True), silero:3.1 (or another version), or a list of timestamps in seconds (e.g. \"[(0.0, 3.50), (32.43, 36.43)]\"). "
"Note: Some additional libraries might be needed (torchaudio and onnxruntime for silero, auditok for auditok)."
)
parser.add_argument('--detect_disfluencies', default=False, help="whether to try to detect disfluencies, marking them as special words [*]", type=str2bool)
parser.add_argument('--recompute_all_timestamps', default=not TRUST_WHISPER_TIMESTAMP_BY_DEFAULT, help="Do not rely at all on Whisper timestamps (Experimental option: did not bring any improvement, but could be useful in cases where Whipser segment timestamp are wrong by more than 0.5 seconds)", type=str2bool)
parser.add_argument("--punctuations_with_words", default=True, help="whether to include punctuations in the words", type=str2bool)
Expand Down

0 comments on commit 8352601

Please sign in to comment.