gagan3012 committed
Commit ce19127
1 Parent(s): 00499de

Create app.py

Files changed (1)
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
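 + # Simple RAG demo Space: download a document from a URL, build a llama_index
 + # vector index over it, and answer questions with Mixtral through the
 + # Hugging Face Inference API.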
 + from llama_index.llms import HuggingFaceInferenceAPI
 + from llama_index.llms import ChatMessage, MessageRole
 + from llama_index.prompts import ChatPromptTemplate
 + from llama_index import VectorStoreIndex, SimpleDirectoryReader, LLMPredictor, ServiceContext, StorageContext, load_index_from_storage
 + import gradio as gr
 + import sys
 + import logging
 + import torch
 + from huggingface_hub import InferenceClient
 + import tqdm as notebook_tqdm
 + import requests
 +
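 + # NOTE: the imports above assume the pre-0.10 llama_index package layout;
 + # later releases moved these modules under llama_index.core and replaced
 + # ServiceContext.
 +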
 + def download_file(url, filename):
 +     """
 +     Download a file from the specified URL and save it locally under the given filename.
 +     """
 +     response = requests.get(url, stream=True)
 +
 +     # Check if the request was successful
 +     if response.status_code == 200:
 +         with open(filename, 'wb') as file:
 +             for chunk in response.iter_content(chunk_size=1024):
 +                 if chunk:  # filter out keep-alive new chunks
 +                     file.write(chunk)
 +         print(f"Download complete: {filename}")
 +     else:
 +         print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
 +
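 + # Chat handler: re-downloads and re-indexes the linked file on every call,
 + # then streams the query engine's answer back to the UI.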
 + def generate(prompt, history, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
 +     mixtral = HuggingFaceInferenceAPI(
 +         model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
 +     )
 +
 +     service_context = ServiceContext.from_defaults(
 +         llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
 +     )
 +
 +     # Download the linked file and index only that document.
 +     filename = file_link.split("/")[-1]
 +     download_file(file_link, filename)
 +
 +     documents = SimpleDirectoryReader(input_files=[filename]).load_data()
 +     index = VectorStoreIndex.from_documents(documents, service_context=service_context)
 +
 +     # Text QA Prompt
 +     chat_text_qa_msgs = [
 +         ChatMessage(
 +             role=MessageRole.SYSTEM,
 +             content=(
 +                 "Always answer the question, even if the context isn't helpful."
 +             ),
 +         ),
 +         ChatMessage(
 +             role=MessageRole.USER,
 +             content=(
 +                 "Context information is below.\n"
 +                 "---------------------\n"
 +                 "{context_str}\n"
 +                 "---------------------\n"
 +                 "Given the context information and not prior knowledge, "
 +                 "answer the question: {query_str}\n"
 +             ),
 +         ),
 +     ]
 +     text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
 +
 +     # Refine Prompt
 +     chat_refine_msgs = [
 +         ChatMessage(
 +             role=MessageRole.SYSTEM,
 +             content=(
 +                 "Always answer the question, even if the context isn't helpful."
 +             ),
 +         ),
 +         ChatMessage(
 +             role=MessageRole.USER,
 +             content=(
 +                 "We have the opportunity to refine the original answer "
 +                 "(only if needed) with some more context below.\n"
 +                 "------------\n"
 +                 "{context_msg}\n"
 +                 "------------\n"
 +                 "Given the new context, refine the original answer to better "
 +                 "answer the question: {query_str}. "
 +                 "If the context isn't useful, output the original answer again.\n"
 +                 "Original Answer: {existing_answer}"
 +             ),
 +         ),
 +     ]
 +     refine_template = ChatPromptTemplate(chat_refine_msgs)
 +
 +     # Query the index with both prompt templates. The UI sampling controls
 +     # (temperature, max_new_tokens, top_p, repetition_penalty) are accepted
 +     # here but not currently forwarded to the inference call.
 +     stream = index.as_query_engine(
 +         text_qa_template=text_qa_template, refine_template=refine_template, similarity_top_k=6
 +     ).query(prompt)
 +     print(str(stream))
 +
 +     output = ""
 +
 +     # Emit the answer character by character so the chat UI renders it as a stream.
 +     for response in str(stream):
 +         output += response
 +         yield output
 +     return output
 +
 + # Helper for gr.File uploads; defined but not wired into the interface below.
 + def upload_file(files):
 +     file_paths = [file.name for file in files]
 +     return file_paths
 +
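 + # Extra ChatInterface controls; Gradio passes their values to generate() after
 + # (prompt, history), in order: file_link, temperature, max_new_tokens, top_p,
 + # repetition_penalty.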
 + additional_inputs = [
 +     gr.Textbox(
 +         label="File Link",
 +         max_lines=1,
 +         interactive=True,
 +         value="https://arxiv.org/pdf/2401.10020.pdf"
 +     ),
 +     gr.Slider(
 +         label="Temperature",
 +         value=0.9,
 +         minimum=0.0,
 +         maximum=1.0,
 +         step=0.05,
 +         interactive=True,
 +         info="Higher values produce more diverse outputs",
 +     ),
 +     gr.Slider(
 +         label="Max new tokens",
 +         value=1024,
 +         minimum=0,
 +         maximum=2048,
 +         step=64,
 +         interactive=True,
 +         info="The maximum number of new tokens",
 +     ),
 +     gr.Slider(
 +         label="Top-p (nucleus sampling)",
 +         value=0.90,
 +         minimum=0.0,
 +         maximum=1.0,
 +         step=0.05,
 +         interactive=True,
 +         info="Higher values sample more low-probability tokens",
 +     ),
 +     gr.Slider(
 +         label="Repetition penalty",
 +         value=1.2,
 +         minimum=1.0,
 +         maximum=2.0,
 +         step=0.05,
 +         interactive=True,
 +         info="Penalize repeated tokens",
 +     )
 + ]
 +
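 + # Each example row is [message, *additional_input_values]; None keeps the
 + # corresponding control's default.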
 + examples = [
 +     ["Explain the paper and describe its novelty", None, None, None, None, None],
 +     ["Can you write a short story about a time-traveling detective who solves historical mysteries?", None, None, None, None, None],
 +     ["I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?", None, None, None, None, None],
 +     ["I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?", None, None, None, None, None],
 +     ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None],
 +     ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None],
 + ]
 +
 + gr.ChatInterface(
 +     fn=generate,
 +     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
 +     additional_inputs=additional_inputs,
 +     title="RAG Demo",
 +     examples=examples,
 +     concurrency_limit=20,
 + ).launch(show_api=False, debug=True)