Skip to content

Commit

Permalink
transcribe speech with timestamps
Browse files Browse the repository at this point in the history
Signed-off-by: Monica Sekoyan <[email protected]>
  • Loading branch information
monica-sekoyan committed Dec 11, 2024
1 parent ade8cab commit 81635e1
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 73 deletions.
2 changes: 1 addition & 1 deletion examples/asr/asr_vad/speech_to_text_with_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from nemo.collections.asr.data import feature_to_text_dataset
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel, EncDecClassificationModel
from nemo.collections.asr.parts.submodules import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
from nemo.collections.asr.parts.utils.vad_utils import (
Expand Down
2 changes: 2 additions & 0 deletions examples/asr/transcribe_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def main(cfg: ParallelTranscriptionConfig):
if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None:
model.change_decoding_strategy(decoder_type=cfg.decoder_type)

model.change_decoding_strategy(cfg.rnnt_decoding, decoder_type=cfg.decoder_type)

cfg.predict_ds.return_sample_id = True
cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)

Expand Down
24 changes: 15 additions & 9 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,16 @@ def shard_manifests_if_needed(
world_size: int,
):
if shard_manifests:
if not torch.distributed.is_available():
logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
return manifest_filepaths
# if not torch.distributed.is_available():
# logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
# return manifest_filepaths

if not torch.distributed.is_initialized():
logging.warning(
'Manifest sharding was requested but torch.distributed is not initialized '
'Did you intend to set the defer_setup flag?'
)
return manifest_filepaths
# if not torch.distributed.is_initialized():
# logging.warning(
# 'Manifest sharding was requested but torch.distributed is not initialized '
# 'Did you intend to set the defer_setup flag?'
# )
# return manifest_filepaths

manifest_filepaths = expand_sharded_filepaths(
sharded_filepaths=manifest_filepaths,
Expand Down Expand Up @@ -848,6 +848,9 @@ def __init__(
self.shard_manifests = shard_manifests

# Shard manifests if necessary and possible and then expand the paths

print(self.shard_manifests)
print('='*20)
manifest_filepath = shard_manifests_if_needed(
shard_manifests=shard_manifests,
shard_strategy=shard_strategy,
Expand All @@ -856,6 +859,9 @@ def __init__(
global_rank=global_rank,
)

print(manifest_filepath)
print('='*20)

# If necessary, cache manifests from object store
cache_datastore_manifests(manifest_filepaths=manifest_filepath)

Expand Down
13 changes: 8 additions & 5 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ def get_tarred_dataset(
if 'labels' not in config:
logging.warning(f"dataset does not have explicitly defined labels")


if 'max_utts' in config:
raise ValueError('"max_utts" parameter is not supported for tarred datasets')
logging.warning('"max_utts" parameter is not supported for tarred datasets')

for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
zip(tarred_audio_filepaths, manifest_filepaths)
Expand Down Expand Up @@ -389,7 +390,7 @@ def get_tarred_dataset(
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
shard_manifests=config.get('shard_manifests', False),
shard_manifests=config.get('shard_manifests', True),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
Expand Down Expand Up @@ -861,7 +862,7 @@ def write_on_batch_end(
):
import lhotse

for sample_id, transcribed_text in prediction:
for sample_id, hyp in prediction:
item = {}
if isinstance(sample_id, lhotse.cut.Cut):
sample = sample_id
Expand All @@ -871,7 +872,8 @@ def write_on_batch_end(
item["offset"] = sample.start
item["duration"] = sample.duration
item["text"] = sample.supervisions[0].text
item["pred_text"] = transcribed_text
item["pred_text"] = hyp.text
item['timestamps'] = hyp.timestep['segment']
self.outf.write(json.dumps(item) + "\n")
self.samples_num += 1
else:
Expand All @@ -880,7 +882,8 @@ def write_on_batch_end(
item["offset"] = sample.offset
item["duration"] = sample.duration
item["text"] = sample.text_raw
item["pred_text"] = transcribed_text
item["pred_text"] = hyp.text
item['timestamps'] = hyp.timestep['segment']
self.outf.write(json.dumps(item) + "\n")
self.samples_num += 1
return
Expand Down
70 changes: 35 additions & 35 deletions nemo/collections/asr/models/clustering_diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,41 +425,41 @@ def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
self._perform_speech_activity_detection()

# Segmentation
scales = self.multiscale_args_dict['scale_dict'].items()
for scale_idx, (window, shift) in scales:

# Segmentation for the current scale (scale_idx)
self._run_segmentation(window, shift, scale_tag=f'_scale{scale_idx}')

# Embedding Extraction for the current scale (scale_idx)
self._extract_embeddings(self.subsegments_manifest_path, scale_idx, len(scales))

self.multiscale_embeddings_and_timestamps[scale_idx] = [self.embeddings, self.time_stamps]

embs_and_timestamps = get_embs_and_timestamps(
self.multiscale_embeddings_and_timestamps, self.multiscale_args_dict
)

# Clustering
all_reference, all_hypothesis = perform_clustering(
embs_and_timestamps=embs_and_timestamps,
AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
out_rttm_dir=out_rttm_dir,
clustering_params=self._cluster_params,
device=self._speaker_model.device,
verbose=self.verbose,
)
logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))

# Scoring
return score_labels(
self.AUDIO_RTTM_MAP,
all_reference,
all_hypothesis,
collar=self._diarizer_params.collar,
ignore_overlap=self._diarizer_params.ignore_overlap,
verbose=self.verbose,
)
# scales = self.multiscale_args_dict['scale_dict'].items()
# for scale_idx, (window, shift) in scales:

# # Segmentation for the current scale (scale_idx)
# self._run_segmentation(window, shift, scale_tag=f'_scale{scale_idx}')

# # Embedding Extraction for the current scale (scale_idx)
# self._extract_embeddings(self.subsegments_manifest_path, scale_idx, len(scales))

# self.multiscale_embeddings_and_timestamps[scale_idx] = [self.embeddings, self.time_stamps]

# embs_and_timestamps = get_embs_and_timestamps(
# self.multiscale_embeddings_and_timestamps, self.multiscale_args_dict
# )

# # Clustering
# all_reference, all_hypothesis = perform_clustering(
# embs_and_timestamps=embs_and_timestamps,
# AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
# out_rttm_dir=out_rttm_dir,
# clustering_params=self._cluster_params,
# device=self._speaker_model.device,
# verbose=self.verbose,
# )
# logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))

# # Scoring
# return score_labels(
# self.AUDIO_RTTM_MAP,
# all_reference,
# all_hypothesis,
# collar=self._diarizer_params.collar,
# ignore_overlap=self._diarizer_params.ignore_overlap,
# verbose=self.verbose,
# )

@staticmethod
def __make_nemo_file_from_folder(filename, source_dir):
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
del signal

best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=True
)

sample_id = sample_id.cpu().detach().numpy()
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,9 @@ def _compute_offsets(
# If the exact timestep information is available, utilize the 1st non-rnnt blank token timestep
# as the start index.
if hypothesis.timestep is not None and len(hypothesis.timestep) > 0:
start_index = max(0, hypothesis.timestep[0] - 1)
first_timestep = hypothesis.timestep[0]
first_timestep = first_timestep if isinstance(first_timestep, int) else first_timestep.item()
start_index = max(0, first_timestep - 1)

# Construct the start and end indices brackets
end_indices = np.asarray(token_repetitions).cumsum()
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/asr/parts/utils/speaker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,8 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, decimals=5):
audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath']
if AUDIO_RTTM_MAP[uniq_id].get('duration', None):
duration = round(AUDIO_RTTM_MAP[uniq_id]['duration'], decimals)
offset = round(AUDIO_RTTM_MAP[uniq_id]['offset'], decimals)
# offset = round(AUDIO_RTTM_MAP[uniq_id].get('offset', 0), decimals)
offset = 0.0
else:
sound = sf.SoundFile(audio_path)
duration = sound.frames / sound.samplerate
Expand Down
67 changes: 47 additions & 20 deletions tools/ctc_segmentation/scripts/get_metrics_and_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jiwer
import re
import argparse
import json
import os
Expand All @@ -23,6 +25,9 @@

from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.utils import logging
from nemo.collections.asr.parts.utils.transcribe_utils import (
PunctuationCapitalization,
)

parser = argparse.ArgumentParser("Calculate metrics and filters out samples based on thresholds")
parser.add_argument(
Expand All @@ -45,7 +50,7 @@
)
parser.add_argument("--max_edge_cer", type=int, help="Threshold edge CER value, %", default=60)
parser.add_argument("--max_duration", type=int, help="Max duration of a segment, seconds", default=-1)
parser.add_argument("--min_duration", type=int, help="Min duration of a segment, seconds", default=1)
parser.add_argument("--min_duration", type=int, help="Min duration of a segment, seconds", default=0)
parser.add_argument(
"--num_jobs",
default=-2,
Expand All @@ -59,7 +64,7 @@
)


def _calculate(line: dict, edge_len: int):
def _calculate(line: dict, edge_len: int, pc):
"""
Calculates metrics for every entry on manifest.json.
Expand All @@ -78,33 +83,55 @@ def _calculate(line: dict, edge_len: int):
"""
eps = 1e-9

text = line["text"].split()
pred_text = line["pred_text"].split()
text = line["text"]
pred_text = line["pred_text"]

num_words = max(len(text), eps)
word_dist = editdistance.eval(text, pred_text)

text = text.lower()
pred_text = pred_text.lower()

text = re.sub(r"[.,?:;]", "", text)
pred_text = re.sub(r"[.,?:;]", "", pred_text)

text_splitted = text.split()
pred_text_splitted = pred_text.split()

num_words = max(len(text_splitted), eps)

if num_words > eps:
measures = jiwer.compute_measures(text, pred_text)
insertions = measures['insertions']

insertion_rate = insertions / num_words * 100

line["insertion_rate"] = insertion_rate
line["insertions"] = insertions

word_dist = editdistance.eval(text_splitted, pred_text_splitted)
line["WER"] = word_dist / num_words * 100.0
num_chars = max(len(line["text"]), eps)
char_dist = editdistance.eval(line["text"], line["pred_text"])
num_chars = max(len(text), eps)
char_dist = editdistance.eval(text, pred_text)
line["CER"] = char_dist / num_chars * 100.0

line["start_CER"] = editdistance.eval(line["text"][:edge_len], line["pred_text"][:edge_len]) / edge_len * 100
line["end_CER"] = editdistance.eval(line["text"][-edge_len:], line["pred_text"][-edge_len:]) / edge_len * 100
line["len_diff_ratio"] = 1.0 * abs(len(text) - len(pred_text)) / max(len(text), eps)
line["start_CER"] = editdistance.eval(text[:edge_len], pred_text[:edge_len]) / edge_len * 100
line["end_CER"] = editdistance.eval(text[-edge_len:], pred_text[-edge_len:]) / edge_len * 100
line["len_diff_ratio"] = 1.0 * abs(len(text_splitted) - len(pred_text_splitted)) / max(len(text_splitted), eps)
return line


def get_metrics(manifest, manifest_out):
"""Calculate metrics for sample in manifest and saves the results to manifest_out"""
pc = PunctuationCapitalization(".,?:;...")

with open(manifest, "r") as f:
lines = f.readlines()

lines = Parallel(n_jobs=args.num_jobs)(
delayed(_calculate)(json.loads(line), edge_len=args.edge_len) for line in tqdm(lines)
delayed(_calculate)(json.loads(line), edge_len=args.edge_len, pc=pc) for line in tqdm(lines)
)
with open(manifest_out, "w") as f_out:
for line in lines:
f_out.write(json.dumps(line) + "\n")
f_out.write(json.dumps(line, ensure_ascii=False) + "\n")
logging.info(f"Metrics save at {manifest_out}")


Expand All @@ -131,16 +158,16 @@ def _apply_filters(
duration = item["duration"]
segmented_duration += duration
if (
cer <= max_cer
and wer <= max_wer
and len_diff_ratio <= max_len_diff_ratio
and item["end_CER"] <= max_edge_cer
and item["start_CER"] <= max_edge_cer
cer > max_cer
and wer > max_wer
and len_diff_ratio > max_len_diff_ratio
and item["end_CER"] > max_edge_cer
and item["start_CER"] > max_edge_cer
and (max_dur == -1 or (max_dur > -1 and duration < max_dur))
and duration > min_dur
and duration < min_dur
):
remaining_duration += duration
f_out.write(json.dumps(item) + "\n")
f_out.write(json.dumps(item, ensure_ascii=False) + "\n")

logging.info("-" * 50)
logging.info("Threshold values:")
Expand Down

0 comments on commit 81635e1

Please sign in to comment.