# Hugging Face Space status: Paused
import transformers | |
import re | |
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM | |
import torch | |
import gradio as gr | |
import json | |
import os | |
import shutil | |
import requests | |
import lancedb | |
import pandas as pd | |
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "PleIAs/Pleias-Rag"

# A Hugging Face access token is required to download the model; fail fast at
# import time when it is missing.
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Load the tokenizer and the causal LM, then move the model to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
model.to(device)

# <|answer_end|> closes the model's answer section. Making it the EOS token lets
# generation, which passes tokenizer.eos_token_id, stop there.
tokenizer.eos_token = "<|answer_end|>"
eos_token_id = tokenizer.eos_token_id
# Pad with token id 1. Assigning pad_token_id also redefines pad_token, so the
# pad token is whatever id 1 maps to, NOT the EOS token.
# NOTE(review): id 1 is presumably the model's dedicated padding token —
# confirm against the tokenizer vocabulary.
tokenizer.pad_token_id = 1

# Generation settings read by pleiasBot.predict. Decoding there is greedy
# (do_sample=False), so `temperature` and `top_p` do not influence the output.
# `top_p` is not referenced anywhere else in this file.
temperature = 0.0
max_new_tokens = 1200
top_p = 0.95
repetition_penalty = 1.0
# At least this many new tokens are generated before EOS may end generation.
min_new_tokens = 600
# Only relevant to beam search; kept False for greedy decoding.
early_stopping = False

# Open the LanceDB table queried by hybrid_search(). The database path is
# relative to the current working directory.
db = lancedb.connect("content19/lancedb_data")
table = db.open_table("edunat19")
def hybrid_search(text, limit=5):
    """Retrieve sources for *text* and render them for the prompt and the page.

    Runs a LanceDB hybrid query (``query_type="hybrid"``) on the module-level
    ``table``, keeps at most *limit* rows, and skips rows whose ``hash`` has
    already been seen.

    Args:
        text: The user's query.
        limit: Maximum number of rows requested from LanceDB (default 5).

    Returns:
        A ``(document, document_html)`` pair:

        * ``document``: each source wrapped in the model's
          ``<|source_start|>...<|source_end|>`` markers, joined with newlines,
          inserted verbatim into the generation prompt;
        * ``document_html``: the same sources as ``<div class="source">`` cards
          inside ``<div id="source_listing">``, with all database text
          HTML-escaped.
    """
    import html  # local import, like the in-function imports used elsewhere in this file

    results = table.search(text, query_type="hybrid").limit(limit).to_pandas()

    seen_hashes = set()
    prompt_sources = []
    html_sources = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        # Rows sharing a hash are treated as the same source: keep the first only.
        if hash_id in seen_hashes:
            continue
        seen_hashes.add(hash_id)
        title = row['section']
        content = row['text']

        # The prompt keeps the raw text: the model must see the source unchanged.
        prompt_sources.append(
            f"<|source_start|><|source_id_start|>{hash_id}<|source_id_end|>"
            f"{title}\n{content}<|source_end|>"
        )

        # The HTML view escapes database text so that markup-like characters in
        # the corpus cannot break or inject into the page. The id is also used in
        # an attribute, so quotes are escaped too (html.escape default).
        safe_id = html.escape(hash_id)
        safe_title = html.escape(str(title), quote=False)
        safe_content = html.escape(str(content), quote=False)
        html_sources.append(
            f'<div class="source" id="{safe_id}"><p><b>{safe_id}</b> : '
            f'{safe_title}<br>{safe_content}</p></div>'
        )

    document = "\n".join(prompt_sources)
    document_html = '<div id="source_listing">' + "".join(html_sources) + "</div>"
    return document, document_html
class pleiasBot:
    """Retrieval-augmented question answering with the Pleias RAG model.

    ``predict`` retrieves sources with ``hybrid_search``, prompts the
    module-level ``model`` with the question followed by those sources, and
    renders the generation as HTML fragments for the Gradio interface.
    """

    def __init__(self, system_prompt="Tu es Appli, un asistant de recherche qui donne des responses sourcées"):
        """Store *system_prompt* on the instance.

        NOTE(review): the system prompt is kept for interface compatibility but
        is not inserted into the prompt built by ``predict``.
        """
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer *user_message* from retrieved sources.

        Args:
            user_message: The question typed by the user.

        Returns:
            A tuple ``(analysis_html, answer_html, sources_html)``:

            * If the generated text splits into exactly two parts around
              ``<|source_analysis_end|>``, the analysis and the answer are
              rendered under their own headings.
            * Otherwise the analysis pane is empty and the whole generation is
              rendered in the answer pane.
            * If generation fails, the error is printed, the analysis pane is
              empty, the answer pane shows an error message, and the sources
              are still returned.
        """
        fiches, fiches_html = hybrid_search(user_message)
        # Prompt layout: the query, the sources, then the opening tag of the
        # source-analysis section that the model continues from.
        detailed_prompt = f"""<|query_start|>{user_message}<|query_end|>\n{fiches}\n<|source_analysis_start|>"""
        input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
        # A single unpadded sequence: every position is attended to.
        attention_mask = torch.ones_like(input_ids)
        sources_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html

        try:
            # Greedy decoding. Sampling knobs such as `temperature` are not
            # passed: they have no effect when do_sample=False, and transformers
            # warns on every call when they are set.
            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                early_stopping=early_stopping,
                min_new_tokens=min_new_tokens,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            # Decode only the tokens generated after the prompt.
            generated_text = tokenizer.decode(output[0][len(input_ids[0]):])

            parts = generated_text.split("<|source_analysis_end|>")
            if len(parts) == 2:
                analysis = parts[0].strip()
                answer = parts[1].replace("<|answer_start|>", "").replace("<|answer_end|>", "").strip()
                # Both panes share the same heading + .generation wrapper layout.
                analysis_text = '<h2 style="text-align:center">Analyse des sources</h2>\n<div class="generation">' + format_references(analysis) + "</div>"
                answer_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + format_references(answer) + "</div>"
            else:
                # No single analysis/answer separator (for example if generation
                # ended before the model emitted it): show the raw output.
                analysis_text = ""
                answer_text = format_references(generated_text)
        except Exception as e:
            # UI boundary: log the full traceback, and tell the user instead of
            # silently blanking every pane.
            print(f"Error during generation: {str(e)}")
            import traceback
            traceback.print_exc()
            error_html = '<div class="generation">Une erreur est survenue lors de la génération de la réponse. Veuillez réessayer.</div>'
            return "", error_html, sources_html

        return analysis_text, answer_text, sources_html
# A citation emitted by the model: <ref name="ID">"QUOTE"</ref>, optionally
# followed by a sentence-ending period, then any whitespace. QUOTE may contain
# double quotes and newlines; the lazy match stops at the first '"</ref>'.
_REF_PATTERN = re.compile(r'<ref name="([^"]+)">"([\s\S]*?)"</ref>(\.)?(\s*)')


def format_references(text):
    """Render model text as HTML, turning citations into numbered tooltips.

    Each ``<ref name="ID">"QUOTE"</ref>`` becomes a bold ``[n]`` marker
    (numbered from 1 in order of appearance) whose hover tooltip shows ID and
    the stripped QUOTE. The whitespace before the citation is removed. The
    citation is then followed by:

    * ``.<br>`` if it was followed by a period (the period and the whitespace
      after it are consumed);
    * otherwise, the whitespace that originally followed it.

    All text, including ID and QUOTE, is HTML-escaped because it comes from
    model output and retrieved documents.

    Args:
        text: Generated text possibly containing ``<ref>`` citations.

    Returns:
        An HTML string.
    """
    import html  # local import, like the in-function imports used elsewhere in this file

    def _escape(s):
        # Text-node escaping only (&, <, >); apostrophes stay readable.
        return html.escape(s, quote=False)

    parts = []
    current_pos = 0
    for ref_number, match in enumerate(_REF_PATTERN.finditer(text), start=1):
        # Text before the citation, without the space that preceded the tag.
        parts.append(_escape(text[current_pos:match.start()].rstrip()))
        ref_id = _escape(match.group(1))
        ref_text = _escape(match.group(2).strip())
        parts.append(
            f'<span class="tooltip"><strong>[{ref_number}]</strong>'
            f'<span class="tooltiptext"><strong>{ref_id}</strong>: {ref_text}</span></span>'
        )
        # A sentence-ending citation keeps its period and starts a new line; a
        # mid-sentence citation keeps the whitespace that followed it.
        parts.append(".<br>" if match.group(3) else match.group(4))
        current_pos = match.end()
    # Text after the last citation (or the whole text when there is none).
    parts.append(_escape(text[current_pos:]))
    return ''.join(parts)
# One shared bot instance serves every request coming from the UI callback.
pleias_bot = pleiasBot()

# Stylesheet handed to gr.Blocks(css=...):
#   .generation              side margins around the analysis/answer text
#   :target                  highlight for the element targeted by the URL fragment
#   .source                  source cards floated side by side, up to 17% wide
#   .tooltip / .tooltiptext  numbered citation markers with a hover popup + arrow
#   .section-title           NOTE(review): not referenced elsewhere in this file
css = """
.generation {
margin-left: 2em;
margin-right: 2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
display: inline-block;
color: #183EFA;
font-weight: bold;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
background-color: #fff;
color: #000;
text-align: left;
padding: 12px;
border-radius: 6px;
border: 1px solid #e5e7eb;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
min-width: 300px;
max-width: 400px;
white-space: normal;
font-size: 0.9em;
line-height: 1.4;
}
.tooltip:hover .tooltiptext {
visibility: visible;
}
.tooltip .tooltiptext::after {
content: "";
position: absolute;
top: 100%;
left: 50%;
margin-left: -5px;
border-width: 5px;
border-style: solid;
border-color: #fff transparent transparent transparent;
}
.section-title {
font-weight: bold;
font-size: 15px;
margin-bottom: 1em;
margin-top: 1em;
}
"""
# Gradio interface
def gradio_interface(user_message):
    """Gradio click handler: answer *user_message* with the shared bot.

    Args:
        user_message: Content of the question textbox.

    Returns:
        The ``(analysis_html, answer_html, sources_html)`` triple from
        ``pleias_bot.predict``. For empty, ``None`` or whitespace-only input it
        returns three empty strings instead, which clears the output panes.
    """
    # Nothing to ask: skip the vector search and the model generation entirely.
    if not user_message or not user_message.strip():
        return "", "", ""
    analysis, response, sources = pleias_bot.predict(user_message)
    return analysis, response, sources
# Create Gradio app.
# Components are placed in the order they are created inside each `with`
# context, so the statement order below defines the page layout.
demo = gr.Blocks(css=css)
with demo:
    # Header with black bar: the project name drawn as ASCII art inside <pre>.
    gr.HTML("""
<div style="display: flex; justify-content: center; width: 100%; background-color: black; padding: 5px 0;">
<pre style="font-family: monospace; line-height: 1.2; font-size: 12px; color: #00ffea; margin: 0;">
_ _ ______ ___ _____
| | (_) | ___ \\/ _ \\| __ \\
_ __ | | ___ _ __ _ ___ ______ | |_/ / /_\\ \\ | \\/
| '_ \\| |/ _ \\ |/ _` / __| |______| | /| _ | | __
| |_) | | __/ | (_| \\__ \\ | |\\ \\| | | | |_\\ \\
| .__/|_|\\___|_|\\__,_|___/ \\_| \\_\\_| |_/\\____/
| |
|_| </pre>
</div>
""")
    # Centered input section: question box and submit button.
    with gr.Column(scale=1):
        text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
        text_button = gr.Button("Interroger pleias-RAG")
    # Analysis and Response in side-by-side columns (2:3 relative width).
    with gr.Row():
        # Left column for analysis
        with gr.Column(scale=2):
            text_output = gr.HTML(label="Analyse des sources")
        # Right column for response
        with gr.Column(scale=3):
            response_output = gr.HTML(label="Réponse")
    # Sources at the bottom
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")
    # The button sends the textbox content to gradio_interface; its three
    # return values fill the analysis, response and sources panes, in order.
    # NOTE(review): only the button triggers a query; text_input.submit is not wired.
    text_button.click(gradio_interface,
                      inputs=text_input,
                      outputs=[text_output, response_output, embedding_output])
# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()