JasperV13 committed
Commit fc4c10f · 1 Parent(s): 6877dc0

added class

Files changed (1): app.py +78 -109
app.py CHANGED
@@ -13,113 +13,88 @@ from langchain.chains import ConversationalRetrievalChain
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp
 from langchain.chains import LLMChain
-
 import time
 import streamlit as st
 
 
 
-loader = TextLoader("Data_blog.txt")
-pages = loader.load()
-
-def split_text(documents: list[Document]):
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=1000,
-        chunk_overlap=150,
-        length_function=len,
-        add_start_index=True,
-    )
-    chunks = text_splitter.split_documents(documents)
-    print(f"Split {len(documents)} documents into {len(chunks)} chunks.")
-
-    document = chunks[10]
-    print(document.page_content)
-    print(document.metadata)
-
-    return chunks
-
-chunks_text = split_text(pages)
-
-embedding = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')  # not the best
-
-docs_text = [doc.page_content for doc in chunks_text]
-
-
-VectorStore = FAISS.from_texts(docs_text, embedding=embedding)
-
-MODEL_ID = "TheBloke/Mistral-7B-OpenOrca-GGUF"
-MODEL_BASENAME = "mistral-7b-openorca.Q4_K_M.gguf"
+class MyBot:
+    def __init__(self, text_file, model_id, model_basename):
+        self.text_file = text_file
+        self.model_id = model_id
+        self.model_basename = model_basename
+        self.loader = TextLoader(self.text_file)
+        self.pages = self.loader.load()
+        self.chunks_text = self.split_text(self.pages)
+        self.docs_text = [doc.page_content for doc in self.chunks_text]
+        self.embedding = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+        self.VectorStore = FAISS.from_texts(self.docs_text, embedding=self.embedding)
+        self.model_path = self.download_model(self.model_id, self.model_basename)
+        self.callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+        self.llm = self.init_llm(self.model_path, self.callback_manager)
+        self.memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            return_messages=True,
+            input_key='question',
+            output_key='answer'
+        )
+        self.qa = self.init_qa(self.llm, self.VectorStore, self.memory)
+
+    def split_text(self, documents):
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=150,
+            length_function=len,
+            add_start_index=True,
+        )
+        chunks = text_splitter.split_documents(documents)
+        return chunks
 
-model_path = hf_hub_download(
-    repo_id=MODEL_ID,
-    filename=MODEL_BASENAME,
+    def download_model(self, model_id, model_basename):
+        model_path = hf_hub_download(
+            repo_id=model_id,
+            filename=model_basename,
             resume_download=True,
         )
-
-print("model_path : ", model_path)
-
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
-CONTEXT_WINDOW_SIZE = 1500
-MAX_NEW_TOKENS = 2000
-N_BATCH = 512
-n_gpu_layers = 40
-kwargs = {
-    "model_path": model_path,
-    "n_ctx": CONTEXT_WINDOW_SIZE,
-    "max_tokens": MAX_NEW_TOKENS,
-    "n_batch": N_BATCH,
-    "n_gpu_layers": n_gpu_layers,
-    "callback_manager": callback_manager,
-    "verbose": True,
-}
-
-
-# Callbacks support token-wise streaming
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
-n_gpu_layers = 40  # Change this value based on your model and your GPU VRAM pool.
-n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
-max_tokens = 2000
-# Make sure the model path is correct for your system!
-llm = LlamaCpp(
-    model_path=model_path,
-    n_gpu_layers=n_gpu_layers,
-
-    n_batch=n_batch,
-    max_tokens=max_tokens,
-    callback_manager=callback_manager,
-    verbose=True,  # Verbose is required to pass to the callback manager
-)
-
-llm = LlamaCpp(**kwargs)
-
-memory = ConversationBufferMemory(
-    memory_key="chat_history",
-    return_messages=True,
-    input_key='question',
-    output_key='answer'
-)
-
-# memory.clear()
-
-qa = ConversationalRetrievalChain.from_llm(
-    llm,
-    chain_type="stuff",
-    retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
-    memory=memory,
-    return_source_documents=True,
-    verbose=False,
-)
-def translate(text, source="English", target="Moroccan Arabic"):
-    client = Client("https://facebook-seamless-m4t-v2-large.hf.space/--replicas/2bmbx/")
-    result = client.predict(
+        print("model_path : ", model_path)
+        return model_path
+
+    def init_llm(self, model_path, callback_manager):
+        CONTEXT_WINDOW_SIZE = 1500
+        MAX_NEW_TOKENS = 2000
+        N_BATCH = 512
+        n_gpu_layers = 40
+        kwargs = {
+            "model_path": model_path,
+            "n_ctx": CONTEXT_WINDOW_SIZE,
+            "max_tokens": MAX_NEW_TOKENS,
+            "n_batch": N_BATCH,
+            "n_gpu_layers": n_gpu_layers,
+            "callback_manager": callback_manager,
+            "verbose": True,
+        }
+        llm = LlamaCpp(**kwargs)
+        return llm
+
+    def init_qa(self, llm, VectorStore, memory):
+        qa = ConversationalRetrievalChain.from_llm(
+            llm,
+            chain_type="stuff",
+            retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
+            memory=memory,
+            return_source_documents=True,
+            verbose=False,
+        )
+        return qa
+
+    def translate(self, text, source="English", target="Moroccan Arabic"):
+        client = Client("https://facebook-seamless-m4t-v2-large.hf.space/--replicas/2bmbx/")
+        result = client.predict(
             text,
             source,
             target,
             api_name="/t2tt"
-    )
-    return result
+        )
+        return result
 
 #---------------------------------------------------------
 
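The `translate` method in the new class wraps a hosted SeamlessM4T v2 Space through `gradio_client`. A minimal standalone sketch of that same call, assuming `gradio_client` is installed and the replica URL from the diff is still reachable (the sample text is illustrative):

from gradio_client import Client

# Hosted SeamlessM4T v2 Space used by the app; the replica URL is taken from the diff
# and may rotate or go offline.
client = Client("https://facebook-seamless-m4t-v2-large.hf.space/--replicas/2bmbx/")

# Text-to-text translation endpoint: positional args are the input text,
# the source language, and the target language.
result = client.predict(
    "Hello, how are you?",  # example input text (assumption)
    "English",
    "Moroccan Arabic",
    api_name="/t2tt",
)
print(result)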
@@ -149,31 +124,25 @@ for message in st.session_state.messages:
     with st.chat_message(message["role"], avatar="logo.png"):
         st.write(message["content"])
 
+# Create an instance of MyBot
+lc = MyBot("Data_blog.txt", "TheBloke/Mistral-7B-OpenOrca-GGUF", "mistral-7b-openorca.Q4_K_M.gguf")
+
+# Use the instance methods in your Streamlit application
 def clear_chat_history():
-    memory.clear()
-    qa = ConversationalRetrievalChain.from_llm(
-        llm,
-        chain_type="stuff",
-        retriever=VectorStore.as_retriever(search_kwargs={"k": 5}),
-        memory=memory,
-        return_source_documents=True,
-        verbose=False,
-    )
+    lc.memory.clear()
+    lc.qa = lc.init_qa(lc.llm, lc.VectorStore, lc.memory)
     st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
 
-st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
-selected_language = st.sidebar.selectbox("Select Language", ["English", "Darija"], index=0)  # English is the default
-
-# Function for generating LLaMA2 response
 def generate_llm_response(prompt_input):
-    res = qa(f'''{prompt_input}''')
+    res = lc.qa(f'''{prompt_input}''')
 
     if selected_language == "Darija":
-        translated_response = translate(res['answer'])
+        translated_response = lc.translate(res['answer'])
         return translated_response
     else:
        return res['answer']
 
+
 # User-provided prompt
 if prompt := st.chat_input("What is up?"):
     if selected_language == "Darija":
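With this commit, the indexing, model download, LLM setup, retrieval chain, and translation all live on the `MyBot` class, so the app only needs a single instance. A minimal usage sketch under that assumption (illustrative, not part of the commit; the `st.cache_resource` wrapper and the sample question are assumptions, and the imports already present in app.py are taken as given):

import streamlit as st

@st.cache_resource  # assumption: build the FAISS index and load the GGUF model only once per process
def load_bot():
    # same arguments the commit passes when it creates the module-level instance
    return MyBot(
        "Data_blog.txt",
        "TheBloke/Mistral-7B-OpenOrca-GGUF",
        "mistral-7b-openorca.Q4_K_M.gguf",
    )

bot = load_bot()

# The chain is invoked the same way the commit calls lc.qa(...): a single question string,
# with chat history supplied by the ConversationBufferMemory held inside the class.
res = bot.qa("What is this blog post about?")
st.write(res["answer"])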