replaced classification model with EncDecSpeakerLabelModel #11887

Draft · wants to merge 2 commits into base: main
234 changes: 19 additions & 215 deletions nemo/collections/asr/models/classification_models.py
@@ -29,6 +29,7 @@

from nemo.collections.asr.data import audio_to_label_dataset, feature_to_label_dataset
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.mixins import TranscriptionMixin, TranscriptionReturnType
from nemo.collections.asr.parts.mixins.transcription import InternalTranscribeConfig
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
@@ -484,210 +485,30 @@ def get_transcribe_config(cls) -> ClassificationInferConfig:
return ClassificationInferConfig()


@deprecated(explanation='EncDecClassificationModel will be merged with EncDecSpeakerLabelModel class.')
class EncDecClassificationModel(_EncDecBaseModel):
"""Encoder decoder Classification models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):

if cfg.get("is_regression_task", False):
raise ValueError(f"EndDecClassificationModel requires the flag is_regression_task to be set as false")

super().__init__(cfg=cfg, trainer=trainer)

def _setup_preprocessor(self):
return EncDecClassificationModel.from_config_dict(self._cfg.preprocessor)

def _setup_encoder(self):
return EncDecClassificationModel.from_config_dict(self._cfg.encoder)

def _setup_decoder(self):
return EncDecClassificationModel.from_config_dict(self._cfg.decoder)

def _setup_loss(self):
return CrossEntropyLoss()

def _setup_metrics(self):
self._accuracy = TopKClassificationAccuracy(dist_sync_on_step=True)
class EncDecClassificationModel(EncDecSpeakerLabelModel):
def forward_for_export(self, audio_signal, length):
encoded, length = self.encoder(audio_signal=audio_signal, length=length)
logits = self.decoder(encoder_output=encoded, length=length)
return logits

def _update_decoder_config(self, labels, cfg):
"""
Update the number of classes in the decoder based on labels provided.

Args:
labels: The current labels of the model
cfg: The config of the decoder which will be updated.
"""

@classmethod
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

Returns:
List of available pre-trained models.
"""
results = []

model = PretrainedModelInfo(
pretrained_model_name="vad_multilingual_marblenet",
description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/vad_multilingual_marblenet",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/vad_multilingual_marblenet/versions/1.10.0/files/vad_multilingual_marblenet.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="vad_telephony_marblenet",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:vad_telephony_marblenet",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/vad_telephony_marblenet/versions/1.0.0rc1/files/vad_telephony_marblenet.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="vad_marblenet",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:vad_marblenet",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/vad_marblenet/versions/1.0.0rc1/files/vad_marblenet.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x1x64_v1",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x1x64_v1",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x1x64_v1/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x1x64_v1.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x2x64_v1",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x2x64_v1",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x2x64_v1/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x2x64_v1.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x1x64_v2",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x1x64_v2",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x1x64_v2/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x1x64_v2.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x2x64_v2",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x2x64_v2",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x2x64_v2/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x2x64_v2.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x1x64_v2_subset_task",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x1x64_v2_subset_task",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x1x64_v2_subset_task/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x1x64_v2_subset_task.nemo",
)
results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="commandrecognition_en_matchboxnet3x2x64_v2_subset_task",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:commandrecognition_en_matchboxnet3x2x64_v2_subset_task",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/commandrecognition_en_matchboxnet3x2x64_v2_subset_task/versions/1.0.0rc1/files/commandrecognition_en_matchboxnet3x2x64_v2_subset_task.nemo",
)
results.append(model)
return results

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"outputs": NeuralType(('B', 'D'), LogitsType())}

# PTL-specific methods
def training_step(self, batch, batch_nb):
audio_signal, audio_signal_len, labels, labels_len = batch
logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
loss_value = self.loss(logits=logits, labels=labels)

self.log('train_loss', loss_value)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', self.trainer.global_step)

self._accuracy(logits=logits, labels=labels)
topk_scores = self._accuracy.compute()
self._accuracy.reset()

for top_k, score in zip(self._accuracy.top_k, topk_scores):
self.log('training_batch_accuracy_top_{}'.format(top_k), score)

return {
'loss': loss_value,
}

def validation_step(self, batch, batch_idx, dataloader_idx=0):
audio_signal, audio_signal_len, labels, labels_len = batch
logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
loss_value = self.loss(logits=logits, labels=labels)
acc = self._accuracy(logits=logits, labels=labels)
correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k
loss = {
'val_loss': loss_value,
'val_correct_counts': correct_counts,
'val_total_counts': total_counts,
'val_acc': acc,
}
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(loss)
else:
self.validation_step_outputs.append(loss)
return loss

def test_step(self, batch, batch_idx, dataloader_idx=0):
audio_signal, audio_signal_len, labels, labels_len = batch
logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
loss_value = self.loss(logits=logits, labels=labels)
acc = self._accuracy(logits=logits, labels=labels)
correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k
loss = {
'test_loss': loss_value,
'test_correct_counts': correct_counts,
'test_total_counts': total_counts,
'test_acc': acc,
}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(loss)
else:
self.test_step_outputs.append(loss)
return loss

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]).sum(axis=0)
total_counts = torch.stack([x['val_total_counts'] for x in outputs]).sum(axis=0)

self._accuracy.correct_counts_k = correct_counts
self._accuracy.total_counts_k = total_counts
topk_scores = self._accuracy.compute()
self._accuracy.reset()

tensorboard_log = {'val_loss': val_loss_mean}
for top_k, score in zip(self._accuracy.top_k, topk_scores):
tensorboard_log['val_epoch_top@{}'.format(top_k)] = score

return {'log': tensorboard_log}

def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
correct_counts = torch.stack([x['test_correct_counts'].unsqueeze(0) for x in outputs]).sum(axis=0)
total_counts = torch.stack([x['test_total_counts'].unsqueeze(0) for x in outputs]).sum(axis=0)

self._accuracy.correct_counts_k = correct_counts
self._accuracy.total_counts_k = total_counts
topk_scores = self._accuracy.compute()
self._accuracy.reset()

tensorboard_log = {'test_loss': test_loss_mean}
for top_k, score in zip(self._accuracy.top_k, topk_scores):
tensorboard_log['test_epoch_top@{}'.format(top_k)] = score
return {'log': tensorboard_log}

OmegaConf.set_struct(cfg, False)
if 'params' in cfg:
cfg.params.num_classes = len(labels)
cfg.num_classes = len(labels)
OmegaConf.set_struct(cfg, True)

@typecheck()
def forward(
self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
):
logits = super().forward(
input_signal=input_signal,
input_signal_length=input_signal_length,
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
)
return logits
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._update_decoder_config(cfg.labels, cfg.decoder)
super().__init__(cfg, trainer)

def change_labels(self, new_labels: List[str]):
"""
@@ -740,23 +561,6 @@ def change_labels(self, new_labels: List[str]):

logging.info(f"Changed decoder output to {self.decoder.num_classes} labels.")

def _update_decoder_config(self, labels, cfg):
"""
Update the number of classes in the decoder based on labels provided.

Args:
labels: The current labels of the model
cfg: The config of the decoder which will be updated.
"""
OmegaConf.set_struct(cfg, False)

if 'params' in cfg:
cfg.params.num_classes = len(labels)
else:
cfg.num_classes = len(labels)

OmegaConf.set_struct(cfg, True)


class EncDecRegressionModel(_EncDecBaseModel):
"""Encoder decoder class for speech regression models.
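Reviewer note on the retained `_update_decoder_config` helper: it is the piece that sizes the decoder head from the label list before `super().__init__` runs, by temporarily unlocking the OmegaConf struct and writing `num_classes`. Below is a minimal standalone sketch of that pattern using plain OmegaConf; the decoder keys are modeled on the test fixture in this PR and the `else` branch mirrors the original helper, so treat it as an illustration rather than the exact new implementation.

```python
from omegaconf import OmegaConf

def update_decoder_config(labels, cfg):
    # Unlock the struct so keys can be added or overwritten.
    OmegaConf.set_struct(cfg, False)
    if 'params' in cfg:
        cfg.params.num_classes = len(labels)
    else:
        cfg.num_classes = len(labels)
    # Re-lock the config after the update.
    OmegaConf.set_struct(cfg, True)

# Illustrative decoder config, keys modeled on the ConvASRDecoderClassification test fixture.
decoder_cfg = OmegaConf.create(
    {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
        'params': {'feat_in': 32, 'num_classes': -1},
    }
)
update_decoder_config(['yes', 'no', 'up', 'down'], decoder_cfg)
print(decoder_cfg.params.num_classes)  # 4
```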
88 changes: 47 additions & 41 deletions tests/collections/asr/test_asr_classification_model.py
@@ -52,7 +52,10 @@ def speech_classification_model():

decoder = {
'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
'params': {'feat_in': 32, 'num_classes': 30,},
'params': {
'feat_in': 32,
'num_classes': 30,
},
}

modelConfig = DictConfig(
@@ -95,7 +98,10 @@ def frame_classification_model():

decoder = {
'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
'params': {'hidden_size': 32, 'num_classes': 5,},
'params': {
'hidden_size': 32,
'num_classes': 5,
},
}

modelConfig = DictConfig(
@@ -149,7 +155,7 @@ def test_forward(self, speech_classification_model):
logprobs_instance = torch.cat(logprobs_instance, 0)

# batch size 4
logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)
logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)[0]

assert logprobs_instance.shape == logprobs_batch.shape
diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch))
@@ -174,44 +180,44 @@ def test_vocab_change(self, speech_classification_model):
# fully connected + bias
assert asr_model.num_weights == nw1 + 3 * (asr_model.decoder._feat_in + 1)

@pytest.mark.unit
def test_transcription(self, speech_classification_model, test_data_dir):
# Ground truth labels = ["yes", "no"]
audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]

model = speech_classification_model.eval()

# Test Top 1 classification transcription
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([1])

# Test Top 5 classification transcription
model._accuracy.top_k = [5] # set top k to 5 for accuracy calculation
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([5])

# Test Top 1 and Top 5 classification transcription
model._accuracy.top_k = [1, 5]
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([2, 1])
assert results[1].shape == torch.Size([2, 5])
assert model._accuracy.top_k == [1, 5]

# Test log probs extraction
model._accuracy.top_k = [1]
results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
assert len(results) == 2
assert results[0].shape == torch.Size([len(model.cfg.labels)])

# Test log probs extraction remains same for any top_k
model._accuracy.top_k = [5]
results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
assert len(results) == 2
assert results[0].shape == torch.Size([len(model.cfg.labels)])
# @pytest.mark.unit
# def test_transcription(self, speech_classification_model, test_data_dir):
# # Ground truth labels = ["yes", "no"]
# audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
# audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]

# model = speech_classification_model.eval()

# # Test Top 1 classification transcription
# results = model.transcribe(audio_paths, batch_size=2)
# assert len(results) == 2
# assert results[0].shape == torch.Size([1])

# # Test Top 5 classification transcription
# model._accuracy.top_k = [5] # set top k to 5 for accuracy calculation
# results = model.transcribe(audio_paths, batch_size=2)
# assert len(results) == 2
# assert results[0].shape == torch.Size([5])

# # Test Top 1 and Top 5 classification transcription
# model._accuracy.top_k = [1, 5]
# results = model.transcribe(audio_paths, batch_size=2)
# assert len(results) == 2
# assert results[0].shape == torch.Size([2, 1])
# assert results[1].shape == torch.Size([2, 5])
# assert model._accuracy.top_k == [1, 5]

# # Test log probs extraction
# model._accuracy.top_k = [1]
# results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
# assert len(results) == 2
# assert results[0].shape == torch.Size([len(model.cfg.labels)])

# # Test log probs extraction remains same for any top_k
# model._accuracy.top_k = [5]
# results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
# assert len(results) == 2
# assert results[0].shape == torch.Size([len(model.cfg.labels)])

@pytest.mark.unit
def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
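A note on the `[0]` added in `test_forward`: with the classification model now inheriting from `EncDecSpeakerLabelModel`, `forward` is assumed to return a pair (logits plus pooled embeddings) rather than the bare logits tensor the old class returned, so callers that only need logits take the first element. A toy, self-contained sketch of that calling pattern (the model below is a hypothetical stand-in, not NeMo code):

```python
import torch

class ToySpeakerLabelStyleModel(torch.nn.Module):
    """Stand-in whose forward returns (logits, embeddings), mimicking the assumed new interface."""

    def __init__(self, feat_in: int = 32, num_classes: int = 30):
        super().__init__()
        self.decoder = torch.nn.Linear(feat_in, num_classes)

    def forward(self, input_signal, input_signal_length):
        # Mean-pool over time to get a (batch, feat_in) "embedding", then classify it.
        embeddings = input_signal.mean(dim=-1)
        logits = self.decoder(embeddings)
        return logits, embeddings

model = ToySpeakerLabelStyleModel()
features = torch.randn(4, 32, 100)   # (batch, feat_in, time)
lengths = torch.full((4,), 100)
logits = model(input_signal=features, input_signal_length=lengths)[0]  # take logits only, as the test now does
print(logits.shape)  # torch.Size([4, 30])
```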