from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
device = 'cpu'
# Initialize the Pinecone client
pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
# Connect to the Pinecone index for querying and storing vectors
index_name = 'abstractive-question-answering'
index = pc.Index(index_name)
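# For reference, a hedged sketch (not part of this script) of the record shape the
# queries below assume: each vector's metadata stores the answer text under the key
# 'Output'. The id and texts here are hypothetical.
#   index.upsert(vectors=[{
#       'id': 'faq-0',
#       'values': embedding,  # 768-dim vector from the retriever loaded below
#       'metadata': {'Output': "Answer text returned to the user"},
#   }])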
# Load the retriever model for sentence embeddings and the T5 model for text generation
def load_models():
    print("Loading models...")
    retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
    # Use the same checkpoint for tokenizer and generator so the vocabularies match
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
    print("Done loading models")
    return retriever, generator, tokenizer

retriever, generator, tokenizer = load_models()
def process_query(query):
    print("Processing...")
    # Encode the query for semantic search (encode returns a batch, so take element 0
    # to get the single flat vector the Pinecone client expects)
    xq = retriever.encode([query])[0].tolist()
    # Query the Pinecone index for the vector most similar to the query
    xc = index.query(vector=xq, top_k=1, include_metadata=True)
    print("Pinecone response:", xc)
    # Guard against an empty result so the variables below are always defined
    if 'matches' not in xc or not xc['matches']:
        return "The topic is not covered in the student manual."
    # Concatenate the original question with the context extracted from the matched metadata
    context = [m['metadata']['Output'] for m in xc['matches']]
    context_str = " ".join(context)
    formatted_query = f"answer the question: {query} context: {context_str}"
    # If the context is longer than 5 lines, return the context extracted from Pinecone directly
    output_text = context_str
    if len(output_text.splitlines()) > 5:
        return output_text
    # If the stored answer is "none", report that the topic is not covered in the student manual
    if output_text.lower() == "none":
        return "The topic is not covered in the student manual."
    # Tokenize the formatted query
    inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
    # Generate an answer with the T5 model
    ids = generator.generate(inputs, num_beams=2, min_length=10, max_length=60, repetition_penalty=1.2)
    # Decode the generated ids into readable text
    answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # If the answer contains any NLI-style label, fall back to the raw context from Pinecone
    nli_keywords = ['not_equivalent', 'not_entailment', 'entailment', 'neutral', 'not_enquiry']
    if any(keyword in answer.lower() for keyword in nli_keywords):
        return context_str
    # Return the generated answer
    return answer
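
# Minimal usage sketch: the module defines process_query but never calls it, so this
# guard shows one way to exercise the pipeline. The sample question is hypothetical.
if __name__ == "__main__":
    sample_question = "What topics does the student manual cover?"  # hypothetical query
    print(process_query(sample_question))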