KleinPenny committed · verified
Commit 68f0d8d · Parent(s): 867343a

Update app.py

Files changed (1):
  1. app.py +89 -71
app.py CHANGED
@@ -5,65 +5,30 @@ import os
 import requests
 import scipy.io.wavfile
 import io
+import time
 
 client = InferenceClient(
     "meta-llama/Meta-Llama-3-8B-Instruct",
     token=os.getenv('hf_token')
 )
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
 
 def process_audio(audio_data):
     if audio_data is None:
-        return "No audio provided"
-
-    print("audio_data:", audio_data)  # added this line
+        return "No audio provided.", ""
 
     # Check whether audio_data is a tuple, and extract the data
     if isinstance(audio_data, tuple):
         sample_rate, data = audio_data
-        print("Sample rate:", sample_rate)
-        print("Data type:", type(data))
     else:
-        return "Invalid audio data format"
-
+        return "Invalid audio data format.", ""
 
     # Convert the audio data to WAV format in memory
     buf = io.BytesIO()
     scipy.io.wavfile.write(buf, sample_rate, data)
     wav_bytes = buf.getvalue()
     buf.close()
-
+
     API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
     headers = {"Authorization": f"Bearer {os.getenv('hf_token')}"}
 
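The hunk ends just before the `query` helper, which sits in unchanged lines and so never appears in the diff; the next hunk calls it as `output = query(wav_bytes)`. Judging from the `API_URL` and `headers` locals above, it is presumably the standard Inference API request wrapper, along the lines of this sketch (an assumption, not the commit's verbatim code):

    # Hypothetical reconstruction of the hidden `query` helper; it lives
    # inside process_audio, so API_URL and headers are in scope.
    def query(wav_bytes):
        # POST the raw WAV bytes to the Whisper endpoint, decode the JSON body
        response = requests.post(API_URL, headers=headers, data=wav_bytes)
        return response.json()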
@@ -74,13 +39,15 @@ def process_audio(audio_data):
     # Call the API to process the audio
     output = query(wav_bytes)
 
-    print(output)
+    print(output)  # Check output in console (logs in HF space)
 
     # Check the API response
     if 'text' in output:
-        return output['text']
+        recognized_text = output['text']
+        return recognized_text, recognized_text
     else:
-        return "Error in processing audio"
+        recognized_text = "The ASR module is still loading, please press the button again!"
+        return recognized_text, ""
 
 # Function to disable the button and show the loading indicator
 def disable_components():
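The body of `disable_components` is also hidden context. Given that `enable_components` below returns `gr.update` objects that re-enable the button and hide the animation, its counterpart presumably returns the mirrored updates; a minimal sketch under that assumption:

    # Assumed mirror of enable_components: grey out the button and show the
    # loading indicator while the ASR request is in flight.
    def disable_components():
        process_button_update = gr.update(interactive=False)
        loading_animation_update = gr.update(visible=True)
        return process_button_update, loading_animation_update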
@@ -94,42 +61,101 @@ def disable_components():
 
 # Function to enable the button and hide the loading indicator
 def enable_components(recognized_text):
-    # After processing completes, recognized_text has already been updated by process_audio
-    # Re-enable process_button
     process_button_update = gr.update(interactive=True)
     # Hide the loading animation
     loading_animation_update = gr.update(visible=False)
     return recognized_text, process_button_update, loading_animation_update
 
+llama_responded = 0
+
+def respond(
+    message,
+    history: list[tuple[str, str]]
+):
+    global llama_responded
+    system_message = "You are a helpful chatbot that answers questions. Give any answer within 50 words."
+    messages = [{"role": "system", "content": system_message}]
+
+    for val in history:
+        print(val[0])
+        if val[0] != None:
+            if val[0]:
+                messages.append({"role": "user", "content": val[0]})
+            if val[1]:
+                messages.append({"role": "assistant", "content": val[1]})
+    messages.append({"role": "user", "content": message})
+
+    response = ""
+
+    for message in client.chat_completion(
+        messages,
+        stream=True,
+    ):
+        token = message.choices[0].delta.content
+        response += token
+
+    llama_responded = 1
+    return response  # gr.Audio("/home/yxpeng/Projects/RAGHack/Exodia/voice_sample/trump1.wav")
+
+def update_response_display():
+    while not llama_responded:
+        time.sleep(1)
+
+def bot(history):
+    global llama_responded
+    # print(history)
+    history.append([None, gr.Audio("/home/yxpeng/Projects/RAGHack/Exodia/voice_sample/trump1.wav")])
+    llama_responded = 0
+
+    return history
 
-# Build the interface
 def create_interface():
     with gr.Blocks() as demo:
-        # Title
-        gr.Markdown("# 语音识别与聊天系统")
+        # Title
+        gr.Markdown("# Exodia AI Assistant")
 
-        # Audio input section
+        # Audio input section
         with gr.Row():
             audio_input = gr.Audio(
                 sources="microphone",
-                type="numpy",  # get the audio data and sample rate
-                label="上传音频"
+                type="numpy",  # Get audio data and sample rate
+                label="Say Something..."
             )
+            recognized_text = gr.Textbox(label="Recognized Text", interactive=False)
 
-        # Recognized-text output section
-        with gr.Row():
-            recognized_text = gr.Textbox(label="识别文本")
-
-        # Button to process the audio
-        process_button = gr.Button("处理音频")
+        # Process audio button
+        process_button = gr.Button("Process Audio")
 
-        # Loading animation
+        # Loading animation
         loading_animation = gr.HTML(
            value='<div style="text-align: center;"><span style="font-size: 18px;">ASR Model is running...</span></div>',
            visible=False
        )
-
-        # Hook up the audio-processing function and update component states on click
+
+        chatbot_custom = gr.Chatbot(height=500)  # Set height to 500 pixels
+
+        # Chat interface using the custom chatbot instance
+        chatbot = gr.ChatInterface(
+            fn=respond,
+            chatbot=chatbot_custom,
+            submit_btn="Start Chatting"
+        )
+        user_start = chatbot.textbox.submit(
+            fn=update_response_display,
+            inputs=[],
+            outputs=[],
+        )
+
+        # Runs when the user submits a request
+        # user_start = chatbot.textbox.submit()
+
+        user_start.then(
+            fn=bot,
+            inputs=[chatbot_custom],
+            outputs=chatbot_custom,  # update the contents of response_display
+        )
+
+        # Associate audio processing function and update component states on click
         process_button.click(
             fn=disable_components,
             inputs=[],
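A note on the `llama_responded` handshake introduced above: `respond` sets the flag once the completion loop finishes, `update_response_display` polls it once per second, and `bot` resets it after appending the audio reply. A module-level integer works for a single-session demo, but it is shared by every visitor to the Space and burns a worker thread while it sleeps. If this is ever refactored, a `threading.Event` gives the same handshake without polling; a sketch of that swapped-in technique, not the commit's code:

    import threading

    # Alternative to the integer flag: block on an Event instead of sleeping.
    llama_done = threading.Event()

    def mark_responded():
        # Call where respond() currently sets llama_responded = 1
        llama_done.set()

    def wait_for_response():
        # Replaces the sleep loop in update_response_display()
        llama_done.wait()
        llama_done.clear()  # reset for the next turn, as bot() does today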
@@ -137,22 +163,14 @@ def create_interface():
         ).then(
             fn=process_audio,
             inputs=[audio_input],
-            outputs=recognized_text
+            outputs=[recognized_text, chatbot.textbox]
         ).then(
             fn=enable_components,
             inputs=[recognized_text],
             outputs=[recognized_text, process_button, loading_animation]
         )
-
-        # Chatbot interface
-        chatbot = gr.ChatInterface(
-            fn=respond,
-            additional_inputs=[
-                gr.Textbox(value="You are a helpful chatbot that answers questions.", label="系统消息")
-            ]
-        )
-
-        # Layout includes the Chatbot
+
+        # Layout includes Chatbot
         with gr.Row():
             chatbot_output = chatbot
 
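This rewiring is why `process_audio` now returns a pair: Gradio maps a tuple return onto the `outputs` list positionally, so one call fills both the read-only `recognized_text` box and the ChatInterface textbox, pre-filling the transcript as the next chat message. A generic illustration of the positional mapping (hypothetical components, not from the commit):

    import gradio as gr

    def split(text):
        # Tuple elements map one-to-one onto the outputs list below
        return text.upper(), text.lower()

    with gr.Blocks() as demo:
        src = gr.Textbox(label="Input")
        upper_box = gr.Textbox(label="Upper")
        lower_box = gr.Textbox(label="Lower")
        gr.Button("Run").click(fn=split, inputs=[src], outputs=[upper_box, lower_box])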
@@ -162,4 +180,4 @@ def create_interface():
 
 if __name__ == "__main__":
     demo = create_interface()
-    demo.launch()
+    demo.launch()
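One behavioral change worth flagging: the old `respond` was a generator that yielded partial text for live streaming, while the new one accumulates the whole completion and returns it once, which is what lets it set `llama_responded` at a well-defined end point. If streaming display were wanted again alongside the flag, the loop could yield as it goes and signal completion in a `finally:` block; a sketch of that variant, not the commit's code:

    def respond_streaming(message, history):
        # Streams partial text like the old respond(), but still signals
        # completion for the bot() follow-up step.
        global llama_responded
        messages = [{"role": "system", "content": "You are a helpful chatbot."}]
        messages.append({"role": "user", "content": message})
        response = ""
        try:
            for chunk in client.chat_completion(messages, stream=True):
                token = chunk.choices[0].delta.content
                if token:  # the final chunk's delta can be None
                    response += token
                    yield response
        finally:
            llama_responded = 1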