import re
import difflib
import torch
import gradio as gr
from transformers import AutoTokenizer, pipeline
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

editorial_model = "PleIAs/Estienne"
token_classifier = pipeline(
    "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
)
tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
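
# Chunking strategy: greedily pack newline-separated parts into chunks of at
# most max_tokens tokens; a single part that is still too long is repeatedly
# halved at the nearest whitespace until every chunk fits.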
def split_text(text, max_tokens=500):
    # Split the text by newline characters
    parts = text.split("\n")
    chunks = []
    current_chunk = ""
    for part in parts:
        # Add the part to the current chunk
        if current_chunk:
            temp_chunk = current_chunk + "\n" + part
        else:
            temp_chunk = part
        # Tokenize the temporary chunk
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)

    # If no newlines were found and the text still exceeds max_tokens, split further
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            # Move the split point forward to the next whitespace character
            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
                split_point += 1
            # Ensure the split point does not go out of range
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)
    return chunks
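
# Render a word-level diff as HTML: unchanged words pass through untouched,
# inserted words are highlighted in green, and deleted words are dropped.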
# Courtesy of Claude
def generate_html_diff(old_text, new_text):
    d = difflib.Differ()
    diff = list(d.compare(old_text.split(), new_text.split()))
    html_diff = []
    for word in diff:
        if word.startswith('  '):
            html_diff.append(word[2:])
        elif word.startswith('+ '):
            html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
        # Nothing is appended for words that start with '- '
    return ' '.join(html_diff)
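
# Example:
#   generate_html_diff("le chat", "le chat noir")
#   -> 'le chat <span style="background-color: #90EE90;">noir</span>'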
# Class encapsulating the editorial-structure classifier
class MistralChatBot:
    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        self.system_prompt = system_prompt

    def predict(self, user_message):
        # Replace newlines with a pilcrow so paragraph breaks survive tokenization
        editorial_text = re.sub("\n", " ¶ ", user_message)
        # Tokenize the text and check whether it exceeds 500 tokens
        num_tokens = len(tokenizer.tokenize(editorial_text))
        if num_tokens > 500:
            # Split the text into chunks
            batch_prompts = split_text(editorial_text, max_tokens=500)
        else:
            batch_prompts = [editorial_text]
        out = token_classifier(batch_prompts)
        # The pipeline returns one list of entity dicts per chunk; stitch the
        # recognized spans back together into a single string
        classified_text = " ".join(
            entity["word"] for chunk in out for entity in chunk
        )
        html_diff = generate_html_diff(editorial_text, classified_text)
        generated_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + html_diff + "</div>"
        return generated_text

# Create the chatbot instance
mistral_bot = MistralChatBot()
# Define the Gradio interface
title = "Éditorialisation"
description = "Un outil expérimental d'identification de la structure du texte à partir d'un encodeur (Deberta)"
examples = [["Qui peut bénéficier de l'AIP?"]]
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Correction d'OCR</h1>""")
    gr.Markdown(description)
    text_input = gr.Textbox(label="Votre texte.", type="text", lines=1)
    gr.Examples(examples=examples, inputs=text_input)
    text_button = gr.Button("Identifier les structures éditoriales")
    text_output = gr.HTML(label="Le texte corrigé")
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])
if __name__ == "__main__":
    demo.queue().launch()