diff --git a/app.py b/app.py
index 568851d..2bece9e 100644
--- a/app.py
+++ b/app.py
@@ -23,10 +23,10 @@
 import io
 import base64
+import torch
 import numpy as np
 import scipy.io.wavfile
 from typing import Text
-from huggingface_hub import HfApi
 import streamlit as st
 from pyannote.audio import Pipeline
 from pyannote.audio import Audio
@@ -49,32 +49,47 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
 PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4"
 EXCERPT = 30.0
 
-st.set_page_config(
-    page_title="pyannote.audio pretrained pipelines", page_icon=PYANNOTE_LOGO
-)
+st.set_page_config(page_title="pyannote pretrained pipelines", page_icon=PYANNOTE_LOGO)
+col1, col2 = st.columns([0.2, 0.8], gap="small")
 
-st.sidebar.image(PYANNOTE_LOGO)
+with col1:
+    st.image(PYANNOTE_LOGO)
+
+with col2:
+    st.markdown(
+        """
+# pretrained pipelines
+Make the most of [pyannote](https://github.com/pyannote) thanks to our [consulting services](https://herve.niderb.fr/consulting.html)
+"""
+    )
 
-st.markdown("""# 🎹 Pretrained pipelines
-""")
 
 
 PIPELINES = [
-    p.modelId
-    for p in HfApi().list_models(filter="pyannote-audio-pipeline")
-    if p.modelId.startswith("pyannote/")
+    "pyannote/speaker-diarization-3.0",
 ]
 
 audio = Audio(sample_rate=16000, mono=True)
 
 
-selected_pipeline = st.selectbox("Select a pipeline", PIPELINES, index=0)
+selected_pipeline = st.selectbox("Select a pretrained pipeline", PIPELINES, index=0)
+
 with st.spinner("Loading pipeline..."):
-    pipeline = Pipeline.from_pretrained(selected_pipeline, use_auth_token=st.secrets["PYANNOTE_TOKEN"])
+    try:
+        use_auth_token = st.secrets["PYANNOTE_TOKEN"]
+    except FileNotFoundError:
+        use_auth_token = None
+    except KeyError:
+        use_auth_token = None
+
+    pipeline = Pipeline.from_pretrained(
+        selected_pipeline, use_auth_token=use_auth_token
+    )
+    if torch.cuda.is_available():
+        pipeline.to(torch.device("cuda"))
 
-uploaded_file = st.file_uploader("Choose an audio file")
+uploaded_file = st.file_uploader("Upload an audio file")
 if uploaded_file is not None:
-
     try:
         duration = audio.get_duration(uploaded_file)
     except RuntimeError as e:
@@ -86,12 +101,12 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
     uri = "".join(uploaded_file.name.split())
     file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}
 
-    with st.spinner(f"Processing first {EXCERPT:g} seconds..."):
+    with st.spinner(f"Processing {EXCERPT:g} seconds..."):
         output = pipeline(file)
 
-    with open('assets/template.html') as html, open('assets/style.css') as css:
+    with open("assets/template.html") as html, open("assets/style.css") as css:
         html_template = html.read()
-        st.markdown('<style>{}</style>'.format(css.read()), unsafe_allow_html=True)
+        st.markdown("<style>{}</style>".format(css.read()), unsafe_allow_html=True)
 
     colors = [
         "#ffd70033",
@@ -105,50 +120,36 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
     ]
     num_colors = len(colors)
 
-    label2color = {label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))}
+    label2color = {
+        label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))
+    }
 
     BASE64 = to_base64(waveform.numpy().T)
     REGIONS = ""
-    LEGENDS = ""
-    labels=[]
 
     for segment, _, label in output.itertracks(yield_label=True):
-        REGIONS += f"var re = wavesurfer.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
-        if not label in labels:
-            LEGENDS += f"