Spaces:
Paused
Paused
import transformers | |
import numpy as np | |
import re | |
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM | |
from vllm import LLM, SamplingParams | |
import torch | |
import gradio as gr | |
import json | |
import os | |
import shutil | |
import requests | |
from pprint import pprint | |
import chromadb | |
import pandas as pd | |
from sklearn.metrics.pairwise import cosine_similarity | |
pd.set_option('display.max_columns', None)

# Define the device.
# NOTE(review): `device` is not referenced anywhere else in the visible code.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Generation hyper-parameters.
temperature = 0.2
max_new_tokens = 1000
top_p = 0.92
repetition_penalty = 1.7

# Decoding settings passed to every `llm.generate` call (MistralChatBot.predict
# reads this module-level name). Previously this line was commented out, which
# made every query fail with NameError and left the values above unused.
# NOTE(review): an earlier variant also used presence_penalty=1.5 and
# stop=["``"]; re-add them if the model needs them.
sampling_params = SamplingParams(
    temperature=temperature,
    top_p=top_p,
    max_tokens=max_new_tokens,
    repetition_penalty=repetition_penalty,
)

# Load the model once at import time; the context window is capped at 4096 tokens.
model_name = "Inagua/code-model"
llm = LLM(model_name, max_model_len=4096)
# Stylesheet for the Gradio app: answer margins, highlighting of the element
# targeted by a reference link (`href="#<ref_id>"`), and the hover tooltip that
# displays each reference's `data-text` attribute.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF; /* Highlight the element targeted by the URL fragment */
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%; /* Adjust this value as needed to control the vertical spacing between the text and the tooltip */
white-space: pre-wrap; /* Allows the text to wrap */
width: 500px; /* Sets a fixed maximum width for the tooltip */
max-width: 500px; /* Ensures the tooltip does not exceed the maximum width */
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Optional: Adds a subtle shadow for better visibility */
}"""
# Originally drafted with the help of ChatGPT.
def format_references(text):
    """Replace inline ``<ref text="...">ID</ref>`` markers with numbered tooltip links.

    Each well-formed reference becomes a ``<span class="tooltip">``. Its
    ``data-text`` attribute holds ``"ID: reference text"``, which the
    ``.tooltip:hover::after`` CSS rule displays on hover. The span wraps an
    anchor ``[n]`` linking to ``#ID``. References are numbered from 1 in
    order of appearance.

    Args:
        text: Text (e.g. model output) that may contain reference markers.

    Returns:
        The text with every well-formed reference replaced. When a malformed
        reference is met (no closing ``">`` or no ``</ref>``), that reference
        and everything after it are kept verbatim instead of being dropped.
    """
    from html import escape  # local import keeps this block self-contained

    ref_start_marker = '<ref text="'
    ref_attr_end = '">'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            break
        text_start = start_pos + len(ref_start_marker)
        # Search after the start marker so its own closing quote cannot match.
        end_pos = text.find(ref_attr_end, text_start)
        if end_pos == -1:
            break  # malformed: the tail is appended verbatim below
        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            break  # malformed: the tail is appended verbatim below

        # Keep the plain text that precedes this reference.
        parts.append(text[current_pos:start_pos])

        ref_text = text[text_start:end_pos].replace('\n', ' ').strip()
        ref_id = text[end_pos + len(ref_attr_end):ref_end_pos].strip()
        # Both values are interpolated into attribute values, so escape quotes
        # and angle brackets to keep the generated markup well-formed.
        ref_text_escaped = escape(ref_text, quote=True)
        ref_id_escaped = escape(ref_id, quote=True)

        parts.append(
            f'<span class="tooltip" data-refid="{ref_id_escaped}" '
            f'data-text="{ref_id_escaped}: {ref_text_escaped}">'
            f'<a href="#{ref_id_escaped}">[{ref_number}]</a></span>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1

    # Remaining text: either the plain tail or an unparseable reference onward.
    parts.append(text[current_pos:])
    return ''.join(parts)
# Class to encapsulate the chatbot backed by the module-level vLLM engine `llm`.
class MistralChatBot:
    """Prompt the shared vLLM engine with a question + context and render the answer as HTML."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): system_prompt is stored but predict() does not use it;
        # confirm whether it should be prepended to the prompt.
        self.system_prompt = system_prompt

    @staticmethod
    def _build_prompt(user_message, context):
        """Return the prompt in the "### Question / ### Contexte / ### Formule" layout.

        ``None`` values (e.g. an untouched UI field) are treated as empty strings.
        """
        return (
            "### Question ###\n" + (user_message or "")
            + "\n\n### Contexte ###\n" + (context or "")
            + "\n\n### Formule ###\n"
        )

    @staticmethod
    def _format_answer(generated_text):
        """Wrap raw generated text in a centred "Réponse" heading and a `.generation` div."""
        # The heading previously opened with <h2> but closed with </h3>.
        return (
            '<h2 style="text-align:center">Réponse</h2>\n'
            '<div class="generation">' + generated_text + "</div>"
        )

    def predict(self, user_message, context):
        """Generate an answer for *user_message* given *context*.

        Uses the module-level ``llm`` engine and ``sampling_params``.

        Returns:
            A 2-tuple ``(answer_html, fiches_html)``; ``fiches_html`` is
            currently always the empty string.
        """
        prompts = [self._build_prompt(user_message, context)]
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
        # One prompt in, so take the first completion of the first request.
        generated_text = outputs[0].outputs[0].text
        # NOTE(review): the generation is inserted into HTML unescaped, so any
        # "<" in model output is interpreted as markup. Consider html.escape()
        # or format_references(), depending on the intended output format.
        fiches_html = ""
        return self._format_answer(generated_text), fiches_html
# Single shared chatbot instance whose predict() is the Gradio click handler.
mistral_bot = MistralChatBot()
# --- Gradio interface metadata ---
# NOTE(review): title, description, examples and additional_inputs are not
# referenced by the Blocks layout in this file; they look like leftovers from a
# ChatInterface-style setup. Confirm before relying on them.
title = "Inagua"
description = "An experimental LLM to interact with DAMAaaS documentation"

# A single example row laid out as [user_message, temperature].
examples = [["How to calculate a linear regression?", 0.7]]

# Temperature control: default 0.2, range 0.05-1.0 in 0.05 steps.
# The info text reads "Higher values give more creativity, but also more strangeness".
additional_inputs = [
    gr.Slider(
        label="Température",
        info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.2,
        interactive=True,
    )
]
# Build the UI. (A redundant `demo = gr.Blocks()` that was immediately
# overwritten by the `with` statement has been removed.)
with gr.Blocks(theme='gradio/monochrome', css=css) as demo:
    gr.HTML("""<h1 style="text-align:center">SkikitLLM</h1>""")
    text_input = gr.Textbox(label="Your question", type="text", lines=1)
    context_input = gr.Textbox(label="Your context", type="text", lines=1)
    text_button = gr.Button("Query SkikitLLM")
    text_output = gr.HTML(label="Answer")
    # predict() returns (answer_html, fiches_html). Give each value its own
    # component; with a single output Gradio would pass the whole tuple to
    # text_output.
    fiches_output = gr.HTML(label="Sources")
    text_button.click(
        mistral_bot.predict,
        inputs=[text_input, context_input],
        outputs=[text_output, fiches_output],
    )

if __name__ == "__main__":
    # queue() enables request queuing before the server starts.
    demo.queue().launch()