Spaces: Runtime error
pchavaux01 committed · Commit 353df34 · Parent(s): c65315c
Upload 4 files
- app.py +1439 -0
- main.py +43 -0
- packages.txt +1 -0
- requirements.txt +15 -0
app.py
ADDED
@@ -0,0 +1,1439 @@
# import dependencies

# Audio Manipulation
import audioread
import librosa
from pydub import AudioSegment, silence
import youtube_dl
from youtube_dl import DownloadError

# Models
import torch
from transformers import pipeline, HubertForCTC, T5Tokenizer, T5ForConditionalGeneration, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2Tokenizer
from pyannote.audio import Pipeline

# Others
from datetime import timedelta
import os
import pandas as pd
import pickle
import re
import streamlit as st
import time
import whisper
from whisper import load_model
import whisperx

# Limit CUDA memory fragmentation. PyTorch expects "option:value" pairs here,
# e.g. "max_split_size_mb:128" (the previous value "128mb" is not a recognized format).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import gc
torch.cuda.empty_cache()
gc.collect()

def config():
    """
    App Configuration
    This function sets the page title and its favicon, initializes some global variables (session_state values),
    displays a title and a subheader, and applies CSS code to the app.
    """
    # Set config
    st.set_page_config(page_title="Speech to Text", page_icon="📝")

    # Create a Data Directory
    # Will not be executed with AI Deploy because it is indicated in the Dockerfile of the app
    if not os.path.exists("../data"):
        os.makedirs("../data")

    # Initialize session state variables
    if 'page_index' not in st.session_state:
        st.session_state['page_index'] = -1  # Handle which page should be displayed (token page, home page, results page, rename page)
        st.session_state['txt_transcript'] = ""  # Save the transcript as .txt so we can display it again on the results page
        st.session_state["process"] = []  # Save the results obtained so we can display them again on the results page
        st.session_state['srt_txt'] = ""  # Save the transcript in the subtitles case to display it on the results page
        st.session_state['srt_token'] = 0  # Is the subtitles parameter enabled or not
        st.session_state['audio_file'] = None  # Save the audio file provided by the user so we can display it again on the results page
        st.session_state["start_time"] = 0  # Default audio player starting point (0 s)
        st.session_state["summary"] = ""  # Save the summary of the transcript so we can display it on the results page
        st.session_state["number_of_speakers"] = 0  # Save the number of speakers detected in the conversation (diarization)
        st.session_state["chosen_mode"] = 0  # Save the mode chosen by the user (diarization or not, timestamps or not)
        st.session_state["btn_token_list"] = []  # List of tokens that indicates which options are activated, to adapt the display on the results page
        st.session_state["my_HF_token"] = "ACCESS_TOKEN_GOES_HERE"  # User's token that allows the use of the diarization model
        st.session_state["disable"] = True  # Default appearance of the button to change your token

    # Display Text and CSS
    st.title("Speech to Text App 📝")

    st.markdown("""
                <style>
                .block-container.css-12oz5g7.egzxvld2{
                    padding: 1%;}
                # speech-to-text-app > div:nth-child(1) > span:nth-child(2){
                    text-align:center;}
                .stRadio > label:nth-child(1){
                    font-weight: bold;
                    }
                .stRadio > div{flex-direction:row;}
                p, span{
                    text-align: justify;
                }
                span{
                    text-align: center;
                }
                """, unsafe_allow_html=True)

    st.subheader("You want to extract text from an audio/video? You are in the right place!")

def load_options(audio_length, dia_pipeline):
    """
    Display options so the user can customize the result (punctuate or summarize the transcript, trim the audio, ...)
    The user chooses his parameters with sliders & checkboxes, both displayed in a st.form so the page doesn't
    reload when interacting with an element (frustrating if it does, because the user loses fluidity).
    :return: the chosen parameters
    """
    # Create a st.form()
    with st.form("form"):
        st.markdown("""<h6>
            You can transcript a specific part of your audio by setting start and end values below (in seconds). Then,
            choose your parameters.</h6>""", unsafe_allow_html=True)

        # Possibility to trim / cut the audio on a specific part (transcribing fewer seconds saves time)
        # To do that, the user selects his time interval with sliders, displayed in 2 different columns
        col1, col2 = st.columns(2)
        with col1:
            start = st.slider("Start value (s)", 0, audio_length, value=0)
        with col2:
            end = st.slider("End value (s)", 0, audio_length, value=audio_length)

        # Create 3 new columns to display the other options
        col1, col2, col3 = st.columns(3)

        # User selects his preferences with checkboxes
        with col1:
            # Get an automatic punctuation
            punctuation_token = st.checkbox("Punctuate my final text", value=True)

            # Differentiate speakers
            if dia_pipeline is None:
                st.write("Diarization model unavailable")
                diarization_token = False
            else:
                diarization_token = st.checkbox("Differentiate speakers")

        with col2:
            # Summarize the transcript
            summarize_token = st.checkbox("Generate a summary", value=False)

            # Generate a SRT file instead of a TXT file (shorter timestamps)
            srt_token = st.checkbox("Generate subtitles file", value=False)

        with col3:
            # Display the timestamp of each transcribed part
            timestamps_token = st.checkbox("Show timestamps", value=True)

            # Improve the transcript with another model (better transcript but longer to obtain)
            choose_better_model = st.checkbox("Change STT Model")

        # The SRT option requires timestamps so it can match text with time => need to correct the following case
        if not timestamps_token and srt_token:
            timestamps_token = True
            st.warning("Srt option requires timestamps. We activated it for you.")

        # Validate choices with a button
        transcript_btn = st.form_submit_button("Transcribe audio!")

    return transcript_btn, start, end, diarization_token, punctuation_token, timestamps_token, srt_token, summarize_token, choose_better_model

access_token = "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"
sst_model = load_model("base.en")


@st.cache(allow_output_mutation=True)
def load_models():
    """
    Instead of systematically downloading each time the models we use (transcript model, summarizer, speaker differentiation, ...)
    thanks to transformers' pipeline, we first try to directly import them locally to save time when the app is launched.
    This function has a st.cache(), because as the models never change, we want the function to execute only one time
    (also to save time). Otherwise, it would run every time we transcribe a new audio file.
    :return: Loaded models
    """

    # Load facebook/hubert-large-ls960-ft model (English speech to text model)
    with st.spinner("Loading Speech to Text Model"):
        # If models are stored in a folder, we import them. Otherwise, we import the models with their respective library
        try:
            stt_tokenizer = pickle.load(open("models/STT_processor_hubert-large-ls960-ft.sav", 'rb'))
        except FileNotFoundError:
            stt_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

        try:
            # stt_model = pickle.load(open("models/STT_model_hubert-large-ls960-ft.sav", 'rb'))
            stt_model = load_model("base.en")
            options = whisper.DecodingOptions(language='english', task='transcribe', without_timestamps=False)
        except FileNotFoundError:
            # stt_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
            stt_model = load_model("base.en")
            options = whisper.DecodingOptions(language='english', task='transcribe', without_timestamps=False)

    # Load T5 model (Auto punctuation model)
    with st.spinner("Loading Punctuation Model"):
        try:
            t5_tokenizer = torch.load("models/T5_tokenizer.sav")
        except OSError:
            t5_tokenizer = T5Tokenizer.from_pretrained("flexudy/t5-small-wav2vec2-grammar-fixer")

        try:
            t5_model = torch.load("models/T5_model.sav")
        except FileNotFoundError:
            t5_model = T5ForConditionalGeneration.from_pretrained("flexudy/t5-small-wav2vec2-grammar-fixer")

    # Load summarizer model
    with st.spinner("Loading Summarization Model"):
        try:
            summarizer = pickle.load(open("models/summarizer.sav", 'rb'))
        except FileNotFoundError:
            summarizer = pipeline("summarization")

    # Load Diarization model (Differentiate speakers)
    with st.spinner("Loading Diarization Model"):
        try:
            dia_pipeline = pickle.load(open("models/dia_pipeline.sav", 'rb'))
        except FileNotFoundError:
            # access_token = "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"
            dia_pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=access_token)
            # dia_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=st.session_state["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"])
        # If the token hasn't been modified, dia_pipeline will automatically be set to None. The functionality will then be disabled.

    return stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline

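# Note on model loading (illustrative, an assumption rather than the original author's comment):
# whisper.load_model("base.en") fetches the English-only "base" Whisper checkpoint (roughly 140 MB)
# into ~/.cache/whisper on the first run; later runs reuse the cached file, so the st.cache() wrapper
# above mainly avoids re-running this function within a session.
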
def transcript_from_url(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline):
    """
    Display a text input area, where the user can enter a YouTube URL link. If the link seems correct, we try to
    extract the audio from the video, and then transcribe it.
    :param stt_tokenizer: Speech to text model's tokenizer
    :param stt_model: Speech to text model
    :param t5_tokenizer: Auto punctuation model's tokenizer
    :param t5_model: Auto punctuation model
    :param summarizer: Summarizer model
    :param dia_pipeline: Diarization Model (Differentiate speakers)
    """

    url = st.text_input("Enter the YouTube video URL then press Enter to confirm!")
    # If the link seems correct, we try to transcribe
    if "youtu" in url:
        filename = extract_audio_from_yt_video(url)
        if filename is not None:
            transcription(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline, filename)
        else:
            st.error("We were unable to extract the audio. Please verify your link, retry or choose another video")

def transcript_from_file(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline):
    """
    Display a file uploader area, where the user can import his own file (mp3, mp4 or wav). If the file format seems
    correct, we transcribe the audio.
    :param stt_tokenizer: Speech to text model's tokenizer
    :param stt_model: Speech to text model
    :param t5_tokenizer: Auto punctuation model's tokenizer
    :param t5_model: Auto punctuation model
    :param summarizer: Summarizer model
    :param dia_pipeline: Diarization Model (Differentiate speakers)
    """

    # File uploader widget with a callback function, so the page reloads if the user uploads a new audio file
    uploaded_file = st.file_uploader("Upload your file! It can be a .mp3, .mp4 or .wav", type=["mp3", "mp4", "wav"],
                                     on_change=update_session_state, args=("page_index", 0,))

    if uploaded_file is not None:
        # Get the name and launch the transcription function
        filename = uploaded_file.name
        transcription(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline, filename,
                      uploaded_file)

def transcription(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline, filename,
                  uploaded_file=None):
    """
    Mini-main function
    Display options, transcribe an audio file and save the results.
    :param stt_tokenizer: Speech to text model's tokenizer
    :param stt_model: Speech to text model
    :param t5_tokenizer: Auto punctuation model's tokenizer
    :param t5_model: Auto punctuation model
    :param summarizer: Summarizer model
    :param dia_pipeline: Diarization Model (Differentiate speakers)
    :param filename: name of the audio file
    :param uploaded_file: file / name of the audio file which allows the code to reach the file
    """

    # If the audio comes from the YouTube extraction mode, the audio is downloaded so the uploaded_file is
    # the same as the filename. We need to change the uploaded_file, which is currently set to None
    if uploaded_file is None:
        uploaded_file = filename

    # Get the audio length of the file
    myaudio = AudioSegment.from_file(uploaded_file)
    audio_length = myaudio.duration_seconds

    # Save the audio (so we can display it on the results page, otherwise it is lost)
    update_session_state("audio_file", uploaded_file)

    # Display the audio file
    st.audio(uploaded_file)

    # Is transcription possible
    if audio_length > 0:

        # We display options and the user shares his wishes
        transcript_btn, start, end, diarization_token, punctuation_token, timestamps_token, srt_token, summarize_token, choose_better_model = load_options(
            int(audio_length), dia_pipeline)

        # If the end value hasn't been changed, we fix it to the max value so we don't cut off some ms of the audio,
        # because the end value is returned by a st.slider which returns it as an int (ex: 12 s instead of end=12.9 s)
        if end == int(audio_length):
            end = audio_length

        # Switching model for the better one
        if choose_better_model:
            with st.spinner("We are loading the better model. Please wait..."):

                try:
                    stt_tokenizer = pickle.load(open("models/STT_tokenizer2_wav2vec2-large-960h-lv60-self.sav", 'rb'))
                except FileNotFoundError:
                    stt_tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

                try:
                    stt_model = pickle.load(open("models/STT_model2_wav2vec2-large-960h-lv60-self.sav", 'rb'))
                except FileNotFoundError:
                    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

        # Validate options and launch the transcription process thanks to the form's button
        if transcript_btn:

            # Check if start & end values are correct
            start, end = correct_values(start, end, audio_length)

            # If start and/or end value(s) have changed, we trim/cut the audio according to the new start/end values.
            if start != 0 or end != audio_length:
                myaudio = myaudio[start * 1000:end * 1000]  # Works in milliseconds (*1000)

            # Transcribe process is running
            with st.spinner("We are transcribing your audio. Please wait"):

                # Initialize variables
                txt_text, srt_text, save_result = init_transcription(start, int(end))
                min_space, max_space = silence_mode_init(srt_token)

                # Differentiate speakers mode
                if diarization_token:

                    # Save the mode chosen by the user, to display the expected results
                    if not timestamps_token:
                        update_session_state("chosen_mode", "DIA")
                    elif timestamps_token:
                        update_session_state("chosen_mode", "DIA_TS")

                    # Convert mp3/mp4 to wav (the differentiate speakers mode only accepts wav files)
                    if filename.endswith((".mp3", ".mp4")):
                        myaudio, filename = convert_file_to_wav(myaudio, filename)
                    else:
                        filename = "../data/" + filename
                        myaudio.export(filename, format="wav")

                    # Differentiate speakers process
                    diarization_timestamps, number_of_speakers = diarization_treatment(filename, dia_pipeline,
                                                                                       max_space, srt_token)
                    # Save the number of detected speakers
                    update_session_state("number_of_speakers", number_of_speakers)

                    # Transcribe process with Diarization Mode
                    save_result, txt_text, srt_text = transcription_diarization(filename, diarization_timestamps,
                                                                                stt_model,
                                                                                stt_tokenizer,
                                                                                diarization_token,
                                                                                srt_token, summarize_token,
                                                                                timestamps_token, myaudio, start,
                                                                                save_result,
                                                                                txt_text, srt_text)

                # Non Diarization Mode
                else:
                    # Save the mode chosen by the user, to display the expected results
                    if not timestamps_token:
                        update_session_state("chosen_mode", "NODIA")
                    if timestamps_token:
                        update_session_state("chosen_mode", "NODIA_TS")

                    filename = "../data/" + filename
                    # Transcribe process with non Diarization Mode
                    save_result, txt_text, srt_text = transcription_non_diarization(filename, myaudio, start, end,
                                                                                    diarization_token, timestamps_token,
                                                                                    srt_token, summarize_token,
                                                                                    stt_model, stt_tokenizer,
                                                                                    min_space, max_space,
                                                                                    save_result, txt_text, srt_text)

                # Save the results so they are not lost when we interact with a button
                update_session_state("process", save_result)
                update_session_state("srt_txt", srt_text)

                # Get the final text (with or without the punctuation token)
                # Diarization Mode
                if diarization_token:
                    # Create the txt text from the process
                    txt_text = create_txt_text_from_process(punctuation_token, t5_model, t5_tokenizer)

                # Non diarization Mode
                else:

                    if punctuation_token:
                        # Need to split the text into blocks of 512 characters since the model has a limited input
                        with st.spinner("Transcription is finished! Let us punctuate your audio"):
                            my_split_text_list = split_text(txt_text, 512)
                            txt_text = ""
                            # Punctuate each text block
                            for my_split_text in my_split_text_list:
                                txt_text += add_punctuation(t5_model, t5_tokenizer, my_split_text)

                # Clean the folder's files
                clean_directory("../data")

                # Display the final transcript
                if txt_text != "":
                    st.subheader("Final text is")

                    # Save txt_text and display it
                    update_session_state("txt_transcript", txt_text)
                    st.markdown(txt_text, unsafe_allow_html=True)

                    # Summarize the transcript
                    if summarize_token:
                        with st.spinner("We are summarizing your audio"):
                            # Display the summary in a st.expander widget so we don't write too much text on the page
                            with st.expander("Summary"):
                                # Need to split the text into blocks of 1024 characters since the model has a limited input
                                if diarization_token:
                                    # In diarization mode, the text to summarize is contained in the "summary" session state variable
                                    my_split_text_list = split_text(st.session_state["summary"], 1024)
                                else:
                                    # In non-diarization mode, it is contained in the txt_text variable
                                    my_split_text_list = split_text(txt_text, 1024)

                                summary = ""
                                # Summarize each text block
                                for my_split_text in my_split_text_list:
                                    summary += summarizer(my_split_text)[0]['summary_text']

                                # Remove multiple spaces and double spaces around the punctuation mark " . "
                                summary = re.sub(' +', ' ', summary)
                                summary = re.sub(r'\s+([?.!"])', r'\1', summary)

                                # Display the summary and save it
                                st.write(summary)
                                update_session_state("summary", summary)

                    # Display buttons to interact with the results

                    # We have 4 possible buttons depending on the user's choices. But we can't set 4 columns for 4
                    # buttons. Indeed, if the user displays only 3 buttons, it is possible that one of columns
                    # 1, 2 or 3 is empty, which would be ugly. We want the activated options to be in the first columns
                    # so that the empty columns are not noticed. To do that, let's create a btn_token_list

                    btn_token_list = [[diarization_token, "dia_token"], [True, "useless_txt_token"],
                                      [srt_token, "srt_token"], [summarize_token, "summarize_token"]]

                    # Save this list to be able to reach it on the other pages of the app
                    update_session_state("btn_token_list", btn_token_list)

                    # Create 4 columns
                    col1, col2, col3, col4 = st.columns(4)

                    # Create a column list
                    col_list = [col1, col2, col3, col4]

                    # Check the value of each token; if True, we put the respective button of the token in a column
                    col_index = 0
                    for elt in btn_token_list:
                        if elt[0]:
                            mycol = col_list[col_index]
                            if elt[1] == "useless_txt_token":
                                # Download your transcript.txt
                                with mycol:
                                    st.download_button("Download as TXT", txt_text, file_name="my_transcription.txt",
                                                       on_click=update_session_state, args=("page_index", 1,))
                            elif elt[1] == "srt_token":
                                # Download your transcript.srt
                                with mycol:
                                    update_session_state("srt_token", srt_token)
                                    st.download_button("Download as SRT", srt_text, file_name="my_transcription.srt",
                                                       on_click=update_session_state, args=("page_index", 1,))
                            elif elt[1] == "dia_token":
                                with mycol:
                                    # Rename the speakers detected in your audio
                                    st.button("Rename Speakers", on_click=update_session_state, args=("page_index", 2,))

                            elif elt[1] == "summarize_token":
                                with mycol:
                                    # Download the summary of your transcript.txt
                                    st.download_button("Download Summary", st.session_state["summary"],
                                                       file_name="my_summary.txt",
                                                       on_click=update_session_state, args=("page_index", 1,))
                            col_index += 1

                else:
                    st.write("Transcription impossible, a problem occurred with your audio or your parameters, "
                             "we apologize :(")

    else:
        st.error("Seems your audio is 0 s long, please change your file")
        time.sleep(3)
        st.stop()

def create_txt_text_from_process(punctuation_token=False, t5_model=None, t5_tokenizer=None):
    """
    If we are in a diarization case (differentiate speakers), we create txt_text from st.session_state['process'].
    There is a lot of information in the process variable, but we only extract the identity of the speaker and
    the sentence spoken, as in a non-diarization case.
    :param punctuation_token: Punctuate or not the transcript (choice fixed by user)
    :param t5_model: T5 Model (Auto punctuation model)
    :param t5_tokenizer: T5's Tokenizer (Auto punctuation model's tokenizer)
    :return: Final transcript (without timestamps)
    """
    txt_text = ""
    # The information to be extracted is different according to the chosen mode
    if punctuation_token:
        with st.spinner("Transcription is finished! Let us punctuate your audio"):
            if st.session_state["chosen_mode"] == "DIA":
                for elt in st.session_state["process"]:
                    # [2:] because we don't want ": text" but only the "text"
                    text_to_punctuate = elt[2][2:]
                    if len(text_to_punctuate) >= 512:
                        text_to_punctuate_list = split_text(text_to_punctuate, 512)
                        punctuated_text = ""
                        for split_text_to_punctuate in text_to_punctuate_list:
                            punctuated_text += add_punctuation(t5_model, t5_tokenizer, split_text_to_punctuate)
                    else:
                        punctuated_text = add_punctuation(t5_model, t5_tokenizer, text_to_punctuate)

                    txt_text += elt[1] + " : " + punctuated_text + '\n\n'

            elif st.session_state["chosen_mode"] == "DIA_TS":
                for elt in st.session_state["process"]:
                    text_to_punctuate = elt[3][2:]
                    if len(text_to_punctuate) >= 512:
                        text_to_punctuate_list = split_text(text_to_punctuate, 512)
                        punctuated_text = ""
                        for split_text_to_punctuate in text_to_punctuate_list:
                            punctuated_text += add_punctuation(t5_model, t5_tokenizer, split_text_to_punctuate)
                    else:
                        punctuated_text = add_punctuation(t5_model, t5_tokenizer, text_to_punctuate)

                    txt_text += elt[2] + " : " + punctuated_text + '\n\n'
    else:
        if st.session_state["chosen_mode"] == "DIA":
            for elt in st.session_state["process"]:
                txt_text += elt[1] + elt[2] + '\n\n'

        elif st.session_state["chosen_mode"] == "DIA_TS":
            for elt in st.session_state["process"]:
                txt_text += elt[2] + elt[3] + '\n\n'

    return txt_text

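# For reference (inferred from display_transcription further below, as an illustration): in "DIA_TS" mode
# each element of st.session_state["process"] looks like
#   [timestamps_str, speaker_id, speaker_name, " : transcribed text", start_in_seconds]
# which is why the code above reads the speaker name from elt[2] and the text from elt[3], while "DIA" mode
# stores [speaker_id, speaker_name, " : transcribed text"] and uses elt[1] / elt[2].
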
def rename_speakers_window():
    """
    Load a new page which allows the user to rename the different speakers from the diarization process
    For example, he can switch from "Speaker1 : 'I wouldn't say that'" to "Mat : 'I wouldn't say that'"
    """

    st.subheader("Here you can rename the speakers as you want")
    number_of_speakers = st.session_state["number_of_speakers"]

    if number_of_speakers > 0:
        # Handle the displayed text according to the number_of_speakers
        if number_of_speakers == 1:
            st.write(str(number_of_speakers) + " speaker has been detected in your audio")
        else:
            st.write(str(number_of_speakers) + " speakers have been detected in your audio")

        # Save the speaker name and its ID in a list, example: [1, 'Speaker1']
        list_of_speakers = []
        for elt in st.session_state["process"]:
            if st.session_state["chosen_mode"] == "DIA_TS":
                if [elt[1], elt[2]] not in list_of_speakers:
                    list_of_speakers.append([elt[1], elt[2]])
            elif st.session_state["chosen_mode"] == "DIA":
                if [elt[0], elt[1]] not in list_of_speakers:
                    list_of_speakers.append([elt[0], elt[1]])

        # Sorting (by ID)
        list_of_speakers.sort()  # [[1, 'Speaker1'], [0, 'Speaker0']] => [[0, 'Speaker0'], [1, 'Speaker1']]

        # Display the saved names so the user can modify them
        initial_names = ""
        for elt in list_of_speakers:
            initial_names += elt[1] + "\n"

        names_input = st.text_area("Just replace the names without changing the format (one per line)",
                                   value=initial_names)

        # Display Options (Cancel / Save)
        col1, col2 = st.columns(2)
        with col1:
            # Cancel changes by clicking a button - callback function to return to the results page
            st.button("Cancel", on_click=update_session_state, args=("page_index", 1,))
        with col2:
            # Confirm changes by clicking a button - callback function to apply the changes and return to the results page
            st.button("Save changes", on_click=click_confirm_rename_btn, args=(names_input, number_of_speakers,))

    # No one to rename
    else:
        st.error("0 speakers have been detected. Seems there is an issue with diarization")
        with st.spinner("Redirecting to transcription page"):
            time.sleep(4)
            # Return to the results page
            update_session_state("page_index", 1)

def click_confirm_rename_btn(names_input, number_of_speakers):
    """
    If the user decides to rename speakers and confirms his choices, we apply the modifications to our transcript
    Then we return to the results page of the app
    :param names_input: string
    :param number_of_speakers: Number of detected speakers in the audio file
    """

    try:
        names_input = names_input.split("\n")[:number_of_speakers]

        for elt in st.session_state["process"]:
            elt[2] = names_input[elt[1]]

        txt_text = create_txt_text_from_process()
        update_session_state("txt_transcript", txt_text)
        update_session_state("page_index", 1)

    except TypeError:  # list indices must be integers or slices, not str (can happen when writing nonsense names)
        st.error("Please respect the 1 name per line format")
        with st.spinner("We are relaunching the page"):
            time.sleep(3)
            update_session_state("page_index", 1)

def transcription_diarization(filename, diarization_timestamps, stt_model, stt_tokenizer, diarization_token, srt_token,
                              summarize_token, timestamps_token, myaudio, start, save_result, txt_text, srt_text):
    """
    Performs transcription with the diarization mode
    :param filename: name of the audio file
    :param diarization_timestamps: timestamps of each audio part (ex: 10 to 50 secs)
    :param stt_model: Speech to text model
    :param stt_tokenizer: Speech to text model's tokenizer
    :param diarization_token: Differentiate or not the speakers (choice fixed by user)
    :param srt_token: Enable/Disable the generation of a srt file (choice fixed by user)
    :param summarize_token: Summarize or not the transcript (choice fixed by user)
    :param timestamps_token: Display and save or not the timestamps (choice fixed by user)
    :param myaudio: AudioSegment file
    :param start: int value (s) given by st.slider() (fixed by user)
    :param save_result: whole process
    :param txt_text: generated .txt transcript
    :param srt_text: generated .srt transcript
    :return: results of the transcribing action
    """
    # Numeric counter that identifies each sequential subtitle
    srt_index = 1

    # Handle a rare case: if there is only one "list" in the list, it is a plain list, not a list of lists
    if not isinstance(diarization_timestamps[0], list):
        diarization_timestamps = [diarization_timestamps]

    # Transcribe each audio chunk (from timestamp to timestamp) and display the transcript
    for index, elt in enumerate(diarization_timestamps):
        sub_start = elt[0]
        sub_end = elt[1]

        transcription = transcribe_audio_part(filename, stt_model, stt_tokenizer, myaudio, sub_start, sub_end,
                                              index)

        # The initial audio has been split with start & end values
        # It begins at 0 s, but the timestamps need to be adjusted by +start*1000 to compensate for the gap
        if transcription != "":
            save_result, txt_text, srt_text, srt_index = display_transcription(diarization_token, summarize_token,
                                                                               srt_token, timestamps_token,
                                                                               transcription, save_result, txt_text,
                                                                               srt_text,
                                                                               srt_index, sub_start + start * 1000,
                                                                               sub_end + start * 1000, elt)
    return save_result, txt_text, srt_text

def transcription_non_diarization(filename, myaudio, start, end, diarization_token, timestamps_token, srt_token,
                                  summarize_token, stt_model, stt_tokenizer, min_space, max_space, save_result,
                                  txt_text, srt_text):
    """
    Performs transcribing action with the non-diarization mode
    :param filename: name of the audio file
    :param myaudio: AudioSegment file
    :param start: int value (s) given by st.slider() (fixed by user)
    :param end: int value (s) given by st.slider() (fixed by user)
    :param diarization_token: Differentiate or not the speakers (choice fixed by user)
    :param timestamps_token: Display and save or not the timestamps (choice fixed by user)
    :param srt_token: Enable/Disable generate srt file (choice fixed by user)
    :param summarize_token: Summarize or not the transcript (choice fixed by user)
    :param stt_model: Speech to text model
    :param stt_tokenizer: Speech to text model's tokenizer
    :param min_space: Minimum temporal distance between two silences
    :param max_space: Maximum temporal distance between two silences
    :param save_result: whole process
    :param txt_text: generated .txt transcript
    :param srt_text: generated .srt transcript
    :return: results of transcribing action
    """

    # Numeric counter identifying each sequential subtitle
    srt_index = 1

    # Get the silences
    silence_list = detect_silences(myaudio)
    if silence_list != []:
        silence_list = get_middle_silence_time(silence_list)
        silence_list = silences_distribution(silence_list, min_space, max_space, start, end, srt_token)
    else:
        silence_list = generate_regular_split_till_end(silence_list, int(end), min_space, max_space)

    # Transcribe each audio chunk (from timestamp to timestamp) and display the transcript
    for i in range(0, len(silence_list) - 1):
        sub_start = silence_list[i]
        sub_end = silence_list[i + 1]

        transcription = transcribe_audio_part(filename, stt_model, stt_tokenizer, myaudio, sub_start, sub_end, i)

        # The initial audio has been split with start & end values
        # It begins at 0 s, but the timestamps need to be adjusted by +start*1000 to compensate for the gap
        if transcription != "":
            save_result, txt_text, srt_text, srt_index = display_transcription(diarization_token, summarize_token,
                                                                               srt_token, timestamps_token,
                                                                               transcription, save_result,
                                                                               txt_text,
                                                                               srt_text,
                                                                               srt_index, sub_start + start * 1000,
                                                                               sub_end + start * 1000)

    return save_result, txt_text, srt_text

def silence_mode_init(srt_token):
    """
    Fix min_space and max_space values
    If the user wants a srt file, we need to have tiny timestamps
    :param srt_token: Enable/Disable generate srt file option (choice fixed by user)
    :return: min_space and max_space values
    """
    if srt_token:
        # We need short intervals if we want a short text
        min_space = 1000  # 1 sec
        max_space = 8000  # 8 secs

    else:
        min_space = 25000  # 25 secs
        max_space = 45000  # 45 secs

    return min_space, max_space

def detect_silences(audio):
    """
    Silence moments detection in an audio file
    :param audio: pydub.AudioSegment file
    :return: list with silence time intervals
    """
    # Get decibels (dBFS) so silence detection depends on the audio instead of a fixed value
    dbfs = audio.dBFS

    # Get silence timestamps > 750 ms
    silence_list = silence.detect_silence(audio, min_silence_len=750, silence_thresh=dbfs - 14)

    return silence_list

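# Illustrative note: pydub's silence.detect_silence returns millisecond intervals such as
# [[1200, 2100], [5300, 6800]]; get_middle_silence_time below collapses each interval to a single
# cut point so the audio is never split in the middle of a word.
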
def generate_regular_split_till_end(time_list, end, min_space, max_space):
    """
    Add automatic "time cuts" to time_list till the end value, depending on min_space and max_space values
    :param time_list: silence time list
    :param end: int value (s)
    :param min_space: Minimum temporal distance between two silences
    :param max_space: Maximum temporal distance between two silences
    :return: list with automatic time cuts
    """
    # A range loop can't handle float values, so we convert to int
    int_last_value = int(time_list[-1])
    int_end = int(end)

    # Add max_space to the last list value and add this value to the list
    for i in range(int_last_value, int_end, max_space):
        value = i + max_space
        if value < end:
            time_list.append(value)

    # Fix the last automatic cut
    # If the gap is small (ex: 395 000, with end = 400 000)
    if end - time_list[-1] < min_space:
        time_list[-1] = end
    else:
        # If the gap is important (ex: 311 000 then 356 000, with end = 400 000, we can't replace and then have 311k to 400k)
        time_list.append(end)
    return time_list

def get_middle_silence_time(silence_list):
    """
    Replace each timestamp in the list by a unique value, which is approximately the middle of each silence interval,
    to avoid word cutting
    :param silence_list: List of lists where each element has a start and end value which describe a silence interval
    :return: Simple float list
    """
    length = len(silence_list)
    index = 0
    while index < length:
        diff = (silence_list[index][1] - silence_list[index][0])
        if diff < 3500:
            silence_list[index] = silence_list[index][0] + diff / 2
            index += 1
        else:
            adapted_diff = 1500
            silence_list.insert(index + 1, silence_list[index][1] - adapted_diff)
            silence_list[index] = silence_list[index][0] + adapted_diff
            length += 1
            index += 2

    return silence_list

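# Worked example (illustrative): a short silence [1200, 2100] (900 ms < 3500) becomes the single cut point
# 1650, while a long silence [10000, 15000] (5000 ms >= 3500) is turned into two cut points, 11500 and 13500,
# so very long pauses still produce reasonably sized chunks.
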
def silences_distribution(silence_list, min_space, max_space, start, end, srt_token=False):
    """
    We keep each silence value if it is sufficiently distant from its neighboring values, without being too much
    :param silence_list: List with silence intervals
    :param min_space: Minimum temporal distance between two silences
    :param max_space: Maximum temporal distance between two silences
    :param start: int value (seconds)
    :param end: int value (seconds)
    :param srt_token: Enable/Disable generate srt file (choice fixed by user)
    :return: list with equally distributed silences
    """
    # If start != 0, we need to adjust the end value since silence detection is performed on the trimmed/cut audio
    # (and not on the original audio) (ex: trimming the audio from 20 s to 2 min gives 0 s to 1 min 40 = 2 min - 20 s)

    # Shift the end according to the start value
    end -= start
    start = 0
    end *= 1000

    # Step 1 - Add the start value
    newsilence = [start]

    # Step 2 - Create a regular distribution between start and the first element of silence_list, so we don't have a gap > max_space and run out of memory
    # example: newsilence = [0] and silence_list starts with 100000 => it would create a massive gap [0, 100000]
    if silence_list[0] - max_space > newsilence[0]:
        for i in range(int(newsilence[0]), int(silence_list[0]), max_space):  # int because a float can't be in a range loop
            value = i + max_space
            if value < silence_list[0]:
                newsilence.append(value)

    # Step 3 - Create a regular distribution until the last value of silence_list
    min_desired_value = newsilence[-1]
    max_desired_value = newsilence[-1]
    nb_values = len(silence_list)

    while nb_values != 0:
        max_desired_value += max_space

        # Get a window of the values greater than min_desired_value and lower than max_desired_value
        silence_window = list(filter(lambda x: min_desired_value < x <= max_desired_value, silence_list))

        if silence_window != []:
            # Get the nearest value to min_desired_value or max_desired_value depending on srt_token
            if srt_token:
                nearest_value = min(silence_window, key=lambda x: abs(x - min_desired_value))
                nb_values -= silence_window.index(nearest_value) + 1  # (index begins at 0, so we add 1)
            else:
                nearest_value = min(silence_window, key=lambda x: abs(x - max_desired_value))
                # Max value index = len of the list
                nb_values -= len(silence_window)

            # Append the nearest value to our list
            newsilence.append(nearest_value)

        # If silence_window is empty, we add max_space to the last value to create an automatic cut and avoid cutting the audio too many times
        else:
            newsilence.append(newsilence[-1] + max_space)

        min_desired_value = newsilence[-1]
        max_desired_value = newsilence[-1]

    # Step 4 - Add the final value (end)
    if end - newsilence[-1] > min_space:
        # Gap > Min Space
        if end - newsilence[-1] < max_space:
            newsilence.append(end)
        else:
            # The gap between the last list value and the end value is too important
            # We need to create automatic max_space cuts till the end
            newsilence = generate_regular_split_till_end(newsilence, end, min_space, max_space)
    else:
        # Gap < Min Space <=> the final value and the last value of newsilence are too close, we need to merge them
        if len(newsilence) >= 2:
            if end - newsilence[-2] <= max_space:
                # Replace if the gap is not too important
                newsilence[-1] = end
            else:
                newsilence.append(end)

        else:
            if end - newsilence[-1] <= max_space:
                # Replace if the gap is not too important
                newsilence[-1] = end
            else:
                newsilence.append(end)

    return newsilence

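# Rough trace (illustrative, with the non-SRT defaults min_space=25000 / max_space=45000):
# silences_distribution([30000, 70000, 100000], 25000, 45000, start=0, end=120) keeps every silence
# because each gap stays under 45 s, then appends the end, giving [0, 30000, 70000, 100000, 120000] (ms).
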
def init_transcription(start, end):
    """
    Initialize values and inform the user that the transcription is in progress
    :param start: int value (s) given by st.slider() (fixed by user)
    :param end: int value (s) given by st.slider() (fixed by user)
    :return: final transcription, final srt text, and the process
    """
    update_session_state("summary", "")
    st.write("Transcription between", start, "and", end, "seconds in process.\n\n")
    txt_text = ""
    srt_text = ""
    save_result = []
    return txt_text, srt_text, save_result

def transcribe_audio_part(filename, stt_model, stt_tokenizer, myaudio, sub_start, sub_end, index):
    """
    Transcribe an audio part between a sub_start and a sub_end value (s)
    :param filename: name of the audio file
    :param stt_model: Speech to text model
    :param stt_tokenizer: Speech to text model's tokenizer
    :param myaudio: AudioSegment file
    :param sub_start: start value (s) of the considered audio part to transcribe
    :param sub_end: end value (s) of the considered audio part to transcribe
    :param index: audio file counter
    :return: transcription of the considered audio part
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        with torch.no_grad():
            new_audio = myaudio[sub_start:sub_end]  # Works in milliseconds
            path = filename[:-3] + "audio_" + str(index) + ".mp3"
            new_audio.export(path)  # Exports to a mp3 file in the current path

            # Load the audio file with librosa, set the sample rate to 16000 Hz because the model we use was trained on 16000 Hz data
            input_audio, _ = librosa.load(path, sr=16000, mono=True)
            # audio = librosa.load(path, sr=16000, mono=True)
            audio = whisper.load_audio(path)
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio).to(stt_model.device)
            # Return a PyTorch torch.Tensor instead of a list of python integers thanks to return_tensors='pt'
            input_values = stt_tokenizer(input_audio, return_tensors="pt").to(device).input_values

            # Get the logits from the data structure containing all the information returned by the model and get our prediction
            # ______________________________
            # logits = stt_model.to(device)(input_values).logits
            # prediction = torch.argmax(logits, dim=-1)
            # _______________________________
            # Decode our string (whisper already returns mixed-case text)
            options = whisper.DecodingOptions(language='english', task='transcribe', without_timestamps=False)
            if isinstance(stt_tokenizer, Wav2Vec2Tokenizer):
                # transcription = stt_tokenizer.batch_decode(prediction)[0]
                result = stt_model.decode(mel, options)
                transcription = result.text
            elif isinstance(stt_tokenizer, Wav2Vec2Processor):
                # transcription = stt_tokenizer.decode(prediction[0])
                result = stt_model.decode(mel, options)
                transcription = result.text
            return transcription

    except audioread.NoBackendError:
        # Means we have a chunk with a [value1 : value2] case with value1 > value2
        st.error("Sorry, seems we have a problem on our side. Please change start & end values.")
        time.sleep(3)
        st.stop()

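# Note (assumption based on Whisper's documented behaviour, not a comment from the original author):
# whisper.pad_or_trim() pads or trims the chunk to exactly 30 seconds of audio, so with the non-SRT
# settings (chunks of up to ~45 s) anything beyond the first 30 s of a chunk is dropped by this decoding path.
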
def optimize_subtitles(transcription, srt_index, sub_start, sub_end, srt_text):
    """
    Create & optimize the subtitles (avoid a too long reading when many words are said in a short time)
    The optimization (the if statement) can sometimes create a gap between the subtitles and the video, if there is
    music for example. In this case, it may be wise to disable the optimization by never going through the if statement.
    :param transcription: transcript generated for an audio chunk
    :param srt_index: Numeric counter that identifies each sequential subtitle
    :param sub_start: beginning of the transcript
    :param sub_end: end of the transcript
    :param srt_text: generated .srt transcript
    """

    transcription_length = len(transcription)

    # The length of the transcript should be limited to about 42 characters per line to avoid this problem
    if transcription_length > 42:
        # Split the timestamp and its transcript in two parts
        # Get the middle timestamp
        diff = (timedelta(milliseconds=sub_end) - timedelta(milliseconds=sub_start)) / 2
        middle_timestamp = str(timedelta(milliseconds=sub_start) + diff).split(".")[0]

        # Get the space index closest to the middle (we don't just take transcription_length/2, to avoid cutting a word)
        space_indexes = [pos for pos, char in enumerate(transcription) if char == " "]
        nearest_index = min(space_indexes, key=lambda x: abs(x - transcription_length / 2))

        # First transcript part
        first_transcript = transcription[:nearest_index]

        # Second transcript part
        second_transcript = transcription[nearest_index + 1:]

        # Add both transcript parts to the srt_text
        srt_text += str(srt_index) + "\n" + str(timedelta(milliseconds=sub_start)).split(".")[0] + " --> " + middle_timestamp + "\n" + first_transcript + "\n\n"
        srt_index += 1
        srt_text += str(srt_index) + "\n" + middle_timestamp + " --> " + str(timedelta(milliseconds=sub_end)).split(".")[0] + "\n" + second_transcript + "\n\n"
        srt_index += 1
    else:
        # Add the transcript without any operation
        srt_text += str(srt_index) + "\n" + str(timedelta(milliseconds=sub_start)).split(".")[0] + " --> " + str(timedelta(milliseconds=sub_end)).split(".")[0] + "\n" + transcription + "\n\n"
        srt_index += 1  # keep the subtitle counter sequential in this branch too

    return srt_text, srt_index

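# For reference, each block appended above follows the usual SRT layout, e.g.:
#   1
#   0:00:01 --> 0:00:07
#   some transcribed text
# str(timedelta(...)) produces "H:MM:SS"; strict SRT parsers may expect the zero-padded "HH:MM:SS,mmm"
# form, so very picky players could need an extra formatting pass.
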
def display_transcription(diarization_token, summarize_token, srt_token, timestamps_token, transcription, save_result,
                          txt_text, srt_text, srt_index, sub_start, sub_end, elt=None):
    """
    Display results
    :param diarization_token: Differentiate or not the speakers (choice fixed by user)
    :param summarize_token: Summarize or not the transcript (choice fixed by user)
    :param srt_token: Enable/Disable generate srt file (choice fixed by user)
    :param timestamps_token: Display and save or not the timestamps (choice fixed by user)
    :param transcription: transcript of the considered audio
    :param save_result: whole process
    :param txt_text: generated .txt transcript
    :param srt_text: generated .srt transcript
    :param srt_index: numeric counter that identifies each sequential subtitle
    :param sub_start: start value (s) of the considered audio part to transcribe
    :param sub_end: end value (s) of the considered audio part to transcribe
    :param elt: timestamp (diarization case only, otherwise elt = None)
    """
    # The display differs depending on the mode (dia, no dia, dia_ts, nodia_ts)
    # Diarization mode
    if diarization_token:

        if summarize_token:
            update_session_state("summary", transcription + " ", concatenate_token=True)

        if not timestamps_token:
            temp_transcription = elt[2] + " : " + transcription
            st.write(temp_transcription + "\n\n")

            save_result.append([int(elt[2][-1]), elt[2], " : " + transcription])

        elif timestamps_token:
            temp_timestamps = str(timedelta(milliseconds=sub_start)).split(".")[0] + " --> " + \
                              str(timedelta(milliseconds=sub_end)).split(".")[0] + "\n"
            temp_transcription = elt[2] + " : " + transcription
            temp_list = [temp_timestamps, int(elt[2][-1]), elt[2], " : " + transcription, int(sub_start / 1000)]
            save_result.append(temp_list)
            st.button(temp_timestamps, on_click=click_timestamp_btn, args=(sub_start,))
            st.write(temp_transcription + "\n\n")

        if srt_token:
            srt_text, srt_index = optimize_subtitles(transcription, srt_index, sub_start, sub_end, srt_text)

    # Non diarization case
    else:
        if not timestamps_token:
            save_result.append([transcription])
            st.write(transcription + "\n\n")

        else:
            temp_timestamps = str(timedelta(milliseconds=sub_start)).split(".")[0] + " --> " + \
                              str(timedelta(milliseconds=sub_end)).split(".")[0] + "\n"
            temp_list = [temp_timestamps, transcription, int(sub_start / 1000)]
            save_result.append(temp_list)
            st.button(temp_timestamps, on_click=click_timestamp_btn, args=(sub_start,))
            st.write(transcription + "\n\n")

        if srt_token:
            srt_text, srt_index = optimize_subtitles(transcription, srt_index, sub_start, sub_end, srt_text)

        txt_text += transcription + " "  # So sentences of x seconds are separated

    return save_result, txt_text, srt_text, srt_index

1079 |
+
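Since display_results() later indexes these lists positionally, it may help to keep in mind the shape of one save_result entry per mode; the values below are hypothetical.

# Illustrative shapes of one save_result entry per mode (hypothetical values),
# matching the indexes read back in display_results():
entry_nodia = ["hello world"]                                                   # elt[0] -> text
entry_dia = [0, "Speaker0", " : hello world"]                                   # elt[1] + elt[2] -> "Speaker0 : hello world"
entry_nodia_ts = ["0:00:00 --> 0:00:05\n", "hello world", 0]                    # elt[0] button label, elt[1] text, elt[2] start (s)
entry_dia_ts = ["0:00:00 --> 0:00:05\n", 0, "Speaker0", " : hello world", 0]    # elt[0] button label, elt[2] + elt[3] text, elt[4] start (s)
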
def add_punctuation(t5_model, t5_tokenizer, transcript):
    """
    Punctuate a transcript
    :return: Punctuated and improved (corrected) transcript
    """
    input_text = "fix: { " + transcript + " } </s>"

    input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=10000, truncation=True,
                                    add_special_tokens=True)

    outputs = t5_model.generate(
        input_ids=input_ids,
        max_length=256,
        num_beams=4,
        repetition_penalty=1.0,
        length_penalty=1.0,
        early_stopping=True
    )

    transcript = t5_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return transcript

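A possible standalone usage sketch; the checkpoint name is an assumption (the app loads its own models in load_models(), which is not shown here), chosen because the "fix: { ... }" prompt format matches flexudy's grammar-fixer family.

from transformers import T5ForConditionalGeneration, T5Tokenizer
from app import add_punctuation

# Assumed checkpoint: a T5 grammar/punctuation fixer that understands the "fix: { ... }" prompt
checkpoint = "flexudy/t5-small-wav2vec2-grammar-fixer"
t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint)

raw = "hello how are you today i am fine thank you"
print(add_punctuation(t5_model, t5_tokenizer, raw))
# e.g. "Hello, how are you today? I am fine, thank you."
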
def convert_file_to_wav(aud_seg, filename):
    """
    Convert an mp3/mp4 file to wav format
    Needs to be modified if you want to convert a format whose extension has more or fewer than 3 letters
    :param aud_seg: pydub.AudioSegment
    :param filename: name of the file
    :return: the converted audio as a pydub.AudioSegment and the name of the converted file
    """
    filename = "../data/my_wav_file_" + filename[:-3] + "wav"
    aud_seg.export(filename, format="wav")

    newaudio = AudioSegment.from_file(filename)

    return newaudio, filename

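A hypothetical usage sketch with pydub (file names are placeholders):

from pydub import AudioSegment
from app import convert_file_to_wav

# Hypothetical mp3 input; the converted wav is written under ../data/ as in the function above
segment = AudioSegment.from_file("interview.mp3", format="mp3")
wav_audio, wav_path = convert_file_to_wav(segment, "interview.mp3")
print(wav_path)  # ../data/my_wav_file_interview.wav
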
def get_diarization(dia_pipeline, filename):
    """
    Diarize an audio file (find the number of speakers, when they speak, ...)
    :param dia_pipeline: Pyannote diarization pipeline
    :param filename: name of a wav audio file
    :return: diarization annotation (speaking time intervals) and the number of detected speakers
    """
    # Get diarization of the audio
    diarization = dia_pipeline({'audio': filename})
    listmapping = diarization.labels()
    listnewmapping = []

    # Rename default speakers' names (default is A, B, ...), we want Speaker0, Speaker1, ...
    number_of_speakers = len(listmapping)
    for i in range(number_of_speakers):
        listnewmapping.append("Speaker" + str(i))

    mapping_dict = dict(zip(listmapping, listnewmapping))

    diarization.rename_labels(mapping_dict,
                              copy=False)  # copy set to False so we don't create a new annotation, we replace the actual one

    return diarization, number_of_speakers

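A hedged usage sketch: with the pyannote.audio 2.x pin from requirements.txt, a diarization pipeline can typically be built as below; the access token and file name are placeholders.

from pyannote.audio import Pipeline
from app import get_diarization

# Placeholder token and file name; the app itself builds this pipeline inside load_models()
dia_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization",
                                        use_auth_token="hf_xxx")
diarization, number_of_speakers = get_diarization(dia_pipeline, "my_wav_file.wav")
print(number_of_speakers)    # e.g. 2
print(diarization.labels())  # e.g. ['Speaker0', 'Speaker1'] after renaming
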
def confirm_token_change(hf_token, page_index):
    """
    Save the Hugging Face token entered by the user.
    Also update the page index variable to indicate that we now want to display the home page instead of the token page
    :param hf_token: user's token
    :param page_index: number that represents the home page index (mentioned in the main.py file)
    """
    update_session_state("my_HF_token", hf_token)
    update_session_state("page_index", page_index)

def convert_str_diarlist_to_timedelta(diarization_result):
    """
    Extract the detected speakers and their respective speaking times from the diarization result and convert them into pandas Timedelta objects
    :param diarization_result: result of diarization
    :return: list with timedelta intervals and their respective speaker
    """

    # Get speaking intervals from diarization
    segments = diarization_result.for_json()["content"]
    diarization_timestamps = []
    for sample in segments:
        # Convert segment into a pd.Timedelta object
        new_seg = [pd.Timedelta(seconds=round(sample["segment"]["start"], 2)),
                   pd.Timedelta(seconds=round(sample["segment"]["end"], 2)), sample["label"]]
        # Start and end = speaking duration
        # label = who is speaking
        diarization_timestamps.append(new_seg)

    return diarization_timestamps

def merge_speaker_times(diarization_timestamps, max_space, srt_token):
    """
    Merge near times for each detected speaker (Same speaker during 1-2s and 3-4s -> Same speaker during 1-4s)
    :param diarization_timestamps: diarization list
    :param max_space: Maximum temporal distance between two silences
    :param srt_token: Enable/Disable generate srt file (choice fixed by user)
    :return: list with timedelta intervals and their respective speaker
    """
    if not srt_token:
        threshold = pd.Timedelta(seconds=max_space / 1000)

        index = 0
        length = len(diarization_timestamps) - 1

        while index < length:
            if diarization_timestamps[index + 1][2] == diarization_timestamps[index][2] and \
                    diarization_timestamps[index + 1][1] - threshold <= diarization_timestamps[index][0]:
                diarization_timestamps[index][1] = diarization_timestamps[index + 1][1]
                del diarization_timestamps[index + 1]
                length -= 1
            else:
                index += 1
    return diarization_timestamps

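A minimal worked example of the merge, using made-up intervals that match the docstring (same speaker during 1-2s and 3-4s, with max_space large enough for the condition to hold):

import pandas as pd
from app import merge_speaker_times

# Hypothetical intervals: the same speaker talks during 1-2s and 3-4s
intervals = [[pd.Timedelta(seconds=1), pd.Timedelta(seconds=2), "Speaker0"],
             [pd.Timedelta(seconds=3), pd.Timedelta(seconds=4), "Speaker0"]]

# max_space is in milliseconds; srt_token=False so merging is enabled
merged = merge_speaker_times(intervals, 5000, False)
print(merged)  # [[Timedelta('0 days 00:00:01'), Timedelta('0 days 00:00:04'), 'Speaker0']]
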
def extending_timestamps(new_diarization_timestamps):
    """
    Extend timestamps between each diarization timestamp if possible, so we avoid word cutting
    :param new_diarization_timestamps: list of diarization timestamps
    :return: list with extended times
    """
    for i in range(1, len(new_diarization_timestamps)):
        if new_diarization_timestamps[i][0] - new_diarization_timestamps[i - 1][1] <= timedelta(milliseconds=3000) and \
                new_diarization_timestamps[i][0] - new_diarization_timestamps[i - 1][1] >= timedelta(milliseconds=100):
            middle = (new_diarization_timestamps[i][0] - new_diarization_timestamps[i - 1][1]) / 2
            new_diarization_timestamps[i][0] -= middle
            new_diarization_timestamps[i - 1][1] += middle

    # Convert the list so we have a milliseconds format
    for elt in new_diarization_timestamps:
        elt[0] = elt[0].total_seconds() * 1000
        elt[1] = elt[1].total_seconds() * 1000

    return new_diarization_timestamps

def clean_directory(path):
    """
    Delete all files of a directory
    :param path: directory's path
    """
    for file in os.listdir(path):
        os.remove(os.path.join(path, file))

def correct_values(start, end, audio_length):
    """
    Start and/or end values can be in conflict, so we check these values
    :param start: int value (s) given by st.slider() (fixed by user)
    :param end: int value (s) given by st.slider() (fixed by user)
    :param audio_length: audio duration (s)
    :return: approved values
    """
    # Start & end values need to be checked

    if start >= audio_length or start >= end:
        start = 0
        st.write("Start value has been set to 0s because of conflicts with other values")

    if end > audio_length or end == 0:
        end = audio_length
        st.write("End value has been set to the maximum value because of conflicts with other values")

    return start, end

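A quick sketch of the correction with hypothetical slider values on a 300 s file:

from app import correct_values

# start is past the end of the file and end is 0, so both get corrected
start, end = correct_values(350, 0, 300)
print(start, end)  # 0 300
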
def split_text(my_text, max_size):
    """
    Split a text
    Maximum sequence length for this model is max_size.
    If the transcript is longer, it needs to be split by the nearest possible value to max_size.
    To avoid cutting words, we cut on "." characters, or on " " if there is no "."
    :param my_text: text to split
    :param max_size: maximum size of a text block
    :return: split text (list of text blocks)
    """

    cut2 = max_size

    # First, we get indexes of "."
    my_split_text_list = []
    nearest_index = 0
    length = len(my_text)
    # We split the transcript in text blocks of size <= max_size.
    if cut2 == length:
        my_split_text_list.append(my_text)
    else:
        while cut2 <= length:
            cut1 = nearest_index
            cut2 = nearest_index + max_size
            # Find the best index to split

            dots_indexes = [index for index, char in enumerate(my_text[cut1:cut2]) if
                            char == "."]
            if dots_indexes != []:
                nearest_index = max(dots_indexes) + 1 + cut1
            else:
                spaces_indexes = [index for index, char in enumerate(my_text[cut1:cut2]) if
                                  char == " "]
                if spaces_indexes != []:
                    nearest_index = max(spaces_indexes) + 1 + cut1
                else:
                    # No "." or " " in this block: cut at the end of the window
                    # (cut2 is already an absolute index, so we don't add cut1 again)
                    nearest_index = cut2
            my_split_text_list.append(my_text[cut1: nearest_index])

    return my_split_text_list

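A hypothetical standalone example of the splitting behaviour:

from app import split_text

# Hypothetical transcript longer than the model's maximum sequence length (here max_size=40)
text = "First sentence of the transcript. Second sentence of the transcript. Third one."
blocks = split_text(text, 40)
print(blocks)
# Each block ends on the last "." (or " ") found inside a 40-character window,
# so no word is cut in the middle.
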
def update_session_state(var, data, concatenate_token=False):
    """
    A simple function to update a session state variable
    :param var: variable's name
    :param data: new value of the variable
    :param concatenate_token: do we replace or concatenate
    """

    if concatenate_token:
        st.session_state[var] += data
    else:
        st.session_state[var] = data

def display_results():
    """
    Display Results page
    This function allows you to display saved results after clicking a button. Without it, Streamlit automatically
    reloads the whole page when a button is clicked, so you would lose all the generated transcript, which would be
    very frustrating for the user.
    """

    # Add a button to return to the main page
    st.button("Load another file", on_click=update_session_state, args=("page_index", 0,))

    # Display results
    st.audio(st.session_state['audio_file'], start_time=st.session_state["start_time"])

    # Display results of transcript by steps
    if st.session_state["process"] != []:

        if st.session_state["chosen_mode"] == "NODIA":  # Non diarization, non timestamps case
            for elt in (st.session_state['process']):
                st.write(elt[0])

        elif st.session_state["chosen_mode"] == "DIA":  # Diarization without timestamps case
            for elt in (st.session_state['process']):
                st.write(elt[1] + elt[2])

        elif st.session_state["chosen_mode"] == "NODIA_TS":  # Non diarization with timestamps case
            for elt in (st.session_state['process']):
                st.button(elt[0], on_click=update_session_state, args=("start_time", elt[2],))
                st.write(elt[1])

        elif st.session_state["chosen_mode"] == "DIA_TS":  # Diarization with timestamps case
            for elt in (st.session_state['process']):
                st.button(elt[0], on_click=update_session_state, args=("start_time", elt[4],))
                st.write(elt[2] + elt[3])

    # Display final text
    st.subheader("Final text is")
    st.write(st.session_state["txt_transcript"])

    # Display Summary
    if st.session_state["summary"] != "":
        with st.expander("Summary"):
            st.write(st.session_state["summary"])

    # Display the buttons in a list to avoid having empty columns (explained in the transcription() function)
    col1, col2, col3, col4 = st.columns(4)
    col_list = [col1, col2, col3, col4]
    col_index = 0

    for elt in st.session_state["btn_token_list"]:
        if elt[0]:
            mycol = col_list[col_index]
            if elt[1] == "useless_txt_token":
                # Download your transcription.txt
                with mycol:
                    st.download_button("Download as TXT", st.session_state["txt_transcript"],
                                       file_name="my_transcription.txt")

            elif elt[1] == "srt_token":
                # Download your transcription.srt
                with mycol:
                    st.download_button("Download as SRT", st.session_state["srt_txt"], file_name="my_transcription.srt")

            elif elt[1] == "dia_token":
                with mycol:
                    # Rename the speakers detected in your audio
                    st.button("Rename Speakers", on_click=update_session_state, args=("page_index", 2,))

            elif elt[1] == "summarize_token":
                with mycol:
                    st.download_button("Download Summary", st.session_state["summary"], file_name="my_summary.txt")
            col_index += 1

def click_timestamp_btn(sub_start):
    """
    When the user clicks a timestamp button, we go to the display results page and st.audio is set to the sub_start value
    It allows the user to listen to the considered part of the audio
    :param sub_start: beginning of the considered transcript (ms)
    """
    update_session_state("page_index", 1)
    update_session_state("start_time", int(sub_start / 1000))  # division to convert ms to s

def diarization_treatment(filename, dia_pipeline, max_space, srt_token):
    """
    Launch the whole diarization process to get speakers time intervals as pandas timedelta objects
    :param filename: name of the audio file
    :param dia_pipeline: Diarization Model (Differentiate speakers)
    :param max_space: Maximum temporal distance between two silences
    :param srt_token: Enable/Disable generate srt file (choice fixed by user)
    :return: speakers time intervals list and number of different detected speakers
    """
    # initialization
    diarization_timestamps = []

    # whole diarization process
    diarization, number_of_speakers = get_diarization(dia_pipeline, filename)

    if len(diarization) > 0:
        diarization_timestamps = convert_str_diarlist_to_timedelta(diarization)
        diarization_timestamps = merge_speaker_times(diarization_timestamps, max_space, srt_token)
        diarization_timestamps = extending_timestamps(diarization_timestamps)

    return diarization_timestamps, number_of_speakers

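After these three steps the intervals are plain millisecond values rather than Timedelta objects; a sketch of the expected shape with hypothetical values:

# Hypothetical return value of diarization_treatment() for a two-speaker file:
# each entry is [start_ms, end_ms, speaker_label]
diarization_timestamps = [
    [0.0, 4250.0, "Speaker0"],
    [4250.0, 9800.0, "Speaker1"],
]
number_of_speakers = 2
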
def extract_audio_from_yt_video(url):
    """
    Extract audio from a YouTube url
    :param url: link of a YT video
    :return: name of the saved audio file, or None if the download failed
    """
    filename = "yt_download_" + url[-11:] + ".mp3"
    try:

        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': filename,
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'mp3',
            }],
        }
        with st.spinner("We are extracting the audio from the video"):
            with youtube_dl.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])

    # Handle DownloadError ("ERROR: unable to download video data: HTTP Error 403: Forbidden"), which happens sometimes
    except DownloadError:
        filename = None

    return filename

main.py
ADDED
@@ -0,0 +1,43 @@
from app import *

if __name__ == '__main__':
    config()

    if st.session_state['page_index'] == -1:
        # Specify token page (mandatory to use the diarization option)
        st.warning('You must specify a token to use the diarization model. Otherwise, the app will be launched without this model. You can learn how to create your token here: https://huggingface.co/pyannote/speaker-diarization')
        text_input = st.text_input("Enter your Hugging Face token:", placeholder="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", type="password")

        # Confirm or continue without the option
        col1, col2 = st.columns(2)

        # save changes button
        with col1:
            confirm_btn = st.button("I have changed my token", on_click=confirm_token_change, args=(text_input, 0), disabled=st.session_state["disable"])
            # if text is changed, button is clickable
            if text_input != "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":
                st.session_state["disable"] = False

        # Continue without a token (there will be no diarization option)
        with col2:
            dont_mind_btn = st.button("Continue without this option", on_click=update_session_state, args=("page_index", 0))

    if st.session_state['page_index'] == 0:
        # Home page
        choice = st.radio("Features", ["By a video URL", "By uploading a file"])

        stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline = load_models()

        if choice == "By a video URL":
            transcript_from_url(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline)

        elif choice == "By uploading a file":
            transcript_from_file(stt_tokenizer, stt_model, t5_tokenizer, t5_model, summarizer, dia_pipeline)

    elif st.session_state['page_index'] == 1:
        # Results page
        display_results()

    elif st.session_state['page_index'] == 2:
        # Rename speakers page
        rename_speakers_window()

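For reference, the routing above is driven entirely by st.session_state['page_index']: -1 shows the Hugging Face token page, 0 the home page, 1 the results page, and 2 the speaker-renaming page. The entry point is presumably main.py (e.g. launched with streamlit run main.py), with config() from app.py initialising those session-state variables beforehand.
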
packages.txt
ADDED
@@ -0,0 +1 @@
libsndfile1-dev

requirements.txt
ADDED
@@ -0,0 +1,15 @@
librosa==0.9.1
youtube_dl==2021.12.17
numba==0.56.4
streamlit==1.9.0
transformers==4.18.0
httplib2==0.20.2
torch==1.11.0
torchtext==0.12.0
torchaudio==0.11.0
sentencepiece==0.1.96
tokenizers==0.12.1
pyannote.audio==2.1.1
pyannote.core==4.4
pydub==0.25.1