tanyuzhou commited on
Commit
e46cfd7
1 Parent(s): 6386785

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -31
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
9
 
 
 
 
 
 
 
 
10
  def respond(
11
  message,
12
  history: list[tuple[str, str]],
@@ -15,37 +21,54 @@ def respond(
15
  temperature,
16
  top_p,
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
  temperature=temperature,
35
  top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
 
 
 
 
 
 
 
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
  demo = gr.ChatInterface(
46
  respond,
47
  additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
  gr.Slider(
@@ -58,6 +81,5 @@ demo = gr.ChatInterface(
58
  ],
59
  )
60
 
61
-
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
 
2
 
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers import TextStreamer
 
6
 
7
+ import spaces
8
 
9
+ # Load model and tokenizer
10
+ model = AutoModelForCausalLM.from_pretrained("Rorical/0-roleplay", return_dict=True, trust_remote_code=True)
11
+ tokenizer = AutoTokenizer.from_pretrained("Rorical/0-roleplay", trust_remote_code=True)
12
+ tokenizer.chat_template = "{% for message in messages %}{{'' + ((message['role'] + '\n') if message['role'] != '' else '') + message['content'] + '' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '星野\n' }}{% endif %}"
13
+
14
+ # Define the response function
15
+ @spaces.GPU
16
  def respond(
17
  message,
18
  history: list[tuple[str, str]],
 
21
  temperature,
22
  top_p,
23
  ):
24
+ # Construct the messages for the chat
25
+ messages = [{"role": "", "content": system_message}]
26
+ for user_message, bot_response in history:
27
+ messages.append({"role": "老师", "content": user_message}) # Assuming the user is "老师"
28
+ messages.append({"role": "星野", "content": bot_response}) # Assuming the bot is "星野"
29
+ messages.append({"role": "老师", "content": message}) # Append the latest user message
30
+
31
+ # Tokenize and prepare inputs
32
+ inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
33
+ inputs = inputs.to("cuda")
34
+
35
+ # Generate response
36
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
+ generate_ids = model.generate(
38
+ inputs,
39
+ max_length=max_tokens,
40
  temperature=temperature,
41
  top_p=top_p,
42
+ streamer=streamer
43
+ )
44
+
45
+ # Decode the generated response
46
+ response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
47
+ response = response.replace("星野\n", "", 1) # Remove the "星野\n" prefix if it exists
48
+ response = response.split("老师\n")[0] # Split on the next possible user input
49
+
50
+ return response, history + [(message, response)]
51
 
52
+ # Default prompt for the chatbot
53
+ prompt = """以下是小鸟游星野的介绍
54
+ 星野是阿拜多斯高中对策委员会的委员长,同时也是学生会副主席。语气懒散,经常自称为大叔,实际上是自己默默承担一切的女生。
55
+ 比起工作,她更喜欢玩。 正因为如此,她经常被委员会的其他人骂。 但是,一旦任务开始,她就会在前线勇敢地战斗以保护她的战友。
56
+ 她在阿拜多斯上高中。与星野一起在对策委员会的成员有白子,茜香,野乃美,和绫音。
57
+ 星野的年龄是17岁,生日为1月2日。
58
+ 星野有一头粉红色的头发,头巾一直长到她的腿上。
59
+ 星野有蓝色和橙色眼睛的异色症。
60
+ 星野其实更符合认真而默默努力的类型。她实际上不相信其它的学校和大人,是对策委员会中最谨慎保守的人。当然,这并不妨碍老师和星野增进关系,成为她唯一信任的大人。
61
+ 是萝莉、有呆毛、天然萌、早熟、学生会副会长、异色瞳、慵懒。
62
+ 星野对海洋动物很感兴趣,对鱼类的知识了解得不少。她在拿到附录中包含2000多种热带鱼图鉴的书后,迫不及待地找了家店坐下来阅读。
63
+ 在众多海洋动物中,星野最喜欢的当属鲸鱼,情人节时星野还在海洋馆买了鲸鱼的巧克力作为纪念。
64
+ 星野还对寻宝有着十分浓厚的兴趣,曾和老师探索了阿拜多斯多个角落。
65
+ 星野给人一种白天睡不醒的瞌睡虫形象。"""
66
 
67
+ # Create the Gradio interface
 
 
68
  demo = gr.ChatInterface(
69
  respond,
70
  additional_inputs=[
71
+ gr.Textbox(value=prompt, label="System message", lines=5),
72
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
73
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
74
  gr.Slider(
 
81
  ],
82
  )
83
 
 
84
  if __name__ == "__main__":
85
+ demo.launch()