Spaces: Runtime error
Update backupapp.py
backupapp.py CHANGED (+456 -15)
@@ -1,14 +1,84 @@
+Modify the program below to add a few picture buttons which will use emojis and a witty title to describe a prompt. The first prompt I want is "Write ten random adult limerick based on quotes that are tweet length and make you laugh. Show as numbered bold faced and large font markdown outline with emojis for each." Modify this code to add the prompt emoji-labeled buttons above the text box. When you click them, pass the variable they contain to a function which runs the chat through the Llama web service call in the code below. Refactor it so it is function based. Put variables that set the description and label for each button right before the st.button() function calls, and use the st.expander() function to create an expanded description container with a witty label so the user can collapse the st.expander to hide buttons of a particular type. This first type will be Wit and Humor. Make sure each label contains appropriate emojis. Code: # Imports
+import base64
+import glob
+import json
+import math
+import mistune
+import openai
+import os
+import pytz
+import re
 import requests
 import streamlit as st
-import 
+import textract
+import time
+import zipfile
+from audio_recorder_streamlit import audio_recorder
+from bs4 import BeautifulSoup
+from collections import deque
+from datetime import datetime
+from dotenv import load_dotenv
+from huggingface_hub import InferenceClient
+from io import BytesIO
+from langchain.chat_models import ChatOpenAI
+from langchain.chains import ConversationalRetrievalChain
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.memory import ConversationBufferMemory
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import FAISS
+from openai import ChatCompletion
+from PyPDF2 import PdfReader
+from templates import bot_template, css, user_template
+from xml.etree import ElementTree as ET
|
34 |
+
# Constants
|
35 |
+
API_URL = 'https://qe55p8afio98s0u3.us-east-1.aws.endpoints.huggingface.cloud' # Dr Llama
|
36 |
API_KEY = os.getenv('API_KEY')
|
|
|
37 |
headers = {
|
38 |
"Authorization": f"Bearer {API_KEY}",
|
39 |
"Content-Type": "application/json"
|
40 |
}
|
+key = os.getenv('OPENAI_API_KEY')
+prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
+# Page config and sidebar are declared up front so all other functions can see these globals
+st.set_page_config(page_title="GPT Streamlit Document Reasoner", layout="wide")
+
+# UI Controls
+should_save = st.sidebar.checkbox("💾 Save", value=True)
+
+# Functions
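+# Stream tokens from the HF inference endpoint into a st.empty() placeholder,
+# skipping special tokens and halting on any configured stop sequence.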
+def StreamLLMChatResponse(prompt):
+    endpoint_url = API_URL
+    hf_token = API_KEY
+    client = InferenceClient(endpoint_url, token=hf_token)
+    gen_kwargs = dict(
+        max_new_tokens=512,
+        top_k=30,
+        top_p=0.9,
+        temperature=0.2,
+        repetition_penalty=1.02,
+        stop_sequences=["\nUser:", "<|endoftext|>", "</s>"],
+    )
+    stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)
+    report = []
+    res_box = st.empty()
+    collected_chunks = []
+    collected_messages = []
+    for r in stream:
+        if r.token.special:
+            continue
+        if r.token.text in gen_kwargs["stop_sequences"]:
+            break
+        collected_chunks.append(r.token.text)
+        chunk_message = r.token.text
+        collected_messages.append(chunk_message)
+        try:
+            report.append(r.token.text)
+            if len(r.token.text) > 0:
+                result = "".join(report).strip()
+                res_box.markdown(f'*{result}*')
+        except:
+            st.write(' ')
 
 def query(payload):
     response = requests.post(API_URL, headers=headers, json=payload)
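The query()/get_output() helpers call the endpoint synchronously, and the endpoint scales to zero when idle (see the "Dr. Llama is asleep" message in main() below). A minimal retry wrapper is sketched here for illustration; it is not part of the commit, it assumes failures surface as exceptions from requests, and the retry count and wait are arbitrary:

import time

def get_output_with_retry(prompt, retries=3, wait_seconds=60):
    # Retry while the KEDA-scaled endpoint cold-starts from zero replicas.
    for attempt in range(retries):
        try:
            return get_output(prompt)
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(wait_seconds)  # give the container time to come up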
@@ -18,19 +88,390 @@ def query(payload):
 def get_output(prompt):
     return query({"inputs": prompt})
 
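+# Build a filesystem-safe filename from a US/Central timestamp plus the first
+# 90 alphanumeric/underscore characters of the prompt.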
+def generate_filename(prompt, file_type):
+    central = pytz.timezone('US/Central')
+    safe_date_time = datetime.now(central).strftime("%m%d_%H%M")
+    replaced_prompt = prompt.replace(" ", "_").replace("\n", "_")
+    safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
+    return f"{safe_date_time}_{safe_prompt}.{file_type}"
+
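+# Post a recorded audio file to OpenAI's transcription endpoint, then run the
+# transcript through chat_with_model and persist both with create_file.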
+def transcribe_audio(openai_key, file_path, model):
+    openai.api_key = openai_key
+    OPENAI_API_URL = "https://api.openai.com/v1/audio/transcriptions"
+    headers = {
+        "Authorization": f"Bearer {openai_key}",
+    }
+    with open(file_path, 'rb') as f:
+        data = {'file': f}
+        response = requests.post(OPENAI_API_URL, headers=headers, files=data, data={'model': model})
+    if response.status_code == 200:
+        st.write(response.json())
+        chatResponse = chat_with_model(response.json().get('text'), '')
+        transcript = response.json().get('text')
+        filename = generate_filename(transcript, 'txt')
+        response = chatResponse
+        user_prompt = transcript
+        create_file(filename, user_prompt, response, should_save)
+        return transcript
+    else:
+        st.write(response.json())
+        st.error("Error in API call.")
+        return None
+
+def save_and_play_audio(audio_recorder):
+    audio_bytes = audio_recorder()
+    if audio_bytes:
+        filename = generate_filename("Recording", "wav")
+        with open(filename, 'wb') as f:
+            f.write(audio_bytes)
+        st.audio(audio_bytes, format="audio/wav")
+        return filename
+    return None
+
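+# Persist a prompt/response pair to disk; any ```python fenced block in the
+# response is also extracted into its own -Code.py file.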
+def create_file(filename, prompt, response, should_save=True):
+    if not should_save:
+        return
+    base_filename, ext = os.path.splitext(filename)
+    has_python_code = bool(re.search(r"```python([\s\S]*?)```", response))
+    if ext in ['.txt', '.htm', '.md']:
+        with open(f"{base_filename}-Prompt.txt", 'w') as file:
+            file.write(prompt)
+        with open(f"{base_filename}-Response.md", 'w') as file:
+            file.write(response)
+        if has_python_code:
+            python_code = re.findall(r"```python([\s\S]*?)```", response)[0].strip()
+            with open(f"{base_filename}-Code.py", 'w') as file:
+                file.write(python_code)
+
+def truncate_document(document, length):
+    return document[:length]
+
+def divide_document(document, max_length):
+    return [document[i:i+max_length] for i in range(0, len(document), max_length)]
+
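+# Read a saved file and return an <a href="data:..."> anchor so the sidebar can
+# offer it as a download, guessing a MIME type from the file extension.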
+def get_table_download_link(file_path):
+    with open(file_path, 'r') as file:
+        try:
+            data = file.read()
+        except:
+            st.write('')
+            return file_path
+    b64 = base64.b64encode(data.encode()).decode()
+    file_name = os.path.basename(file_path)
+    ext = os.path.splitext(file_name)[1]  # get the file extension
+    if ext == '.txt':
+        mime_type = 'text/plain'
+    elif ext == '.py':
+        mime_type = 'text/plain'
+    elif ext == '.xlsx':
+        mime_type = 'text/plain'
+    elif ext == '.csv':
+        mime_type = 'text/plain'
+    elif ext == '.htm':
+        mime_type = 'text/html'
+    elif ext == '.md':
+        mime_type = 'text/markdown'
+    else:
+        mime_type = 'application/octet-stream'  # general binary data type
+    href = f'<a href="data:{mime_type};base64,{b64}" target="_blank" download="{file_name}">{file_name}</a>'
+    return href
+
+
def CompressXML(xml_text):
|
180 |
+
root = ET.fromstring(xml_text)
|
181 |
+
for elem in list(root.iter()):
|
182 |
+
if isinstance(elem.tag, str) and 'Comment' in elem.tag:
|
183 |
+
elem.parent.remove(elem)
|
184 |
+
return ET.tostring(root, encoding='unicode', method="xml")
|
185 |
+
|
+def read_file_content(file, max_length):
+    if file.type == "application/json":
+        content = json.load(file)
+        return str(content)
+    elif file.type == "text/html" or file.type == "text/htm":
+        content = BeautifulSoup(file, "html.parser")
+        return content.text
+    elif file.type == "application/xml" or file.type == "text/xml":
+        tree = ET.parse(file)
+        root = tree.getroot()
+        xml = CompressXML(ET.tostring(root, encoding='unicode'))
+        return xml
+    elif file.type == "text/markdown" or file.type == "text/md":
+        md = mistune.create_markdown()
+        content = md(file.read().decode())
+        return content
+    elif file.type == "text/plain":
+        return file.getvalue().decode()
+    else:
+        return ""
+
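+# Stream a chat completion, echoing partial text into a st.empty() placeholder;
+# an optional document_section rides along as an assistant message.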
+def chat_with_model(prompt, document_section, model_choice='gpt-3.5-turbo'):
+    model = model_choice
+    conversation = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
+    conversation.append({'role': 'user', 'content': prompt})
+    if len(document_section) > 0:
+        conversation.append({'role': 'assistant', 'content': document_section})
+    start_time = time.time()
+    report = []
+    res_box = st.empty()
+    collected_chunks = []
+    collected_messages = []
+    for chunk in openai.ChatCompletion.create(model=model, messages=conversation, temperature=0.5, stream=True):  # use the selected model rather than a hard-coded one
+        collected_chunks.append(chunk)
+        chunk_message = chunk['choices'][0]['delta']
+        collected_messages.append(chunk_message)
+        content = chunk["choices"][0].get("delta", {}).get("content")
+        try:
+            report.append(content)
+            if len(content) > 0:
+                result = "".join(report).strip()
+                res_box.markdown(f'*{result}*')
+        except:
+            st.write(' ')
+    full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
+    st.write("Elapsed time:")
+    st.write(time.time() - start_time)
+    return full_reply_content
+
+def chat_with_file_contents(prompt, file_content, model_choice='gpt-3.5-turbo'):
+    conversation = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
+    conversation.append({'role': 'user', 'content': prompt})
+    if len(file_content) > 0:
+        conversation.append({'role': 'assistant', 'content': file_content})
+    response = openai.ChatCompletion.create(model=model_choice, messages=conversation)
+    return response['choices'][0]['message']['content']
+
+def extract_mime_type(file):
+    if isinstance(file, str):
+        pattern = r"type='(.*?)'"
+        match = re.search(pattern, file)
+        if match:
+            return match.group(1)
+        else:
+            raise ValueError(f"Unable to extract MIME type from {file}")
+    elif hasattr(file, 'type'):  # duck-type Streamlit's UploadedFile; the bare name `streamlit` is never imported here
+        return file.type
+    else:
+        raise TypeError("Input should be a string or a Streamlit UploadedFile object")
+
+def extract_file_extension(file):
+    # get the file name directly from the UploadedFile object
+    file_name = file.name
+    pattern = r".*?\.(.*?)$"
+    match = re.search(pattern, file_name)
+    if match:
+        return match.group(1)
+    else:
+        raise ValueError(f"Unable to extract file extension from {file_name}")
+
+def pdf2txt(docs):
+    text = ""
+    for file in docs:
+        file_extension = extract_file_extension(file)
+        st.write(f"File type extension: {file_extension}")
+        try:
+            if file_extension.lower() in ['py', 'txt', 'html', 'htm', 'xml', 'json']:
+                text += file.getvalue().decode('utf-8')
+            elif file_extension.lower() == 'pdf':
+                from PyPDF2 import PdfReader
+                pdf = PdfReader(BytesIO(file.getvalue()))
+                for page in range(len(pdf.pages)):
+                    text += pdf.pages[page].extract_text()  # new PyPDF2 syntax
+        except Exception as e:
+            st.write(f"Error processing file {file.name}: {e}")
+    return text
+
+def txt2chunks(text):
+    text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
+    return text_splitter.split_text(text)
+
+def vector_store(text_chunks):
+    embeddings = OpenAIEmbeddings(openai_api_key=key)
+    return FAISS.from_texts(texts=text_chunks, embedding=embeddings)
+
+def get_chain(vectorstore):
+    llm = ChatOpenAI()
+    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectorstore.as_retriever(), memory=memory)
+
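+# Route a question through the session's ConversationalRetrievalChain and render
+# the alternating user/bot history using the HTML templates from the templates module.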
+def process_user_input(user_question):
+    response = st.session_state.conversation({'question': user_question})
+    st.session_state.chat_history = response['chat_history']
+    for i, message in enumerate(st.session_state.chat_history):
+        template = user_template if i % 2 == 0 else bot_template
+        st.write(template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
+        filename = generate_filename(user_question, 'txt')
+        response = message.content
+        user_prompt = user_question
+        create_file(filename, user_prompt, response, should_save)
+
+def divide_prompt(prompt, max_length):
+    words = prompt.split()
+    chunks = []
+    current_chunk = []
+    current_length = 0
+    for word in words:
+        if len(word) + current_length <= max_length:
+            current_length += len(word) + 1
+            current_chunk.append(word)
+        else:
+            chunks.append(' '.join(current_chunk))
+            current_chunk = [word]
+            current_length = len(word)
+    chunks.append(' '.join(current_chunk))
+    return chunks
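+# Example: divide_prompt("one two three four", 9) -> ['one two', 'three four']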
+
+def create_zip_of_files(files):
+    zip_name = "all_files.zip"
+    with zipfile.ZipFile(zip_name, 'w') as zipf:
+        for file in files:
+            zipf.write(file)
+    return zip_name
+
+def get_zip_download_link(zip_file):
+    with open(zip_file, 'rb') as f:
+        data = f.read()
+    b64 = base64.b64encode(data).decode()
+    href = f'<a href="data:application/zip;base64,{b64}" download="{zip_file}">Download All</a>'
+    return href
+
 def main():
-    st.title("
-
-
-    if st.button("
-
-
-
-
-
-
-
-
+    st.title(" DrLlama7B")
+    prompt = f"Write ten funny jokes that are tweet length stories that make you laugh. Show as markdown outline with emojis for each."
+    example_input = st.text_input("Enter your example text:", value=prompt)
+    if st.button("Run Prompt With Dr Llama"):
+        try:
+            StreamLLMChatResponse(example_input)
+        except:
+            st.write('Dr. Llama is asleep. Starting up now on A10 - please give 5 minutes then retry as KEDA scales up from zero to activate running container(s).')
+    openai.api_key = os.getenv('OPENAI_KEY')
+    menu = ["txt", "htm", "xlsx", "csv", "md", "py"]
+    choice = st.sidebar.selectbox("Output File Type:", menu)
+    model_choice = st.sidebar.radio("Select Model:", ('gpt-3.5-turbo', 'gpt-3.5-turbo-0301'))
+    filename = save_and_play_audio(audio_recorder)
+    if filename is not None:
+        transcription = transcribe_audio(key, filename, "whisper-1")
+        st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
+        filename = None
+    user_prompt = st.text_area("Enter prompts, instructions & questions:", '', height=100)
+    collength, colupload = st.columns([2, 3])  # adjust the ratio as needed
+    with collength:
+        max_length = st.slider("File section length for large files", min_value=1000, max_value=128000, value=12000, step=1000)
+    with colupload:
+        uploaded_file = st.file_uploader("Add a file for context:", type=["pdf", "xml", "json", "xlsx", "csv", "html", "htm", "md", "txt"])
+    document_sections = deque()
+    document_responses = {}
+    if uploaded_file is not None:
+        file_content = read_file_content(uploaded_file, max_length)
+        document_sections.extend(divide_document(file_content, max_length))
+    if len(document_sections) > 0:
+        if st.button("👁️ View Upload"):
+            st.markdown("**Sections of the uploaded file:**")
+            for i, section in enumerate(list(document_sections)):
+                st.markdown(f"**Section {i+1}**\n{section}")
+        st.markdown("**Chat with the model:**")
+        for i, section in enumerate(list(document_sections)):
+            if i in document_responses:
+                st.markdown(f"**Section {i+1}**\n{document_responses[i]}")
+            else:
+                if st.button(f"Chat about Section {i+1}"):
+                    st.write('Reasoning with your inputs...')
+                    response = chat_with_model(user_prompt, section, model_choice)
+                    st.write('Response:')
+                    st.write(response)
+                    document_responses[i] = response
+                    filename = generate_filename(f"{user_prompt}_section_{i+1}", choice)
+                    create_file(filename, user_prompt, response, should_save)
+                    st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
+    if st.button('💬 Chat'):
+        st.write('Reasoning with your inputs...')
+        user_prompt_sections = divide_prompt(user_prompt, max_length)
+        full_response = ''
+        for prompt_section in user_prompt_sections:
+            response = chat_with_model(prompt_section, ''.join(list(document_sections)), model_choice)
+            full_response += response + '\n'  # combine the responses
+        response = full_response
+        st.write('Response:')
+        st.write(response)
+        filename = generate_filename(user_prompt, choice)
+        create_file(filename, user_prompt, response, should_save)
+        st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
+    all_files = glob.glob("*.*")
+    all_files = [file for file in all_files if len(os.path.splitext(file)[0]) >= 20]  # exclude files with short names
+    all_files.sort(key=lambda x: (os.path.splitext(x)[1], x), reverse=True)  # sort by file type and file name in descending order
+    if st.sidebar.button("🗑 Delete All"):
+        for file in all_files:
+            os.remove(file)
+        st.experimental_rerun()
+    if st.sidebar.button("⬇️ Download All"):
+        zip_file = create_zip_of_files(all_files)
+        st.sidebar.markdown(get_zip_download_link(zip_file), unsafe_allow_html=True)
+    file_contents = ''
+    next_action = ''
+    for file in all_files:
+        col1, col2, col3, col4, col5 = st.sidebar.columns([1, 6, 1, 1, 1])  # adjust the ratio as needed
+        with col1:
+            if st.button("📝", key="md_" + file):  # md emoji button
+                with open(file, 'r') as f:
+                    file_contents = f.read()
+                next_action = 'md'
+        with col2:
+            st.markdown(get_table_download_link(file), unsafe_allow_html=True)
+        with col3:
+            if st.button("📂", key="open_" + file):  # open emoji button
+                with open(file, 'r') as f:
+                    file_contents = f.read()
+                next_action = 'open'
+        with col4:
+            if st.button("🔍", key="read_" + file):  # search emoji button
+                with open(file, 'r') as f:
+                    file_contents = f.read()
+                next_action = 'search'
+        with col5:
+            if st.button("🗑", key="delete_" + file):
+                os.remove(file)
+                st.experimental_rerun()
+    if len(file_contents) > 0:
+        if next_action == 'open':
+            file_content_area = st.text_area("File Contents:", file_contents, height=500)
+        if next_action == 'md':
+            st.markdown(file_contents)
+        if next_action == 'search':
+            file_content_area = st.text_area("File Contents:", file_contents, height=500)
+            st.write('Reasoning with your inputs...')
+            response = chat_with_model(user_prompt, file_contents, model_choice)
+            filename = generate_filename(file_contents, choice)
+            create_file(filename, user_prompt, response, should_save)
+            st.experimental_rerun()
+
+
+    # Feedback
+    # Step: Give User a Way to Upvote or Downvote
+    feedback = st.radio("Step 8: Give your feedback", ("👍 Upvote", "👎 Downvote"))
+
+    if feedback == "👍 Upvote":
+        st.write("You upvoted 👍. Thank you for your feedback!")
+    else:
+        st.write("You downvoted 👎. Thank you for your feedback!")
+
+    load_dotenv()
+    st.write(css, unsafe_allow_html=True)
+    st.header("Chat with documents :books:")
+    user_question = st.text_input("Ask a question about your documents:")
+    if user_question:
+        process_user_input(user_question)
+    with st.sidebar:
+        st.subheader("Your documents")
+        docs = st.file_uploader("import documents", accept_multiple_files=True)
+        with st.spinner("Processing"):
+            raw = pdf2txt(docs)
+            if len(raw) > 0:
+                length = str(len(raw))
+                text_chunks = txt2chunks(raw)
+                vectorstore = vector_store(text_chunks)
+                st.session_state.conversation = get_chain(vectorstore)
+                st.markdown('# AI Search Index of Length:' + length + ' Created.')  # add timing
+                filename = generate_filename(raw, 'txt')
+                create_file(filename, raw, '', should_save)
 
 if __name__ == "__main__":
     main()
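For reference, the instruction text committed at the top of the file asks for emoji-labeled prompt buttons grouped under a collapsible st.expander(), but the commit itself does not add them. Below is a minimal sketch of one way the "Wit and Humor" group could look; it assumes StreamLLMChatResponse from this file, and the button label, description, and expander title are illustrative:

import streamlit as st

def add_witty_humor_buttons():
    # Collapsible container with a witty label; collapsing it hides this button type.
    with st.expander("Wit and Humor 🤣", expanded=True):
        # Description and label variables set right before the st.button() call,
        # as the pasted-in instructions request.
        button_description = "Ten random adult limericks from tweet-length quotes 🍺"
        button_label = "Generate Limericks 😂"
        st.markdown(button_description)
        if st.button(button_label):
            # Pass the prompt carried by this button into the Llama web service call.
            StreamLLMChatResponse(
                'Write ten random adult limerick based on quotes that are tweet '
                'length and make you laugh. Show as numbered bold faced and large '
                'font markdown outline with emojis for each.'
            )

# Called near the top of main(), before the text input, so the buttons sit above the box:
#     add_witty_humor_buttons()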