Spaces:
Runtime error
Runtime error
bofenghuang
commited on
Commit
β’
a4b1443
0
Parent(s):
Initial commit
Browse files- .gitattributes +34 -0
- README.md +10 -0
- app.py +356 -0
- requirements.txt +8 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Vigogne-Chat
|
3 |
+
emoji: π¦
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.27.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
app.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 Bofeng Huang
|
4 |
+
|
5 |
+
"""
|
6 |
+
Modified from: https://huggingface.co/spaces/mosaicml/mpt-7b-chat/raw/main/app.py
|
7 |
+
|
8 |
+
Usage:
|
9 |
+
CUDA_VISIBLE_DEVICES=0
|
10 |
+
|
11 |
+
python vigogne/demo/demo_chat.py \
|
12 |
+
--base_model_name_or_path huggyllama/llama-7b \
|
13 |
+
--lora_model_name_or_path bofenghuang/vigogne-chat-7b
|
14 |
+
"""
|
15 |
+
|
16 |
+
# import datetime
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
from threading import Event, Thread
|
21 |
+
from typing import List, Optional
|
22 |
+
|
23 |
+
|
24 |
+
# from uuid import uuid4
|
25 |
+
|
26 |
+
import fire
|
27 |
+
import json
|
28 |
+
import gradio as gr
|
29 |
+
|
30 |
+
# import requests
|
31 |
+
import torch
|
32 |
+
from peft import PeftModel
|
33 |
+
from transformers import (
|
34 |
+
AutoModelForCausalLM,
|
35 |
+
AutoTokenizer,
|
36 |
+
GenerationConfig,
|
37 |
+
StoppingCriteriaList,
|
38 |
+
TextIteratorStreamer,
|
39 |
+
)
|
40 |
+
|
41 |
+
from vigogne.constants import ASSISTANT, USER
|
42 |
+
from vigogne.preprocess import generate_inference_chat_prompt
|
43 |
+
from vigogne.inference.inference_utils import StopWordsCriteria
|
44 |
+
|
45 |
+
logging.basicConfig(
|
46 |
+
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
47 |
+
datefmt="%Y-%m-%dT%H:%M:%SZ",
|
48 |
+
)
|
49 |
+
logger = logging.getLogger(__name__)
|
50 |
+
logger.setLevel(logging.DEBUG)
|
51 |
+
|
52 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
53 |
+
|
54 |
+
try:
|
55 |
+
if torch.backends.mps.is_available():
|
56 |
+
device = "mps"
|
57 |
+
except:
|
58 |
+
pass
|
59 |
+
|
60 |
+
logger.info(f"Model will be loaded on device `{device}`")
|
61 |
+
|
62 |
+
|
63 |
+
# def log_conversation(conversation_id, history, messages, generate_kwargs):
|
64 |
+
# logging_url = os.getenv("LOGGING_URL", None)
|
65 |
+
# if logging_url is None:
|
66 |
+
# return
|
67 |
+
|
68 |
+
# timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
69 |
+
|
70 |
+
# data = {
|
71 |
+
# "conversation_id": conversation_id,
|
72 |
+
# "timestamp": timestamp,
|
73 |
+
# "history": history,
|
74 |
+
# "messages": messages,
|
75 |
+
# "generate_kwargs": generate_kwargs,
|
76 |
+
# }
|
77 |
+
|
78 |
+
# try:
|
79 |
+
# requests.post(logging_url, json=data)
|
80 |
+
# except requests.exceptions.RequestException as e:
|
81 |
+
# print(f"Error logging conversation: {e}")
|
82 |
+
|
83 |
+
|
84 |
+
def user(message, history):
|
85 |
+
# Append the user's message to the conversation history
|
86 |
+
return "", history + [[message, ""]]
|
87 |
+
|
88 |
+
|
89 |
+
# def get_uuid():
|
90 |
+
# return str(uuid4())
|
91 |
+
|
92 |
+
|
93 |
+
def main(
|
94 |
+
base_model_name_or_path: str = "huggyllama/llama-7b",
|
95 |
+
lora_model_name_or_path: str = "bofenghuang/vigogne-chat-7b",
|
96 |
+
load_8bit: bool = False,
|
97 |
+
server_name: Optional[str] = "0.0.0.0",
|
98 |
+
server_port: Optional[str] = None,
|
99 |
+
share: bool = False,
|
100 |
+
):
|
101 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)
|
102 |
+
|
103 |
+
if device == "cuda":
|
104 |
+
model = AutoModelForCausalLM.from_pretrained(
|
105 |
+
base_model_name_or_path,
|
106 |
+
load_in_8bit=load_8bit,
|
107 |
+
torch_dtype=torch.float16,
|
108 |
+
device_map="auto",
|
109 |
+
)
|
110 |
+
model = PeftModel.from_pretrained(
|
111 |
+
model,
|
112 |
+
lora_model_name_or_path,
|
113 |
+
torch_dtype=torch.float16,
|
114 |
+
)
|
115 |
+
elif device == "mps":
|
116 |
+
model = AutoModelForCausalLM.from_pretrained(
|
117 |
+
base_model_name_or_path,
|
118 |
+
device_map={"": device},
|
119 |
+
torch_dtype=torch.float16,
|
120 |
+
)
|
121 |
+
model = PeftModel.from_pretrained(
|
122 |
+
model,
|
123 |
+
lora_model_name_or_path,
|
124 |
+
device_map={"": device},
|
125 |
+
torch_dtype=torch.float16,
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, device_map={"": device}, low_cpu_mem_usage=True)
|
129 |
+
model = PeftModel.from_pretrained(
|
130 |
+
model,
|
131 |
+
lora_model_name_or_path,
|
132 |
+
device_map={"": device},
|
133 |
+
)
|
134 |
+
|
135 |
+
if not load_8bit and device != "cpu":
|
136 |
+
model.half() # seems to fix bugs for some users.
|
137 |
+
|
138 |
+
model.eval()
|
139 |
+
|
140 |
+
# NB
|
141 |
+
stop_words = [f"<|{ASSISTANT}|>", f"<|{USER}|>"]
|
142 |
+
stop_words_criteria = StopWordsCriteria(stop_words=stop_words, tokenizer=tokenizer)
|
143 |
+
pattern_trailing_stop_words = re.compile(rf'(?:{"|".join([re.escape(stop_word) for stop_word in stop_words])})\W*$')
|
144 |
+
|
145 |
+
def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, conversation_id=None):
|
146 |
+
# logger.info(f"History: {json.dumps(history, indent=4, ensure_ascii=False)}")
|
147 |
+
|
148 |
+
# Construct the input message string for the model by concatenating the current system message and conversation history
|
149 |
+
messages = generate_inference_chat_prompt(history, tokenizer)
|
150 |
+
logger.info(messages)
|
151 |
+
assert messages is not None, "User input is too long!"
|
152 |
+
|
153 |
+
# Tokenize the messages string
|
154 |
+
input_ids = tokenizer(messages, return_tensors="pt")["input_ids"].to(device)
|
155 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
156 |
+
generate_kwargs = dict(
|
157 |
+
input_ids=input_ids,
|
158 |
+
generation_config=GenerationConfig(
|
159 |
+
temperature=temperature,
|
160 |
+
do_sample=temperature > 0.0,
|
161 |
+
top_p=top_p,
|
162 |
+
top_k=top_k,
|
163 |
+
repetition_penalty=repetition_penalty,
|
164 |
+
max_new_tokens=max_new_tokens,
|
165 |
+
),
|
166 |
+
streamer=streamer,
|
167 |
+
stopping_criteria=StoppingCriteriaList([stop_words_criteria]),
|
168 |
+
)
|
169 |
+
|
170 |
+
# stream_complete = Event()
|
171 |
+
|
172 |
+
def generate_and_signal_complete():
|
173 |
+
model.generate(**generate_kwargs)
|
174 |
+
# stream_complete.set()
|
175 |
+
|
176 |
+
# def log_after_stream_complete():
|
177 |
+
# stream_complete.wait()
|
178 |
+
# log_conversation(
|
179 |
+
# conversation_id,
|
180 |
+
# history,
|
181 |
+
# messages,
|
182 |
+
# {
|
183 |
+
# "top_k": top_k,
|
184 |
+
# "top_p": top_p,
|
185 |
+
# "temperature": temperature,
|
186 |
+
# "repetition_penalty": repetition_penalty,
|
187 |
+
# },
|
188 |
+
# )
|
189 |
+
|
190 |
+
t1 = Thread(target=generate_and_signal_complete)
|
191 |
+
t1.start()
|
192 |
+
|
193 |
+
# t2 = Thread(target=log_after_stream_complete)
|
194 |
+
# t2.start()
|
195 |
+
|
196 |
+
# Initialize an empty string to store the generated text
|
197 |
+
partial_text = ""
|
198 |
+
for new_text in streamer:
|
199 |
+
# NB
|
200 |
+
new_text = pattern_trailing_stop_words.sub("", new_text)
|
201 |
+
|
202 |
+
partial_text += new_text
|
203 |
+
history[-1][1] = partial_text
|
204 |
+
yield history
|
205 |
+
|
206 |
+
logger.info(f"Response: {history[-1][1]}")
|
207 |
+
|
208 |
+
with gr.Blocks(
|
209 |
+
theme=gr.themes.Soft(),
|
210 |
+
css=".disclaimer {font-variant-caps: all-small-caps;}",
|
211 |
+
) as demo:
|
212 |
+
# conversation_id = gr.State(get_uuid)
|
213 |
+
gr.Markdown(
|
214 |
+
"""<h1><center>π¦ Vigogne Chat</center></h1>
|
215 |
+
|
216 |
+
This demo is of [Vigogne-Chat-7B](https://huggingface.co/bofenghuang/vigogne-chat-7b). It's based on [LLaMA-7B](https://github.com/facebookresearch/llama) finetuned to conduct French π«π· dialogues between a user and an AI assistant.
|
217 |
+
|
218 |
+
For more information, please visit the [Github repo](https://github.com/bofenghuang/vigogne) of the Vigogne project.
|
219 |
+
"""
|
220 |
+
)
|
221 |
+
chatbot = gr.Chatbot().style(height=500)
|
222 |
+
with gr.Row():
|
223 |
+
with gr.Column():
|
224 |
+
msg = gr.Textbox(
|
225 |
+
label="Chat Message Box",
|
226 |
+
placeholder="Chat Message Box",
|
227 |
+
show_label=False,
|
228 |
+
).style(container=False)
|
229 |
+
with gr.Column():
|
230 |
+
with gr.Row():
|
231 |
+
submit = gr.Button("Submit")
|
232 |
+
stop = gr.Button("Stop")
|
233 |
+
clear = gr.Button("Clear")
|
234 |
+
with gr.Row():
|
235 |
+
with gr.Accordion("Advanced Options:", open=False):
|
236 |
+
with gr.Row():
|
237 |
+
with gr.Column():
|
238 |
+
with gr.Row():
|
239 |
+
max_new_tokens = gr.Slider(
|
240 |
+
label="Max New Tokens",
|
241 |
+
value=512,
|
242 |
+
minimum=0,
|
243 |
+
maximum=1024,
|
244 |
+
step=1,
|
245 |
+
interactive=True,
|
246 |
+
info="The Max number of new tokens to generate.",
|
247 |
+
)
|
248 |
+
with gr.Column():
|
249 |
+
with gr.Row():
|
250 |
+
temperature = gr.Slider(
|
251 |
+
label="Temperature",
|
252 |
+
value=0.1,
|
253 |
+
minimum=0.0,
|
254 |
+
maximum=1.0,
|
255 |
+
step=0.1,
|
256 |
+
interactive=True,
|
257 |
+
info="Higher values produce more diverse outputs.",
|
258 |
+
)
|
259 |
+
with gr.Column():
|
260 |
+
with gr.Row():
|
261 |
+
top_p = gr.Slider(
|
262 |
+
label="Top-p (nucleus sampling)",
|
263 |
+
value=1.0,
|
264 |
+
minimum=0.0,
|
265 |
+
maximum=1,
|
266 |
+
step=0.01,
|
267 |
+
interactive=True,
|
268 |
+
info=(
|
269 |
+
"Sample from the smallest possible set of tokens whose cumulative probability "
|
270 |
+
"exceeds top_p. Set to 1 to disable and sample from all tokens."
|
271 |
+
),
|
272 |
+
)
|
273 |
+
with gr.Column():
|
274 |
+
with gr.Row():
|
275 |
+
top_k = gr.Slider(
|
276 |
+
label="Top-k",
|
277 |
+
value=0,
|
278 |
+
minimum=0.0,
|
279 |
+
maximum=200,
|
280 |
+
step=1,
|
281 |
+
interactive=True,
|
282 |
+
info="Sample from a shortlist of top-k tokens β 0 to disable and sample from all tokens.",
|
283 |
+
)
|
284 |
+
with gr.Column():
|
285 |
+
with gr.Row():
|
286 |
+
repetition_penalty = gr.Slider(
|
287 |
+
label="Repetition Penalty",
|
288 |
+
value=1.0,
|
289 |
+
minimum=1.0,
|
290 |
+
maximum=2.0,
|
291 |
+
step=0.1,
|
292 |
+
interactive=True,
|
293 |
+
info="Penalize repetition β 1.0 to disable.",
|
294 |
+
)
|
295 |
+
with gr.Row():
|
296 |
+
gr.Markdown(
|
297 |
+
"Disclaimer: Vigogne is still under development, and there are many limitations that have to be addressed. Please note that it is possible that the model generates harmful or biased content, incorrect information or generally unhelpful answers.",
|
298 |
+
elem_classes=["disclaimer"],
|
299 |
+
)
|
300 |
+
with gr.Row():
|
301 |
+
gr.Markdown(
|
302 |
+
"Acknowledgements: This demo is built on top of [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat). Thanks for their contribution!",
|
303 |
+
elem_classes=["disclaimer"],
|
304 |
+
)
|
305 |
+
|
306 |
+
submit_event = msg.submit(
|
307 |
+
fn=user,
|
308 |
+
inputs=[msg, chatbot],
|
309 |
+
outputs=[msg, chatbot],
|
310 |
+
queue=False,
|
311 |
+
).then(
|
312 |
+
fn=bot,
|
313 |
+
inputs=[
|
314 |
+
chatbot,
|
315 |
+
max_new_tokens,
|
316 |
+
temperature,
|
317 |
+
top_p,
|
318 |
+
top_k,
|
319 |
+
repetition_penalty,
|
320 |
+
# conversation_id,
|
321 |
+
],
|
322 |
+
outputs=chatbot,
|
323 |
+
queue=True,
|
324 |
+
)
|
325 |
+
submit_click_event = submit.click(
|
326 |
+
fn=user,
|
327 |
+
inputs=[msg, chatbot],
|
328 |
+
outputs=[msg, chatbot],
|
329 |
+
queue=False,
|
330 |
+
).then(
|
331 |
+
fn=bot,
|
332 |
+
inputs=[
|
333 |
+
chatbot,
|
334 |
+
max_new_tokens,
|
335 |
+
temperature,
|
336 |
+
top_p,
|
337 |
+
top_k,
|
338 |
+
repetition_penalty,
|
339 |
+
# conversation_id,
|
340 |
+
],
|
341 |
+
outputs=chatbot,
|
342 |
+
queue=True,
|
343 |
+
)
|
344 |
+
stop.click(
|
345 |
+
fn=None,
|
346 |
+
inputs=None,
|
347 |
+
outputs=None,
|
348 |
+
cancels=[submit_event, submit_click_event],
|
349 |
+
queue=False,
|
350 |
+
)
|
351 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
352 |
+
|
353 |
+
demo.queue(max_size=128, concurrency_count=2)
|
354 |
+
demo.launch(enable_queue=True, share=share, server_name=server_name, server_port=server_port)
|
355 |
+
|
356 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
loralib
|
3 |
+
sentencepiece
|
4 |
+
git+https://github.com/huggingface/transformers.git
|
5 |
+
accelerate
|
6 |
+
bitsandbytes
|
7 |
+
git+https://github.com/huggingface/peft.git
|
8 |
+
gradio
|