Update README.md
#7, Added Finetuned Dasheng on Audioset.
RicherMans authored Aug 13, 2024
1 parent 70024aa commit 750e7b5
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions README.md
@@ -151,6 +151,46 @@ MultiGPU support is realized using [Accelerate](https://huggingface.co/docs/acce
accelerate launch --mixed_precision='bf16' dasheng/train/train.py dasheng/train/config/dasheng_base.yaml
```

## FAQ

### Is there an Audioset-finetuned Dasheng?

Yes. The Audioset-finetuned base model achieves an mAP of 49.7 and can be used as follows:

```python
from typing import Any, Mapping

import dasheng
import torch


class DashengAudiosetClassifier(torch.nn.Module):
    """Dasheng base encoder followed by a 527-class classification head."""

    def __init__(self) -> None:
        super().__init__()
        self.dashengmodel = dasheng.dasheng_base()
        self.classifier = torch.nn.Sequential(
            torch.nn.LayerNorm(self.dashengmodel.embed_dim),
            torch.nn.Linear(self.dashengmodel.embed_dim, 527),
        )

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        # `strict` and `assign` are kept for signature compatibility but are not used.
        # strict=False tolerates the extra classifier keys in the checkpoint.
        self.dashengmodel.load_state_dict(state_dict, strict=False)
        # Strip the 'outputlayer.' prefix so the remaining keys match self.classifier.
        for_classifier_dict = {}
        for k, v in state_dict.items():
            if 'outputlayer' in k:
                for_classifier_dict[k.replace('outputlayer.', '')] = v
        self.classifier.load_state_dict(for_classifier_dict)
        return self

    def forward(self, x):
        # Pool the encoder output by averaging over dimension 1, then classify.
        x = self.dashengmodel(x).mean(1)
        # Sigmoid gives an independent score in [0, 1] for each of the 527 classes.
        return self.classifier(x).sigmoid()


mdl = DashengAudiosetClassifier()
check = torch.hub.load_state_dict_from_url(
    'https://zenodo.org/records/13315686/files/dasheng_audioset_mAP497.pt?download=1',
    map_location='cpu',
)
mdl.load_state_dict(check)

prediction = mdl(torch.randn(1, 16000))
```
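
The classifier above returns one sigmoid score per class (527 values per input). As a minimal sketch of how such an output could be inspected, the helper below picks the highest-scoring class indices from a plain list of scores, for example one obtained with `prediction[0].tolist()`. The name `top_k_classes` and its `threshold` parameter are hypothetical and not part of `dasheng`; mapping indices to Audioset label names requires a label file that is not covered here.

```python
from typing import List, Sequence, Tuple


def top_k_classes(
    probs: Sequence[float], k: int = 5, threshold: float = 0.0
) -> List[Tuple[int, float]]:
    """Return up to k (class_index, score) pairs, highest score first.

    Hypothetical helper for illustration; not part of the dasheng package.
    """
    # Pair each score with its class index, dropping scores below the threshold.
    candidates = [(i, p) for i, p in enumerate(probs) if p >= threshold]
    # Sort by score in descending order; ties keep the lower index first.
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return candidates[:k]


# Stand-in for `prediction[0].tolist()`; a real output would have 527 entries.
example_probs = [0.02, 0.91, 0.10, 0.47, 0.05]
print(top_k_classes(example_probs, k=2))  # [(1, 0.91), (3, 0.47)]
```

Because the scores come from a sigmoid rather than a softmax, several classes can be active at once, so a threshold (rather than only the single best class) is often the more natural way to read the output.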


## Citation

```bibtex