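"""Gradio demo comparing Mixtral-8x7B-Instruct answers with and without
retrieval-augmented generation (RAG) over a linked document, using
llama_index for retrieval and the Hugging Face Inference API for the LLM."""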
from llama_index.llms import HuggingFaceInferenceAPI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
import gradio as gr
import requests
import os
import json
def download_file(url, filename):
    """
    Download a file from the specified URL and save it locally under the given
    filename in content/. Skips the download if the file is already present.
    """
    if filename == '':
        return
    os.makedirs('content', exist_ok=True)
    if filename in os.listdir('content/'):
        return
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open('content/' + filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive chunks
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
def save_answer(prompt, rag_answer, norag_answer):
    file_path = 'saved_answers.jsonl'
    # Record for the current prompt/answer pair
    json_dict = {
        'prompt': prompt,
        'rag_answer': rag_answer,
        'norag_answer': norag_answer
    }
    # Append only the new record: write_to_jsonl opens the file in append
    # mode, so rewriting the full dataset here would duplicate old entries.
    write_to_jsonl(file_path, [json_dict])
def check_answer(prompt):
    file_path = 'saved_answers.jsonl'
    # Load the cache; treat a missing or corrupt file as an empty cache
    try:
        existing_data = load_jsonl(file_path)
    except (FileNotFoundError, json.JSONDecodeError):
        return None, None
    # Return the cached answers for this prompt, if present
    for entry in existing_data:
        if entry['prompt'] == prompt:
            return entry['rag_answer'], entry['norag_answer']
    return None, None
# Helper functions
def load_jsonl(file_path):
    """Read a JSONL file into a list of dicts (one JSON object per line)."""
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data

def write_to_jsonl(file_path, data):
    """Append each item in data to the file, one JSON object per line."""
    with open(file_path, 'a') as file:
        for item in data:
            file.write(json.dumps(item) + '\n')
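# Each line of saved_answers.jsonl holds one record, e.g.:
# {"prompt": "...", "rag_answer": "...", "norag_answer": "..."}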
def generate(prompt, history, rag_only, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Serve cached answers if this prompt has been asked before
    rag_answer, norag_answer = check_answer(prompt)
    if rag_answer is not None:
        if rag_only:
            return f'* Mixtral + RAG Output:\n{rag_answer}'
        return f'* Mixtral Output:\n{norag_answer}\n\n* Mixtral + RAG Output:\n{rag_answer}'
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )
    # Fetch the linked document, then index everything under content/
    download_file(file_link, file_link.split("/")[-1])
    documents = SimpleDirectoryReader("content/").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    # Text QA Prompt
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
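    # llama_index's response synthesizer first answers from the top-ranked
    # chunk using text_qa_template, then revisits the draft answer once per
    # remaining chunk using the refine template below.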
    # Refine Prompt
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)
    # The inference endpoint requires a strictly positive temperature
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    response = index.as_query_engine(
        text_qa_template=text_qa_template, refine_template=refine_template,
        similarity_top_k=6, temperature=temperature,
        max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty
    ).query(prompt)
    output_rag = str(response)
    print(output_rag)
    ### No-RAG baseline
    if not rag_only:
        # These templates never reference {context_str}: the same query engine
        # is used, but the retrieved chunks are ignored in the answer.
        chat_text_qa_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="Always answer the question",
            ),
            ChatMessage(
                role=MessageRole.USER,
                content="answer the question: {query_str}\n",
            ),
        ]
        text_qa_template_nr = ChatPromptTemplate(chat_text_qa_msgs_nr)
        # Refine Prompt
        chat_refine_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content="Always answer the question",
            ),
            ChatMessage(
                role=MessageRole.USER,
                content=(
                    "answer the question: {query_str}. "
                    "If the context isn't useful, output the original answer again.\n"
                    "Original Answer: {existing_answer}"
                ),
            ),
        ]
        refine_template_nr = ChatPromptTemplate(chat_refine_msgs_nr)
        response_nr = index.as_query_engine(
            text_qa_template=text_qa_template_nr, refine_template=refine_template_nr, similarity_top_k=6
        ).query(prompt)
        output_norag = str(response_nr)
        save_answer(prompt, output_rag, output_norag)
        return f'* Mixtral Output:\n{output_norag}\n\n* Mixtral + RAG Output:\n{output_rag}'
    return f'* Mixtral + RAG Output:\n{output_rag}'
def upload_file(files):
    # Currently unused by the ChatInterface below
    file_paths = [file.name for file in files]
    return file_paths
additional_inputs = [
    gr.Checkbox(
        label="RAG Only",
        interactive=True,
        value=False
    ),
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf"
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
examples=[["What is a trustworthy digital repository, where can you find this information?", None, None, None, None, None, None, ],
["What are things a repository must have?", None, None, None, None, None, None,],
["What principles should record creators follow?", None, None, None, None, None, None,],
["Write a very short summary of Data Sanitation Techniques by Edgar Dale, and write a citation in APA style.", None, None, None, None, None, None,],
["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None, None,],
["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None, None,],
]
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
).queue().launch(show_api=False, debug=True, share=True)
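# Note: share=True also exposes a temporary public *.gradio.live URL in
# addition to the local server; debug=True blocks the main thread and
# surfaces errors in the console.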