# RAG-Demo / app.py
import os

import gradio as gr
import requests
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms import ChatMessage, HuggingFaceInferenceAPI, MessageRole
from llama_index.prompts import ChatPromptTemplate
def download_file(url, filename):
    """Download the file at `url` and save it locally under `filename`."""
    response = requests.get(url, stream=True)
    # Check that the request was successful before writing anything.
    if response.status_code == 200:
        with open(filename, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive chunks
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
def generate(prompt, history, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # The sampling controls (temperature, max_new_tokens, top_p,
    # repetition_penalty) come from the Gradio sliders below but are not
    # forwarded to the Inference API in this demo.
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )
    # Download the linked document into a local data directory and index it.
    # (The original read from Colab's /content, which does not exist here.)
    os.makedirs("data", exist_ok=True)
    download_file(file_link, os.path.join("data", file_link.split("/")[-1]))
    documents = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    # Text QA prompt: answers the question from the retrieved context.
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
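    # {context_str} and {query_str} above are placeholders the query engine
    # fills at query time with the retrieved chunks and the user's question.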
    # Refine prompt: revises the running answer as more context arrives.
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)
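    # The refine template is reused when the retrieved context spans multiple
    # LLM calls: each pass gets the new chunk as {context_msg} and the running
    # answer as {existing_answer}.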
    response = index.as_query_engine(
        text_qa_template=text_qa_template, refine_template=refine_template, similarity_top_k=6
    ).query(prompt)
    print(str(response))
    # The query engine returns the full answer at once; yield it character by
    # character so the chat UI renders it as if it were streaming.
    output = ""
    for char in str(response):
        output += char
        yield output
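# gr.ChatInterface treats a generator fn as a streaming handler: each yielded
# string replaces the in-progress bot message in the chat window.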
def upload_file(files):
    # Currently unused helper: maps gr.File uploads to their local paths.
    file_paths = [file.name for file in files]
    return file_paths
additional_inputs = [
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
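# These widgets are passed positionally to generate() after (prompt, history):
# file_link, temperature, max_new_tokens, top_p, repetition_penalty.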
examples=[["Explain the paper and describe its novelty", None, None, None, None, None, ],
["Can you write a short story about a time-traveling detective who solves historical mysteries?", None, None, None, None, None,],
["I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?", None, None, None, None, None,],
["I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?", None, None, None, None, None,],
["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None,],
["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None,],
]
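# Each example row is the chat message followed by one placeholder per
# additional input above.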
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False, debug=True)