# truongghieu's Hugging Face Space — app.py (revision 1db8934, 3.93 kB)
# (header reconstructed from file-viewer residue that was pasted into the source)
import speech_recognition as sr
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig,BitsAndBytesConfig
import torch
import os
from openai import OpenAI
# OpenAI client for the GPT-3.5 comparison button; the key comes from the
# Space's secrets. If OPENAI_API_KEY is unset, `key` is None and any API
# call will fail with an auth error at click time (not at startup).
key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=key)
# Fine-tuned medical QA checkpoint served by this Space.
Medical_finetunned_model = "truongghieu/deci-finetuned_Prj2"
# Placeholder answer text (appears unused below — NOTE(review): candidate for removal).
answer_text = "This is an answer"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 4-bit NF4 quantization config — only usable on CUDA (bitsandbytes requirement).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
)
tokenizer = AutoTokenizer.from_pretrained(Medical_finetunned_model, trust_remote_code=True)
# On GPU: load the Prj2 checkpoint quantized to 4-bit. On CPU: fall back to a
# different, unquantized checkpoint ("deci-finetuned", no "_Prj2" suffix) —
# NOTE(review): the two branches load *different* models; confirm intentional.
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(Medical_finetunned_model, trust_remote_code=True, quantization_config=bnb_config)
else:
    model = AutoModelForCausalLM.from_pretrained("truongghieu/deci-finetuned", trust_remote_code=True)
def generate_text(prompt="", penalty_alpha=0.6, do_sample=True, top_k=5,
                  temperature=0.5, repetition_penalty=1.0, max_new_tokens=30):
    """Generate an answer to *prompt* with the fine-tuned local model.

    Parameters mirror, in order, the seven Gradio inputs wired to the
    "Answer" button (the defaults match the UI sliders' initial values),
    so the previous ``*args`` positional contract is preserved.

    Returns:
        The decoded model output as a string (the decode includes the
        prompt tokens and any special tokens, as before), or the literal
        message "Please input text" when the prompt is empty.
    """
    # Guard: nothing to generate from — avoid a pointless model call.
    if prompt == "":
        return "Please input text"
    generation_config = GenerationConfig(
        penalty_alpha=penalty_alpha,
        do_sample=do_sample,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
        # Padding with EOS silences the "no pad token" warning for causal LMs.
        pad_token_id=tokenizer.eos_token_id,
    )
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(input_ids, generation_config=generation_config)
    return tokenizer.decode(output_ids[0])
def gpt_generate(prompt="", penalty_alpha=0.6, do_sample=True, top_k=5,
                 temperature=0.5, repetition_penalty=1.0, max_new_tokens=30):
    """Answer *prompt* with OpenAI's gpt-3.5-turbo.

    Accepts the same seven positional inputs as ``generate_text`` because
    both buttons share the slider widgets; only ``prompt``, ``temperature``
    and ``max_new_tokens`` are meaningful to the OpenAI API — the rest are
    accepted and ignored, preserving the original call contract.

    Returns:
        The assistant reply text, or "Please input text" for an empty
        prompt (guard added for consistency with ``generate_text`` and to
        avoid a billed API call on empty input).
    """
    if prompt == "":
        return "Please input text"
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_new_tokens,
    )
    return response.choices[0].message.content
def recognize_speech(audio_data):
    """Transcribe a Gradio audio tuple via Google's free recognition endpoint.

    ``audio_data`` is the ``(sample_rate, samples)`` tuple produced by
    ``gr.Audio(type="numpy")``. Returns the recognized text, or a
    human-readable error string on failure (the UI shows it directly).
    """
    # Wrap the raw samples for the SpeechRecognition library.
    # sample_width=2 assumes 16-bit mono samples — NOTE(review): if Gradio
    # delivers float or stereo arrays this AudioData is malformed; confirm
    # the component's output dtype/channels against the Gradio version used.
    audio_data = sr.AudioData(np.array(audio_data[1]), sample_rate=audio_data[0] , sample_width=2)
    recognizer = sr.Recognizer()
    try:
        # Network call to Google's public speech API (no key configured here).
        text = recognizer.recognize_google(audio_data)
        return text
    except sr.UnknownValueError:
        # Audio was received but nothing intelligible was decoded.
        return "Speech Recognition could not understand audio."
    except sr.RequestError as e:
        # Network / quota / service failure reaching the API.
        return f"Could not request results from Google Speech Recognition service; {e}"
# UI layout: record audio -> recognize -> feed the transcript plus the
# generation sliders into either the local fine-tuned model or GPT-3.5.
with gr.Blocks() as demo:
    with gr.Row():
        # Microphone/file input as a (sample_rate, numpy samples) tuple.
        inp = gr.Audio(type="numpy")
        out_text_predict = gr.Textbox(label="Recognized Speech")
        button = gr.Button("Recognize Speech" , size="lg")
        button.click(recognize_speech, inp, out_text_predict)
    with gr.Row():
        # Generation hyper-parameters; defaults match generate_text's behavior.
        with gr.Row():
            penalty_alpha_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="penalty alpha",value=0.6)
            do_sample_checkbox = gr.Checkbox(label="do sample",value=True)
            top_k_slider = gr.Slider(minimum=0, maximum=10, step=1, label="top k", value=5)
        with gr.Row():
            temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="temperature",value=0.5)
            repetition_penalty_slider = gr.Slider(minimum=0, maximum=2, step=0.1, label="repetition penalty",value=1.0)
            max_new_tokens_slider = gr.Slider(minimum=0, maximum=200, step=1, label="max new tokens",value=30)
    with gr.Row():
        out_answer = gr.Textbox(label="Answer")
        button_answer = gr.Button("Answer")
        # Local fine-tuned model path: transcript + all seven controls.
        button_answer.click(generate_text, [out_text_predict, penalty_alpha_slider, do_sample_checkbox, top_k_slider, temperature_slider, repetition_penalty_slider, max_new_tokens_slider], out_answer)
    with gr.Row():
        gpt_output = gr.Textbox(label="GPT-3.5 Turbo Output")
        button_gpt = gr.Button("GPT-3.5 Answer")
        # GPT-3.5 path receives the same seven inputs for a like-for-like comparison.
        button_gpt.click(gpt_generate,[out_text_predict, penalty_alpha_slider, do_sample_checkbox, top_k_slider, temperature_slider, repetition_penalty_slider, max_new_tokens_slider],gpt_output)
# Blocking call: serves the app (module top level, as is conventional in Spaces).
demo.launch()