File size: 5,242 Bytes
2b51ab4
 
e514fc1
280f5c9
cc35983
f5ef79f
833d7c1
e514fc1
ec687da
 
beb3558
cc35983
e46cfd7
cc35983
 
 
6b9222f
 
cc35983
 
 
 
 
e514fc1
6b9222f
 
cc35983
0be072c
cc35983
6b9222f
 
 
e46cfd7
cc35983
 
 
e46cfd7
cc35983
 
 
 
280f5c9
2a86641
e46cfd7
 
4cb7b09
 
cc35983
 
d1a3973
cc35983
e75f787
cc35983
 
99cd0ff
 
 
9dc9092
 
 
 
99cd0ff
a5b82ea
12cb692
 
 
 
a5b82ea
99cd0ff
 
 
 
eed839a
99cd0ff
 
 
 
 
 
 
 
 
 
eed839a
 
 
 
 
 
 
 
 
 
 
 
a996869
cc35983
 
 
 
 
6b9222f
 
cc35983
ca32728
cc35983
 
 
 
 
 
 
 
 
 
2b51ab4
 
e46cfd7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr

from transformers import AutoModelForCausalLM, AutoTokenizer

import spaces

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Rorical/0-roleplay", trust_remote_code=True)
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id

# Define the response function
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    bot_name,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    model = AutoModelForCausalLM.from_pretrained("Rorical/0-roleplay", return_dict=True, trust_remote_code=True)
    tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + ((message['role'] + '\n') if message['role'] != '' else '') + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>" + bot_name + "\n' }}{% endif %}" # Be careful that this model used custom chat template.
    
    # Construct the messages for the chat
    messages = [{"role": "", "content": system_message}]
    for user_message, bot_response in history:
        messages.append({"role": user_name, "content": user_message})
        messages.append({"role": bot_name, "content": bot_response})
    messages.append({"role": user_name, "content": message})
    
    # Tokenize and prepare inputs
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
    inputs = inputs.to("cuda")
    
    # Generate response
    generate_ids = model.generate(
        inputs,
        max_length=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    
    print("response: ", tokenizer.decode(generate_ids[0], skip_special_tokens=True))
    
    # Decode the generated response
    response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
    response = response.split(f"{user_name}\n{message}\n{bot_name}\n")[1]
    
    return response

# Default prompt for the chatbot
prompt = """# 角色扮演

## 角色扮演说明
- 你将扮演角色“星野”与用户进行对话
- 你的任务是回答用户的问题,或者与用户进行有意义的对话
- 请尽量保持角色的设定和性格
- 不要生成除了星野之外任何人物的台词

## 对话格式
[用户昵称]
[用户对话]
[角色昵称]
[角色对话]

## 角色信息
- 名字:小鸟游星野

## 设定
- 星野是阿拜多斯高中对策委员会的委员长,同时也是学生会副主席。语气懒散,经常自称为“大叔”,实际上是自己默默承担一切的女生。
- 比起工作,她更喜欢玩。 正因为如此,她经常被委员会的其他人骂。 但是,一旦任务开始,她就会在前线勇敢地战斗以保护她的战友。
- 她在阿拜多斯上高中。与星野一起在对策委员会的成员有白子,茜香,野乃美,和绫音。
- 星野的年龄是17岁,生日为1月2日。
- 星野有一头粉红色的头发,头巾一直长到她的腿上。 
- 星野有蓝色和橙色眼睛的异色症。
- 星野其实更符合认真而默默努力的类型。她实际上不相信其它的学校和大人,是对策委员会中最谨慎保守的人。当然,这并不妨碍老师和星野增进关系,成为她唯一信任的大人。
- 是萝莉、有呆毛、天然萌、早熟、学生会副会长、异色瞳、慵懒。
- 星野对海洋动物很感兴趣,对鱼类的知识了解得不少。她在拿到附录中包含2000多种热带鱼图鉴的书后,迫不及待地找了家店坐下来阅读。
- 在众多海洋动物中,星野最喜欢的当属鲸鱼,情人节时星野还在海洋馆买了鲸鱼的巧克力作为纪念。
- 星野还对寻宝有着十分浓厚的兴趣,曾和老师探索了阿拜多斯多个角落。
- 星野给人一种白天睡不醒的瞌睡虫形象。

## 台词示例
来得正好呀,老师。还是一如既往的那么辛苦呢~ 
嗯?我的生日?什、什么,不用在意我这个大叔啦。但是……谢谢你。有点高兴呢
诶嘿嘿~大叔我真是不擅长这样的日子啊~大家都闪闪发光的……嗯,不过偶尔这样也不错
这里好凉快,很不错呢~
说我非常重要什么的?哎嘿,感谢你的恭维咯。
我一个人这么开心真的好吗……你说担心多余了?……是吗? 
哼~哼~哼~适合睡午觉的地方在哪里呢~
老师也变了呢~为什么会喜欢这样的我呢
诶嘿嘿,能这么打动我的人也只有老师了
"""

# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="老师", label="User name", lines=1),
        gr.Textbox(value="星野", label="Bot name", lines=1),
        gr.Textbox(value=prompt, label="System message", lines=5),
        gr.Slider(minimum=1, maximum=32768, value=20480, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()