lukecq committed on
Commit
2360578
·
verified ·
1 Parent(s): 015f3f9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -63
app.py CHANGED
@@ -1,64 +1,181 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Shared inference client; `respond` streams chat completions through it.
_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
client = InferenceClient(_MODEL_ID)
8
-
9
-
10
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the assistant's reply to ``message`` as a growing string.

    Args:
        message: The new user message.
        history: Earlier ``(user, assistant)`` pairs; falsy entries in a
            pair are left out of the prompt.
        system_message: Content of the leading system message.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.

    Yields:
        The reply accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # The loop variable is deliberately not called `message`, so the
    # parameter of that name is no longer shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # A chunk may carry no text (content is None); concatenating that
        # onto a str would raise TypeError, so only append real tokens.
        if token:
            response += token
        yield response
41
-
42
-
43
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Extra controls, passed to `respond` after (message, history) in this order.
_system_box = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
_max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
_temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)


# Only start the server when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import time
3
+ from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
4
+ from io import BytesIO
5
+ from urllib.request import urlopen
6
+ import librosa
7
+ import os, json
8
+ from sys import argv
9
+ from vllm import LLM, SamplingParams
10
+
11
def load_model_processor(model_path):
    """Create the processor and the vLLM engine for ``model_path``.

    Args:
        model_path: Model identifier or path handed to both
            ``AutoProcessor.from_pretrained`` and ``LLM``.

    Returns:
        A ``(llm, processor)`` tuple.
    """
    engine_options = {
        "model": model_path,
        "trust_remote_code": True,
        "gpu_memory_utilization": 0.4,
        "enforce_eager": True,
        # At most five audio items per prompt.
        "limit_mm_per_prompt": {"audio": 5},
    }
    # The processor is created before the engine, as in the original order.
    processor = AutoProcessor.from_pretrained(model_path)
    llm = LLM(**engine_options)
    return llm, processor
19
+
20
# Model served by this app. The trailing comment suggests the path once came
# from the command line (argv[1]); it is hard-coded here.
model_path1 = "Qwen/Qwen2-Audio-7B-Instruct"  # argv[1]

# Loaded once at import time; `bot` uses these module-level objects.
_loaded = load_model_processor(model_path1)
model1, processor1 = _loaded
22
+
23
def response_to_audio_conv(conversation, model=None, processor=None, temperature=0.1, repetition_penalty=1.1, top_p=0.9,
                           max_new_tokens=2048):
    """Generate one reply for a conversation that may contain audio.

    Args:
        conversation: List of message dicts. A message whose ``content`` is a
            list is scanned for ``{"type": "audio", "audio_url": ...}`` parts.
        model: Object exposing ``generate(prompts, sampling_params=...)``.
        processor: Object exposing ``apply_chat_template`` and
            ``feature_extractor.sampling_rate``.
        temperature: Sampling temperature.
        repetition_penalty: Repetition penalty.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.

    Returns:
        The text of the first output of the first generation result.
    """
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Audio is resampled to the processor's rate on load, so the same value
    # is reported to the engine below instead of a hard-coded 16000.
    sampling_rate = processor.feature_extractor.sampling_rate

    audios = []
    for message in conversation:
        content = message["content"]
        if not isinstance(content, list):
            continue
        for ele in content:
            if ele["type"] == "audio" and ele["audio_url"] is not None:
                waveform = librosa.load(ele["audio_url"], sr=sampling_rate)[0]
                audios.append(waveform)

    sampling_params = SamplingParams(
        temperature=temperature, max_tokens=max_new_tokens, repetition_penalty=repetition_penalty, top_p=top_p, top_k=20,
        stop_token_ids=[],
    )

    # Named `request` rather than `input` to avoid shadowing the builtin.
    request = {
        "prompt": text,
        "multi_modal_data": {
            "audio": [(audio, sampling_rate) for audio in audios],
        },
    }

    output = model.generate([request], sampling_params=sampling_params)[0]
    return output.outputs[0].text
52
+
53
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a like/dislike event."""
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
55
+
56
def add_message(history, message):
    """Append the submitted files and text to the chat history.

    Args:
        history: List of ``{"role", "content"}`` message dicts. Non-string
            user content is indexed with ``[0]`` to obtain a file path
            (assumes file turns arrive as a sequence whose first item is the
            path -- verify against the Chatbot component in use).
        message: Dict with a ``"files"`` list of paths and a ``"text"`` entry.

    Returns:
        The updated history and a cleared, non-interactive
        ``gr.MultimodalTextbox`` (input stays locked while the bot answers).
    """
    # Paths of files that are already part of the conversation.
    known_paths = {
        turn["content"][0]
        for turn in history
        if turn["role"] == "user" and not isinstance(turn["content"], str)
    }

    for path in message["files"]:
        if path not in known_paths:
            history.append({"role": "user", "content": {"path": path}})

    # Skip empty text as well as None: an audio-only submit would otherwise
    # add an empty user turn to the history.
    if message["text"]:
        history.append({"role": "user", "content": message["text"]})

    return history, gr.MultimodalTextbox(value=None, interactive=False)
67
+
68
def format_user_messgae(message):
    """Convert one user history turn into a single-part conversation message.

    Args:
        message: Dict whose ``"content"`` is either a string (text turn) or
            an indexable value whose first item is used as the audio path.

    Returns:
        ``{"role": "user", "content": [part]}`` where ``part`` is a
        ``"text"`` part for string content and an ``"audio"`` part otherwise.
    """
    content = message["content"]
    # isinstance also accepts str subclasses; an exact type comparison would
    # treat them as audio and use their first character as the path.
    if isinstance(content, str):
        part = {"type": "text", "text": content}
    else:
        part = {"type": "audio", "audio_url": content[0]}
    return {"role": "user", "content": [part]}
73
+
74
def history_to_conversation(history):
    """Build the model conversation from the chat history.

    User turns with falsy content are skipped. The remaining user turns are
    formatted with ``format_user_messgae``; an audio path that has already
    been seen is dropped, and consecutive user turns are merged into one
    message with several content parts. Non-user turns are passed through
    unchanged.

    Args:
        history: List of ``{"role", "content"}`` message dicts.

    Returns:
        The conversation list (also printed as JSON for debugging).
    """
    conversation = []
    seen_audio = set()  # set membership instead of scanning a list per turn

    for turn in history:
        if turn["role"] != "user":
            conversation.append(turn)
            continue
        if not turn["content"]:
            continue

        # Separate name so the loop variable is not rebound mid-iteration.
        formatted = format_user_messgae(turn)
        part = formatted["content"][0]

        if part["type"] == "audio":
            if part["audio_url"] in seen_audio:
                continue
            seen_audio.add(part["audio_url"])

        if conversation and conversation[-1]["role"] == "user":
            # Fold into the preceding user message.
            conversation[-1]["content"].append(part)
        else:
            conversation.append(formatted)

    print(json.dumps(conversation, indent=4, ensure_ascii=False))
    return conversation
97
+
98
def bot(history: list, temperature=0.1, repetition_penalty=1.1, top_p=0.9,
        max_new_tokens=2048):
    """Generate the assistant reply and stream it into ``history``.

    The full reply is produced first by ``response_to_audio_conv`` using the
    module-level ``model1`` / ``processor1``; it is then revealed one
    character at a time with a 10 ms pause between updates.

    Args:
        history: Chat history as a list of message dicts; mutated in place.
        temperature: Forwarded to ``response_to_audio_conv``.
        repetition_penalty: Forwarded to ``response_to_audio_conv``.
        top_p: Forwarded to ``response_to_audio_conv``.
        max_new_tokens: Forwarded to ``response_to_audio_conv``.

    Yields:
        ``history`` after each appended character.
    """
    generation_options = {
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    conversation = history_to_conversation(history)
    response = response_to_audio_conv(conversation, model=model1, processor=processor1, **generation_options)
    print("Bot:", response)

    # The new assistant turn starts empty and grows as characters arrive.
    reply = {"role": "assistant", "content": ""}
    history.append(reply)
    for character in response:
        reply["content"] += character
        time.sleep(0.01)
        yield history
110
+
111
# Usage notes for the UI; currently only referenced by commented-out code below.
insturctions = """**Instruction**: there are three input format:
1. text: input text message only
2. audio: upload audio file or record a voice message
3. audio + text: record a voice message and input text message"""

# HTML snippets rendered through gr.Markdown at the top of the page.
_title_html = """<div style="text-align: center; font-size: 32px; font-weight: bold;">SeaLLMs-Audio ChatBot</div>"""

# NOTE(review): this text names SeaLLMs-Audio-7B-Chat while the module loads
# "Qwen/Qwen2-Audio-7B-Instruct" -- confirm which model is intended.
_description_html = """<div style="text-align: center; font-size: 16px;">
This WebUI is based on SeaLLMs-Audio-7B-Chat, developed by Alibaba DAMO Academy.<br>
You can interact with the chatbot in <b>English, Chinese, Indonesian, Thai, or Vietnamese</b>.<br>
For each round, you can input <b>audio and/or text</b>.
</div>"""

# NOTE(review): [Website] and [Model] point to the same URL -- confirm.
_links_html = """<div style="text-align: center; font-size: 16px;">
<a href="https://huggingface.co/SeaLLMs/SeaLLMs-v3-7B-Chat">[Website]</a> &nbsp;
<a href="https://huggingface.co/SeaLLMs/SeaLLMs-v3-7B-Chat">[Model🤗]</a> &nbsp;
<a href="https://github.com/liuchaoqun/SeaLLMs-Audio">[Github]</a>
</div>"""


def _unlock_input():
    """Return a textbox update that makes the input interactive again."""
    return gr.MultimodalTextbox(interactive=True)


with gr.Blocks() as demo:
    # gr.Markdown("""<p align="center"><img src="images/seal_logo.png" style="height: 80px"/><p>""")
    # gr.Image("images/seal_logo.png", elem_id="seal_logo", show_label=False,height=80,show_fullscreen_button=False)
    gr.Markdown(
        _title_html,
    )

    # Description text
    gr.Markdown(
        _description_html,
    )

    # Links with proper formatting
    gr.Markdown(
        _links_html,
    )

    # gr.Markdown(insturctions)
    # with gr.Row():
    #     with gr.Column():
    #         temperature = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Temperature")
    #     with gr.Column():
    #         top_p = gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.1, label="Top P")
    #     with gr.Column():
    #         repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.1, step=0.1, label="Repetition Penalty")
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages")

    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="single",
        file_types=['.wav'],
        placeholder="Enter message (optional) ...",
        show_label=False,
        sources=["microphone", "upload"],
    )

    # Submit: record the user turn and lock the input, then stream the bot
    # reply, then unlock the input again.
    chat_msg = chat_input.submit(
        fn=add_message,
        inputs=[chatbot, chat_input],
        outputs=[chatbot, chat_input],
    )
    bot_msg = chat_msg.then(fn=bot, inputs=chatbot, outputs=chatbot, api_name="bot_response")
    # bot_msg = chat_msg.then(bot, [chatbot, temperature, repetition_penalty, top_p], chatbot, api_name="bot_response")
    bot_msg.then(fn=_unlock_input, inputs=None, outputs=[chat_input])

    # chatbot.like(print_like_dislike, None, None, like_user_message=True)

    clear_button = gr.ClearButton([chatbot, chat_input])

PORT = 7950
# NOTE(review): root_path is a hard-coded gateway URL for one specific
# environment -- confirm it is wanted wherever this app is deployed.
_launch_options = {
    "server_port": PORT,
    "show_api": True,
    "allowed_paths": [],
    "root_path": f"https://dsw-gateway.alibaba-inc.com/dsw81322/proxy/{PORT}/",
}
demo.launch(**_launch_options)
174
+
175
+ # demo.launch(
176
+ # share=False,
177
+ # inbrowser=True,
178
+ # server_port=7950,
179
+ # server_name="0.0.0.0",
180
+ # max_threads=40
181
+ # )