Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from tempfile import NamedTemporaryFile | |
from typing import Any | |
import streamlit as st | |
from conette import CoNeTTEModel, conette | |
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build the CoNeTTE model, reusing one instance across Streamlit reruns.

    Streamlit re-executes the script on each widget interaction; without
    caching the model would be constructed again on each of those runs.
    ``st.cache_resource`` keeps the returned object and hands it back for
    calls made with the same arguments.

    Args:
        *args: Positional arguments forwarded unchanged to ``conette``.
        **kwargs: Keyword arguments forwarded unchanged to ``conette``.
            Arguments take part in the cache key, so they must be hashable
            by Streamlit.

    Returns:
        The object returned by ``conette(*args, **kwargs)``.
    """
    return conette(*args, **kwargs)
def main() -> None:
    """Render the captioning page and describe each uploaded audio file.

    The page exposes the task embedding, the beam size and the minimal and
    maximal prediction sizes as widgets, then runs the model on every
    uploaded file and writes the first candidate caption for each one.

    Captions are memoised in ``st.session_state`` so that a rerun with the
    same file contents and the same options does not call the model again.
    """
    # Local import keeps this block self-contained; only used for cache keys.
    import hashlib

    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    beam_size: int = st.select_slider(  # type: ignore
        "Beam size",
        list(range(1, 20)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(  # type: ignore
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(  # type: ignore
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    st.write("Recommanded audio: lasting from 1s to 30s, sampled at 32 kHz.")

    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    if not audios:
        return

    # The decoding options are identical for every file of this run.
    kwargs: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
    )

    for audio in audios:
        data = audio.getvalue()

        # Key on the contents as well as the name: two different files
        # uploaded under the same name must not share a cached caption.
        digest = hashlib.sha256(data).hexdigest()
        cand_key = f"{audio.name}-{digest}-{kwargs}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
        else:
            # Only touch the disk when the model actually has to run.
            with NamedTemporaryFile() as temp:
                temp.write(data)
                # The model opens the file by path, so the buffered bytes
                # must reach the disk before it is called.
                temp.flush()
                outputs = model(
                    temp.name,
                    **kwargs,
                )
            cand = outputs["cands"][0]
            st.session_state[cand_key] = cand

        st.write(f"Output for {audio.name}:")
        st.write(" - ", cand)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()