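# Gradio demo for EssayGPT (申论大模型): loads the Qwen1.5-1.8B-Chat base model,
# merges a LoRA adapter from the current directory into it, and serves a web UI
# where users paste source materials, select the relevant ones, and ask questions.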
import gradio as gr

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftModel
import torch

# Base model directory and LoRA adapter directory (the adapter files are
# expected in the current working directory).
model_path = "Qwen1.5-1.8B-Chat"
lora_path = "."

# Prefer the first GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
)
config_kwargs = {"device_map": device}

# Load the base model in half precision, attach the LoRA adapter, and merge the
# adapter weights into the base model so inference runs on a single plain model.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    **config_kwargs
)

model = PeftModel.from_pretrained(model, lora_path)
model = model.merge_and_unload()
model.eval()

# Number of material tabs shown in the UI.
MAX_MATERIALS = 4


def call(related_materials, materials, question):
    # Assemble the prompt: only the materials selected as relevant, each prefixed
    # with its number ("材料N"), followed by the question ("问题:...").
    query_texts = [f"材料{i + 1}\n{material}" for i, material in enumerate(materials) if i in related_materials]
    query_texts.append(f"问题:{question}")
    query = "\n".join(query_texts)
    # System prompt: "Please answer the question based on the materials provided below."
    messages = [
        {"role": "system", "content": "请你根据以下提供的材料来回答问题"},
        {"role": "user", "content": query}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Debug output: prompt length in tokens.
    print(len(model_inputs.input_ids[0]))
    # max_length caps the total token budget (prompt plus generated answer); for
    # batched or padded inputs, model_inputs.attention_mask should be passed as well.
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_length=8096
    )
    # Strip the prompt tokens so only the newly generated answer is decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response


def create_ui():
    with gr.Blocks() as app:
        # Title: "EssayGPT - shenlun (civil-service essay) large model".
        gr.Markdown("""<center><font size=8>EssayGPT-申论大模型</center>""")
        # Usage steps: 1. paste each material into its tab; 2. enter the question and
        # requirements; 3. select the materials the question needs; 4. click "提问!" (Ask!).
        gr.Markdown(
            """<center><font size=4>1.把材料填入对应位置 2.输入问题和要求 3.选择解答问题需要的相关材料 4.点击"提问!"</center>""")
        with gr.Row():
            with gr.Column():
                materials = []

                # One tab per material slot ("材料N"), each with a textbox for its content.
                for i in range(MAX_MATERIALS):
                    with gr.Tab(f"材料{i + 1}"):
                        materials.append(gr.Textbox(label="材料内容"))
            with gr.Column():
                # Multi-select dropdown: the material numbers the question relies on.
                related_materials = gr.Dropdown(
                    choices=list(range(1, MAX_MATERIALS + 1)), multiselect=True,
                    label="问题所需相关材料")
                question = gr.Textbox(label="问题")  # question / requirements
                submit = gr.Button("提问!")  # "Ask!" button
                answer = gr.Textbox(label="回答")  # model answer
        # Register the click handler while still inside the Blocks context.
        build_ui({"materials": materials, "related_materials": related_materials, "question": question,
                  "submit": submit, "answer": answer})
    return app


def build_ui(components):
    def func(related_materials, question, *materials):
        # Guard: at least one material must be selected
        # ("请选择问题所需相关材料" = "please select the materials the question needs").
        if not related_materials:
            return "请选择问题所需相关材料"
        # The dropdown holds 1-based material numbers; convert to 0-based indices.
        related_materials = [i - 1 for i in related_materials]
        return call(related_materials, materials, question)

    components["submit"].click(func,
                               [components["related_materials"], components["question"], *components["materials"]],
                               components["answer"])


def run():
    app = create_ui()
    app.queue()  # queue requests so long generations do not block the interface
    app.launch(share=True)  # share=True also creates a temporary public Gradio link


if __name__ == '__main__':
    run()