arjunanand13 committed
Commit fc583d6
1 Parent(s): 7b1bacb

Update app.py

Files changed (1): app.py (+14, -29)
app.py CHANGED
@@ -3,11 +3,11 @@ import json
 from torch import cuda, bfloat16
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
 from langchain.llms import HuggingFacePipeline
+from langchain.vectorstores import FAISS
+from langchain.chains import ConversationalRetrievalChain
 import gradio as gr
-import os
-import faiss
-import numpy as np
 from langchain.embeddings import HuggingFaceEmbeddings
+import os
 
 class Chatbot:
     def __init__(self):
@@ -37,19 +37,14 @@ class Chatbot:
         )
         self.llm = HuggingFacePipeline(pipeline=self.generate_text)
 
-        # Initialize the embedding model
-        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
-
         try:
-            # Initialize FAISS with GPU support
-            cpu_index = faiss.read_index('faiss_index_new_model3.index')
-            res = faiss.StandardGpuResources()  # Use this to allocate the GPU resources
-            co = faiss.GpuClonerOptions()
-            co.useFloat16 = True  # Enable float16 for better performance
-            self.vectorstore = faiss.index_cpu_to_gpu(res, 0, cpu_index, co)
+            # self.vectorstore = FAISS.load_local('faiss_index', HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"}))
+            self.vectorstore = FAISS.load_local('faiss_index_new_model3.index', HuggingFaceEmbeddings(model_name="flax-sentence-embeddings/all_datasets_v3_MiniLM-L12", model_kwargs={"device": "cuda"}))
+            # cpu_index = faiss.read_index('faiss_index_new_model3.index')
+            # gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)
             print("Loaded embedding successfully")
-        except Exception as e:
-            print("FAISS could not be imported or index could not be loaded.")
+        except ImportError as e:
+            print("FAISS could not be imported. Make sure FAISS is installed correctly.")
             raise e
 
         self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vectorstore.as_retriever(), return_source_documents=True)
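Note: in LangChain, FAISS.load_local reads a folder written by save_local (containing index.faiss and index.pkl), not a raw file produced by faiss.write_index, so 'faiss_index_new_model3.index' needs to be such a folder. A minimal sketch of producing one, assuming the same MiniLM embedding model as above; the `texts` list is a hypothetical stand-in for the chunked JSON data:

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    # Must match the model passed to FAISS.load_local at query time,
    # otherwise query vectors and index vectors live in different spaces.
    embeddings = HuggingFaceEmbeddings(
        model_name="flax-sentence-embeddings/all_datasets_v3_MiniLM-L12",
        model_kwargs={"device": "cuda"},
    )

    # Hypothetical placeholder for the chunked JSON content.
    texts = ["chunk one ...", "chunk two ..."]

    vectorstore = FAISS.from_texts(texts, embeddings)
    # Writes index.faiss and index.pkl into the folder that
    # FAISS.load_local reads back with the same embeddings.
    vectorstore.save_local("faiss_index_new_model3.index")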
@@ -63,10 +58,10 @@ class Chatbot:
         return False
 
     def format_prompt(self, query):
-        prompt = f"""
+        prompt=f"""
         You are a knowledgeable assistant with access to a comprehensive database.
         I need you to answer my question and provide related information in a specific format.
-        I have provided four relatable json files, choose the most suitable chunks for answering the query.
+        I have provided four relatable json files , choose the most suitable chunks for answering the query
         Here's what I need:
         Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
 
@@ -86,20 +81,10 @@ class Chatbot:
     def qa_infer(self, query):
         content = ""
         formatted_prompt = self.format_prompt(query)
-
-        # Embed the query
-        query_embedding = self.embeddings.embed_query(formatted_prompt)
-
-        # Perform the search
-        distances, indices = self.vectorstore.search(np.array([query_embedding]), k=5)
-
-        # Retrieve the top documents
-        for idx in indices[0]:
-            doc = self.vectorstore.get_document(idx)
+        result = self.chain({"question": formatted_prompt, "chat_history": self.chat_history})
+        for doc in result['source_documents']:
             content += "-" * 50 + "\n"
             content += doc.page_content + "\n"
-
-        result = self.chain({"question": formatted_prompt, "chat_history": self.chat_history})
         print(content)
         print("#" * 100)
         print(result['answer'])
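With return_source_documents=True set on the chain, a single call now performs both retrieval and generation, returning the answer together with the retrieved chunks that the loop above iterates over. A minimal usage sketch, assuming `chain` refers to the self.chain built in __init__ and `query` is a placeholder question string:

    chat_history = []
    result = chain({"question": query, "chat_history": chat_history})

    print(result["answer"])                  # generated answer
    for doc in result["source_documents"]:   # chunks the retriever surfaced
        print(doc.page_content)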
@@ -158,4 +143,4 @@ class Chatbot:
 
 # Instantiate and launch the chatbot
 chatbot = Chatbot()
-chatbot.launch_interface()
+chatbot.launch_interface()
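For reference, the removed GPU path survives only as a commented-out stub above (where gpu_resource is left undefined). A sketch of that raw-FAISS approach, assuming the faiss-gpu build is installed:

    import faiss

    # Load the raw index from disk and clone it onto GPU 0.
    cpu_index = faiss.read_index('faiss_index_new_model3.index')
    res = faiss.StandardGpuResources()   # allocates GPU scratch resources
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True                 # store vectors in float16 on the GPU
    gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index, co)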
 