kastan's picture
initial commit, fixing chat history
493fae2
raw
history blame
3.93 kB
import os
from typing import Any, Dict, List
# for vector search
import pinecone # cloud-hosted vector database for context retrieval
# for auto-gpu selection
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
# load custom code
from clip_for_ppts import ClipImage
# from gpu_memory_utils import (get_device_with_most_free_memory,
# get_free_memory_dict,
# get_gpu_ids_with_sufficient_memory)
# from PIL import Image
# Directory holding the lecture slide folders, resolved against the working
# directory at import time.
LECTURE_SLIDES_DIR = os.path.join(os.getcwd(), "lecture_slides")
# Pinecone API key taken from the environment; None when the variable is unset.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
class Retrieval:
    """Context retrieval for the chatbot: text passages from Pinecone and slide images via CLIP.

    Construction loads the Pinecone-backed vector store and, when ``use_clip``
    is true, the CLIP slide-search helper. When CLIP is disabled the image
    search methods return an empty list instead of failing.
    """

    def __init__(self, device='cuda', use_clip=True):
        """Set up state and load the retrieval backends.

        Args:
            device: Device string handed to the CLIP helper (e.g. ``'cuda'``).
            use_clip: When False, CLIP is not loaded and image searches
                return ``[]``.
        """
        self.user_question = ''
        self.max_text_length = None
        self.pinecone_index_name = 'uiuc-chatbot'  # uiuc-chatbot-v2
        self.use_clip = use_clip
        self.clip_search_class = None  # populated by _load_clip() when use_clip is True

        # init parameters
        self.device = device
        self.num_answers_generated = 3  # default number of contexts returned per query
        self.vectorstore = None  # populated by _load_pinecone_vectorstore()

        # Build the vector store (and CLIP, if enabled) up front.
        self.load_modules()

    def load_modules(self):
        """Load the Pinecone vector store and, if enabled, the CLIP helper."""
        self._load_pinecone_vectorstore()
        if self.use_clip:
            self._load_clip()
        else:
            print("CLIP IS MANUALLY DISABLED for speed.. REENABLE LATER. ")

    def _load_pinecone_vectorstore(self,):
        """Connect to Pinecone and wrap the index in a LangChain vector store."""
        model_name = "intfloat/e5-large"  # text embedding model (1024 dims per the original note)
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        pinecone.init(api_key=PINECONE_API_KEY, environment="us-west1-gcp")
        # Use the configured index name rather than a second hard-coded copy,
        # so changing self.pinecone_index_name actually takes effect.
        pinecone_index = pinecone.Index(self.pinecone_index_name)
        self.vectorstore = Pinecone(index=pinecone_index,
                                    embedding_function=embeddings.embed_query,
                                    text_key="text")

    def retrieve_contexts_from_pinecone(self, user_question: str, topk: int = None) -> List[str]:
        '''
        Invoke Pinecone for vector search. These vector databases are created in the notebook
        `data_formatting_patel.ipynb` and `data_formatting_student_notes.ipynb`.

        Args:
            user_question: The query. If it contains the `<|prompter|>` marker, only the text
                after the last marker is used for the search.
            topk: Number of results to request; defaults to `self.num_answers_generated`.

        Returns:
            A list of strings, one per retrieved document, formatted as
            "<page_content>. Source: page <page_number> in <textbook_name>".

        Raises:
            KeyError: If a retrieved document lacks `page_number` or `textbook_name` metadata.
        '''
        try:
            # catch other models that have different prompting
            user_question = user_question.split("<|prompter|>")[-1]
        except Exception as e:
            # Best effort: keep the question as given if it cannot be split.
            print("Failed to split user question: ", e)

        if topk is None:
            topk = self.num_answers_generated

        # similarity search
        top_context_list = self.vectorstore.similarity_search(user_question, k=topk)

        # add the source info to the bottom of the context.
        return [
            f"{doc.page_content}. "
            f"Source: page {doc.metadata['page_number']} in {doc.metadata['textbook_name']}"
            for doc in top_context_list
        ]

    def _load_clip(self):
        """Create the CLIP slide-search helper on the device chosen at construction."""
        self.clip_search_class = ClipImage(path_of_ppt_folders=LECTURE_SLIDES_DIR,
                                           path_to_save_image_features=os.getcwd(),
                                           mode='text',
                                           device=self.device)

    @staticmethod
    def _slide_paths(search_hits) -> List[str]:
        """Turn CLIP search hits into paths under LECTURE_SLIDES_DIR.

        Each hit's first two items are joined onto the slides directory
        (presumably folder name then file name -- confirm against ClipImage).
        """
        return [os.path.join(LECTURE_SLIDES_DIR, hit[0], hit[1]) for hit in search_hits]

    def reverse_img_search(self, img):
        """Find slide images similar to ``img``.

        Returns:
            A list of slide image paths; empty when CLIP is not loaded.
        """
        if self.clip_search_class is None:
            print("CLIP is not loaded; reverse image search returns no images.")
            return []
        hits = self.clip_search_class.image_to_images_search(img)
        return self._slide_paths(hits)

    def clip_text_to_image(self, search_question: str, num_images_returned: int = 4):
        """
        Run CLIP text-to-image search over the lecture slides.

        Args:
            search_question: Text to search for.
            num_images_returned: Number of top results requested from CLIP.

        Returns:
            A list of slide image paths in all cases; empty when CLIP is not loaded.
        """
        if self.clip_search_class is None:
            print("CLIP is not loaded; text-to-image search returns no images.")
            return []
        hits = self.clip_search_class.text_to_image_search(search_text=search_question,
                                                           top_k_to_return=num_images_returned)
        return self._slide_paths(hits)