import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# model_path = 'dreamerdeo/Sailor2-0.8B-Chat'
model_path = 'sail/Sailor-0.5B-Chat'

# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Use CUDA when available for the best experience.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Custom stopping criterion: halt generation as soon as an end-of-turn token is emitted.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]  # Token ID of <|im_end|> in the Qwen tokenizer family used by Sailor.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Stop if the last generated token is a stop token.
                return True
        return False
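
# A possible simplification (an alternative, not what this demo ships): pass
# eos_token_id=tokenizer.convert_tokens_to_ids("<|im_end|>") to model.generate()
# and drop the custom StoppingCriteria entirely.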


system_role = 'system'
user_role = 'user'
assistant_role = 'assistant'

sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
ct_end_token = "<|endoftext|>"  # End-of-text token from the base (pre-trained) model; unused in this chat demo.

system_prompt = (
    'You are an AI assistant named Sailor2, created by Sea AI Lab. '
    'As an AI assistant, you can answer questions in English, Chinese, and Southeast Asian languages '
    'such as Burmese, Cebuano, Ilocano, Indonesian, Javanese, Khmer, Lao, Malay, Sundanese, Tagalog, Thai, Vietnamese, and Waray. '
    'Your responses should be friendly, unbiased, informative, detailed, and faithful.'
)

system_prompt = f"{sft_start_token}{system_role}\n{system_prompt}{sft_end_token}"
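
# The assembled system turn follows the ChatML layout used by Qwen-based models, e.g.:
#   <|im_start|>system
#   You are an AI assistant named Sailor2, ... faithful.<|im_end|>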

# Function to generate model predictions.
@spaces.GPU()
def predict(message, history):
    # Normalize the conversation history format.
    if history is None:
        history = []

    # Append the current user message, with the bot's reply left empty for now.
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Format the conversation into the ChatML prompt the model expects.
    messages = (
        system_prompt
        + sft_end_token.join([
            sft_end_token.join([
                f"\n{sft_start_token}{user_role}\n" + item[0],
                f"\n{sft_start_token}{assistant_role}\n" + item[1]
            ]) for item in history_transformer_format
        ])
    )
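    # For a one-turn request, the prompt assembled above looks like (illustrative):
    #   <|im_start|>system\n...<|im_end|>
    #   <|im_start|>user\nHow to cook a fish?<|im_end|>
    #   <|im_start|>assistant\n
    # The assistant turn is left open so the model continues from there.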
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    # The streamer is only consumed by the optional streaming path (commented out below);
    # the blocking call further down does not read from it.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.8,
        top_k=20,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
        repetition_penalty=1.1,
    )

    outputs = model.generate(**generate_kwargs)
    # outputs[0] contains the prompt followed by the reply; decode only the new tokens
    # so the prompt is not echoed back to the user.
    new_tokens = outputs[0][model_inputs["input_ids"].shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    final_message = generated_text.replace(sft_end_token, "").strip()

    return final_message

    # # Run generation in a background thread and stream tokens as they arrive:
    # t = Thread(target=model.generate, kwargs=generate_kwargs)
    # t.start()

    # # Yield the partial message as it is generated.
    # partial_message = ""
    # for new_token in streamer:
    #     partial_message += new_token
    #     if sft_end_token in partial_message:  # Stop marker detected.
    #         break
    #     yield partial_message

    # # Clean up the final reply and update the history.
    # final_message = partial_message.replace(sft_end_token, "").strip()
    # history.append([message, final_message])
    # # Return the final conversation history as a list of tuples.
    # yield [(msg, bot) for msg, bot in history]

css = """
full-height {
    height: 100%;
}
"""

# Example prompts in English, Indonesian, Thai, and Vietnamese (all ask how to cook/grill a fish).
prompt_examples = [
    'How to cook a fish?',
    'Cara memanggang ikan',
    'วิธีย่างปลา',
    'Cách nướng cá'
]

placeholder = """
<div style="opacity: 0.5;">
    <img src="https://raw.githubusercontent.com/sail-sg/sailor-llm/main/misc/banner.jpg" style="width:30%;">
    <br>Sailor models are designed to understand and generate text across diverse linguistic landscapes of these SEA regions:
    <br>🇮🇩Indonesian, 🇹🇭Thai, 🇻🇳Vietnamese, 🇲🇾Malay, and 🇱🇦Lao.
</div>
"""

chatbot = gr.Chatbot(label='Sailor', placeholder=placeholder)
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    # gr.Markdown("""<center><font size=8>Sailor-Chat Bot⚓</center>""")
    gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor-llm/raw/main/misc/wide_sailor_banner.jpg" style="height: 110px"/></p>""")
    gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples, css=css)

demo.launch()  # Launch the web interface.