# chat-llm-streaming/retrieval.py
import json
import os
import pathlib
import sys
import time
from typing import Any, Dict, List, Optional

import pinecone  # cloud-hosted vector database for context retrieval
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings  # for vector search
from langchain.vectorstores import Pinecone
from PIL import Image
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          GPT2Tokenizer, OPTForCausalLM, T5ForConditionalGeneration)

class Retrieval:

    def __init__(self, device='cuda', use_clip=True):
        self.user_question = ''
        self.max_text_length = None
        self.pinecone_index_name = 'uiuc-chatbot'  # uiuc-chatbot-v2
        self.use_clip = use_clip

        # init parameters
        self.device = device
        self.num_answers_generated = 3
        self.vectorstore = None
    def _load_pinecone_vectorstore(self):
        model_name = "intfloat/e5-large"  # strong general-purpose text embedding model, 1024 dims
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        load_dotenv()  # loads PINECONE_API_KEY from a .env file, if present
        pinecone.init(api_key=os.environ['PINECONE_API_KEY'], environment="us-west1-gcp")
        pinecone_index = pinecone.Index(self.pinecone_index_name)
        self.vectorstore = Pinecone(index=pinecone_index, embedding_function=embeddings.embed_query, text_key="text")
    def retrieve_contexts_from_pinecone(self, user_question: str, topk: Optional[int] = None) -> List[str]:
        '''
        Invoke Pinecone for vector search. The vector databases are created in the notebooks
        `data_formatting_patel.ipynb` and `data_formatting_student_notes.ipynb`.

        Returns a list of strings, one per retrieved LangChain Document, each with source info appended.
        The retrieved Documents have the properties: `doc.page_content`: str,
        `doc.metadata['page_number']`: int, `doc.metadata['textbook_name']`: str.
        '''
        print("USER QUESTION: ", user_question)
        print("TOPK: ", topk)

        if topk is None:
            topk = self.num_answers_generated
        if self.vectorstore is None:
            # lazily connect to Pinecone if the vectorstore has not been loaded yet
            self._load_pinecone_vectorstore()

        # similarity search
        top_context_list = self.vectorstore.similarity_search(user_question, k=topk)

        # add the source info to the bottom of each context
        top_context_metadata = [f"Source: page {doc.metadata['page_number']} in {doc.metadata['textbook_name']}" for doc in top_context_list]
        relevant_context_list = [f"{doc.page_content}. {meta}" for doc, meta in zip(top_context_list, top_context_metadata)]
        return relevant_context_list
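

# Minimal usage sketch. Assumes PINECONE_API_KEY is set in the environment or a .env
# file, and that the 'uiuc-chatbot' index already contains documents carrying the
# 'page_number' and 'textbook_name' metadata fields; the question below is illustrative.
if __name__ == '__main__':
    retrieval = Retrieval(device='cuda')
    retrieval._load_pinecone_vectorstore()
    contexts = retrieval.retrieve_contexts_from_pinecone("What is an action potential?", topk=3)
    for context in contexts:
        print(context)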