Spaces:
Paused
Paused
import transformers | |
import re | |
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM | |
from vllm import LLM, SamplingParams | |
import torch | |
import gradio as gr | |
import json | |
import os | |
import shutil | |
import requests | |
import lancedb | |
import pandas as pd | |
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Generation settings read by CassandreChatBot.predict when it builds its
# SamplingParams.
temperature = 0.7
top_p = 0.95
max_new_tokens = 3000
repetition_penalty = 1.2

# Model identifier handed to vLLM, which is instantiated once at import time.
model_name = "PleIAs/Cassandre-RAG"
llm = LLM(model_name, max_model_len=8128)

# LanceDB table that hybrid_search queries for source documents.
db = lancedb.connect("content/lancedb_data")
table = db.open_table("eduv1")
def hybrid_search(text, limit=6):
    """Retrieve source documents for a query and render them two ways.

    Runs a hybrid search on the module-level LanceDB ``table`` and formats
    every hit both as plain text (for the model prompt) and as HTML (for
    display).

    Args:
        text: Query string forwarded to ``table.search``.
        limit: Maximum number of documents to retrieve. Defaults to 6, the
            value that used to be hard-coded.

    Returns:
        A ``(document, document_html)`` tuple of strings. ``document`` holds
        one ``**hash**\\ntitle\\ntext`` entry per hit, separated by blank
        lines. ``document_html`` holds one ``<div class="source">`` per hit,
        whose ``id`` is the document hash, wrapped in
        ``<div id="source_listing">``.
    """
    results = table.search(text, query_type="hybrid").limit(limit).to_pandas()

    entries = []
    entries_html = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        title = row['main_title']
        content = row['text']
        entries.append(f"**{hash_id}**\n{title}\n{content}")
        # NOTE(review): title and content are inserted into the markup as-is;
        # confirm the stored text never contains characters that need HTML
        # escaping.
        entries_html.append(f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</div>')

    document = "\n\n".join(entries)
    document_html = '<div id="source_listing">' + "".join(entries_html) + "</div>"
    return document, document_html
class CassandreChatBot:
    """Chatbot that answers a question from documents retrieved for it."""

    def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
        """Store the system prompt.

        Args:
            system_prompt: Persona text kept on the instance.
                NOTE(review): ``predict`` does not read this attribute.
        """
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer ``user_message`` using retrieved sources.

        Retrieves sources with ``hybrid_search``, builds the
        query/source/answer prompt, generates a completion with the
        module-level vLLM engine and converts the result to HTML.

        Args:
            user_message: The user's question or instruction.

        Returns:
            A ``(generated_text, fiches_html)`` tuple of HTML strings: the
            answer (with reference markers turned into tooltips by
            ``format_references``) and the listing of retrieved sources, each
            preceded by an ``<h2>`` heading.
        """
        fiches, fiches_html = hybrid_search(user_message)

        # NOTE(review): the module-level ``repetition_penalty`` value is
        # passed as vLLM's ``presence_penalty``. Left unchanged on purpose;
        # confirm this is intended rather than vLLM's own
        # ``repetition_penalty`` parameter.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=repetition_penalty,
            stop=["#END#"],
        )
        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        # Headings are closed with </h2>; they were previously closed with a
        # mismatched </h3>.
        generated_text = (
            '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">'
            + format_references(generated_text)
            + "</div>"
        )
        fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
        return generated_text, fiches_html
def format_references(text):
    """Turn ``<ref text="...">id</ref>`` markers into numbered tooltip links.

    Each well-formed marker becomes a ``<span class="tooltip">`` holding a
    ``[n]`` link to ``#id``; ``n`` counts from 1 in order of appearance. The
    quoted text and the id are HTML-escaped before being placed in
    attributes, so quotes, ``&``, ``<`` and ``>`` cannot break the markup.
    Newlines inside the quoted text are replaced by spaces.

    Text outside the markers is copied through unchanged. If a marker is
    incomplete (no closing ``">`` or no ``</ref>``), the remainder of the
    input starting at that marker is kept as escaped, visible text instead
    of being dropped.

    Args:
        text: Generated answer that may contain reference markers.

    Returns:
        The answer as an HTML string.
    """
    import html  # stdlib; imported locally so the function is self-contained

    ref_start_marker = '<ref text="'
    attr_end_marker = '">'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            parts.append(text[current_pos:])
            break
        parts.append(text[current_pos:start_pos])

        # Look for the closing '">' only after the opening marker, so the
        # marker's own trailing quote can never be mistaken for it.
        quote_start = start_pos + len(ref_start_marker)
        end_pos = text.find(attr_end_marker, quote_start)
        ref_end_pos = -1 if end_pos == -1 else text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            # Incomplete marker (e.g. truncated generation): keep the tail
            # readable rather than silently losing it.
            parts.append(html.escape(text[start_pos:]))
            break

        ref_text = text[quote_start:end_pos].replace('\n', ' ').strip()
        ref_id = text[end_pos + len(attr_end_marker):ref_end_pos].strip()
        ref_text_encoded = html.escape(ref_text)
        ref_id_encoded = html.escape(ref_id)

        parts.append(
            f'<span class="tooltip" data-refid="{ref_id_encoded}" '
            f'data-text="{ref_id_encoded}: {ref_text_encoded}">'
            f'<a href="#{ref_id_encoded}">[{ref_number}]</a></span>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    return ''.join(parts)
# Single chatbot instance shared by every call to the Gradio callback.
cassandre_bot = CassandreChatBot()

# Stylesheet passed to gr.Blocks: answer margins, highlight of the source
# targeted by a reference link, floating source cards, and the hover tooltip
# that shows a reference's data-text attribute.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
def gradio_interface(user_message):
    """Gradio callback: answer ``user_message`` with the shared chatbot.

    Args:
        user_message: Text entered by the user.

    Returns:
        A two-item tuple: the answer HTML, then the sources HTML, as
        produced by ``cassandre_bot.predict``.
    """
    answer_html, sources_html = cassandre_bot.predict(user_message)
    return answer_html, sources_html
# Build the Gradio UI: question box and button on the left, answer on the
# right, retrieved sources in a row underneath.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation; confirm the sources row is a sibling of the
# first row.
with gr.Blocks(css=css) as demo:
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Votre question ou votre instruction",
                lines=3,
            )
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de Cassandre")

    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")

    # Clicking the button sends the question to the chatbot and fills both
    # HTML panes with its two return values.
    text_button.click(
        gradio_interface,
        inputs=text_input,
        outputs=[text_output, embedding_output],
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()