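# penlosa / app.py
# Hugging Face Space file committed by Pablo276 — commit 31577a2 ("Update app.py"), 4.79 kB.
# (Header copied from the Hub file view: raw · history · blame · malware scan: "No virus".)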
import streamlit as st
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
from langchain.document_loaders import TextLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from tempfile import NamedTemporaryFile
import os
import shutil
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from gradio_client import Client
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import FAISS
import time
# Reset the scratch folder that receives the uploaded FAISS index files.
# Clearing is best-effort (ignore_errors), and creation tolerates only an
# already-existing directory. The previous bare `except: pass` also hid real
# failures (e.g. permission errors) and even KeyboardInterrupt/SystemExit.
# NOTE(review): Streamlit re-executes the script on each rerun, and this folder
# is shared by every session, so concurrent users can clobber each other's files.
shutil.rmtree("tempDir", ignore_errors=True)
os.makedirs("tempDir", exist_ok=True)
# Stylesheet for the chat bubbles built from bot_template / user_template.
# main() renders it with st.write(..., unsafe_allow_html=True). The <style>
# element must be closed, otherwise the raw-HTML block is left unterminated.
css = '''
<style>
.chat-message {
padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
}
.chat-message.user {
background-color: #2b313e
}
.chat-message.bot {
background-color: #475063
}
.chat-message .avatar {
width: 20%;
}
.chat-message .avatar img {
max-width: 78px;
max-height: 78px;
border-radius: 50%;
object-fit: cover;
}
.chat-message .message {
width: 80%;
padding: 0 1.5rem;
color: #fff;
}
</style>
'''
def _chat_bubble(role, avatar_img):
    """Return the HTML skeleton of one chat bubble for ``role`` ("bot" or "user").

    The literal ``{{MSG}}`` placeholder is left in place; the caller substitutes
    the message text with ``str.replace`` before rendering.
    """
    parts = (
        "",
        f'<div class="chat-message {role}">',
        '<div class="avatar">',
        avatar_img,
        "</div>",
        '<div class="message">{{MSG}}</div>',
        "</div>",
        "",
    )
    return "\n".join(parts)


# Bubble for model answers (avatar image with inline sizing/rounding styles).
bot_template = _chat_bubble(
    "bot",
    '<img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png" '
    'style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">',
)

# Bubble for the student's questions (generic user icon).
user_template = _chat_bubble(
    "user",
    '<img src="https://cdn-icons-png.flaticon.com/512/149/149071.png">',
)
def save_uploadedfile(uploadedfile, directory="tempDir"):
    """Write an uploaded file into ``directory`` and show a Streamlit confirmation.

    Only the final path component of the client-supplied name is used. A crafted
    name such as ``../../app.py`` therefore cannot write outside ``directory``.

    Args:
        uploadedfile: object from ``st.file_uploader``; its ``.name`` and
            ``.getbuffer()`` are used.
        directory: destination folder (default ``"tempDir"``). It must already
            exist, since ``open`` does not create it.

    Returns:
        Whatever ``st.success`` returns for the confirmation message.

    Raises:
        ValueError: if the upload's name has no usable file-name component.
    """
    safe_name = os.path.basename(uploadedfile.name)
    if safe_name in ("", ".", ".."):
        raise ValueError(f"invalid upload file name: {uploadedfile.name!r}")
    with open(os.path.join(directory, safe_name), "wb") as f:
        f.write(uploadedfile.getbuffer())
    return st.success(f"Saved File:{safe_name} to {directory}")
def ricerca_llama(domanda):
    """Send ``domanda`` to the hosted Llama 2 Gradio Space and return its reply.

    Args:
        domanda: the prompt text (converted with ``str()`` before sending).

    Returns:
        The reply as a string, cut at the first ``"</s>"`` marker.
    """
    # NOTE(review): a new Client is created on every call. Before caching it,
    # check whether the Space's /chat endpoint keeps chat state per client
    # session; reusing one client could then leak history between calls.
    client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
    # After the prompt: the system prompt, then numeric generation settings passed
    # through unchanged (presumably temperature, max new tokens, top-p, repetition
    # penalty — confirm against the Space's API page).
    risultato = client.predict(
        str(domanda),
        "you are a university professor, use appropriate language to answer students' questions .",
        0.1,
        2000,
        0.1,
        1.1,
        api_name="/chat",
    )
    print(domanda)
    # The reply presumably ends with Llama's "</s>" end-of-sequence token (TODO confirm).
    # Cut only at that marker: the old split on any "<" truncated answers that
    # legitimately contain "<", e.g. "a < b" or HTML/code snippets.
    return str(risultato).partition("</s>")[0]
class CustomLLM(LLM):
    """LangChain LLM adapter that forwards every prompt to ``ricerca_llama``.

    Stop sequences are not supported: a non-None ``stop`` raises ``ValueError``.
    """

    @property
    def _llm_type(self) -> str:
        """Return the type tag LangChain uses for this LLM."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Answer ``prompt`` via the remote Llama 2 Space.

        ``prompt`` is whatever text LangChain hands to the LLM. With the
        retrieval chain in get_conversation_chain, this is presumably the
        filled-in prompt template rather than the raw user question.
        ``run_manager`` and extra keyword arguments are ignored.

        Raises:
            ValueError: if ``stop`` is provided.
        """
        if stop is None:
            return ricerca_llama(prompt)
        raise ValueError("stop kwargs are not permitted.")
def hande_user_input(user_question):
    """Send ``user_question`` through the conversation chain and render the chat.

    If no chain has been built yet (no index loaded via "Procedi"), shows a
    warning and returns instead of raising TypeError on ``None(...)``.

    The bubbles are rendered with ``unsafe_allow_html=True``, so every message
    (user input and model output alike) is HTML-escaped before substitution.
    This prevents either one from injecting markup or scripts into the page.
    """
    import html

    conversation = st.session_state.conversation
    if conversation is None:
        st.warning("Carica prima il tuo indice FAISS e premi \"Procedi\".")
        return

    response = conversation({"question": user_question})
    st.session_state.chat_history = response["chat_history"]
    for i, message in enumerate(st.session_state.chat_history):
        # Even indices are rendered as the user and odd ones as the bot. This
        # assumes the memory alternates human/AI messages, starting with the human.
        template = user_template if i % 2 == 0 else bot_template
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )
def get_conversation_chain(vectorstore):
    """Build a ConversationalRetrievalChain over ``vectorstore`` answered by ``CustomLLM``.

    Chat turns are kept in a ConversationBufferMemory under the "chat_history"
    key. They are returned as message objects (``return_messages=True``), which
    hande_user_input reads via ``.content``.
    """
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=CustomLLM(),
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
def main():
    """Streamlit entry point.

    The main area holds the question box and chat. The sidebar uploads a FAISS
    index; "Procedi" saves it to "tempDir", loads it, and builds the conversation chain.
    """
    st.set_page_config(page_title="chat with unipv")
    # Per-session state that must survive Streamlit reruns.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.write(css, unsafe_allow_html=True)
    user_input = st.text_input("fai una domanda al tuo professore ")
    if user_input:
        hande_user_input(user_input)

    with st.sidebar:
        st.subheader("Your faiss index")
        documents = st.file_uploader("upload your faiss index here ", accept_multiple_files=True)
        if st.button("Procedi"):
            if not documents:
                # With nothing uploaded, FAISS.load_local would fail on an empty folder.
                st.warning("Nessun file caricato: carica prima i file del tuo indice FAISS.")
                return
            with st.spinner("sto processando i tuoi dati"):
                for document in documents:
                    save_uploadedfile(document)
                # NOTE(review): gte-base is loaded through the *Instruct* embeddings
                # wrapper. This must match however the uploaded index was built — confirm.
                embeddings = HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
                # SECURITY: FAISS.load_local unpickles the uploaded index .pkl file.
                # Unpickling user-supplied data allows arbitrary code execution on the
                # server, so only accept indexes from trusted sources.
                new_db = FAISS.load_local("tempDir", embeddings)
                st.session_state.conversation = get_conversation_chain(new_db)


if __name__ == "__main__":
    main()