# Exodia / app.py — Hugging Face Space (revision 9d03774, 4.93 kB)
import gradio as gr
import numpy as np
from huggingface_hub import InferenceClient
import os
import requests
import scipy.io.wavfile
# Chat model served through the Hugging Face Inference API; the auth token is
# read from the 'hf_token' environment variable (e.g. a Space secret).
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=os.getenv("hf_token"),
)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion from the module-level ``client``.

    Args:
        message: The new user message.
        history: Prior turns as (user, assistant) string pairs; empty
            entries are skipped.
        system_message: Content of the leading system-role message.
        max_tokens: Generation cap forwarded to ``chat_completion``.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling parameter forwarded to the API.

    Yields:
        The accumulated assistant response after each streamed token
        (Gradio's streaming-chat convention).
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    # Loop variable renamed from `message` (original shadowed the function
    # parameter of the same name).
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # Final/keep-alive stream chunks can carry delta.content=None;
        # concatenating None to a str raises TypeError, so skip them.
        if token:
            response += token
            yield response
def process_audio(audio_data):
    """Transcribe recorded audio via the HF Inference API (whisper-large-v2).

    Args:
        audio_data: Gradio ``Audio(type="numpy")`` payload — a tuple of
            ``(sample_rate, numpy_array)`` — or ``None`` if nothing was
            recorded.

    Returns:
        The transcribed text on success, otherwise a short error string
        (the UI displays whatever is returned in the textbox).
    """
    if audio_data is None:
        return "No audio provided"
    print("audio_data:", audio_data)  # debug: inspect the raw payload
    # Gradio's numpy audio component yields (sample_rate, data); anything
    # else means the component was configured differently.
    if isinstance(audio_data, tuple):
        sample_rate, data = audio_data
        print("Sample rate:", sample_rate)
        print("Data type:", type(data))
    else:
        return "Invalid audio data format"

    # Persist the capture as a WAV file for upload to the ASR endpoint.
    local_wav_file = "converted_audio.wav"
    scipy.io.wavfile.write(local_wav_file, sample_rate, data)

    API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
    headers = {"Authorization": f"Bearer {os.getenv('hf_token')}"}

    def query(filename):
        """POST the WAV bytes to the inference endpoint and decode JSON."""
        with open(filename, "rb") as f:
            file_data = f.read()
        # timeout prevents a hung request from freezing the UI handler.
        response = requests.post(API_URL, headers=headers, data=file_data, timeout=60)
        return response.json()

    # Network failures or a non-JSON (e.g. HTML error) body previously
    # crashed the handler; degrade to the existing error string instead.
    try:
        output = query(local_wav_file)
    except (requests.RequestException, ValueError) as exc:
        print("ASR request failed:", exc)
        return "Error in processing audio"
    print(output)

    # Successful responses carry the transcript under 'text'.
    if 'text' in output:
        return output['text']
    else:
        return "Error in processing audio"
# Switch the UI into its "busy" state before audio processing starts.
def disable_components():
    """Return updates that show a processing hint, lock the button,
    and reveal the loading indicator.

    Returns:
        Tuple of ``gr.update`` objects for
        (recognized_text, process_button, loading_animation).
    """
    return (
        gr.update(value='正在处理,请稍候...'),  # "processing, please wait" placeholder
        gr.update(interactive=False),           # lock process_button
        gr.update(visible=True),                # show the loading banner
    )
# Restore the UI to its idle state once audio processing has finished.
def enable_components(recognized_text):
    """Pass the transcript through unchanged while re-enabling the button
    and hiding the loading indicator.

    Args:
        recognized_text: Transcript already written by ``process_audio``.

    Returns:
        Tuple of (recognized_text, button update, indicator update).
    """
    button_state = gr.update(interactive=True)   # unlock process_button
    spinner_state = gr.update(visible=False)     # hide the loading banner
    return recognized_text, button_state, spinner_state
# Build the Gradio UI (component creation order defines the layout).
def create_interface():
    """Assemble the Blocks app: mic recorder + ASR button on top,
    a Llama-3 chat interface below.

    Returns:
        The constructed ``gr.Blocks`` demo, ready for ``.launch()``.
    """
    with gr.Blocks() as demo:
        # Page title ("Speech Recognition and Chat System")
        gr.Markdown("# 语音识别与聊天系统")
        # Audio input row: microphone capture delivered as numpy data
        with gr.Row():
            audio_input = gr.Audio(
                sources="microphone",
                type="numpy",  # delivers (sample_rate, ndarray) to process_audio
                label="上传音频"
            )
        # Row holding the transcript textbox
        with gr.Row():
            recognized_text = gr.Textbox(label="识别文本")
        # Button that triggers transcription
        process_button = gr.Button("处理音频")
        # Hidden-by-default loading banner, toggled by the click chain below
        loading_animation = gr.HTML(
            value='<div style="text-align: center;"><span style="font-size: 18px;">ASR Model is running...</span></div>',
            visible=False
        )
        # Three-step click chain: lock the UI, run ASR, then unlock.
        # Each .then() runs only after the previous step completes.
        process_button.click(
            fn=disable_components,
            inputs=[],
            outputs=[recognized_text, process_button, loading_animation]
        ).then(
            fn=process_audio,
            inputs=[audio_input],
            outputs=recognized_text
        ).then(
            fn=enable_components,
            inputs=[recognized_text],
            outputs=[recognized_text, process_button, loading_animation]
        )
        # Streaming chat UI backed by respond(); the extra textbox feeds
        # the system prompt.
        chatbot = gr.ChatInterface(
            fn=respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful chatbot that answers questions.", label="系统消息")
            ]
        )
        # NOTE(review): ChatInterface renders where it is created above;
        # this row only rebinds the name and likely has no layout effect —
        # confirm before relying on it.
        with gr.Row():
            chatbot_output = chatbot
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    app = create_interface()
    app.launch()