added class
app.py CHANGED
@@ -13,113 +13,88 @@ from langchain.chains import ConversationalRetrievalChain
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp
 from langchain.chains import LLMChain
-
 import time
 import streamlit as st



-[removed lines 22-51: content not shown in this diff view]
+class MyBot:
+    def __init__(self, text_file, model_id, model_basename):
+        self.text_file = text_file
+        self.model_id = model_id
+        self.model_basename = model_basename
+        self.loader = TextLoader(self.text_file)
+        self.pages = self.loader.load()
+        self.chunks_text = self.split_text(self.pages)
+        self.docs_text = [doc.page_content for doc in self.chunks_text]
+        self.embedding = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+        self.VectorStore = FAISS.from_texts(self.docs_text, embedding=self.embedding)
+        self.model_path = self.download_model(self.model_id, self.model_basename)
+        self.callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+        self.llm = self.init_llm(self.model_path, self.callback_manager)
+        self.memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True,
+            input_key='question',
+            output_key='answer'
+        )
+        self.qa = self.init_qa(self.llm, self.VectorStore, self.memory)
+
+    def split_text(self, documents):
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=150,
+            length_function=len,
+            add_start_index=True,
+        )
+        chunks = text_splitter.split_documents(documents)
+        return chunks

-[removed lines 53-55: content not shown in this diff view]
+    def download_model(self, model_id, model_basename):
+        model_path = hf_hub_download(
+            repo_id=model_id,
+            filename=model_basename,
             resume_download=True,
         )
+        print("model_path : ", model_path)
+        return model_path
+
+    def init_llm(self, model_path, callback_manager):
+        CONTEXT_WINDOW_SIZE = 1500
+        MAX_NEW_TOKENS = 2000
+        N_BATCH = 512
+        n_gpu_layers = 40
+        kwargs = {
+            "model_path": model_path,
+            "n_ctx": CONTEXT_WINDOW_SIZE,
+            "max_tokens": MAX_NEW_TOKENS,
+            "n_batch": N_BATCH,
+            "n_gpu_layers": n_gpu_layers,
+            "callback_manager": callback_manager,
+            "verbose": True,
+        }
+        llm = LlamaCpp(**kwargs)
+        return llm
+
+    def init_qa(self, llm, VectorStore, memory):
+        qa = ConversationalRetrievalChain.from_llm(
+            llm,
+            chain_type="stuff",
+            retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
+            memory=memory,
+            return_source_documents=True,
+            verbose=False,
+        )
+        return qa

-[removed lines 59-62: content not shown in this diff view]
-CONTEXT_WINDOW_SIZE = 1500
-MAX_NEW_TOKENS = 2000
-N_BATCH = 512
-n_gpu_layers = 40
-kwargs = {
-    "model_path": model_path,
-    "n_ctx": CONTEXT_WINDOW_SIZE,
-    "max_tokens": MAX_NEW_TOKENS,
-    "n_batch": N_BATCH,
-    "n_gpu_layers": n_gpu_layers,
-    "callback_manager": callback_manager,
-    "verbose": True,
-}
-
-
-# Callbacks support token-wise streaming
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
-n_gpu_layers = 40  # Change this value based on your model and your GPU VRAM pool.
-n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
-max_tokens = 2000
-# Make sure the model path is correct for your system!
-llm = LlamaCpp(
-    model_path=model_path,
-    n_gpu_layers=n_gpu_layers,
-
-    n_batch=n_batch,
-    max_tokens=max_tokens,
-    callback_manager=callback_manager,
-    verbose=True,  # Verbose is required to pass to the callback manager
-)
-
-llm = LlamaCpp(**kwargs)
-
-memory = ConversationBufferMemory(
-    memory_key="chat_history",
-    return_messages=True,
-    input_key='question',
-    output_key='answer'
-)
-
-# memory.clear()
-
-qa = ConversationalRetrievalChain.from_llm(
-    llm,
-    chain_type="stuff",
-    retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
-    memory=memory,
-    return_source_documents=True,
-    verbose=False,
-)
-def translate(text, source="English", target="Moroccan Arabic"):
-    client = Client("https://facebook-seamless-m4t-v2-large.hf.space/--replicas/2bmbx/")
-    result = client.predict(
+    def translate(self, text, source="English", target="Moroccan Arabic"):
+        client = Client("https://facebook-seamless-m4t-v2-large.hf.space/--replicas/2bmbx/")
+        result = client.predict(
             text,
             source,
             target,
             api_name="/t2tt"
-    )
-    return result
+        )
+        return result

 #---------------------------------------------------------
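For reference, a minimal sketch of how the MyBot class introduced above could be exercised on its own, outside Streamlit. This is not part of the commit: the question string is made up, and it assumes the imports already present at the top of app.py (TextLoader, RecursiveCharacterTextSplitter, HuggingFaceEmbeddings, FAISS, CallbackManager, StreamingStdOutCallbackHandler, ConversationBufferMemory, hf_hub_download, LlamaCpp, ConversationalRetrievalChain, Client) and that Data_blog.txt sits next to the script.

    # Hypothetical usage sketch, not part of app.py: build the bot once and ask a question.
    bot = MyBot(
        text_file="Data_blog.txt",
        model_id="TheBloke/Mistral-7B-OpenOrca-GGUF",
        model_basename="mistral-7b-openorca.Q4_K_M.gguf",
    )

    # The chain is callable; with ConversationBufferMemory attached, passing the
    # question alone is enough, and return_source_documents=True adds the
    # retrieved chunks to the result dict.
    res = bot.qa("What is the blog post about?")
    print(res["answer"])
    for doc in res["source_documents"]:
        print(doc.page_content[:80])

    # Optional: route the answer through the hosted SeamlessM4T Space for Darija.
    # print(bot.translate(res["answer"]))

Bundling the loader, splitter, FAISS index, LlamaCpp model, memory, and chain into one object also means the whole pipeline can be cached (for example with st.cache_resource) instead of being rebuilt on every Streamlit rerun.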
@@ -149,31 +124,25 @@ for message in st.session_state.messages:
     with st.chat_message(message["role"], avatar="logo.png"):
         st.write(message["content"])

+# Create an instance of MyBot
+lc = MyBot("Data_blog.txt", "TheBloke/Mistral-7B-OpenOrca-GGUF", "mistral-7b-openorca.Q4_K_M.gguf")
+
+# Use the instance methods in your Streamlit application
 def clear_chat_history():
-    memory.clear()
-    qa =
-        llm,
-        chain_type="stuff",
-        retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
-        memory=memory,
-        return_source_documents=True,
-        verbose=False,
-    )
+    lc.memory.clear()
+    lc.qa = lc.init_qa(lc.llm, lc.VectorStore, lc.memory)
     st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

-st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
-selected_language = st.sidebar.selectbox("Select Language", ["English", "Darija"], index=0)  # English is the default
-
-# Function for generating LLaMA2 response
 def generate_llm_response(prompt_input):
-    res = qa(f'''{prompt_input}''')
+    res = lc.qa(f'''{prompt_input}''')

     if selected_language == "Darija":
-        translated_response = translate(res['answer'])
+        translated_response = lc.translate(res['answer'])
         return translated_response
     else:
         return res['answer']

+
 # User-provided prompt
 if prompt := st.chat_input("What is up?"):
     if selected_language == "Darija":
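One point worth checking in the second hunk: it deletes the sidebar button and the selected_language selectbox, yet the new code still reads selected_language inside generate_llm_response and still defines clear_chat_history. If those widgets are not created elsewhere in app.py, the page would need wiring along these lines (a sketch reusing the removed lines; whether it is actually needed depends on the rest of the file):

    # Sidebar wiring as it existed before this commit; assumed to live elsewhere in
    # app.py after the refactor. Without it, selected_language is undefined when
    # generate_llm_response runs.
    st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
    selected_language = st.sidebar.selectbox(
        "Select Language", ["English", "Darija"], index=0,  # English is the default
    )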