vmoras committed on
Commit
d1701ad
·
1 Parent(s): 37890b1

Improve front and add some util functions

Files changed (6)
  1. .gitignore +4 -0
  2. app.py +128 -0
  3. audio.py +201 -0
  4. audio_model.py +39 -0
  5. requirements.txt +11 -0
  6. utils.py +73 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .idea/
+ .venv/
+ __pycache__/
+ .env
app.py ADDED
@@ -0,0 +1,128 @@
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ from utils import add_data_table, remove_data_table, create_chatbot
+ import gradio as gr
+
+
+ with gr.Blocks() as app:
+     with gr.Tab('General info'):
+         client = gr.Textbox(label='Nombre del cliente', placeholder='Inserte el nombre del cliente')
+         language = gr.CheckboxGroup(
+             choices=['español', 'inglés', 'portugués'], value='español', label='Idiomas', interactive=True,
+             info='Seleccione todos los idiomas que el chatbot va a hablar (al menos debe tener 1 idioma)'
+         )
+         name = gr.Dropdown(
+             choices=['Bella'], value='Bella', label='Nombre del chatbot',
+             info='Seleccione el nombre del chatbot, si no se encuentra en la lista, contacte al administrador'
+         )
+         num_questions = gr.Number(
+             value=5, minimum=2, maximum=10, label='Número de preguntas', interactive=True,
+             info='Máximo número de preguntas que puede hacer el usuario.'
+         )
+
+     with gr.Tab('Images'):
+         base_image = gr.Image(label='Imagen base para los videos', sources=['upload'])
+
+     with gr.Tab('Greeting and goodbye'):
+         _ = gr.Markdown(
+             'Ingrese los saludos, despedidas y mensajes de error que deba usar el chatbot.'
+         )
+         with gr.Row():
+             greet = gr.Textbox(label='Mensaje', info='Ingrese el mensaje a decir por el chatbot.')
+             type_greet = gr.Dropdown(
+                 choices=['Saludo', 'Despedida', 'Error'], value='Saludo', interactive=True,
+                 info='Seleccione si es saludo, despedida o mensaje de error.'
+             )
+             send_greet_button = gr.Button(value='Añadir')
+         messages_table = gr.DataFrame(
+             headers=['Eliminar', 'Tipo mensaje', 'Mensaje'], type='array', interactive=False
+         )
+
+     with gr.Tab('Random data'):
+         _ = gr.Markdown(
+             'Si quiere que Bella diga algunos datos random mientras busca la información, ingrese dichos párrafos acá.'
+         )
+         with gr.Row():
+             random_data = gr.Text(placeholder='Ingrese el dato random', label='Dato random')
+             send_random_button = gr.Button(value='Añadir')
+         random_table = gr.DataFrame(headers=['Eliminar', 'Dato random'], type='array', interactive=False)
+
+     with gr.Tab('Questions - Context'):
+         with gr.Row():
+             question = gr.Text(placeholder='Ingrese su pregunta', label='Pregunta')
+             context = gr.Text(placeholder='Ingrese el párrafo u oración que contesta dicha pregunta', label='Contexto')
+             send_question_button = gr.Button(value='Añadir')
+         questions_table = gr.DataFrame(
+             headers=['Eliminar', 'Pregunta', 'Contexto'], type='array', interactive=False
+         )
+
+     with gr.Tab('General prompt'):
+         general_prompt = gr.Text(placeholder='Ingrese el prompt general del bot', label='Prompt')
+
+     with gr.Tab('Context prompt'):
+         context_prompt = gr.Text(placeholder='Ingrese el prompt usado para encontrar el contexto', label='Prompt')
+
+     with gr.Tab('Create chatbot'):
+         _ = gr.Markdown(
+             'Asegúrese de que toda la información esté correcta antes de enviarla.'
+         )
+         create_chatbot_button = gr.Button(value='Crear chatbot')
+
+     with gr.Tab('Test'):
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     video = gr.Video(interactive=False, label='Video', autoplay=True)
+                 with gr.Row():
+                     output_audio = gr.Audio(interactive=False, label='Audio', autoplay=True)
+
+             with gr.Column():
+                 with gr.Row():
+                     chat = gr.Chatbot(label='Chat')
+                 with gr.Row():
+                     text = gr.Text(label='Write your question')
+
+     with gr.Tab('Submit'):
+         _ = gr.Markdown(
+             'Asegúrese de que hizo las suficientes pruebas para aprobar el chatbot.'
+         )
+         submit_button = gr.Button(value='ENVIAR!')
+         output_file = gr.File(interactive=False, label='Output file')
+
+     # ----------------------------------------------- ACTIONS -----------------------------------------------------
+
+     # Add info to the tables (the greet handler returns three values, so it needs three outputs)
+     send_greet_button.click(
+         add_data_table, [messages_table, type_greet, greet], [messages_table, greet, type_greet]
+     )
+     send_random_button.click(
+         add_data_table, [random_table, random_data], [random_table, random_data]
+     )
+     send_question_button.click(
+         add_data_table, [questions_table, question, context], [questions_table, question, context]
+     )
+
+     # Remove info from the tables
+     messages_table.select(
+         remove_data_table, messages_table, messages_table
+     )
+     random_table.select(
+         remove_data_table, random_table, random_table
+     )
+     questions_table.select(
+         remove_data_table, questions_table, questions_table
+     )
+
+     # Create the chatbot: create media and vectorstore
+     create_chatbot_button.click(
+         lambda: gr.Button(value='Creating chatbot...', interactive=False),
+         None,
+         create_chatbot_button
+     ).then(
+         create_chatbot,
+         [client, language, name, base_image, messages_table, random_table, questions_table],
+         create_chatbot_button
+     )
+
+ app.launch(debug=True)
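
The two chained events on create_chatbot_button are the standard Gradio pattern for long-running work: the first .click() returns an updated, disabled button so the user gets immediate feedback, and .then() runs the slow job and hands the button back. A minimal, self-contained sketch of the pattern (component names are illustrative, not from this app):

```python
import time
import gradio as gr


def long_task():
    time.sleep(5)  # stand-in for the real work (media + vectorstore creation)
    return gr.Button(value='Done!', interactive=True)


with gr.Blocks() as demo:
    button = gr.Button(value='Run')
    # First event: disable the button immediately so it cannot be clicked twice
    button.click(
        lambda: gr.Button(value='Running...', interactive=False), None, button
    ).then(
        long_task, None, button  # second event: do the work, then restore the button
    )

demo.launch()
```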
audio.py ADDED
@@ -0,0 +1,201 @@
+ import re
+ import os
+ import nltk
+ import torch
+ import pickle
+ import torchaudio
+ import numpy as np
+ import gradio as gr
+ from google.cloud import storage
+ from TTS.tts.models.xtts import Xtts
+ from nltk.tokenize import sent_tokenize
+ from huggingface_hub import hf_hub_download
+ from TTS.tts.configs.xtts_config import XttsConfig
+
+
+ def _download_starting_files() -> None:
+     """
+     Downloads the embeddings from a bucket
+     """
+     os.makedirs('assets', exist_ok=True)
+
+     # Download credentials file
+     hf_hub_download(
+         repo_id=os.environ.get('DATA'), repo_type='dataset', filename="credentials.json",
+         token=os.environ.get('HUB_TOKEN'), local_dir="assets"
+     )
+
+     # Initialise a client
+     credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
+     storage_client = storage.Client.from_service_account_json(credentials)
+     bucket = storage_client.get_bucket('embeddings-bella')
+
+     # Get both embeddings
+     blob = bucket.blob("gpt_cond_latent.npy")
+     blob.download_to_filename('assets/gpt_cond_latent.npy')
+     blob = bucket.blob("speaker_embedding.npy")
+     blob.download_to_filename('assets/speaker_embedding.npy')
+
+
+ def _load_array(filename):
+     """
+     Opens a pickled file and returns its contents; used for the numpy embedding files
+     """
+     with open(filename, 'rb') as f:
+         return pickle.load(f)
+
+
+ # Get embeddings
+ _download_starting_files()
+ os.environ['COQUI_TOS_AGREED'] = '1'
+
+ # Used to generate audio based on a sample
+ nltk.download('punkt')
+ model_path = os.path.join("tts_model")
+
+ config = XttsConfig()
+ config.load_json(os.path.join(model_path, "config.json"))
+
+ model = Xtts.init_from_config(config)
+ model.load_checkpoint(
+     config,
+     checkpoint_path=os.path.join(model_path, "model.pth"),
+     vocab_path=os.path.join(model_path, "vocab.json"),
+     eval=True,
+     use_deepspeed=True,
+ )
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model.to(device)
+
+ # Speaker latent
+ path_latents = 'assets/gpt_cond_latent.npy'
+ gpt_cond_latent = _load_array(path_latents)
+
+ # Speaker embedding
+ path_embedding = 'assets/speaker_embedding.npy'
+ speaker_embedding = _load_array(path_embedding)
+
+
+ def get_audio(text: str, language: str = 'es') -> gr.Audio:
+     """
+     Generates an audio file (output.wav) from the given text and language and
+     returns it wrapped in a gr.Audio component
+     :param text: used to generate the audio
+     :param language: 'es', 'en' or 'pt'
+     :return: gr.Audio with the generated audio
+     """
+     # Creates an audio with the answer and saves it as output.wav
+     _save_audio(text, language)
+
+     return gr.Audio(value='output.wav', interactive=False, visible=True)
+
+
+ def _save_audio(answer: str, language: str) -> None:
+     """
+     Splits the answer into sentences, cleans them and creates an audio for each one, then concatenates
+     all the audios and saves them into a file (output.wav)
+     """
+     # Split the answer into sentences and clean it
+     sentences = _get_clean_answer(answer, language)
+
+     # Get the voice of each sentence
+     audio_segments = []
+     for sentence in sentences:
+         audio_stream = _get_voice(sentence, language)
+         audio_stream = torch.tensor(audio_stream)
+         audio_segments.append(audio_stream)
+
+     # Concatenate and save all audio segments
+     concatenated_audio = torch.cat(audio_segments, dim=0)
+     torchaudio.save('output.wav', concatenated_audio.unsqueeze(0), 24000)
+
+
+ def _get_voice(sentence: str, language: str) -> np.ndarray:
+     """
+     Returns a numpy array with a wav of an audio with the given sentence and language
+     """
+     out = model.inference(
+         sentence,
+         language=language,
+         gpt_cond_latent=gpt_cond_latent,
+         speaker_embedding=speaker_embedding,
+         temperature=0.1
+     )
+     return out['wav']
+
+
+ def _get_clean_answer(answer: str, language: str) -> list[str]:
+     """
+     Returns a list of sentences of the answer. It also removes links
+     """
+     # Replace the links with a spoken phrase in the target language
+     if language == 'en':
+         clean_answer = re.sub(r'http[s]?://\S+', 'the following link', answer)
+         max_characters = 250
+     elif language == 'es':
+         clean_answer = re.sub(r'http[s]?://\S+', 'el siguiente link', answer)
+         max_characters = 239
+     else:
+         clean_answer = re.sub(r'http[s]?://\S+', 'o seguinte link', answer)
+         max_characters = 203
+
+     # Change the name from Bella to Bela
+     clean_answer = clean_answer.replace('Bella', 'Bela')
+
+     # Remove Florida and zipcode
+     clean_answer = re.sub(r', FL \d+', "", clean_answer)
+
+     # Split the answer into sentences with nltk and make sure they are shorter than the maximum possible
+     # characters
+     split_sentences = sent_tokenize(clean_answer)
+     sentences = []
+     for sentence in split_sentences:
+         if len(sentence) > max_characters:
+             sentences.extend(_split_sentence(sentence, max_characters))
+         else:
+             sentences.append(sentence)
+
+     return sentences
+
+
+ def _split_sentence(sentence: str, max_characters: int) -> list[str]:
+     """
+     Returns the sentence split into shorter parts. The split point is the comma nearest
+     to the middle of the sentence; if there is no comma, a space is used, or simply the
+     middle. If the resulting parts are still too long, they are split again recursively
+     """
+     # Get index of each comma
+     sentences = []
+     commas = [i for i, c in enumerate(sentence) if c == ',']
+
+     # No commas, search for spaces
+     if len(commas) == 0:
+         commas = [i for i, c in enumerate(sentence) if c == ' ']
+
+     # No commas or spaces, split it in the middle
+     if len(commas) == 0:
+         sentences.append(sentence[:len(sentence) // 2])
+         sentences.append(sentence[len(sentence) // 2:])
+         return sentences
+
+     # Nearest index to the middle
+     split_point = min(commas, key=lambda x: abs(x - (len(sentence) // 2)))
+
+     if sentence[split_point] == ',':
+         left = sentence[:split_point]
+         right = sentence[split_point + 2:]
+     else:
+         left = sentence[:split_point]
+         right = sentence[split_point + 1:]
+
+     if len(left) > max_characters:
+         sentences.extend(_split_sentence(left, max_characters))
+     else:
+         sentences.append(left)
+     if len(right) > max_characters:
+         sentences.extend(_split_sentence(right, max_characters))
+     else:
+         sentences.append(right)
+
+     return sentences
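
_split_sentence recurses until every chunk fits under max_characters, preferring the comma closest to the middle, then spaces, then a hard midpoint cut. A quick check of that behaviour (run inside this module, since it needs _split_sentence in scope; the 40-character limit is illustrative, the real limits are 250/239/203):

```python
sentence = ('XTTS keeps prosody stable on short inputs, so long sentences are '
            'split, ideally at a comma near the middle')

for part in _split_sentence(sentence, 40):
    print(len(part), repr(part))
# Every printed length should be <= 40; joined back together, the parts read as
# the original sentence minus the commas/spaces consumed at the split points
```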
audio_model.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import requests
+ from tqdm import tqdm
+
+
+ def _download_file(url, destination):
+     response = requests.get(url, stream=True)
+     total_size_in_bytes = int(response.headers.get('content-length', 0))
+     block_size = 1024
+
+     progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+
+     with open(destination, 'wb') as file:
+         for data in response.iter_content(block_size):
+             progress_bar.update(len(data))
+             file.write(data)
+
+     progress_bar.close()
+
+
+ def download_model():
+     # Define files and their corresponding URLs
+     files_to_download = {
+         'LICENSE.txt': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/LICENSE.txt?download=true',
+         'README.md': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/README.md?download=true',
+         'config.json': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/config.json?download=true',
+         'model.pth': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/model.pth?download=true',
+         'vocab.json': 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2/vocab.json?download=true',
+     }
+
+     if not os.path.exists("tts_model"):
+         os.makedirs("tts_model")
+
+     # Download the model files (callers are expected to guard against re-downloading)
+     print("[COQUI TTS] STARTUP: Checking Model is Downloaded.")
+     for filename, url in files_to_download.items():
+         destination = f'tts_model/{filename}'
+         print(f"[COQUI TTS] STARTUP: Downloading {filename}...")
+         _download_file(url, destination)
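
Note that download_model() re-downloads every file on each call; the only guard is the os.path.exists('tts_model') check in utils.py, so a run interrupted mid-download leaves a directory that is never repaired. A per-file guard is a small optional hardening (a sketch under that assumption, not part of the commit):

```python
import os
import requests
from tqdm import tqdm

BASE_URL = 'https://huggingface.co/coqui/XTTS-v2/resolve/v2.0.2'
FILENAMES = ['LICENSE.txt', 'README.md', 'config.json', 'model.pth', 'vocab.json']


def download_model_if_missing():
    """Variant of download_model() that skips files already present on disk."""
    os.makedirs('tts_model', exist_ok=True)
    for filename in FILENAMES:
        destination = f'tts_model/{filename}'
        # NOTE: a partially written file is also skipped; delete tts_model/
        # after a failed run to force a clean re-download
        if os.path.exists(destination):
            continue
        response = requests.get(f'{BASE_URL}/{filename}?download=true', stream=True)
        total = int(response.headers.get('content-length', 0))
        with open(destination, 'wb') as file, tqdm(total=total, unit='iB', unit_scale=True) as bar:
            for chunk in response.iter_content(1024):
                bar.update(len(chunk))
                file.write(chunk)
```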
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ pinecone-client==2.2.4
+ python-dotenv==1.0.0
+ google-cloud-storage==2.13.0
+ requests==2.31.0
+ tqdm==4.66.1
+ nltk==3.8.1
+ # deepspeed==0.12.3
+ torch==2.1.1
+ torchaudio==2.1.1
+ TTS==0.21.2
+ numpy==1.22.0
utils.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import gradio as gr
+ import audio_model
+ if not os.path.exists('tts_model'):  # Get TTS model
+     audio_model.download_model()
+ import audio
+
+
+ def add_data_table(table: list[list[str]], first: str, last: str = None):
+     """
+     Adds the data to the table. Some entries consist of two columns, others of only one,
+     so the new row and the returned values differ accordingly.
+     """
+     if last is None:
+         new_row = ['❌', first]
+         new_value = ''
+     elif first == 'Saludo' or first == 'Despedida' or first == 'Error':
+         new_row = ['❌', first, last]
+         new_value = '', first
+     else:
+         new_row = ['❌', first, last]
+         new_value = '', ''
+
+     # The table is empty, do not append the row but replace the first one
+     if all(column == '' for column in table[0]):
+         table[0] = new_row
+
+     # Add the new data
+     else:
+         table.append(new_row)
+
+     if last is None:
+         return table, new_value
+     return table, *new_value
+
+
+ def remove_data_table(table: list[list[str]], evt: gr.SelectData):
+     """
+     Deletes a row of the table if the selected column is the first one
+     """
+     # The clicked column is not the first one (the one with the X), do nothing
+     if evt.index[1] != 0:
+         return table
+
+     # The table only has one row, do not delete it, just restore the default one
+     if len(table) == 1:
+         table[0] = ['' for _ in range(len(table[0]))]
+
+     # Delete the row
+     else:
+         del table[evt.index[0]]
+     return table
+
+
+ def create_chatbot(
+         client: str, language: list[str], chatbot: str, messages_table, random_table, questions_table,
+ ):
+     # Set up general info
+     client_name = client.lower().replace(' ', '-')
+     chatbot_name = chatbot.lower()
+
+     # Create prerecorded media (greeting, goodbye, error, random and waiting)
+     for message in messages_table:
+         pass
+
+     # get_audio()
+
+     # Set up vectorstore
+
+     # Upload data to bucket in GCP (videos, audio, prompts and csv files)
+
+     # Change text in the button
+     return gr.Button(value='Chatbot created!!!', interactive=True)
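
create_chatbot is still a stub: the loop over messages_table does nothing yet, and the vectorstore and upload steps are only comments. For the first step, one plausible completion generates a prerecorded audio per greeting/goodbye/error row via audio.get_audio; the helper name and file layout below are assumptions for illustration, not part of this commit:

```python
import os

import audio


def _create_prerecorded_audios(client_name: str, messages_table: list[list[str]]) -> list[str]:
    # Hypothetical helper: rows look like ['❌', 'Saludo', 'Hola, soy Bella...']
    # as built by add_data_table; audio.get_audio() writes the speech to output.wav
    paths = []
    os.makedirs(f'media/{client_name}', exist_ok=True)
    for i, (_, message_type, message) in enumerate(messages_table):
        audio.get_audio(message)  # defaults to Spanish ('es')
        destination = f'media/{client_name}/{message_type.lower()}_{i}.wav'
        os.replace('output.wav', destination)  # move the generated file into place
        paths.append(destination)
    return paths
```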