
Commit: initial commit
hejonathan committed Aug 13, 2024
1 parent 3473bb2 commit d76313f
Showing 3 changed files with 74 additions and 70 deletions.
26 changes: 18 additions & 8 deletions pyha_analyzer/dataset.py
@@ -258,19 +258,28 @@ def __getitem__(self, index): #-> Any:
            index = index,
            class_to_idx = self.class_to_idx,
            conf=self.cfg)



        # if self.train:
        # audio, target = self.mixup(audio, target)
        # audio = self.audio_augmentations(audio)
        # audio, target = self.mixup(audio, target)
        # audio = self.audio_augmentations(audio)
        image = self.to_image(audio)
        # if self.train:
        # image = self.image_augmentations(image)

        if image.isnan().any():
            logger.error("ERROR IN ANNOTATION #%s", index)
            self.bad_files.append(index)
            image = torch.zeros(image.shape)
            target = torch.zeros(target.shape)

        # If dataframe has saved onehot encodings, return those
        # Assume column names are species names
        if self.onehot:
            target = self.samples.loc[index, self.classes].values.astype(np.int32)
            target = torch.Tensor(target)

        file_name = self.samples.iloc[index][self.cfg.file_name_col]
        return audio, int(target.argmax()), file_name
        return image, target

    def get_num_classes(self) -> int:
        """ Returns number of classes
@@ -362,9 +371,11 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData
        lambda x: pd.Series(x[cfg.file_name_col].unique()).sample(frac=train_p)
    )
    train = data[data[cfg.file_name_col].isin(train_files)]

    valid = data[~data.index.isin(train.index)]

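    # Note: the two slices below keep only the first 1% of each split, presumably a quick debugging subsample.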
    train = train[:len(train)//100]
    valid = valid[:len(valid)//100]

    train_ds = PyhaDFDataset(train, train=True, species=classes, cfg=cfg)

    valid_ds = PyhaDFDataset(valid, train=False, species=classes, cfg=cfg)
@@ -378,7 +389,6 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData

    return train_ds, valid_ds, infer_ds, classes


def set_torch_file_sharing(_) -> None:
    """
    Sets torch.multiprocessing to use file sharing
Expand Down Expand Up @@ -455,4 +465,4 @@ def main() -> None:
    # for _, (_, _) in enumerate(infer_dataloader):
    # break
if __name__ == '__main__':
    main()
    main()
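
The __getitem__ change above swaps the dataset's return contract from (audio, class index, file name) to (image, target): a spectrogram tensor plus a one-hot (or multi-hot) label vector when self.onehot is set. A minimal consumption sketch, assuming a PyhaDFDataset named train_ds built as in get_datasets:

    import torch
    from torch.utils.data import DataLoader

    # Hypothetical consumer of the new (image, target) contract;
    # train_ds is a PyhaDFDataset as constructed above.
    loader = DataLoader(train_ds, batch_size=8, shuffle=True)
    for images, targets in loader:
        class_ids = targets.argmax(dim=-1)  # collapse one-hot vectors back to class indices
        break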
33 changes: 6 additions & 27 deletions pyha_analyzer/eval.py
@@ -350,7 +350,7 @@ def main(in_sweep=True) -> None:

    # Load in dataset
    logger.info("Loading Dataset...")
    train_dataset, val_dataset, infer_dataset, classes = get_datasets(cfg)
    #train_dataset, val_dataset, infer_dataset, classes = get_datasets(cfg)


    print(train_dataset)
@@ -468,15 +468,10 @@ def preprocess_function(examples):
    from scipy.special import softmax
    # Load the metric for evaluation
    metric_precision = load("precision", trust_remote_code=True)
    metric_precision_multi = load("precision", "multilabel", trust_remote_code=True)
    metric_recall = load("recall", trust_remote_code=True)
    metric_recall_multi = load("recall", "multilabel", trust_remote_code=True)
    metric_f1 = load("f1", trust_remote_code=True)
    metric_f1_multi = load("f1", "multilabel", trust_remote_code=True)
    metric_roc_auc = load("roc_auc", "multiclass", trust_remote_code=True)
    metric_roc_auc_multi = load("roc_auc", "multilabel", trust_remote_code=True)
    metric_accuracy = load("accuracy", trust_remote_code=True)
    metric_accuracy_multi = load("accuracy", "multilabel", trust_remote_code=True)

    all_logits = []
    all_labels = []
@@ -492,31 +487,15 @@ def compute_metrics(eval_pred):
        accuracy = metric_accuracy.compute(predictions=predictions, references=labels)
        roc_auc = metric_roc_auc.compute(prediction_scores=prob, references=labels, average="macro", multi_class="ovr")

        # onehot = np.zeros((labels.size, labels.max()+1), dtype=int)
        # onehot[:,labels] = 1
        # predictions_onehot = np.zeros((predictions.size, predictions.max()+1), dtype=int)
        # predictions_onehot[:,predictions] = 1

        # precision_multi = metric_precision_multi.compute(predictions=predictions_onehot, references=onehot, average='macro')
        # recall_multi = metric_recall_multi.compute(predictions=predictions_onehot, references=labels, average='macro')
        # f1_multi = metric_f1_multi.compute(predictions=predictions_onehot, references=onehot, average='macro')
        # accuracy_multi = metric_accuracy_multi.compute(predictions=predictions_onehot, references=onehot)
        # roc_auc_multi = metric_roc_auc_multi.compute(prediction_scores=prob, references=onehot, average="macro", multi_class="ovr")

        all_logits.append(logits)
        all_labels.append(labels)

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'accuracy': accuracy,
            'roc_auc': roc_auc,
            # 'precision_multi': precision_multi,
            # 'recall_multi': recall_multi,
            # 'f1_multi': f1_multi,
            # 'accuracy_multi': accuracy_multi,
            # 'roc_auc_multi': roc_auc_multi,
            **precision,
            **recall,
            **f1,
            **accuracy,
            **roc_auc,
        }


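The rewritten return above works because each evaluate metric's compute() call returns a small dict keyed by the metric name, so the ** spreads merge cleanly into one result dict. A minimal sketch of that pattern in isolation (assumes the evaluate package is installed; the arrays are made-up examples):

    import numpy as np
    from evaluate import load

    metric_precision = load("precision")
    metric_recall = load("recall")

    preds = np.array([0, 1, 1, 0])
    refs = np.array([0, 1, 0, 0])

    precision = metric_precision.compute(predictions=preds, references=refs, average="macro")
    recall = metric_recall.compute(predictions=preds, references=refs, average="macro")

    merged = {**precision, **recall}  # keys: "precision", "recall"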
85 changes: 50 additions & 35 deletions pyha_analyzer/train.py
@@ -26,7 +26,7 @@
from pyha_analyzer.models.early_stopper import EarlyStopper
from pyha_analyzer.models.timm_model import TimmModel

from datasets import Dataset, DatasetDict, ClassLabel, Features, Value, Audio, Sequence
from datasets import Dataset, DatasetDict, ClassLabel, Features, Value, Image

from huggingface_hub import notebook_login

@@ -351,25 +351,23 @@ def main(in_sweep=True) -> None:
    # Load in dataset
    logger.info("Loading Dataset...")
    train_dataset, val_dataset, infer_dataset, classes = get_datasets(cfg)


    print(train_dataset)
    def pytorch_dataset_to_hf_dataset(pytorch_dataset):
        def generator():
            for i in range(len(pytorch_dataset)):
                audio, target, file_name = pytorch_dataset[i]
                audio = audio.numpy().astype(np.float32)
                #print(f"Shape of audio: {audio.shape}")
                #print(f"Type of image_list: {type(image)}")
                #print(f"Length of image_list: {len(image)}")
                image, target = pytorch_dataset[i]
                image = image.numpy().astype(np.float32)
                target = torch.argmax(target)

                yield {
                    'audio': {'array': audio, 'path': file_name, 'sampling_rate': 16000},
                    'file': file_name,
                    'image': {'array': image, 'path': ""},
                    'label': int(target)  # Ensure target is an integer
                }

        features = Features({
            'audio': Audio(sampling_rate=16000),
            'file': Value('string'),
            'image': Image(),
            'label': ClassLabel(names=classes)  # Use the ClassLabel feature here
        })

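The generator/Features pair above is the usual input to Dataset.from_generator (the call itself sits in the collapsed region below). A minimal standalone sketch with made-up data, assuming datasets and numpy are installed:

    import numpy as np
    from datasets import ClassLabel, Dataset, Features, Image

    classes = ["species_a", "species_b"]  # hypothetical label names

    def generator():
        for i in range(4):
            # Random uint8 array standing in for a to_image() spectrogram.
            img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
            yield {"image": img, "label": i % len(classes)}

    features = Features({"image": Image(), "label": ClassLabel(names=classes)})
    hf_ds = Dataset.from_generator(generator, features=features)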
@@ -401,30 +399,23 @@ def generator():


    ## Training
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    from transformers import AutoFeatureExtractor
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    model_checkpoint = "microsoft/resnet-50"
    from transformers import AutoImageProcessor, ResNetForImageClassification, Trainer, TrainingArguments
    feature_extractor = AutoImageProcessor.from_pretrained(model_checkpoint)
    print(feature_extractor)
    max_duration = 5.0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device:', device)
    def preprocess_function(examples):
        audio_arrays = [np.array(x["array"]) for x in examples["audio"]]
        inputs = feature_extractor(
            audio_arrays,
            sampling_rate=feature_extractor.sampling_rate,
            max_length=int(feature_extractor.sampling_rate * max_duration),
            truncation=True,
        )
        images = examples["image"]
        inputs = feature_extractor(images, return_tensors="pt")
        return inputs

    encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio", "file"], batched=True)
    encoded_dataset = dataset.map(preprocess_function, remove_columns=["image"], batched=True, load_from_cache_file=False)

    from transformers import ASTForAudioClassification, TrainingArguments, Trainer

    num_labels = len(id2label)
    model = ASTForAudioClassification.from_pretrained(
    model = ResNetForImageClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
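
For reference, AutoImageProcessor resizes and normalizes a batch of images and returns a dict with a pixel_values tensor, which is what the map call above feeds to the model. A minimal sketch with a dummy image (assumes Hub access to download the processor config):

    import numpy as np
    from PIL import Image as PILImage
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    dummy = PILImage.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))
    inputs = processor([dummy], return_tensors="pt")
    print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])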
@@ -433,13 +424,17 @@ def preprocess_function(examples):
    )
    model.to(device)

    model_name = model_checkpoint.split("/")[-1]

    bs = 8
    lr = 5e-6
    lr = 0.1

    model_name = model_checkpoint.split("/")[-1]
    model_name = f"{model_name}-bs{bs}-lr{lr}"
    train_dataset.samples.to_csv(f"{model_name}/train_samples.csv")
    val_dataset.samples.to_csv(f"{model_name}/valid_samples.csv")

    args = TrainingArguments(
        f"{model_name}-bs{bs}-lr{lr}",
        model_name,
        eval_strategy = "steps",
        eval_steps = 8000,
        save_strategy = "steps",
@@ -448,8 +443,8 @@ def preprocess_function(examples):
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=bs,
        num_train_epochs=3,
        warmup_ratio=0.125,
        num_train_epochs=10,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="precision",
@@ -465,6 +460,7 @@ def preprocess_function(examples):
    from scipy.special import softmax
    # Load the metric for evaluation
    metric_precision = load("precision", trust_remote_code=True)
    metric_recall = load("recall", trust_remote_code=True)
    metric_f1 = load("f1", trust_remote_code=True)
    metric_roc_auc = load("roc_auc", "multiclass", trust_remote_code=True)
    metric_accuracy = load("accuracy", trust_remote_code=True)
@@ -476,13 +472,32 @@ def compute_metrics(eval_pred):
        prob = softmax(logits, axis=-1)

        precision = metric_precision.compute(predictions=predictions, references=labels, average='macro')
        recall = metric_recall.compute(predictions=predictions, references=labels, average='macro')
        f1 = metric_f1.compute(predictions=predictions, references=labels, average='macro')
        accuracy = metric_accuracy.compute(predictions=predictions, references=labels)
        try:
            roc_auc_mac = metric_roc_auc.compute(prediction_scores=prob, references=labels, average="macro", multi_class="ovr")
            return {**precision, **f1, **roc_auc_mac, **accuracy}
        except Exception as e:
            return {**precision, **f1, **accuracy}
        roc_auc = metric_roc_auc.compute(prediction_scores=prob, references=labels, average="macro", multi_class="ovr")

        import pickle
        pickle_file = f'{model_name}/logits_labels.pkl'

        data_to_pickle = {
            'logits': logits,
            'labels': labels,
            'id2label': id2label
        }

        # Save the lists to a pickle file
        with open(pickle_file, 'wb') as file:
            pickle.dump(data_to_pickle, file)

        return {
            **precision,
            **recall,
            **f1,
            **accuracy,
            **roc_auc,
        }

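Note that compute_metrics above rewrites logits_labels.pkl in 'wb' mode on every evaluation pass, so the file holds only the most recent eval's outputs. A minimal sketch for loading it back, assuming the same model_name directory:

    import pickle

    # model_name as defined above, e.g. "resnet-50-bs8-lr0.1"
    with open(f"{model_name}/logits_labels.pkl", "rb") as f:
        saved = pickle.load(f)

    logits, labels, id2label = saved["logits"], saved["labels"], saved["id2label"]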


    trainer = Trainer(
@@ -494,11 +509,11 @@ def compute_metrics(eval_pred):
        compute_metrics=compute_metrics,
    )
    save_path = '.'

    trainer.train()
    trainer.save_model()
    model.save_pretrained(save_path)
    feature_extractor.save_pretrained(save_path)
    trainer.push_to_hub()


if __name__ == '__main__':
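Since the run saves the model and processor to save_path ('.'), a finished checkpoint can be reloaded for inference. A minimal sketch with a dummy image, assuming the training above completed and saved into the current directory:

    import numpy as np
    import torch
    from PIL import Image as PILImage
    from transformers import AutoImageProcessor, ResNetForImageClassification

    processor = AutoImageProcessor.from_pretrained(".")  # save_path used above
    model = ResNetForImageClassification.from_pretrained(".").eval()

    dummy = PILImage.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))
    with torch.no_grad():
        logits = model(**processor([dummy], return_tensors="pt")).logits
    print(model.config.id2label[logits.argmax(-1).item()])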
