Spaces:
QLWD
/
Running on Zero

File size: 4,098 Bytes
51a7d9e
232a3c7
 
51a7d9e
232a3c7
51a7d9e
 
 
232a3c7
51a7d9e
de898ce
cad536c
51a7d9e
232a3c7
 
51a7d9e
bd34f0b
232a3c7
bd34f0b
232a3c7
bd34f0b
 
51a7d9e
5169fea
 
 
 
 
 
 
 
 
 
 
51a7d9e
232a3c7
77e3cd7
 
232a3c7
 
77e3cd7
232a3c7
51a7d9e
232a3c7
248fc73
bd34f0b
51a7d9e
248fc73
 
cd70784
248fc73
 
51a7d9e
cd70784
248fc73
 
51a7d9e
 
232a3c7
bd34f0b
232a3c7
639e063
232a3c7
 
 
edb9e8a
232a3c7
edb9e8a
bd34f0b
 
 
232a3c7
 
51a7d9e
232a3c7
51a7d9e
232a3c7
 
edb9e8a
 
51a7d9e
edb9e8a
 
 
 
51a7d9e
232a3c7
51a7d9e
 
 
 
 
 
232a3c7
51a7d9e
 
 
 
232a3c7
51a7d9e
232a3c7
 
 
 
 
51a7d9e
 
 
 
232a3c7
51a7d9e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
import gradio as gr
from threading import Thread
import spaces
import os

# 从环境变量中获取 Hugging Face 模型信息
HF_TOKEN = os.environ.get("HF_TOKEN", None)
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # 替换为基础模型
LORA_MODEL_PATH = "QLWD/test-7b"  # 替换为 LoRA 模型仓库路径

# 定义界面标题和描述
TITLE = "<h1><center>LoRA 微调模型测试</center></h1>"

DESCRIPTION = f"""
<h3>模型: <a href="https://huggingface.co/{LORA_MODEL_PATH}">LoRA 微调模型</a></h3>
<center>
<p>测试基础模型 + LoRA 补丁的生成效果。</p>
</center>
"""

CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""

# 加载基础模型和 LoRA 微调权重
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto", use_auth_token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_auth_token=HF_TOKEN)

# 加载 LoRA 微调权重
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, use_auth_token=HF_TOKEN)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# 定义推理函数
@spaces.GPU(duration=50)
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    conversation = []
    
    # 添加系统提示,定义模型的角色
    conversation.append({"role": "system", "content": "你是一个名为'漏洞助手'的检测代码漏洞的AI助手,帮助用户找到并修复代码中的安全问题,给出代码漏洞的具体片段,指出类型,给出修复建议。"})

    # 将历史对话内容添加到会话中
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "漏洞助手", "content": answer}])
    
    # 添加当前用户的输入到对话中
    conversation.append({"role": "user", "content": message})

    # 使用自定义对话模板生成 input_ids
    input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_ids, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    
    # 设置生成参数
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=penalty,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=[151645, 151643],
    )

    # 启动生成线程
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

# 定义 Gradio 界面
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ 参数设置", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", render=False),
        ],
        cache_examples=False,
    )

# 启动 Gradio 应用
if __name__ == "__main__":
    demo.launch()