Esrath commited on
Commit
96d5d14
·
verified ·
1 Parent(s): 02c3007

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +144 -0
  2. chainlit.md +1 -0
  3. inject.py +45 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import transformers
5
+ import chainlit as cl
6
+ from getpass import getpass
7
+ from dotenv import load_dotenv
8
+ from huggingface_hub import login
9
+ from transformers import AutoModel
10
+ from langchain.llms import BaseLLM
11
+ from langchain import HuggingFaceHub
12
+ from langchain_community.llms import Ollama
13
+ from langchain_community.llms import Cohere
14
+ from langchain_community.llms import LlamaCpp
15
+ from langchain.llms import HuggingFacePipeline
16
+ from langchain_community.vectorstores import FAISS
17
+ from langchain_community.llms import CTransformers
18
+ from langchain.chains import ConversationalRetrievalChain
19
+ from langchain.retrievers import ContextualCompressionRetriever
20
+ from langchain_community.embeddings import HuggingFaceEmbeddings
21
+ from langchain.retrievers.document_compressors import FlashrankRerank
22
+ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
24
+ from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
25
+
26
+ load_dotenv()
27
+ COHERE_API_KEY = os.getenv('COHERE_API_KEY')
28
+
29
+
30
+ # HUGGINGFACEHUB_API_TOKEN = getpass()
31
+ # os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
32
+ # load_dotenv()
33
+
34
+ # HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
35
+ # print(HUGGINGFACE_TOKEN)
36
+ # login(token = HUGGINGFACE_TOKEN)
37
+
38
+
39
+ # embeddings_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
40
+
41
+ # from transformers import AutoModel
42
+
43
+ embeddings_model = HuggingFaceEmbeddings(
44
+ model_name="mixedbread-ai/mxbai-embed-large-v1",
45
+ model_kwargs={'device': 'cpu'},
46
+ )
47
+
48
+ # Load FIASS db index as retriever
49
+ db = FAISS.load_local("mxbai_faiss_index_v2", embeddings_model, allow_dangerous_deserialization=True)
50
+ retriever = db.as_retriever()
51
+
52
+ # Use Flashrank as rerank engine
53
+ compressor = FlashrankRerank()
54
+
55
+ # Pass reranker as base compressor and retriever as base retriever
56
+ # to ContextualCompressonRetriever.
57
+ compression_retriever = ContextualCompressionRetriever(
58
+ base_compressor=compressor, base_retriever=retriever
59
+ )
60
+
61
+ # I/0 stream
62
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
63
+
64
+
65
+ #* Round 2
66
+ # llm = HuggingFaceHub(
67
+ # huggingfacehub_api_token=HUGGINGFACE_TOKEN,
68
+ # repo_id=model_id,
69
+ # model_kwargs={
70
+ # "temperature": 0.5
71
+ # }
72
+ # )
73
+
74
+ #* Round 3
75
+ # llm = CTransformers(model=model_id)
76
+ # llm = CTransformers(model='IlyaGusev/saiga_llama3_8b_gguf', model_file='model-q4_K.gguf', model_type="llama")
77
+
78
+ # llm = CTransformers(model='../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf', model_type='llama')
79
+
80
+ #* Round 4
81
+ # n_gpu_layers = 15
82
+ # n_batch = 128
83
+ # llm = LlamaCpp(
84
+ # model_path="../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf",
85
+ # # n_ctx = 1024,
86
+ # n_gpu_layers=n_gpu_layers,
87
+ # n_batch=n_batch,
88
+ # f16_kv=True,
89
+ # callback_manager=callback_manager,
90
+ # verbose=True,
91
+ # )
92
+
93
+ # llm = Ollama(model="llama3", temperature=0.2)
94
+ llm = Cohere(temperature=0.2)
95
+
96
+ @cl.on_chat_start
97
+ async def on_chat_start():
98
+
99
+ message_history = ChatMessageHistory()
100
+
101
+ memory = ConversationBufferMemory(
102
+ memory_key="chat_history",
103
+ output_key="answer",
104
+ chat_memory=message_history,
105
+ return_messages=True,
106
+ )
107
+
108
+ chain = ConversationalRetrievalChain.from_llm(
109
+ llm,
110
+ chain_type="stuff",
111
+ retriever=compression_retriever,
112
+ memory=memory,
113
+ return_source_documents=True,
114
+ )
115
+
116
+ cl.user_session.set("chain", chain)
117
+
118
+ #TODO: Stream response
119
+ @cl.on_message
120
+ async def main(message: cl.Message):
121
+ chain = cl.user_session.get("chain")
122
+ cb = cl.AsyncLangchainCallbackHandler()
123
+
124
+ res = await chain.acall(message.content, callbacks=[cb])
125
+ answer = res["answer"]
126
+ source_documents = res["source_documents"]
127
+
128
+ text_elements = []
129
+
130
+ #* Returning Sources
131
+ if source_documents:
132
+ for source_idx, source_doc in enumerate(source_documents):
133
+ source_name = f"source_{source_idx+1}"
134
+ text_elements.append(
135
+ cl.Text(content=source_doc.page_content, name=source_name)
136
+ )
137
+ source_names = [text_el.name for text_el in text_elements]
138
+
139
+ if source_names:
140
+ answer += f"\nSources: {', '.join(source_names)}"
141
+ else:
142
+ answer += "\nNo sources found"
143
+
144
+ await cl.Message(content=answer, elements=text_elements, author="Brocxi").send()
chainlit.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Hi! I am Brocxi, your virtual assistant for the God of War Ragnarok. I will be your game guide during your adventure through the landscapes of Norse mythology.
inject.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from transformers import AutoModel
4
+ from langchain.storage import LocalFileStore
5
+ from langchain.document_loaders import TextLoader
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain.embeddings import CacheBackedEmbeddings
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.document_loaders import DirectoryLoader
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+
12
+
13
+
14
+ # cache_store = LocalFileStore("./mxbai_cache_v2/")
15
+
16
+ # Load txt files from dir
17
+ loader = DirectoryLoader('../extracted_files', glob="*.txt", loader_cls=TextLoader, show_progress=True)
18
+ docs = loader.load()
19
+
20
+ # Chunking
21
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
22
+ chunk_size=256,
23
+ chunk_overlap=64,
24
+ )
25
+ chunked = text_splitter.split_documents(docs)
26
+
27
+ # model = AutoModel.from_pretrained('mixedbread-ai/mxbai-embed-large-v1', trust_remote_code=True)
28
+
29
+ model_name = "mixedbread-ai/mxbai-embed-large-v1"
30
+ model_kwargs = {'device': 'cpu'}
31
+ embeddings_model = HuggingFaceEmbeddings(
32
+ model_name=model_name,
33
+ model_kwargs=model_kwargs,
34
+ )
35
+
36
+ # embeddings_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
37
+
38
+ cached_embedder = CacheBackedEmbeddings.from_bytes_store(
39
+ embeddings_model, cache_store, namespace="mixedbread-ai/mxbai-embed-large-v1")
40
+
41
+ db = FAISS.from_documents(chunked, cached_embedder)
42
+
43
+ db.save_local("mxbai_faiss_index_v2")
44
+
45
+ print("Embeddings saved ...")