from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
import gradio as gr
import requests
import os
import json

# Directory that downloaded source documents are stored in and indexed from.
CONTENT_DIR = 'content/'
# JSON Lines file used as a prompt -> answers cache between runs.
ANSWERS_PATH = 'saved_answers.jsonl'


def download_file(url, filename, dest_dir=CONTENT_DIR):
    """Download ``url`` and save it as ``dest_dir/filename``.

    Nothing is fetched when ``filename`` is empty or a file with that name
    already exists in ``dest_dir``. A non-200 response is reported with a
    printed message instead of raising.

    Args:
        url: Address of the file to fetch.
        filename: Name to store the file under inside ``dest_dir``.
        dest_dir: Target directory; created if it does not exist.
    """
    # Check the cheap local conditions first so an empty link or an
    # already-downloaded file never triggers a network request.
    if not filename:
        return
    os.makedirs(dest_dir, exist_ok=True)
    if filename in os.listdir(dest_dir):
        return

    target_path = os.path.join(dest_dir, filename)
    # Stream into a temporary name and rename on success, so an interrupted
    # download does not leave a truncated file that every later call would
    # treat as "already downloaded".
    partial_path = target_path + '.part'
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
            return
        try:
            with open(partial_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        file.write(chunk)
            os.replace(partial_path, target_path)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    print(f"Download complete: {filename}")


def save_answer(prompt, rag_answer, norag_answer):
    """Append one ``{prompt, rag_answer, norag_answer}`` record to the answer cache.

    Only the new record is appended; the file is created on first use.
    """
    record = {
        'prompt': prompt,
        'rag_answer': rag_answer,
        'norag_answer': norag_answer,
    }
    write_to_jsonl(ANSWERS_PATH, [record])


def check_answer(prompt):
    """Look up cached answers for ``prompt``.

    Returns:
        ``(rag_answer, norag_answer)`` from the first cached record whose
        prompt matches exactly, or ``(None, None)`` when there is no match
        or the cache file is missing/unreadable.
    """
    try:
        existing_data = load_jsonl(ANSWERS_PATH)
    except (OSError, ValueError):
        # Missing file or corrupt JSON: behave as a cache miss.
        return None, None

    for entry in existing_data:
        if entry.get('prompt') == prompt:
            return entry.get('rag_answer'), entry.get('norag_answer')
    return None, None


# Helper functions
def load_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed objects, skipping blank lines."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def write_to_jsonl(file_path, data):
    """Append each item of ``data`` to ``file_path`` as one JSON line.

    The file is opened in append mode, so existing lines are kept and the
    file is created if it does not exist.
    """
    with open(file_path, 'a', encoding='utf-8') as file:
        file.writelines(json.dumps(item) + '\n' for item in data)


# Prompts for the retrieval-augmented ("RAG") answer.
_RAG_SYSTEM_PROMPT = "Always answer the question, even if the context isn't helpful."
_RAG_QA_PROMPT = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the question: {query_str}\n"
)
_RAG_REFINE_PROMPT = (
    "We have the opportunity to refine the original answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question: {query_str}. "
    "If the context isn't useful, output the original answer again.\n"
    "Original Answer: {existing_answer}"
)

# Prompts for the "no RAG" answer: the same query engine is used, but the
# templates never insert the retrieved context into the message.
_NORAG_SYSTEM_PROMPT = "Always answer the question"
_NORAG_QA_PROMPT = "answer the question: {query_str}\n"
_NORAG_REFINE_PROMPT = (
    "answer the question: {query_str}. "
    "If the context isn't useful, output the original answer again.\n"
    "Original Answer: {existing_answer}"
)


def _build_index(file_link):
    """Download ``file_link`` if needed and build a vector index over CONTENT_DIR."""
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"  # alternative: Mistral-7B-Instruct-v0.2
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )
    download_file(file_link, file_link.split("/")[-1])
    # NOTE: every file in CONTENT_DIR is indexed, not only the one just
    # downloaded, so previously downloaded documents also feed the context.
    documents = SimpleDirectoryReader(CONTENT_DIR).load_data()
    return VectorStoreIndex.from_documents(documents, service_context=service_context)


def _make_templates(system_text, qa_text, refine_text):
    """Return ``(text_qa_template, refine_template)`` that share one system message."""
    def _template(user_text):
        return ChatPromptTemplate([
            ChatMessage(role=MessageRole.SYSTEM, content=system_text),
            ChatMessage(role=MessageRole.USER, content=user_text),
        ])
    return _template(qa_text), _template(refine_text)


def _format_output(rag_answer, norag_answer, rag_only):
    """Render the labelled answer text shown in the chat window."""
    if rag_only:
        return f'* Mixtral + RAG Output:\n{rag_answer}'
    return f'* Mixtral Output:\n{norag_answer}\n\n* Mixtral + RAG Output:\n{rag_answer}'


def generate(prompt, history, rag_only, file_link, temperature=0.9, max_new_tokens=256,
             top_p=0.95, repetition_penalty=1.0,):
    """Answer ``prompt`` for the Gradio chat UI, with and without retrieval.

    If the answer cache (ANSWERS_PATH) holds a record for this exact prompt,
    it is returned directly. Otherwise the document at ``file_link`` is
    downloaded into CONTENT_DIR, that directory is indexed, and the index is
    queried with a context-grounded prompt (the "RAG" answer). Unless
    ``rag_only`` is truthy, the index is queried again with prompts that omit
    the retrieved context (the "no RAG" answer), and both answers are cached.

    NOTE(review): the cache is keyed on the prompt only, so changing
    ``file_link`` does not invalidate a previously cached answer.

    Args:
        prompt: The user's question.
        history: Chat history supplied by ``gr.ChatInterface``; not used here.
        rag_only: When truthy, only the RAG answer is produced/returned.
        file_link: URL of the document to download and index.
        temperature: Converted to float and clamped to at least 1e-2.
        max_new_tokens: Forwarded to the RAG query engine.
        top_p: Converted to float and forwarded to the RAG query engine.
        repetition_penalty: Forwarded to the RAG query engine.

    Returns:
        The labelled answer text.
    """
    cached_rag, cached_norag = check_answer(prompt)
    if cached_rag is not None:
        return _format_output(cached_rag, cached_norag, rag_only)

    index = _build_index(file_link)

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    text_qa_template, refine_template = _make_templates(
        _RAG_SYSTEM_PROMPT, _RAG_QA_PROMPT, _RAG_REFINE_PROMPT
    )
    # NOTE(review): the sampling settings are passed to as_query_engine rather
    # than to the LLM constructor; confirm they actually reach the model.
    rag_response = index.as_query_engine(
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        similarity_top_k=6,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    ).query(prompt)
    output_rag = str(rag_response)
    print(output_rag)

    if rag_only:
        return _format_output(output_rag, None, rag_only=True)

    text_qa_template_nr, refine_template_nr = _make_templates(
        _NORAG_SYSTEM_PROMPT, _NORAG_QA_PROMPT, _NORAG_REFINE_PROMPT
    )
    norag_response = index.as_query_engine(
        text_qa_template=text_qa_template_nr,
        refine_template=refine_template_nr,
        similarity_top_k=6,
    ).query(prompt)
    output_norag = str(norag_response)

    save_answer(prompt, output_rag, output_norag)
    return _format_output(output_rag, output_norag, rag_only=False)


def upload_file(files):
    """Return the local paths (``.name``) of uploaded Gradio file objects."""
    return [file.name for file in files]


additional_inputs = [
    gr.Checkbox(
        label="RAG Only",
        interactive=True,
        value=False,
    ),
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# Each example row is a prompt followed by one None per additional input.
_EXAMPLE_PROMPTS = [
    "What is a trustworthy digital repository, where can you find this information?",
    "What are things a repository must have?",
    "What principles should record creators follow?",
    "Write a very short summary of Data Sanitation Techniques by Edgar Dale, and write a citation in APA style.",
    "Can you explain how the QuickSort algorithm works and provide a Python implementation?",
    "What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?",
]
examples = [[example_prompt] + [None] * len(additional_inputs) for example_prompt in _EXAMPLE_PROMPTS]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=True,
        likeable=True,
        layout="panel",
    ),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
).queue().launch(show_api=False, debug=True, share=True)