# ZoniaQwen / app.py
# Hugging Face Space by ZoniaChatbot — "Update app.py" (commit 793c82d, verified)
import argparse
import os
import gradio as gr
from loguru import logger
from similarities import BertSimilarity, BM25Similarity
from chatpdf import Rag
# Absolute path of the directory containing this file (used to locate assets/).
pwd_path = os.path.abspath(os.path.dirname(__file__))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    parser.add_argument("--corpus_files", type=str, default="Acuerdo009.pdf")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=8082)
    parser.add_argument("--share", action='store_true', default=True, help="share model")
    args = parser.parse_args()
    logger.info(args)

    # Initialize the retrieval (similarity) model and the RAG pipeline.
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    model = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        # --corpus_files accepts a comma-separated list of documents.
        corpus_files=args.corpus_files.split(','),
        device=args.device,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    logger.info(f"chatpdf model: {model}")

    def predict_stream(message, history):
        """Yield the model's answer incrementally (streaming chat handler).

        :param message: the user's latest message.
        :param history: list of (human, assistant) turn pairs from Gradio.
        """
        # Re-shape Gradio's history into the [[human, assistant], ...] form
        # that the Rag model expects, then stream chunks back to the UI.
        history_format = []
        for human, assistant in history:
            history_format.append([human, assistant])
        model.history = history_format
        for chunk in model.predict_stream(message):
            yield chunk

    def predict(message, history):
        """Return the full answer followed by the retrieved reference snippets.

        :param message: the user's latest message.
        :param history: chat history (unused here; the Rag model keeps its own).
        """
        logger.debug(message)
        response, reference_results = model.predict(message)
        r = response + "\n\n" + '\n'.join(reference_results)
        logger.debug(r)
        return r

    chatbot_stream = gr.Chatbot(
        height=600,
        avatar_images=(
            os.path.join(pwd_path, "assets/user.png"),
            os.path.join(pwd_path, "assets/Logo1.png"),
        ),
        bubble_full_width=False,
    )
    title = " 馃ChatPDF Zonia馃 "
    # Bug fix: "!importante" (Spanish typo) is invalid CSS and the rule was
    # silently ignored; "!important" is the correct token.
    css = """.toast-wrap { display: none !important } """
    examples = ['Puede hablarme del PNL?', 'Introducci贸n a la PNL']
    # NOTE(review): the non-streaming `predict` is wired in even though
    # `predict_stream` exists — kept as-is to preserve current behavior.
    chat_interface_stream = gr.ChatInterface(
        predict,
        textbox=gr.Textbox(lines=4, placeholder="Ask me question", scale=7),
        title=title,
        chatbot=chatbot_stream,
        css=css,
        examples=examples,
        theme='soft',
    )
    # Launch without server_name/server_port: Hugging Face Spaces supplies those.
    chat_interface_stream.launch()