timm model huggingface trial
hejonathan committed Aug 12, 2024
1 parent 3473bb2 commit 2028879
Showing 5 changed files with 593 additions and 41 deletions.
1 change: 1 addition & 0 deletions pyha_analyzer/__init__.py
@@ -3,3 +3,4 @@
 from . import sweeps
 from . import train
 from . import eval
+from . import timm_huggingface
27 changes: 19 additions & 8 deletions pyha_analyzer/dataset.py
@@ -258,19 +258,28 @@ def __getitem__(self, index): #-> Any:
             index = index,
             class_to_idx = self.class_to_idx,
             conf=self.cfg)
 
-        # if self.train:
-        #     audio, target = self.mixup(audio, target)
-        #     audio = self.audio_augmentations(audio)
+        if self.train:
+            audio, target = self.mixup(audio, target)
+            audio = self.audio_augmentations(audio)
+        image = self.to_image(audio)
+        if self.train:
+            image = self.image_augmentations(image)
+
+        if image.isnan().any():
+            logger.error("ERROR IN ANNOTATION #%s", index)
+            self.bad_files.append(index)
+            image = torch.zeros(image.shape)
+            target = torch.zeros(target.shape)
 
         #If dataframe has saved onehot encodings, return those
        #Assume column names are species names
         if self.onehot:
             target = self.samples.loc[index, self.classes].values.astype(np.int32)
             target = torch.Tensor(target)
 
-        file_name = self.samples.iloc[index][self.cfg.file_name_col]
-        return audio, int(target.argmax()), file_name
+        return image, target
 
     def get_num_classes(self) -> int:
         """ Returns number of classes
Expand Down Expand Up @@ -365,6 +374,9 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData

valid = data[~data.index.isin(train.index)]

train = train[:len(train)//100]
valid = valid[:len(valid)//100]

train_ds = PyhaDFDataset(train, train=True, species=classes, cfg=cfg)

valid_ds = PyhaDFDataset(valid, train=False, species=classes, cfg=cfg)
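The two added lines keep only the first 1% of each split so the trial runs quickly. A head slice can be class-skewed, so a random (though still unstratified) variant may be worth considering; a sketch, assuming train and valid are pandas DataFrames:

    train = train.sample(frac=0.01, random_state=0)
    valid = valid.sample(frac=0.01, random_state=0)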
@@ -378,7 +390,6 @@ def get_datasets(cfg) -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFData
 
     return train_ds, valid_ds, infer_ds, classes
 
-
 def set_torch_file_sharing(_) -> None:
     """
     Sets torch.multiprocessing to use file sharing
@@ -455,4 +466,4 @@ def main() -> None:
     # for _, (_, _) in enumerate(infer_dataloader):
     #     break
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
32 changes: 32 additions & 0 deletions pyha_analyzer/timm_huggingface.py
@@ -0,0 +1,32 @@
from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
import timm

class ECAConfig(PretrainedConfig):
    model_type = 'mymodel'
    def __init__(self, model_name: str = 'timm/eca_nfnet_l0', **kwargs):
        # default model_name lets config_class round-trip through from_pretrained
        super().__init__(**kwargs)
        self.model_name = model_name

class ECAModel(PreTrainedModel):
    config_class = ECAConfig
    def __init__(self, config):
        super().__init__(config)
        self.model = timm.create_model(config.model_name, pretrained=True)
        # keep timm's data config separate so the huggingface config is not overwritten
        self.data_cfg = self.model.default_cfg
        print(self.data_cfg)
        self.img_size = (self.data_cfg["test_input_size"][-1]
                         if "test_input_size" in self.data_cfg
                         else self.data_cfg["input_size"][-1])

    def get_transform(self):
        return timm.data.transforms_factory.transforms_imagenet_eval(
            img_size=self.img_size,
            interpolation=self.data_cfg["interpolation"],
            mean=self.data_cfg["mean"],
            std=self.data_cfg["std"],
            crop_pct=self.data_cfg["crop_pct"],
        )

    def forward(self, pixel_values):
        return self.model(pixel_values)
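AutoModel and AutoConfig are imported above but never used. If the intent is to resolve the wrapper through the auto classes, the standard transformers registration would look roughly like this sketch:

    AutoConfig.register('mymodel', ECAConfig)   # 'mymodel' matches ECAConfig.model_type
    AutoModel.register(ECAConfig, ECAModel)

    config = ECAConfig(model_name='timm/eca_nfnet_l0')
    model = AutoModel.from_config(config)       # resolves to ECAModel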
67 changes: 34 additions & 33 deletions pyha_analyzer/train.py
@@ -26,9 +26,10 @@
 from pyha_analyzer.models.early_stopper import EarlyStopper
 from pyha_analyzer.models.timm_model import TimmModel
 
-from datasets import Dataset, DatasetDict, ClassLabel, Features, Value, Audio, Sequence
+from datasets import Dataset, DatasetDict, ClassLabel, Features, Image
 
 from huggingface_hub import notebook_login
+from pyha_analyzer.timm_huggingface import ECAConfig, ECAModel
 
 notebook_login()

Expand Down Expand Up @@ -356,20 +357,19 @@ def main(in_sweep=True) -> None:
 def pytorch_dataset_to_hf_dataset(pytorch_dataset):
     def generator():
         for i in range(len(pytorch_dataset)):
-            audio, target, file_name = pytorch_dataset[i]
-            audio = audio.numpy().astype(np.float32)
-            #print(f"Shape of audio: {audio.shape}")
+            image, target = pytorch_dataset[i]
+            image = image.numpy().astype(np.uint8)
+            #print(f"Type of image_list: {type(image)}")
+            #print(f"Length of image_list: {len(image)}")
             target = torch.argmax(target)
             yield {
-                'audio': {'array': audio, 'path': file_name, 'sampling_rate': 16000},
-                'file': file_name,
+                'image': image,  # datasets.Image() accepts an HWC uint8 numpy array directly
                 'label': int(target)  # Ensure target is an integer
             }
 
     features = Features({
-        'audio': Audio(sampling_rate=16000),
-        'file': Value('string'),
+        'image': Image(),  # encodes the numpy array; decodes to a PIL image on access
         'label': ClassLabel(names=classes)  # Use the ClassLabel feature here
     })
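The collapsed lines below presumably build the Dataset/DatasetDict from this generator; with the features defined above, the standard construction would be along these lines (split handling is an assumption):

    hf_dataset = Dataset.from_generator(generator, features=features)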

@@ -401,39 +401,41 @@ def generator():


 ## Training
-model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
-from transformers import AutoFeatureExtractor
-feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
-print(feature_extractor)
-max_duration = 5.0
+model_name = "timm/eca_nfnet_l0"
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print('device:', device)
 
+config = ECAConfig(model_name)
+model = ECAModel(config)
+model.to(device)
+
+model_name = model_name.split("/")[-1]
+
+print('model loaded successfully')
+
+transform = model.get_transform()
 def preprocess_function(examples):
-    audio_arrays = [np.array(x["array"]) for x in examples["audio"]]
-    inputs = feature_extractor(
-        audio_arrays,
-        sampling_rate=feature_extractor.sampling_rate,
-        max_length=int(feature_extractor.sampling_rate * max_duration),
-        truncation=True,
-    )
-    return inputs
+    # datasets.Image() decodes each stored array back to a PIL image; the
+    # torchvision pipeline from transforms_imagenet_eval takes one PIL image
+    # at a time, so transform each example individually
+    pixel_values = [transform(img.convert("RGB")) for img in examples["image"]]
+    return {"pixel_values": pixel_values}
 
-encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio", "file"], batched=True)
+encoded_dataset = dataset.map(preprocess_function, remove_columns=["image"], batched=True)
 
-from transformers import ASTForAudioClassification, TrainingArguments, Trainer
+from transformers import TrainingArguments, Trainer
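Before handing encoded_dataset to the Trainer, it is worth confirming the mapped columns; a sketch, assuming the DatasetDict exposes a 'train' split as the collapsed Trainer setup below suggests:

    encoded_dataset.set_format('torch')
    example = encoded_dataset['train'][0]
    print(example['pixel_values'].shape, example['label'])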

 num_labels = len(id2label)
-model = ASTForAudioClassification.from_pretrained(
-    model_checkpoint,
-    num_labels=num_labels,
-    label2id=label2id,
-    id2label=id2label,
-    ignore_mismatched_sizes=True
-)
-model.to(device)
+# model = ECAModel.from_pretrained(
+#     model_checkpoint,
+#     num_labels=num_labels,
+#     label2id=label2id,
+#     id2label=id2label,
+#     ignore_mismatched_sizes=True
+# )
 
-model_name = model_checkpoint.split("/")[-1]
-
 
 bs = 8
 lr = 5e-6
@@ -497,7 +499,6 @@ def compute_metrics(eval_pred):
 trainer.train()
 trainer.save_model()
 model.save_pretrained(save_path)
-feature_extractor.save_pretrained(save_path)
 trainer.push_to_hub()
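The checkpoint written to save_path (defined in the collapsed lines above) can be reloaded through the wrapper's config_class for inference, e.g.:

    model = ECAModel.from_pretrained(save_path)
    model.eval()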


