|
import streamlit as st |
|
from langchain.text_splitter import CharacterTextSplitter |
|
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings |
|
from langchain.document_loaders import TextLoader |
|
from langchain.embeddings import SentenceTransformerEmbeddings |
|
from tempfile import NamedTemporaryFile |
|
import os |
|
import shutil |
|
from typing import Any, List, Mapping, Optional |
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
from langchain.llms.base import LLM |
|
from gradio_client import Client |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.vectorstores import FAISS |
|
import time |
|
# Start each execution of this module with an empty "tempDir" scratch folder,
# which save_uploadedfile() fills and FAISS.load_local() reads from.
# NOTE(review): this runs at import/script level, so under Streamlit's rerun
# model the folder is reset on every script run and is shared by all
# concurrent sessions — confirm that is acceptable for multi-user deployment.
# ignore_errors=True keeps the "folder may not exist yet" case silent, while
# makedirs(exist_ok=True) still surfaces real failures (e.g. permissions)
# instead of swallowing them with a bare except.
shutil.rmtree("tempDir", ignore_errors=True)
os.makedirs("tempDir", exist_ok=True)
|
# Stylesheet for the chat bubbles (.chat-message.user / .chat-message.bot with
# .avatar and .message children) used by the HTML chat templates.
# The closing </style> matters: without it a Markdown renderer treats the
# <style> HTML block as running to the end of the rendered element.
css = '''
<style>
.chat-message {
    padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
}
.chat-message.user {
    background-color: #2b313e
}
.chat-message.bot {
    background-color: #475063
}
.chat-message .avatar {
  width: 20%;
}
.chat-message .avatar img {
  max-width: 78px;
  max-height: 78px;
  border-radius: 50%;
  object-fit: cover;
}
.chat-message .message {
  width: 80%;
  padding: 0 1.5rem;
  color: #fff;
}
</style>
'''
|
|
|
# HTML for one assistant message; "{{MSG}}" is replaced with the message text.
# Keep this free of blank lines: in Markdown an HTML <div> block ends at the
# first blank line, and anything after it would be rendered as plain Markdown.
bot_template = '''
<div class="chat-message bot">
    <div class="avatar">
        <img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png" alt="bot avatar" style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
    </div>
    <div class="message">{{MSG}}</div>
</div>
'''
|
|
|
# HTML for one user message; "{{MSG}}" is replaced with the message text.
# Keep this free of blank lines: in Markdown an HTML <div> block ends at the
# first blank line, and anything after it would be rendered as plain Markdown.
user_template = '''
<div class="chat-message user">
    <div class="avatar">
        <img src="https://cdn-icons-png.flaticon.com/512/149/149071.png" alt="user avatar">
    </div>
    <div class="message">{{MSG}}</div>
</div>
'''
|
|
|
|
|
def save_uploadedfile(uploadedfile, directory="tempDir"):
    """Write an uploaded file into *directory* and show a success banner.

    Only the final path component of ``uploadedfile.name`` is used, so a
    client-supplied name such as ``../../app.py`` cannot write outside
    *directory*.

    Args:
        uploadedfile: Object exposing ``.name`` and ``.getbuffer()`` (such as
            the items returned by ``st.file_uploader``).
        directory: Destination folder; defaults to ``"tempDir"``.

    Returns:
        Whatever ``st.success`` returns for the confirmation message.

    Raises:
        ValueError: If the upload name has no usable file-name component.
    """
    filename = os.path.basename(uploadedfile.name)
    if filename in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploadedfile.name!r}")
    with open(os.path.join(directory, filename), "wb") as f:
        f.write(uploadedfile.getbuffer())
    return st.success(f"Saved File:{filename} to {directory}")
|
|
|
_LLAMA_SPACE_URL = "https://ysharma-explore-llamav2-with-tgi.hf.space/"
_DEFAULT_SYSTEM_PROMPT = (
    "you are a university professor, use appropriate language to answer "
    "students' questions ."
)


def ricerca_llama(domanda, system_prompt=_DEFAULT_SYSTEM_PROMPT):
    """Send *domanda* to the hosted Llama-2 Gradio Space and return the answer.

    Args:
        domanda: Prompt to send; converted with ``str()``.
        system_prompt: System message passed as the second positional argument
            of the Space's ``/chat`` endpoint.

    Returns:
        The response as text, cut at the first ``</s>`` end-of-sequence marker
        and stripped of surrounding whitespace.
    """
    import logging

    logging.getLogger(__name__).debug("Llama prompt: %s", domanda)
    # NOTE(review): a new Client is created per call, as before.
    client = Client(_LLAMA_SPACE_URL)
    # Positional sampling arguments expected by the /chat endpoint; presumably
    # temperature, max new tokens, top-p and repetition penalty — TODO confirm
    # against the Space's API page.
    risultato = client.predict(
        str(domanda), system_prompt, 0.1, 2000, 0.1, 1.1, api_name="/chat"
    )
    # Cut only at the "</s>" marker; the former split("<") also truncated
    # legitimate answers containing "<" (e.g. "a < b").
    return str(risultato).split("</s>")[0].strip()
|
|
|
|
|
|
|
class CustomLLM(LLM):
    """LangChain ``LLM`` adapter that delegates generation to ``ricerca_llama``."""

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain reports for this LLM type."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the completion for *prompt*.

        The remote endpoint is not given stop sequences, so they are applied
        client-side: the completion is cut at the earliest occurrence of any
        non-empty string in *stop*. ``run_manager`` and extra kwargs are
        accepted for interface compatibility but unused.
        """
        risultato = ricerca_llama(prompt)
        if stop:
            hits = [pos for pos in (risultato.find(s) for s in stop if s) if pos >= 0]
            if hits:
                risultato = risultato[: min(hits)]
        return risultato
|
|
|
def hande_user_input(user_question):
    """Ask the stored conversation chain *user_question* and render the chat.

    Saves the chain's ``chat_history`` into ``st.session_state`` and writes
    every message with the user/bot HTML templates: even indices are rendered
    as user turns, odd indices as bot turns.

    If no chain has been built yet (``st.session_state.conversation`` is
    ``None``), a warning is shown instead of calling ``None``.
    """
    import html

    conversation = st.session_state.conversation
    if conversation is None:
        st.warning('Carica prima il tuo indice FAISS e premi "Procedi".')
        return

    response = conversation({"question": user_question})
    st.session_state.chat_history = response["chat_history"]
    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        # Escape because the templates are rendered with unsafe_allow_html=True
        # (user and model text must not inject markup); newlines become <br>
        # so a blank line cannot terminate the surrounding Markdown HTML block.
        safe_text = html.escape(message.content).replace("\n", "<br>")
        st.write(template.replace("{{MSG}}", safe_text), unsafe_allow_html=True)
|
|
|
|
|
def get_conversation_chain(vectorstore):
    """Build a ``ConversationalRetrievalChain`` that answers from *vectorstore*.

    Generation uses ``CustomLLM``, context comes from the vector store's
    ``as_retriever()``, and dialogue is kept in a ``ConversationBufferMemory``
    under the ``"chat_history"`` key with ``return_messages=True``.

    Args:
        vectorstore: A LangChain vector store exposing ``as_retriever()``.

    Returns:
        The configured conversational retrieval chain.
    """
    language_model = CustomLLM()
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=language_model,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
|
def main():
    """Render the chat page and the sidebar used to load a FAISS index.

    The main area shows a question box and, once a question is entered, the
    conversation. The sidebar accepts the uploaded index files; pressing
    "Procedi" saves them to ``tempDir``, loads them with ``FAISS.load_local``
    and stores a new conversation chain in ``st.session_state``.
    """
    st.set_page_config(page_title="chat with unipv")

    # Initialise per-session state keys on first use.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.write(css, unsafe_allow_html=True)
    user_input = st.text_input("fai una domanda al tuo professore ")
    if user_input:
        hande_user_input(user_input)

    with st.sidebar:
        st.subheader("Your faiss index")
        documents = st.file_uploader(
            "upload your faiss index here ", accept_multiple_files=True
        )
        if st.button("Procedi"):
            # Without uploads FAISS.load_local would fail on an empty folder.
            if not documents:
                st.warning(
                    "Carica i file dell'indice FAISS (index.faiss e index.pkl) "
                    "prima di procedere."
                )
                return
            with st.spinner("sto processando i tuoi dati"):
                for document in documents:
                    save_uploadedfile(document)

                # FAISS.load_local with the default index_name="index" reads
                # index.faiss and index.pkl; report missing files clearly.
                missing = [
                    name
                    for name in ("index.faiss", "index.pkl")
                    if not os.path.isfile(os.path.join("tempDir", name))
                ]
                if missing:
                    st.error(f"Mancano i file dell'indice FAISS: {', '.join(missing)}")
                    return

                # NOTE(review): these embeddings must match the model used to
                # build the uploaded index; confirm that the Instructor wrapper
                # is intended for "thenlper/gte-base".
                embeddings = HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
                # SECURITY: LangChain's FAISS.load_local unpickles index.pkl.
                # Unpickling a user-uploaded file can execute arbitrary code;
                # only accept indexes from trusted sources.
                new_db = FAISS.load_local("tempDir", embeddings)
                st.session_state.conversation = get_conversation_chain(new_db)


if __name__ == "__main__":
    main()