import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
import gradio as gr
from threading import Thread
import spaces
import os
# Hugging Face access token (needed for private repos) and model identifiers
HF_TOKEN = os.environ.get("HF_TOKEN", None)
BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # replace with your base model
LORA_MODEL_PATH = "QLWD/test-3b"  # replace with your LoRA adapter repo
# UI title and description
TITLE = "<h1><center>LoRA Fine-Tuned Model Test</center></h1>"
DESCRIPTION = f"""
<h3>Model: <a href="https://huggingface.co/{LORA_MODEL_PATH}">LoRA fine-tuned model</a></h3>
<center>
<p>Test the generation quality of the base model plus the LoRA adapter.</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# Load the base model and tokenizer
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
# Apply the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, token=HF_TOKEN)
model.eval()  # device_map="auto" already placed the weights; a blanket .to() could conflict with it
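
# Optional (an assumption, not part of the original app): merging the adapter into
# the base weights removes the PEFT forward-hook overhead at inference time, at the
# cost of no longer being able to swap adapters:
# model = model.merge_and_unload()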
# Streaming inference function
@spaces.GPU(duration=60)  # seconds of ZeroGPU time reserved per call; generation needs more than a couple of seconds
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    # Rebuild the message list from Gradio's (user, assistant) history pairs
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})
    # Render the conversation with the model's chat template, then tokenize
    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
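    # For reference: Qwen2.5's ChatML template renders each turn as
    #   <|im_start|>{role}\n{content}<|im_end|>
    # and add_generation_prompt=True appends an opening <|im_start|>assistant tag.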
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Generation parameters
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=penalty,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0
        temperature=temperature,
        eos_token_id=[151645, 151643],  # Qwen2.5 stop tokens: <|im_end|>, <|endoftext|>
    )
    # Run generation on a background thread and stream partial text to the UI
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
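
# A minimal local smoke test (hypothetical usage, not wired into the Space UI):
# drain the generator and keep the last accumulated chunk as the full reply.
#
#   reply = ""
#   for partial in stream_chat("Hello!", [], 0.8, 256, 0.8, 20, 1.0):
#       reply = partial
#   print(reply)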
# Gradio UI
chatbot = gr.Chatbot(height=450)
with gr.Blocks(css=CSS) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Generation settings", open=False, render=False),
additional_inputs=[
gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="top_p", render=False),
gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", render=False),
],
        examples=[
            ["Please generate a short sentence about learning for me"],
            ["Explain the concept of quantum computing"],
            ["Give me some Python programming tips"],
            ["Create a fixed header with CSS and JavaScript"],
        ],
cache_examples=False,
)
# Launch the Gradio app
if __name__ == "__main__":
demo.launch()