Skip to content

Commit

Permalink
before timm
Browse files Browse the repository at this point in the history
  • Loading branch information
hejonathan committed Aug 11, 2024
1 parent 799fcb5 commit 3473bb2
Show file tree
Hide file tree
Showing 7 changed files with 1,019 additions and 15 deletions.
51 changes: 51 additions & 0 deletions compute_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pickle
from scipy.special import softmax
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize
from collections import Counter
from tqdm import tqdm

pickle_location = 'ast-finetuned-audioset-10-10-0.4593-bs8-lr1e-05/checkpoint-24000/logits_labels.pkl'

# Open the file in binary read mode
with open(pickle_location, 'rb') as file:
# Load the data using pickle
data = pickle.load(file)

logits = data['logits'][0]
labels = data['labels'][0]

prob = softmax(logits, axis=-1)
pred = np.argmax(logits, axis=-1)

print('MultiClass:')
precision = precision_score(labels, pred, average='macro', zero_division=1)
print(f"Precision: {precision}")
recall = recall_score(labels, pred, average='macro')
print(f"Recall: {recall}")
f1 = f1_score(labels, pred, average='macro')
print(f"F1 Score: {f1}")

roc_auc = roc_auc_score(labels, prob, average='macro', multi_class='ovr')

print(f"ROC AUC: {roc_auc}")

print()
print('MultiLabel:')


labels_onehot = label_binarize(labels, classes=np.arange(logits.shape[1]))
thres = 0.5
_prob = prob
prob = (prob > thres).astype(int)

# Multi-label metrics
precision = precision_score(labels_onehot, prob, average='macro', zero_division=1)
print(f"Precision: {precision}")
recall = recall_score(labels_onehot, prob, average='macro')
print(f"Recall: {recall}")
f1 = f1_score(labels_onehot, prob, average='macro')
print(f"F1 Score: {f1}")
roc_auc = roc_auc_score(labels_onehot, _prob, average='macro', multi_class='ovr')
print(f"ROC AUC: {roc_auc}")
375 changes: 375 additions & 0 deletions compute_metrics.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions entry_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
""" Entry point into PyHa Analyzer train function """
import torch
import pyha_analyzer as pa

if __name__ == '__main__':
torch.multiprocessing.set_sharing_strategy('file_system')
torch.multiprocessing.set_start_method('spawn')
DO_TRAIN = True
if DO_TRAIN:
pa.eval.main(in_sweep=False)
else:
pa.sweeps.main()
1 change: 1 addition & 0 deletions pyha_analyzer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from os import path
from . import sweeps
from . import train
from . import eval
1 change: 0 additions & 1 deletion pyha_analyzer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData
infer_ds = None
else:
infer = pd.read_csv(cfg.infer_csv)
infer = infer[:len(infer)//100]
infer_ds = PyhaDFDataset(infer, train=False, species=classes, onehot=True, cfg=cfg)

return train_ds, valid_ds, infer_ds, classes
Expand Down
Loading

0 comments on commit 3473bb2

Please sign in to comment.