import argparse
import os

import gradio as gr
from loguru import logger
from similarities import BertSimilarity

from chatpdf import Rag

pwd_path = os.path.abspath(os.path.dirname(__file__))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=8082)
    parser.add_argument("--share", action='store_true', default=True, help="create a public Gradio share link")
    args = parser.parse_args()
    logger.info(args)
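
    # Load the sentence-embedding similarity model and assemble the RAG pipeline from the CLI arguments.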
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    model = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        corpus_files=args.corpus_files.split(','),
        device=args.device,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    logger.info(f"chatpdf model: {model}")
    def predict_stream(message, history):
        history_format = []
        for human, assistant in history:
            history_format.append([human, assistant])
        model.history = history_format
        for chunk in model.predict_stream(message):
            yield chunk
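
    # Non-streaming callback: returns the full answer followed by the retrieved reference passages.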
    def predict(message, history):
        logger.debug(message)
        response, reference_results = model.predict(message)
        r = response + "\n\n" + '\n'.join(reference_results)
        logger.debug(r)
        return r
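
    # Chat UI: chatbot pane with avatars, page title, a small CSS tweak, and example prompts.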
    chatbot_stream = gr.Chatbot(
        height=600,
        avatar_images=(
            os.path.join(pwd_path, "assets/user.png"),
            os.path.join(pwd_path, "assets/Logo1.png"),
        ), bubble_full_width=False)

    title = "ChatPDF Zonia"
    css = """.toast-wrap { display: none !important } """
    examples = ['¿Puede hablarme del PNL?', 'Introducción a la PNL']

    chat_interface_stream = gr.ChatInterface(
        predict,
        textbox=gr.Textbox(lines=4, placeholder="Ask me a question", scale=7),
        title=title,
        chatbot=chatbot_stream,
        css=css,
        examples=examples,
        theme='soft',
    )
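
    # Start the Gradio server with the host, port, and share settings given on the command line.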
    chat_interface_stream.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
    )