-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: update to latest pyannote and wavesurfer (#3)
- Loading branch information
Showing
3 changed files
with
87 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,10 +23,10 @@ | |
|
||
import io | ||
import base64 | ||
import torch | ||
import numpy as np | ||
import scipy.io.wavfile | ||
from typing import Text | ||
from huggingface_hub import HfApi | ||
import streamlit as st | ||
from pyannote.audio import Pipeline | ||
from pyannote.audio import Audio | ||
|
@@ -49,32 +49,47 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text: | |
PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4" | ||
EXCERPT = 30.0 | ||
|
||
st.set_page_config( | ||
page_title="pyannote.audio pretrained pipelines", page_icon=PYANNOTE_LOGO | ||
) | ||
st.set_page_config(page_title="pyannote pretrained pipelines", page_icon=PYANNOTE_LOGO) | ||
|
||
col1, col2 = st.columns([0.2, 0.8], gap="small") | ||
|
||
st.sidebar.image(PYANNOTE_LOGO) | ||
with col1: | ||
st.image(PYANNOTE_LOGO) | ||
|
||
with col2: | ||
st.markdown( | ||
""" | ||
# pretrained pipelines | ||
Make the most of [pyannote](https://github.com/pyannote) thanks to our [consulting services](https://herve.niderb.fr/consulting.html) | ||
""" | ||
) | ||
|
||
st.markdown("""# 🎹 Pretrained pipelines | ||
""") | ||
|
||
PIPELINES = [ | ||
p.modelId | ||
for p in HfApi().list_models(filter="pyannote-audio-pipeline") | ||
if p.modelId.startswith("pyannote/") | ||
"pyannote/speaker-diarization-3.0", | ||
] | ||
|
||
audio = Audio(sample_rate=16000, mono=True) | ||
|
||
selected_pipeline = st.selectbox("Select a pipeline", PIPELINES, index=0) | ||
selected_pipeline = st.selectbox("Select a pretrained pipeline", PIPELINES, index=0) | ||
|
||
|
||
with st.spinner("Loading pipeline..."): | ||
pipeline = Pipeline.from_pretrained(selected_pipeline, use_auth_token=st.secrets["PYANNOTE_TOKEN"]) | ||
try: | ||
use_auth_token = st.secrets["PYANNOTE_TOKEN"] | ||
except FileNotFoundError: | ||
use_auth_token = None | ||
except KeyError: | ||
use_auth_token = None | ||
|
||
pipeline = Pipeline.from_pretrained( | ||
selected_pipeline, use_auth_token=use_auth_token | ||
) | ||
if torch.cuda.is_available(): | ||
pipeline.to(torch.device("cuda")) | ||
|
||
uploaded_file = st.file_uploader("Choose an audio file") | ||
uploaded_file = st.file_uploader("Upload an audio file") | ||
if uploaded_file is not None: | ||
|
||
try: | ||
duration = audio.get_duration(uploaded_file) | ||
except RuntimeError as e: | ||
|
@@ -86,12 +101,12 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text: | |
uri = "".join(uploaded_file.name.split()) | ||
file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri} | ||
|
||
with st.spinner(f"Processing first {EXCERPT:g} seconds..."): | ||
with st.spinner(f"Processing {EXCERPT:g} seconds..."): | ||
output = pipeline(file) | ||
|
||
with open('assets/template.html') as html, open('assets/style.css') as css: | ||
with open("assets/template.html") as html, open("assets/style.css") as css: | ||
html_template = html.read() | ||
st.markdown('<style>{}</style>'.format(css.read()), unsafe_allow_html=True) | ||
st.markdown("<style>{}</style>".format(css.read()), unsafe_allow_html=True) | ||
|
||
colors = [ | ||
"#ffd70033", | ||
|
@@ -105,50 +120,36 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text: | |
] | ||
num_colors = len(colors) | ||
|
||
label2color = {label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))} | ||
label2color = { | ||
label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels())) | ||
} | ||
|
||
BASE64 = to_base64(waveform.numpy().T) | ||
|
||
REGIONS = "" | ||
LEGENDS = "" | ||
labels=[] | ||
for segment, _, label in output.itertracks(yield_label=True): | ||
REGIONS += f"var re = wavesurfer.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});" | ||
if not label in labels: | ||
LEGENDS += f"<li><span style='background-color:{label2color[label]}'></span>{label}</li>" | ||
labels.append(label) | ||
REGIONS += f"regions.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});" | ||
|
||
html = html_template.replace("BASE64", BASE64).replace("REGIONS", REGIONS) | ||
components.html(html, height=250, scrolling=True) | ||
st.markdown("<div style='overflow : auto'><ul class='legend'>"+LEGENDS+"</ul></div>", unsafe_allow_html=True) | ||
|
||
st.markdown("---") | ||
|
||
with io.StringIO() as fp: | ||
output.write_rttm(fp) | ||
content = fp.getvalue() | ||
|
||
b64 = base64.b64encode(content.encode()).decode() | ||
href = f'Download as <a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">RTTM</a> or run it on the whole {int(duration):d}s file:' | ||
href = f'<a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">Download</a> result in RTTM file format or run it locally:' | ||
st.markdown(href, unsafe_allow_html=True) | ||
|
||
code = f""" | ||
from pyannote.audio import Pipeline | ||
pipeline = Pipeline.from_pretrained("{selected_pipeline}") | ||
output = pipeline("{uploaded_file.name}") | ||
""" | ||
st.code(code, language='python') | ||
|
||
|
||
|
||
st.sidebar.markdown( | ||
""" | ||
------------------- | ||
To use these pipelines on more and longer files on your own (GPU, hence much faster) servers, check the [documentation](https://github.com/pyannote/pyannote-audio). | ||
# load pretrained pipeline | ||
from pyannote.audio import Pipeline | ||
pipeline = Pipeline.from_pretrained("{selected_pipeline}", | ||
use_auth_token=HUGGINGFACE_TOKEN) | ||
For [technical questions](https://github.com/pyannote/pyannote-audio/discussions) and [bug reports](https://github.com/pyannote/pyannote-audio/issues), please check [pyannote.audio](https://github.com/pyannote/pyannote-audio) Github repository. | ||
# (optional) send pipeline to GPU | ||
import torch | ||
pipeline.to(torch.device("cuda")) | ||
For commercial enquiries and scientific consulting, please contact [me](mailto:[email protected]). | ||
""" | ||
) | ||
# process audio file | ||
output = pipeline("audio.wav")""" | ||
st.code(code, language="python") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,41 @@ | ||
<script src="https://unpkg.com/wavesurfer.js"></script> | ||
<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.regions.min.js"></script> | ||
<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.timeline.min.js"></script> | ||
<br> | ||
<div id="waveform"></div> | ||
<div id="timeline"></div> | ||
<br> | ||
<div><button onclick="play()" id="ppb">Play</button><div> | ||
<script type="text/javascript"> | ||
var labels=[]; | ||
var wavesurfer = WaveSurfer.create({ | ||
container: '#waveform', | ||
barGap: 2, | ||
barHeight: 3, | ||
barWidth: 3, | ||
barRadius: 2, | ||
plugins: [ | ||
WaveSurfer.regions.create({}), | ||
WaveSurfer.timeline.create({ | ||
container: "#timeline", | ||
notchPercentHeight: 40, | ||
primaryColor: "#444", | ||
primaryFontColor: "#444" | ||
}) | ||
] | ||
}); | ||
wavesurfer.load('BASE64'); | ||
wavesurfer.on('ready', function () { | ||
wavesurfer.play(); | ||
}); | ||
wavesurfer.on('play',function() { | ||
document.getElementById('ppb').innerHTML = "Pause"; | ||
}); | ||
wavesurfer.on('pause',function() { | ||
document.getElementById('ppb').innerHTML = "Play"; | ||
}); | ||
<script type="module"> | ||
import WaveSurfer from 'https://unpkg.com/wavesurfer.js@7/dist/wavesurfer.esm.js' | ||
import RegionsPlugin from 'https://unpkg.com/wavesurfer.js@7/dist/plugins/regions.esm.js' | ||
|
||
|
||
var labels=[]; | ||
const wavesurfer = WaveSurfer.create({ | ||
container: '#waveform', | ||
barGap: 2, | ||
barHeight: 3, | ||
barWidth: 3, | ||
barRadius: 2, | ||
}); | ||
|
||
const regions = wavesurfer.registerPlugin(RegionsPlugin.create()) | ||
|
||
wavesurfer.load('BASE64'); | ||
wavesurfer.on('ready', function () { | ||
wavesurfer.play(); | ||
}); | ||
|
||
wavesurfer.on('decode', function () { | ||
|
||
REGIONS | ||
document.addEventListener('keyup', event => { | ||
if (event.code === 'Space') { | ||
play(); | ||
} | ||
}) | ||
function play(){ | ||
wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play(); | ||
} | ||
|
||
wavesurfer.play(); | ||
|
||
}); | ||
|
||
wavesurfer.on('click', () => { | ||
play(); | ||
}); | ||
|
||
|
||
function play(){ | ||
wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play(); | ||
} | ||
|
||
</script> | ||
<div id="waveform"></div> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1 @@ | ||
torch==1.11.0 | ||
torchvision==0.12.0 | ||
torchaudio==0.11.0 | ||
torchtext==0.12.0 | ||
speechbrain==0.5.12 | ||
pyannote-audio>=2.1 | ||
|
||
pyannote-audio==3.0.1 |