Carlosito16 committed on
Commit 1d0f253
1 Parent(s): f3517b6

Update app.py

Files changed (1)
  1. app.py +116 -32
app.py CHANGED
@@ -1,49 +1,133 @@
 import streamlit as st
+from streamlit_chat import message as st_message
 import pandas as pd
-import copy
-from googletrans import Translator
-from langchain.vectorstores import FAISS
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from streamlit_extras.row import row
-import requests
-from bs4 import BeautifulSoup
-from urllib.parse import urlparse
-from collections import Counter
+import numpy as np
+import datetime
+import gspread
 import torch
-from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+# from langchain.vectorstores import Chroma
 from langchain.vectorstores import FAISS
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+
 from langchain import HuggingFacePipeline
 from langchain.chains import RetrievalQA
-from collections import Counter
+from langchain.prompts import PromptTemplate
+from langchain.memory import ConversationBufferWindowMemory
+
+
+from langchain.chains import LLMChain
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 
-st.set_page_config(
-    page_title="Hello",
-    page_icon="👋",
+
+prompt_template = """
+You are the chatbot and your job is to give answers.
+MUST only use the following pieces of context to answer the question at the end. If the answers are not in the context or you are not sure of the answer, just say that you don't know, don't try to make up an answer.
+{context}
+Question: {question}
+When encountering abusive, offensive, or harmful language, such as fuck, bitch, etc., just politely ask the users to maintain appropriate behaviours.
+Always make sure to elaborate your response.
+Never answer with any unfinished response.
+Answer:
+"""
+PROMPT = PromptTemplate(
+    template=prompt_template, input_variables=["context", "question"]
 )
+# chain_type_kwargs = {"prompt": PROMPT}
+
+@st.cache_resource
+def load_conversational_qa_memory_retriever():
+
+    question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
+    doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt=PROMPT)
+    memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history", return_messages=True, output_key='answer')
+
+
+
+    conversational_qa_memory_retriever = ConversationalRetrievalChain(
+        retriever=vector_database.as_retriever(),
+        question_generator=question_generator,
+        combine_docs_chain=doc_chain,
+        return_source_documents=True,
+        memory=memory,
+        get_chat_history=lambda h: h)
+    return conversational_qa_memory_retriever, question_generator
 
-st.markdown("main page")
+def new_retrieve_answer():
+    prompt_answer = st.session_state.my_text_input + ". Try to be elaborate and informative in your answer."
+    answer = conversational_qa_memory_retriever({"question": prompt_answer})
 
+    print(f"condensed question: {question_generator.run({'chat_history': answer['chat_history'], 'question': prompt_answer})}")
 
-# @st.cache_resource
-# def load_llm_model(max_length=256, temperature=0, repetition_penalty=1.3):
-# # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
-# #                                         task= 'text2text-generation',
-# #                                         model_kwargs={ "device_map": "auto",
-# #                                                       "load_in_8bit": True, "max_length": 256, "temperature": 0,
-# #                                                       "repetition_penalty": 1.5})
+    print(answer["chat_history"])
 
+    st.session_state.chat_history.append({"message": st.session_state.my_text_input, "is_user": True})
+    st.session_state.chat_history.append({"message": answer['answer'][6:], "is_user": False})
+
+    st.session_state.my_text_input = ""
+
+    return answer['answer'][6:]  # this positional slicing helps remove "<pad> " at the beginning
 
-#     llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
-#                                             task= 'text2text-generation',
-
-#                                             model_kwargs={ "max_length": max_length, "temperature": temperature,
-#                                                           "torch_dtype": torch.float32,
-#                                                           "repetition_penalty": repetition_penalty})
-#     return llm
+def clean_chat_history():
+    st.session_state.chat_history = []
+    conversational_qa_memory_retriever.memory.chat_memory.clear()  # also clear the chain's conversation memory
+
+
+if "history" not in st.session_state:  # this one is for the google sheet logging
+    st.session_state.history = []
+
+
+if "chat_history" not in st.session_state:  # this one is to pass previous messages into chat flow
+    st.session_state.chat_history = []
+
+
+llm_model = st.session_state['model']
+vector_database = st.session_state['faiss_db']
+conversational_qa_memory_retriever, question_generator = load_conversational_qa_memory_retriever()
+
+
+
+print("all load done")
+
+
+# Try adding this to clear the memory in each session
+if st.session_state.chat_history == []:
+    conversational_qa_memory_retriever.memory.chat_memory.clear()
+
+
+
+st.write("# extraGPT 🤖 ")
+
+with st.expander("key information"):
+    st.write(st.session_state['chunked_df'], unsafe_allow_html=True)
+    st.markdown(st.session_state['max_length'])
+    st.markdown(st.session_state['temperature'])
+    st.markdown(st.session_state['repetition_penalty'])
+
+
+
+st.write(' ⚠️ Please expect to wait **~ 10 - 20 seconds per question** as this app is running on CPU against a 3-billion-parameter LLM')
+
+st.markdown("---")
+st.write(" ")
+st.write("""
+### ❔ Ask a question
+""")
+
+
+
+
+for chat in st.session_state.chat_history:
+    st_message(**chat)
 
+query_input = st.text_input(label='Type a question', key='my_text_input', on_change=new_retrieve_answer)
 
 
-# max_length, temperature, repetition_penalty = 128, 0 , 1.3
 
-# llm_model = load_llm_model(max_length, temperature, repetition_penalty)
-# st.markdown("model successfully downloaded")
+clear_button = st.button("Start new convo",
+                         on_click=clean_chat_history)
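
Note on the new chain wiring: load_conversational_qa_memory_retriever() combines a question-condensing LLMChain, a "stuff" QA chain, and a 3-turn window memory. Below is a minimal, Streamlit-free sketch of the same wiring; llm and faiss_db are placeholders for the objects the app pulls from st.session_state['model'] and st.session_state['faiss_db'], and build_chain is a hypothetical helper, not part of the commit.

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.memory import ConversationBufferWindowMemory

def build_chain(llm, faiss_db, answer_prompt):
    # Rewrites each follow-up into a standalone question using the chat history.
    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    # "stuff" chain: retrieved chunks are concatenated into answer_prompt's {context}.
    doc_chain = load_qa_chain(llm, chain_type="stuff", prompt=answer_prompt)
    # Keep only the last 3 exchanges; output_key='answer' is needed because
    # return_source_documents=True makes the chain return several output keys.
    memory = ConversationBufferWindowMemory(
        k=3, memory_key="chat_history", return_messages=True, output_key="answer")
    return ConversationalRetrievalChain(
        retriever=faiss_db.as_retriever(),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        return_source_documents=True,
        memory=memory,
        get_chat_history=lambda h: h)

# chain = build_chain(llm, faiss_db, PROMPT)
# result = chain({"question": "What does the index cover?"})
# result["answer"], result["source_documents"]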
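
One caveat the diff's own comment points at: answer['answer'][6:] assumes every fastchat-t5 response starts with the 6-character prefix "<pad> ". A slightly safer variant (a hypothetical helper, not in the commit) strips the prefix only when it is actually present:

def strip_pad(text: str) -> str:
    # Remove the leading "<pad> " emitted by fastchat-t5, if present,
    # instead of unconditionally slicing off the first 6 characters.
    return text.removeprefix("<pad> ").strip()  # str.removeprefix: Python 3.9+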