|
|
|
from llama_index.llms import HuggingFaceInferenceAPI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
import gradio as gr

import requests
import os
import json
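# RAG demo: answer questions about a downloaded document with Mixtral served
# through the Hugging Face Inference API, optionally comparing against the
# same model answering without the retrieved context. Prompt/answer pairs are
# cached in a local JSONL file so repeated prompts skip the model entirely.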
|
def download_file(url, filename):
    """
    Download a file from the specified URL and save it locally under the
    given filename inside the content/ directory.
    """
    # Make sure the target directory exists before checking or writing.
    os.makedirs('content', exist_ok=True)

    # Skip empty filenames and files that have already been downloaded,
    # before issuing any request.
    if filename == '' or filename in os.listdir('content/'):
        return

    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open('content/' + filename, 'wb') as file:
            # Stream the body in 1 KB chunks so large files are never held
            # in memory all at once.
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
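# Answer cache: each line of saved_answers.jsonl holds one JSON object with a
# prompt and its RAG / no-RAG answers; lookups are a simple linear scan.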
def save_answer(prompt, rag_answer, norag_answer):
    file_path = 'saved_answers.jsonl'

    json_dict = {
        'prompt': prompt,
        'rag_answer': rag_answer,
        'norag_answer': norag_answer
    }

    # Start with whatever has been cached so far; an absent file just means
    # an empty cache.
    try:
        existing_data = load_jsonl(file_path)
    except FileNotFoundError:
        existing_data = []

    existing_data.append(json_dict)
    write_to_jsonl(file_path, existing_data)

|
def check_answer(prompt):
    file_path = 'saved_answers.jsonl'

    # A missing or unreadable cache file simply means no cached answer.
    try:
        existing_data = load_jsonl(file_path)
    except (FileNotFoundError, json.JSONDecodeError):
        return None, None

    # Return the first cached entry whose prompt matches exactly.
    for entry in existing_data:
        if entry['prompt'] == prompt:
            return entry['rag_answer'], entry['norag_answer']

    return None, None

|
def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data


def write_to_jsonl(file_path, data):
    # Rewrite the whole file: `data` already contains every entry, old and
    # new, so appending here would duplicate the existing lines.
    with open(file_path, 'w') as file:
        for item in data:
            file.write(json.dumps(item) + '\n')

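# Chat handler wired into gr.ChatInterface below: `prompt` and `history` come
# from the chat box, and the remaining arguments are supplied by the
# additional_inputs widgets in the order they are declared.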
def generate(prompt, history, rag_only, file_link, temperature=0.9,
             max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):

    # Serve a cached answer if this exact prompt has been asked before.
    rag_answer, norag_answer = check_answer(prompt)
    if rag_answer is not None:
        if rag_only:
            return f'* Mixtral + RAG Output:\n{rag_answer}'
        else:
            return f'* Mixtral Output:\n{norag_answer}\n\n* Mixtral + RAG Output:\n{rag_answer}'

|
    # Mixtral served through the Hugging Face Inference API, paired with a
    # local BAAI/bge-small-en-v1.5 embedding model for retrieval.
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )

    # Fetch the linked document (if not already present) and index everything
    # under content/.
    download_file(file_link, file_link.split("/")[-1])
    documents = SimpleDirectoryReader("content/").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)

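    # Two prompt templates drive the query engine: text_qa drafts an answer
    # from the retrieved context, and refine revises that draft when more
    # context chunks are retrieved than fit in a single LLM call.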
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)

|
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)

|
    # Clamp temperature to a small positive value; the inference API does
    # not accept 0.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Retrieve the 6 most similar chunks and synthesize an answer with the
    # templates above.
    response = index.as_query_engine(
        text_qa_template=text_qa_template, refine_template=refine_template,
        similarity_top_k=6, temperature=temperature,
        max_new_tokens=max_new_tokens, top_p=top_p,
        repetition_penalty=repetition_penalty
    ).query(prompt)
    print(str(response))

    output_rag = str(response)

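    # Baseline pass: ask the same question with templates that leave out the
    # retrieved context, so the two outputs can be compared side by side.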
    if not rag_only:
        chat_text_qa_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=(
                    "Always answer the question"
                ),
            ),
            ChatMessage(
                role=MessageRole.USER,
                content=(
                    "answer the question: {query_str}\n"
                ),
            ),
        ]
        text_qa_template_nr = ChatPromptTemplate(chat_text_qa_msgs_nr)

        chat_refine_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=(
                    "Always answer the question"
                ),
            ),
            ChatMessage(
                role=MessageRole.USER,
                content=(
                    "answer the question: {query_str}. "
                    "If the context isn't useful, output the original answer again.\n"
                    "Original Answer: {existing_answer}"
                ),
            ),
        ]
        refine_template_nr = ChatPromptTemplate(chat_refine_msgs_nr)

        response_nr = index.as_query_engine(
            text_qa_template=text_qa_template_nr,
            refine_template=refine_template_nr, similarity_top_k=6
        ).query(prompt)

        output_norag = str(response_nr)
        save_answer(prompt, output_rag, output_norag)

        return f'* Mixtral Output:\n{output_norag}\n\n* Mixtral + RAG Output:\n{output_rag}'

    return f'* Mixtral + RAG Output:\n{output_rag}'

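# Helper for a gr.File upload widget; not wired into the interface below,
# which takes a document URL instead.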
def upload_file(files):
    file_paths = [file.name for file in files]
    return file_paths

|
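# Extra controls rendered under the chat box; their values are passed to
# generate() after (prompt, history), in this order.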
additional_inputs = [
    gr.Checkbox(
        label="RAG Only",
        interactive=True,
        value=False
    ),
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf"
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

|
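# Example prompts; the trailing None entries pad each row to match the
# additional inputs declared above.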
examples = [
    ["What is a trustworthy digital repository, where can you find this information?", None, None, None, None, None, None],
    ["What are things a repository must have?", None, None, None, None, None, None],
    ["What principles should record creators follow?", None, None, None, None, None, None],
    ["Write a very short summary of Data Sanitation Techniques by Edgar Dale, and write a citation in APA style.", None, None, None, None, None, None],
    ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None, None],
    ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None, None],
]

|
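# Build and launch the chat UI. queue() lets long-running generate() calls be
# processed without request timeouts; share=True exposes a temporary public URL.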
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
).queue().launch(show_api=False, debug=True, share=True)