Accelerate training with DeepSpeed and use an enlarged CommonVoice dataset #1

Open · wants to merge 6 commits into base: main
49 changes: 49 additions & 0 deletions bash_runners/run_small_ds.sh
@@ -0,0 +1,49 @@
mkdir -p logs

deepspeed run_speech_recognition_seq2seq_streaming.py \
--deepspeed="ds_config.json" \
--model_name_or_path="ales/whisper-small-belarusian" \
--dataset_name="../data/hf_dataset_cv14/" \
--dataset_config_name="be" \
--language="be" \
--train_split_name="train" \
--eval_split_name="dev" \
--model_index_name="Whisper Small Belarusian" \
\
--max_steps="150000" \
--output_dir="./whisper-small-be-deepspeed" \
--per_device_train_batch_size="64" \
--per_device_eval_batch_size="32" \
--logging_steps="50" \
--logging_first_step \
--learning_rate="3e-5" \
--learning_rate_end="1e-5" \
--from_disk \
--warmup_steps="0" \
--evaluation_strategy="steps" \
--eval_steps="7000" \
--save_strategy="steps" \
--save_steps="7000" \
--gradient_checkpointing \
--fp16 \
\
--shuffle_buffer_size="500" \
--generation_max_length="225" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--report_to="wandb" \
--metric_for_best_model="wer" \
--greater_is_better="False" \
--load_best_model_at_end \
\
--do_train \
--do_eval \
--ignore_data_skip \
--predict_with_generate \
--do_normalize_eval \
--streaming_train="True" \
--streaming_eval="True" \
--seed="43" \
--use_auth_token \
--push_to_hub="False" 2>&1 | tee "logs/train_$(date +"%Y%m%d-%H%M%S").log"
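
Note on the effective batch size: with the ZeRO batch fields left on "auto" (see ds_config.json below), the global batch size works out to per-device batch × gradient accumulation × number of GPUs. A minimal sketch of that arithmetic, assuming 2 GPUs and the default gradient accumulation of 1 (both assumptions, not values taken from this PR):

# Effective batch size arithmetic (assumed values, not taken from this PR).
per_device_train_batch_size = 64   # matches run_small_ds.sh
gradient_accumulation_steps = 1    # assumed default
num_gpus = 2                       # assumed hardware
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)        # 128 under these assumptions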
29 changes: 29 additions & 0 deletions ds_config.json
@@ -0,0 +1,29 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},

"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},

"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}
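
The "auto" fields above are resolved by the Hugging Face Trainer from the matching training arguments when the config path is passed through the deepspeed argument (here via --deepspeed="ds_config.json"). A minimal sketch of the equivalent programmatic wiring, assuming the values from run_small_ds.sh plus default gradient accumulation and clipping (assumptions, not part of this PR):

# Minimal sketch: how the "auto" fields in ds_config.json get filled in.
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-be-deepspeed",
    deepspeed="ds_config.json",       # the config shown above
    fp16=True,                        # -> fp16.enabled
    per_device_train_batch_size=64,   # -> train_micro_batch_size_per_gpu
    gradient_accumulation_steps=1,    # -> gradient_accumulation_steps (assumed)
    max_grad_norm=1.0,                # -> gradient_clipping (assumed default)
    learning_rate=3e-5,
)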
Empty file added prepare_dataset/__init__.py
Empty file.
189 changes: 189 additions & 0 deletions prepare_dataset/make_dataset.py
@@ -0,0 +1,189 @@
import os
import argparse
from argparse import Namespace
from typing import List, Dict, Union
from pathlib import Path

import pandas as pd
from datasets import DatasetDict, Dataset, Audio


def make_train_val_test_split(
path_to_cv: Union[str, Path], split_to_include: List[str]
) -> Dict[str, pd.DataFrame]:
"""Creating a custom train/val/text split based
on the original splits from CommonVoice.

The base split from CommonVoice are:
'train' -> 'train', 'dev' -> 'validation', 'test' -> 'test'.
Also CommonVoice have 'validated', 'unvalidated', 'other'
split, and this split can have the same speaker(client_id)
from base split('train', 'dev', 'test'). Also this split
can have the same sentences.
Therefore, need to filter out additional splits to 'test', 'dev'.

For example 'validated' can contain speaker 1 from 'test',
and speaker 2 from 'dev', so need to add data
with this speaker to 'test' and 'dev' split accordingly,
other data with different speakers go to train.
After this preproccesing need to remove audio with
validation and test sentence from train split, and remove
audio with dev sentence from test split.

Args:
path_to_cv (str): Path to folder which contain CommanVoice dataset
split_to_include (List[str]): Split from CommonVoice dataset
for spliting into train/val/test and concatenate with base split.

Returns:
Dict[str, pd.DataFrame]: Dict of train/val/test split DataFrames
"""
# Load base Dataframe from CommonVoice
base_train_df = pd.read_csv(os.path.join(Path(path_to_cv), "train.tsv"), sep="\t")
base_val_df = pd.read_csv(os.path.join(Path(path_to_cv), "dev.tsv"), sep="\t")
base_test_df = pd.read_csv(os.path.join(Path(path_to_cv), "test.tsv"), sep="\t")

    # Get the lists of unique speakers in dev and test
val_speaker_list = base_val_df["client_id"].unique()
test_speaker_list = base_test_df["client_id"].unique()

    # Load the DataFrames of the additional splits to include in the final dataset
splits = {
split: pd.read_csv(os.path.join(Path(path_to_cv), f"{split}.tsv"), sep="\t")
for split in split_to_include
}

    # Partition the additional splits by speaker: rows whose speaker does not
    # appear in dev or test go to train; the rest go to their matching split
split_for_train = [
df[
~df["client_id"].isin(val_speaker_list)
& ~df["client_id"].isin(test_speaker_list)
]
for df in splits.values()
]
split_for_val = [
df[df["client_id"].isin(val_speaker_list)] for df in splits.values()
]
split_for_test = [
df[df["client_id"].isin(test_speaker_list)] for df in splits.values()
]

    # Concatenate all DataFrames of each split into one
train_final_df = pd.concat(split_for_train + [base_train_df])
val_final_df = pd.concat(split_for_val + [base_val_df])
test_final_df = pd.concat(split_for_test + [base_test_df])

# Reset the index of the final dataframes
train_final_df = train_final_df.reset_index(drop=True)
val_final_df = val_final_df.reset_index(drop=True)
test_final_df = test_final_df.reset_index(drop=True)

    # Get the lists of unique sentences in dev and test
val_sentence_list = val_final_df["sentence"].unique()
test_sentence_list = test_final_df["sentence"].unique()

    # Remove dev and test sentences from train
train_final_df = train_final_df[
~train_final_df["sentence"].isin(val_sentence_list)
& ~train_final_df["sentence"].isin(test_sentence_list)
]

    # Remove test sentences from the validation split
val_final_df = val_final_df[~val_final_df["sentence"].isin(test_sentence_list)]

# Make dict of splits
final_dict_df = {
"train": train_final_df,
"validation": val_final_df,
"test": test_final_df,
}
return final_dict_df


def process_raw_data(
path_to_cv: Union[str, Path],
dict_df: Dict[str, pd.DataFrame],
sampling_rate: int = 16000,
) -> DatasetDict:
"""Create a HuggingFace Dataset from raw train/val/test split dict

Args:
path_to_cv (str): Path to folder which contain CommanVoice dataset
dict_df (Dict[str, pd.DataFrame]): Dict of train/val/test split DataFrames
sampling_rate (int, optional): Sampling rate for read audio with hfd.Audio.
Defaults to 16000.

Returns:
DatasetDict: Processed HuggingFace DatasetDict
"""

    # Add the full path to each clip in a new 'audio' column
    def add_column(df: pd.DataFrame) -> pd.DataFrame:
        clips_dir = Path(path_to_cv) / "clips"
        df["audio"] = df["path"].apply(lambda p: str(clips_dir / str(p)))
        return df

dict_df = {split: add_column(df) for split, df in dict_df.items()}

    # Build a hf.DatasetDict from the dict of DataFrames
hf_dataset = DatasetDict(
{
split: Dataset.from_pandas(df[["audio", "sentence"]].reset_index(drop=True))
            for split, df in dict_df.items()
}
)

# Process audio column to hf.Audio feature
hf_dataset = hf_dataset.cast_column(
column="audio", feature=Audio(sampling_rate=sampling_rate)
)
return hf_dataset


def pipline(run_opts: Namespace):
    """Pipeline: process the raw CommonVoice data and save the dataset to disk.

    Args:
        run_opts (Namespace): Options parsed from the command line
    """
    # Make the train/val/test split for the dataset
    data_df = make_train_val_test_split(
        path_to_cv=run_opts.path_to_cv, split_to_include=run_opts.split_to_include
    )
    # Process the dataset
    hf_dataset = process_raw_data(
        path_to_cv=run_opts.path_to_cv, dict_df=data_df, sampling_rate=run_opts.sr
    )
    # Save to disk
    hf_dataset.save_to_disk(run_opts.output_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create a CommonVoice dataset for ASR")
parser.add_argument(
"--path_to_cv", type=str, help="Directory which contain CommonVoice Dataset"
)

    parser.add_argument(
        "--split_to_include",
        default=["validated"],
        nargs="+",
        help="Splits of the CommonVoice dataset to include in the new dataset. "
        "Supported values: {validated, invalidated, other}",
    )

    parser.add_argument(
        "--sr",
        type=int,
        default=16000,
        help="Sampling rate of the audio feature in the created dataset",
    )

parser.add_argument(
"--output_path",
type=str,
default="hf_dataset",
help="Directory for saving HuggingFace DatasetDict",
)

args = parser.parse_args()
pipline(args)
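
Once the script has been run, the saved DatasetDict can be loaded back and pointed to by the training script through --dataset_name and --from_disk. A minimal sketch of loading and inspecting the result, using the script's default --output_path as an assumed location (not necessarily the exact path used in this PR):

# Minimal sketch: load and inspect the DatasetDict produced by make_dataset.py.
from datasets import load_from_disk

dataset = load_from_disk("hf_dataset")    # default --output_path; adjust as needed
print(dataset)                            # DatasetDict with train/validation/test splits
print(dataset["train"].column_names)      # ['audio', 'sentence']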
42 changes: 42 additions & 0 deletions prepare_dataset/test_make_dataset.py
@@ -0,0 +1,42 @@
from typing import Dict

import pandas as pd
import numpy as np

from .make_dataset import make_train_val_test_split


def compute_stats(dataset: Dict[str, pd.DataFrame], column: str) -> np.ndarray:
    """Calculate the pairwise intersection between splits based on a DataFrame column.

    Args:
        dataset (Dict[str, pd.DataFrame]): Dict of train/val/test split DataFrames
        column (str): Column in the DataFrames used to check the intersection
            between splits

    Returns:
        np.ndarray: 3x3 square matrix with intersection counts between splits.
    """
stats = [
[
            len(set(f1_val[column]).intersection(f2_val[column]))
for f2_val in dataset.values()
]
for f1_val in dataset.values()
]
return np.array(stats)


def test_intersection_speaker():
path_to_cv = "/mnt/980pro/datasets/commonvoice14/cv-corpus-14.0-2023-06-23/be/"
split_to_include = ["validated", "invalidated", "other"]
dataset = make_train_val_test_split(path_to_cv, split_to_include)
stats = compute_stats(dataset, "client_id")
    assert np.all(stats == np.diag(np.diagonal(stats)))


def test_intersection_sentence():
path_to_cv = "/mnt/980pro/datasets/commonvoice14/cv-corpus-14.0-2023-06-23/be/"
split_to_include = ["validated", "invalidated", "other"]
dataset = make_train_val_test_split(path_to_cv, split_to_include)
stats = compute_stats(dataset, "sentence")
assert np.all(stats == np.diag(np.diagonal(stats)))
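
Both tests assert that the pairwise intersection matrix is diagonal, i.e. each split only overlaps with itself. A tiny self-contained sketch of that check on made-up data (it mirrors compute_stats inline rather than importing it; none of the values come from this PR):

# Synthetic illustration of the diagonal check used by the tests above.
import numpy as np
import pandas as pd

toy = {
    "train": pd.DataFrame({"client_id": ["a", "b", "b"]}),
    "validation": pd.DataFrame({"client_id": ["c"]}),
    "test": pd.DataFrame({"client_id": ["d", "e"]}),
}
stats = np.array(
    [
        [len(set(x["client_id"]).intersection(y["client_id"])) for y in toy.values()]
        for x in toy.values()
    ]
)
print(stats)  # only the diagonal (unique-speaker counts 2, 1, 2) is non-zero
assert np.all(stats == np.diag(np.diagonal(stats)))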
2 changes: 2 additions & 0 deletions requirements.txt
@@ -8,3 +8,5 @@ evaluate>=0.3.0
more-itertools
tensorboard
openpyxl
deepspeed
git+https://github.com/huggingface/accelerate.git@main