WooWoof_AI / app.py
larry1129's picture
Update app.py
0af3958 verified
raw
history blame
3.12 kB
import gradio as gr
from transformers import AutoTokenizer
from peft import PeftModel
import torch
import os
import spaces
# 获取 Hugging Face 访问令牌
hf_token = os.getenv("HF_API_TOKEN")
# 定义基础模型名称
base_model_name = "larry1129/meta-llama-3.1-8b-bnb-4bit" # 替换为你的基础模型名称
# 定义 adapter 模型名称
adapter_model_name = "larry1129/WooWoof_AI" # 替换为你的 adapter 模型名称
# 加载分词器(无需 GPU,可在全局加载)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=hf_token)
# 定义一个全局变量用于缓存模型
model = None
# 定义提示生成函数
def generate_prompt(instruction, input_text=""):
if input_text:
prompt = f"""### Instruction:
{instruction}
### Input:
{input_text}
### Response:
"""
else:
prompt = f"""### Instruction:
{instruction}
### Response:
"""
return prompt
# 定义生成响应的函数,并使用 @spaces.GPU 装饰
@spaces.GPU
def generate_response(instruction, input_text):
global model
if model is None:
# 在函数内部导入需要 GPU 的库
import bitsandbytes
from transformers import AutoModelForCausalLM
# 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
device_map="auto",
torch_dtype=torch.float16,
use_auth_token=hf_token,
trust_remote_code=True # 如果你的模型使用自定义代码,请保留此参数
)
# 加载 adapter 并将其应用到基础模型上
model = PeftModel.from_pretrained(
base_model,
adapter_model_name,
device_map="auto",
torch_dtype=torch.float16,
use_auth_token=hf_token,
trust_remote_code=True
)
# 设置 pad_token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# 切换到评估模式
model.eval()
# 生成提示
prompt = generate_prompt(instruction, input_text)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=128,
temperature=0.7,
top_p=0.95,
do_sample=True,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("### Response:")[-1].strip()
return response
# 创建 Gradio 接口
iface = gr.Interface(
fn=generate_response,
inputs=[
gr.Textbox(lines=2, placeholder="请输入指令...", label="Instruction"),
gr.Textbox(lines=2, placeholder="如果有额外输入,请在此填写...", label="Input (可选)")
],
outputs="text",
title="WooWoof AI 交互式聊天",
description="基于 LLAMA 3.1 的大语言模型,支持指令和可选输入。",
allow_flagging="never"
)
# 启动 Gradio 接口
iface.launch()