"""Gradio demo for Cassandre, a retrieval-augmented chatbot.

A user question is matched against a LanceDB table with a hybrid search,
the retrieved passages are placed in the prompt of a vLLM-served model,
and the generated answer is rendered as HTML with clickable references.

NOTE(review): in the reviewed copy of this file the HTML markup inside the
string literals was missing (the literals were empty or broken across
lines). The tags below were reconstructed from the stylesheet in ``css``
(classes ``generation``, ``source``, ``tooltip``, the ``data-text``
attribute and the ``:target`` rule) and from the offsets used in
``format_references``. Confirm them against version control.
"""

import html
import json
import os
import re
import shutil

import gradio as gr
import lancedb
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define variables
temperature = 0.7
max_new_tokens = 3000
top_p = 0.95
repetition_penalty = 1.2

model_name = "PleIAs/Cassandre-RAG"

# Initialize vLLM
llm = LLM(model_name, max_model_len=8128)

# Connect to the LanceDB database
db = lancedb.connect("content/lancedb_data")
table = db.open_table("eduv1")


def hybrid_search(text):
    """Retrieve the six best matching passages for *text*.

    Runs a hybrid query on the module-level ``table`` and formats every
    returned row (columns ``hash``, ``main_title`` and ``text``) twice.

    Returns:
        A ``(document, document_html)`` tuple: ``document`` is the plain
        text block inserted in the model prompt, ``document_html`` is the
        HTML listing of the same passages shown to the user.
    """
    results = table.search(text, query_type="hybrid").limit(6).to_pandas()

    prompt_entries = []
    html_entries = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        title = row['main_title']
        content = row['text']

        # The prompt receives the raw text.
        prompt_entries.append(f"**{hash_id}**\n{title}\n{content}")

        # The page receives escaped text: database content is not markup,
        # and the hash is also used as an attribute value (anchor target
        # for the references produced by format_references).
        safe_id = html.escape(hash_id, quote=True)
        safe_title = html.escape(str(title))
        safe_content = html.escape(str(content))
        html_entries.append(
            f'<div class="source" id="{safe_id}">'
            f'<p><b>{safe_id}</b> : {safe_title}<br>{safe_content}</p>'
            '</div>'
        )

    document = "\n\n".join(prompt_entries)
    document_html = '<div id="source_listing">' + "".join(html_entries) + "</div>"
    return document, document_html


class CassandreChatBot:
    """Answer questions with the module-level vLLM model and retrieved sources."""

    def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
        # Stored on the instance; predict() does not read it.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Generate a sourced answer for *user_message*.

        Returns:
            A ``(answer_html, sources_html)`` tuple of HTML fragments.
        """
        fiches, fiches_html = hybrid_search(user_message)

        # NOTE(review): the module constant is called ``repetition_penalty``
        # but is passed as ``presence_penalty``; vLLM exposes both
        # parameters with different semantics. Behaviour is kept as-is --
        # confirm which one is intended.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=repetition_penalty,
            stop=["#END#"],
        )

        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""

        prompts = [detailed_prompt]
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        generated_text = (
            '<h2 style="text-align:center">Réponse</h2>\n'
            '<div class="generation">'
            + format_references(generated_text)
            + "</div>"
        )
        fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
        return generated_text, fiches_html


def format_references(text):
    """Replace ``<ref text="QUOTE">ID</ref>`` markers with numbered tooltips.

    Each complete marker becomes a ``tooltip`` span whose ``data-text``
    attribute carries the quoted passage and whose link points to the
    element with id ``ID``. References are numbered from 1 in order of
    appearance. Text outside the markers is passed through unchanged.

    If a marker is not terminated (no closing ``">`` or ``</ref>``), the
    scan stops and everything from that marker onwards is dropped.
    """
    ref_start_marker = '<ref text="'
    attr_end_marker = '">'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1

    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No further reference: keep the tail of the answer.
            parts.append(text[current_pos:])
            break

        parts.append(text[current_pos:start_pos])

        end_pos = text.find(attr_end_marker, start_pos)
        if end_pos == -1:
            break

        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        # The quote ends up inside a double-quoted attribute, so quotes
        # must be escaped as well as &, < and >.
        ref_text_encoded = html.escape(ref_text, quote=True)

        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            break

        ref_id = text[end_pos + len(attr_end_marker):ref_end_pos].strip()
        ref_id_encoded = html.escape(ref_id, quote=True)

        tooltip_html = (
            f'<span class="tooltip" data-refid="{ref_id_encoded}" '
            f'data-text="{ref_id_encoded}: {ref_text_encoded}">'
            f'<a href="#{ref_id_encoded}">[{ref_number}]</a></span>'
        )
        parts.append(tooltip_html)

        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1

    return ''.join(parts)


# Initialize the CassandreChatBot
cassandre_bot = CassandreChatBot()

# CSS for styling
css = """
.generation {
    margin-left:2em;
    margin-right:2em;
}
:target {
    background-color: #CCF3DF;
}
.source {
    float:left;
    max-width:17%;
    margin-left:2%;
}
.tooltip {
    position: relative;
    cursor: pointer;
    font-variant-position: super;
    color: #97999b;
}
.tooltip:hover::after {
    content: attr(data-text);
    position: absolute;
    left: 0;
    top: 120%;
    white-space: pre-wrap;
    width: 500px;
    max-width: 500px;
    z-index: 1;
    background-color: #f9f9f9;
    color: #000;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 5px;
    display: block;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""


# Gradio interface
def gradio_interface(user_message):
    """Gradio callback: return the answer HTML and the sources HTML."""
    response, sources = cassandre_bot.predict(user_message)
    return response, sources


# Create Gradio app
demo = gr.Blocks(css=css)

with demo:
    gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger Cassandre")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de Cassandre")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")

    text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])

# Launch the app
if __name__ == "__main__":
    demo.launch()