feat: update to latest pyannote and wavesurfer (#3)
hbredin authored Oct 6, 2023
1 parent 98df54b commit 57604a5
Showing 3 changed files with 87 additions and 97 deletions.
app.py: 93 changes (47 additions & 46 deletions)
@@ -23,10 +23,10 @@
 
 import io
 import base64
+import torch
 import numpy as np
 import scipy.io.wavfile
 from typing import Text
-from huggingface_hub import HfApi
 import streamlit as st
 from pyannote.audio import Pipeline
 from pyannote.audio import Audio
@@ -49,32 +49,47 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
 PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4"
 EXCERPT = 30.0
 
-st.set_page_config(
-    page_title="pyannote.audio pretrained pipelines", page_icon=PYANNOTE_LOGO
-)
+st.set_page_config(page_title="pyannote pretrained pipelines", page_icon=PYANNOTE_LOGO)
 
-st.sidebar.image(PYANNOTE_LOGO)
+col1, col2 = st.columns([0.2, 0.8], gap="small")
+
+with col1:
+    st.image(PYANNOTE_LOGO)
 
-st.markdown("""# 🎹 Pretrained pipelines
-""")
+with col2:
+    st.markdown(
+        """
+# pretrained pipelines
+Make the most of [pyannote](https://github.com/pyannote) thanks to our [consulting services](https://herve.niderb.fr/consulting.html)
+"""
+    )
 
 PIPELINES = [
-    p.modelId
-    for p in HfApi().list_models(filter="pyannote-audio-pipeline")
-    if p.modelId.startswith("pyannote/")
+    "pyannote/speaker-diarization-3.0",
 ]
 
 audio = Audio(sample_rate=16000, mono=True)
 
-selected_pipeline = st.selectbox("Select a pipeline", PIPELINES, index=0)
+selected_pipeline = st.selectbox("Select a pretrained pipeline", PIPELINES, index=0)
 
 with st.spinner("Loading pipeline..."):
-    pipeline = Pipeline.from_pretrained(selected_pipeline, use_auth_token=st.secrets["PYANNOTE_TOKEN"])
+    try:
+        use_auth_token = st.secrets["PYANNOTE_TOKEN"]
+    except FileNotFoundError:
+        use_auth_token = None
+    except KeyError:
+        use_auth_token = None
+
+    pipeline = Pipeline.from_pretrained(
+        selected_pipeline, use_auth_token=use_auth_token
+    )
+    if torch.cuda.is_available():
+        pipeline.to(torch.device("cuda"))
 
-uploaded_file = st.file_uploader("Choose an audio file")
+uploaded_file = st.file_uploader("Upload an audio file")
 if uploaded_file is not None:
 
     try:
         duration = audio.get_duration(uploaded_file)
     except RuntimeError as e:
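
Note: the try/except added above exists because st.secrets raises FileNotFoundError when no secrets.toml file is present at all, and KeyError when the file exists but lacks the entry, exactly the two cases the new code catches. A minimal sketch of the same fallback pattern (the helper name is hypothetical):

import streamlit as st

def get_secret(name):
    # st.secrets behaves like a dict backed by .streamlit/secrets.toml;
    # a missing file raises FileNotFoundError, a missing key raises KeyError
    try:
        return st.secrets[name]
    except (FileNotFoundError, KeyError):
        return None  # fall back to anonymous access

use_auth_token = get_secret("PYANNOTE_TOKEN")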
@@ -86,12 +101,12 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
     uri = "".join(uploaded_file.name.split())
     file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}
 
-    with st.spinner(f"Processing first {EXCERPT:g} seconds..."):
+    with st.spinner(f"Processing {EXCERPT:g} seconds..."):
         output = pipeline(file)
 
-    with open('assets/template.html') as html, open('assets/style.css') as css:
+    with open("assets/template.html") as html, open("assets/style.css") as css:
         html_template = html.read()
-        st.markdown('<style>{}</style>'.format(css.read()), unsafe_allow_html=True)
+        st.markdown("<style>{}</style>".format(css.read()), unsafe_allow_html=True)
 
     colors = [
         "#ffd70033",
@@ -105,50 +120,36 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
     ]
     num_colors = len(colors)
 
-    label2color = {label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))}
+    label2color = {
+        label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))
+    }
 
     BASE64 = to_base64(waveform.numpy().T)
 
     REGIONS = ""
-    LEGENDS = ""
-    labels=[]
     for segment, _, label in output.itertracks(yield_label=True):
-        REGIONS += f"var re = wavesurfer.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
-        if not label in labels:
-            LEGENDS += f"<li><span style='background-color:{label2color[label]}'></span>{label}</li>"
-            labels.append(label)
+        REGIONS += f"regions.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
 
     html = html_template.replace("BASE64", BASE64).replace("REGIONS", REGIONS)
     components.html(html, height=250, scrolling=True)
-    st.markdown("<div style='overflow : auto'><ul class='legend'>"+LEGENDS+"</ul></div>", unsafe_allow_html=True)
 
     st.markdown("---")
 
     with io.StringIO() as fp:
         output.write_rttm(fp)
         content = fp.getvalue()
 
     b64 = base64.b64encode(content.encode()).decode()
-    href = f'Download as <a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">RTTM</a> or run it on the whole {int(duration):d}s file:'
+    href = f'<a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">Download</a> result in RTTM file format or run it locally:'
     st.markdown(href, unsafe_allow_html=True)
 
     code = f"""
-from pyannote.audio import Pipeline
-pipeline = Pipeline.from_pretrained("{selected_pipeline}")
-output = pipeline("{uploaded_file.name}")
-"""
-    st.code(code, language='python')
-
-
-    st.sidebar.markdown(
-        """
--------------------
-To use these pipelines on more and longer files on your own (GPU, hence much faster) servers, check the [documentation](https://github.com/pyannote/pyannote-audio).
-
-For [technical questions](https://github.com/pyannote/pyannote-audio/discussions) and [bug reports](https://github.com/pyannote/pyannote-audio/issues), please check [pyannote.audio](https://github.com/pyannote/pyannote-audio) Github repository.
-
-For commercial enquiries and scientific consulting, please contact [me](mailto:[email protected]).
-"""
-    )
+# load pretrained pipeline
+from pyannote.audio import Pipeline
+pipeline = Pipeline.from_pretrained("{selected_pipeline}",
+                                    use_auth_token=HUGGINGFACE_TOKEN)
+
+# (optional) send pipeline to GPU
+import torch
+pipeline.to(torch.device("cuda"))
+
+# process audio file
+output = pipeline("audio.wav")"""
+    st.code(code, language="python")
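
The hunk headers above reference a to_base64 helper whose body falls outside the diff context. A plausible sketch, consistent with the file's imports (io, base64, scipy.io.wavfile) and with the template's wavesurfer.load('BASE64') call, might look like this (reconstruction, not the committed code):

import io
import base64
import numpy as np
import scipy.io.wavfile
from typing import Text

def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
    # Hypothetical reconstruction: encode the excerpt as an in-memory WAV,
    # then wrap it in a data URI that wavesurfer.js can load directly.
    with io.BytesIO() as buffer:
        scipy.io.wavfile.write(buffer, sample_rate, waveform)
        b64 = base64.b64encode(buffer.getvalue()).decode()
    return f"data:audio/x-wav;base64,{b64}"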
assets/template.html: 83 changes (39 additions & 44 deletions)

@@ -1,46 +1,41 @@
-<script src="https://unpkg.com/wavesurfer.js"></script>
-<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.regions.min.js"></script>
-<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.timeline.min.js"></script>
-<br>
-<div id="waveform"></div>
-<div id="timeline"></div>
-<br>
-<div><button onclick="play()" id="ppb">Play</button></div>
-<script type="text/javascript">
-var labels=[];
-var wavesurfer = WaveSurfer.create({
-    container: '#waveform',
-    barGap: 2,
-    barHeight: 3,
-    barWidth: 3,
-    barRadius: 2,
-    plugins: [
-        WaveSurfer.regions.create({}),
-        WaveSurfer.timeline.create({
-            container: "#timeline",
-            notchPercentHeight: 40,
-            primaryColor: "#444",
-            primaryFontColor: "#444"
-        })
-    ]
-});
-wavesurfer.load('BASE64');
-wavesurfer.on('ready', function () {
-    wavesurfer.play();
-});
-wavesurfer.on('play',function() {
-    document.getElementById('ppb').innerHTML = "Pause";
-});
-wavesurfer.on('pause',function() {
-    document.getElementById('ppb').innerHTML = "Play";
-});
-REGIONS
-document.addEventListener('keyup', event => {
-    if (event.code === 'Space') {
-        play();
-    }
-})
-function play(){
-    wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play();
-}
-</script>
+<script type="module">
+import WaveSurfer from 'https://unpkg.com/wavesurfer.js@7/dist/wavesurfer.esm.js'
+import RegionsPlugin from 'https://unpkg.com/wavesurfer.js@7/dist/plugins/regions.esm.js'
+
+var labels=[];
+const wavesurfer = WaveSurfer.create({
+    container: '#waveform',
+    barGap: 2,
+    barHeight: 3,
+    barWidth: 3,
+    barRadius: 2,
+});
+
+const regions = wavesurfer.registerPlugin(RegionsPlugin.create())
+
+wavesurfer.load('BASE64');
+wavesurfer.on('ready', function () {
+    wavesurfer.play();
+});
+
+wavesurfer.on('decode', function () {
+
+    REGIONS
+
+    wavesurfer.play();
+
+});
+
+wavesurfer.on('click', () => {
+    play();
+});
+
+function play(){
+    wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play();
+}
+
+</script>
+<div id="waveform"></div>

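For context on the REGIONS placeholder: app.py string-replaces it with one regions.addRegion(...) call per diarization segment before handing the page to components.html, matching the wavesurfer.js v7 API where regions live on the registered plugin instance rather than on the wavesurfer object. A hypothetical two-segment expansion, sketched in Python:

# Illustrative only: what app.py might inject in place of REGIONS
# for two segments (times and colors are made up).
segments = [(0.5, 2.1, "#ffd70033"), (2.4, 5.0, "#00ffff33")]
REGIONS = "".join(
    f"regions.addRegion({{start: {start:g}, end: {end:g}, "
    f"color: '{color}', resize : false, drag : false}});"
    for start, end, color in segments
)
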
requirements.txt: 8 changes (1 addition & 7 deletions)

@@ -1,7 +1 @@
-torch==1.11.0
-torchvision==0.12.0
-torchaudio==0.11.0
-torchtext==0.12.0
-speechbrain==0.5.12
-pyannote-audio>=2.1
+pyannote-audio==3.0.1

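The new pin tracks pyannote.audio 3.x, which pulls in torch as its own dependency, so the explicit torch/torchvision/torchaudio/torchtext/speechbrain pins are no longer needed. A minimal smoke test for the pinned setup, assuming a valid Hugging Face access token (HF_TOKEN is a placeholder), mirrors the snippet the app now displays:

import torch
from pyannote.audio import Pipeline

# "HF_TOKEN" stands in for a real Hugging Face access token
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.0", use_auth_token="HF_TOKEN"
)
if torch.cuda.is_available():
    # pyannote 3.x loads on CPU by default; moving to GPU is explicit
    pipeline.to(torch.device("cuda"))
output = pipeline("audio.wav")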