# gutgut / app.py — Carlos Rosas
# Commit 06286b2 (verified): "Rename app(1).py to app.py"
# Size: 5.07 kB
import html
import json
import os
import re
import shutil

import gradio as gr
import lancedb
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams
# Define the device
# NOTE(review): `device` is not referenced anywhere else in this file;
# confirm it is unused before removing it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define variables
# Sampling settings read by CassandreChatBot.predict when it builds SamplingParams.
temperature = 0.7
max_new_tokens = 3000  # passed to SamplingParams as max_tokens
top_p = 0.95
# NOTE(review): despite its name, this value is passed to SamplingParams as
# `presence_penalty`, not `repetition_penalty`; confirm which was intended.
repetition_penalty = 1.2
model_name = "PleIAs/Cassandre-RAG"
# Initialize vLLM
# The model is loaded once, at import time, and the single `llm` object is
# reused by every call to CassandreChatBot.predict.
# NOTE(review): 8128 may be a typo for 8192 — confirm the intended context length.
llm = LLM(model_name, max_model_len=8128)
# Connect to the LanceDB database
# Relative path — presumably resolved against the process working directory.
db = lancedb.connect("content/lancedb_data")
# "eduv1" is the table queried by hybrid_search.
table = db.open_table("eduv1")
def hybrid_search(text, limit=6):
    """Search the LanceDB ``table`` with ``query_type="hybrid"`` and format the hits.

    Args:
        text: The user query, used as the search input.
        limit: Maximum number of rows to retrieve (defaults to 6, the
            previously hard-coded value).

    Returns:
        A ``(document, document_html)`` tuple:

        * ``document`` -- plain text for the LLM prompt: one
          ``**<hash>**\\n<title>\\n<text>`` entry per row, entries separated
          by a blank line. The text is kept verbatim for the model.
        * ``document_html`` -- the same rows as ``<div class="source">``
          elements (each with ``id`` set to the row hash, the target of the
          citation links) wrapped in ``<div id="source_listing">``. Title,
          text and hash are HTML-escaped because they are arbitrary corpus
          data rendered through ``gr.HTML``.
    """
    results = table.search(text, query_type="hybrid").limit(limit).to_pandas()
    document = []
    document_html = []
    for _, row in results.iterrows():
        hash_id = str(row["hash"])
        title = row["main_title"]
        content = row["text"]
        document.append(f"**{hash_id}**\n{title}\n{content}")
        # Escape before embedding in markup; quote=True because the id also
        # lands inside a double-quoted attribute.
        safe_id = html.escape(hash_id, quote=True)
        safe_title = html.escape(str(title))
        safe_content = html.escape(str(content))
        document_html.append(
            f'<div class="source" id="{safe_id}"><p><b>{safe_id}</b> : '
            f"{safe_title}<br>{safe_content}</p></div>"
        )
    document = "\n\n".join(document)
    document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
    return document, document_html
class CassandreChatBot:
    """Answer a question from retrieved sources using the module-level vLLM model.

    Each call to :meth:`predict` runs ``hybrid_search`` on the question,
    builds a ``### Query ### / ### Source ### / ### Answer ###`` prompt,
    generates with ``llm``, and returns HTML for the answer and the sources.
    """

    def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
        # Default (in English): "You are Cassandre, the Éducation nationale
        # chatbot that gives sourced answers."
        # NOTE(review): predict() never inserts this prompt into the model
        # input; confirm whether that is intentional.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Generate a sourced answer to *user_message*.

        Args:
            user_message: The user's question or instruction.

        Returns:
            ``(answer_html, sources_html)``: the generated answer, with inline
            citations converted by ``format_references``, under a "Réponse"
            heading; and the retrieved sources listing under a "Sources"
            heading.
        """
        sources_text, sources_html = hybrid_search(user_message)
        # NOTE(review): the module-level `repetition_penalty` is passed as
        # vLLM's `presence_penalty`, not its `repetition_penalty`; confirm intent.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=repetition_penalty,
            stop=["#END#"],
        )
        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{sources_text}\n\n### Answer ###\n"""
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text
        # Headings close with </h2> to match their opening <h2> tag.
        answer_html = (
            '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">'
            + format_references(generated_text)
            + "</div>"
        )
        sources_html = '<h2 style="text-align:center">Sources</h2>\n' + sources_html
        return answer_html, sources_html
def format_references(text):
    """Replace inline ``<ref text="QUOTE">ID</ref>`` citations with numbered tooltips.

    Each complete citation becomes::

        <span class="tooltip" data-refid="ID" data-text="ID: QUOTE"><a href="#ID">[n]</a></span>

    numbered from 1 in order of appearance. The link targets the element whose
    ``id`` is ``ID`` (the source cards rendered by ``hybrid_search``). The ID
    and quote are HTML-escaped, quotes included, because both land in
    double-quoted attributes. Text outside citations is copied through
    unchanged.

    If a citation is unterminated (no closing ``">`` or ``</ref>``, e.g.
    because generation stopped mid-citation), it and everything after it
    are kept verbatim instead of being silently dropped.

    Args:
        text: Raw model output.

    Returns:
        The text with every complete citation replaced by its tooltip HTML.
    """
    ref_start_marker = '<ref text="'
    quote_end_marker = '">'
    ref_end_marker = "</ref>"
    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            break
        # Search for the closing '">' only after the opening marker.
        quote_start = start_pos + len(ref_start_marker)
        quote_end = text.find(quote_end_marker, quote_start)
        if quote_end == -1:
            break
        id_start = quote_end + len(quote_end_marker)
        ref_end_pos = text.find(ref_end_marker, id_start)
        if ref_end_pos == -1:
            break
        # Only emit the preceding text once the citation is known to be complete.
        parts.append(text[current_pos:start_pos])
        # Newlines in the quoted passage become spaces before it is placed in
        # the tooltip attribute.
        ref_text = text[quote_start:quote_end].replace("\n", " ").strip()
        raw_id = text[id_start:ref_end_pos].strip()
        ref_id = html.escape(raw_id, quote=True)
        tooltip_text = html.escape(f"{raw_id}: {ref_text}", quote=True)
        parts.append(
            f'<span class="tooltip" data-refid="{ref_id}" data-text="{tooltip_text}">'
            f'<a href="#{ref_id}">[{ref_number}]</a></span>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    # Remainder after the last complete citation (including any unterminated one).
    parts.append(text[current_pos:])
    return "".join(parts)
# Initialize the CassandreChatBot
# Single shared instance used by the Gradio click handler (gradio_interface).
cassandre_bot = CassandreChatBot()
# CSS for styling
# Passed to gr.Blocks(css=css). The selectors match markup produced in this file:
#   .generation - wrapper around the generated answer (CassandreChatBot.predict)
#   :target     - highlights the source card a citation link (href="#<hash>") points to
#   .source     - one retrieved source card (hybrid_search), floated left, at most 17% wide
#   .tooltip    - citation marker (format_references); its ::after pseudo-element
#                 displays the element's data-text attribute on hover
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio interface
def gradio_interface(user_message):
    """Click handler: answer *user_message* with the shared ``cassandre_bot``.

    Returns the ``(answer_html, sources_html)`` pair from
    ``CassandreChatBot.predict``, in the order of the two output components
    the button is wired to.
    """
    answer_html, sources_html = cassandre_bot.predict(user_message)
    return answer_html, sources_html
# Create Gradio app
# Layout: a title, then a row with the question input (scale 2) beside the
# answer panel (scale 3), then a full-width row for the retrieved sources.
demo = gr.Blocks(css=css)
with demo:
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            # UI labels (French): "Your question or your instruction" / "Ask Cassandre".
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            # "Cassandre's answer"
            text_output = gr.HTML(label="La réponse de Cassandre")
    with gr.Row():
        # "The sources used"
        embedding_output = gr.HTML(label="Les sources utilisées")
    # gradio_interface returns (answer_html, sources_html); Gradio assigns the
    # returned values to the outputs in list order.
    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])
# Launch the app
# Only when run as a script; importing the module still builds `demo`.
if __name__ == "__main__":
    demo.launch()