rohan13's picture
added press enter to label
b0d0d67
raw
history blame
3.92 kB
import time
import uuid
import gradio as gr
from gtts import gTTS
from transformers import pipeline
from main import index, run
# Speech-recognition pipeline (Whisper base); get_output calls it on the
# recorded audio file and reads the "text" field of the result.
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")
"""Use text to call chat method from main.py"""
# Model names offered by the (currently hidden) model selector.
models = ["GPT-3.5", "Flan UL2", "Flan T5"]
# Build the Gradio Blocks app.
# NOTE(review): indentation is missing in this copy of the file; presumably the
# definitions and components that follow belong inside this block — confirm.
with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
# State component; a session_id attribute is attached to this object later on.
state = gr.State([])
def create_session_id():
    """Return a freshly generated random (version 4) UUID as a string."""
    session_uuid = uuid.uuid4()
    return str(session_uuid)
def add_text(history, text, model):
    """Answer a typed question and add the exchange to the chat history.

    Args:
        history: Chat transcript so far, a list of (user, bot) pairs.
        text: Question typed by the user.
        model: Name of the model to query.

    Returns:
        A two-item tuple: a new history list with the (question, answer)
        pair appended (the input list is not mutated), and an empty string.
    """
    print("Question asked: " + text)
    answer = run_model(text, model)
    updated_history = [*history, (text, answer)]
    print(updated_history)
    return updated_history, ""
def run_model(text, model):
    """Query the backend for an answer and log how long the call took.

    Args:
        text: Question to send.
        model: Name of the model to query.

    Returns:
        The response from ``run`` (imported from main), with a newline
        inserted in front of every "SOURCES:" marker it contains.
    """
    started = time.time()
    print("start time:" + str(started))
    answer = run(text, model, state.session_id)
    elapsed = time.time() - started

    # Push the sources list onto its own line so it is separated from the answer.
    marker = "SOURCES:"
    if marker in answer:
        answer = answer.replace(marker, "\n" + marker)

    print(answer)
    print("Time taken: " + str(elapsed))
    return answer
def get_output(history, audio, model):
    """Answer a spoken question and append a spoken reply to the chat history.

    The recording is transcribed with the speech-recognition pipeline ``p``,
    the transcript is answered via ``run_model``, and the answer (without its
    sources section) is converted to speech with gTTS.

    Args:
        history: Chat transcript so far; it is extended in place.
        audio: File path of the recording, or None/empty when there is none.
        model: Name of the model to query.

    Returns:
        The same history list. When a recording was given, a new entry pairs
        the recording with the synthesized reply audio, or with the plain
        text reply when there is nothing to read aloud.
    """
    # Gradio can invoke this handler without a recording (for example when
    # the audio input is cleared); there is nothing to transcribe then.
    if not audio:
        return history

    question = p(audio)["text"]
    response = run_model(question, model)

    # Only the answer is read aloud: drop everything from SOURCES: onwards.
    trimmed_response = response.split("SOURCES:")[0]
    if not trimmed_response.strip():
        # gTTS rejects empty text, so show the text reply instead of crashing.
        history.append(((audio,), response))
        print(history)
        return history

    # One file per reply: a single fixed file name would be overwritten when
    # several requests are processed at the same time.
    audio_path = "response-" + uuid.uuid4().hex + ".wav"
    speech = gTTS(text=trimmed_response, lang='en', slow=False)
    speech.save(audio_path)

    history.append(((audio,), (audio_path,)))
    print(history)
    return history
def set_model(history, model):
    """Handle a model selection: reset the chat and call ``index``.

    Args:
        history: Current chat transcript; it is replaced, not extended.
        model: Name of the selected model.

    Returns:
        The transcript produced by ``get_first_message``.
    """
    print("Model selected: " + model)
    welcome_history = get_first_message(history)
    # index comes from main.py; it receives the model and the session ID.
    index(model, state.session_id)
    return welcome_history
def get_first_message(history):
    """Return the initial chat transcript holding only the welcome message.

    Args:
        history: Ignored; accepted so the function fits the handler signature.

    Returns:
        A one-element list whose entry has no user message (None) and the
        welcome text as the bot message.
    """
    welcome = 'Learn about the course and get answers with sources.\n This is an experiment using AI, so it might make errors'
    return [(None, welcome)]
def bot(history):
    """Return the chat history unchanged.

    Used as the follow-up step of the event chains so the chatbot component
    is refreshed with the history produced by the preceding handler.
    """
    return history
# Generate the session ID once, while the UI is being built, and store it as
# a plain attribute on the State component.
# NOTE(review): the handlers read state.session_id directly, so every request
# appears to share this single ID — confirm whether per-user sessions were intended.
state.session_id = create_session_id()
print("Session ID: " + state.session_id)
# Title on top in middle of the page
# gr.HTML("<h1 style='text-align: center;'>Course Assistant - 3D Printing Revolution</h1>")
# Chat window, pre-populated with the welcome message.
chatbot = gr.Chatbot(get_first_message([]), elem_id="chatbot", label='3D Printing Applications Question Answer Bot').style(height=300,
container=False)
# with gr.Row():
# Create radio button to select model
# The selector is created hidden (visible=False) with "GPT-3.5" preselected.
radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value", visible=False)
with gr.Row():
# with gr.Column(scale=0.75):
# Single-line question box; pressing Enter submits it.
txt = gr.Textbox(
label="Ask your question here and press enter",
placeholder="Enter text and press enter", lines=1
).style(container=False)
# with gr.Column(scale=0.25):
# Microphone input is created but hidden (visible=False).
audio = gr.Audio(source="microphone", type="filepath", visible=False)
# On Enter: answer the question and clear the textbox, then pass the history
# through bot to update the chatbot.
txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
bot, chatbot, chatbot
)
# On a new recording: transcribe, answer and append the spoken reply.
audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot], show_progress=True).then(
bot, chatbot, chatbot
)
# On model change: reset the chat to the welcome message and call index via set_model.
radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)
# Second change listener that sets the audio component back to None.
audio.change(lambda: None, None, audio)
# Run the model-selection handler once at build time for the default model.
# NOTE(review): this passes the Chatbot component itself as `history`; set_model
# only forwards it to get_first_message, which ignores its argument.
set_model(chatbot, radio.value)
if __name__ == "__main__":
    # Enable request queueing once, with up to 5 concurrent workers. A bare
    # demo.queue() call right before this one was redundant: its settings
    # were immediately replaced by this call.
    demo.queue(concurrency_count=5)
    # debug=True keeps the process attached and surfaces errors in the console.
    demo.launch(debug=True)