# rag_search/main.py

import os

from fastapi import FastAPI
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor

# The imports below are only used by commented-out alternatives further down;
# uncomment them together with that code.
# from langchain import HuggingFaceHub
# from langchain.prompts import PromptTemplate
# from langchain.chains import LLMChain
# from transformers import pipeline

# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")

# pipe = pipeline("text2text-generation", model="google/flan-t5-small")
#
# @app.get("/generate")
# def generate(text: str):
#     """
#     Using the text2text-generation pipeline from `transformers`, generate text
#     from the given input text. The model used is `google/flan-t5-small`, which
#     can be found [here](https://huggingface.co/google/flan-t5-small).
#     """
#     output = pipe(text)
#     return {"output": output[0]["generated_text"]}
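
# Example request (hypothetical, once the endpoint above is re-enabled):
#   GET /generate?text=Summarize%20this%20text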

def _check_if_db_exists(db_path: str) -> bool:
    """Return True if an index exists on disk at the given path."""
    return os.path.exists(db_path)

def _load_embeddings_from_db(
    db_present: bool,
    domain: str,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
):
    """Create an embeddings model and load the on-disk index if one exists."""
    # Create embeddings model with content support
    embeddings = Embeddings({"path": path, "content": True})

    # If no vector DB is present, return the empty embeddings model
    if not db_present:
        return embeddings

    if domain == "":
        embeddings.load("index")  # change this later
    else:
        embeddings.load(f"index/{domain}")
    return embeddings
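
# A minimal sketch of how an index like the one loaded above could be built
# and saved. The `docs` list and domain name are hypothetical; with content
# support enabled, txtai indexes (id, text, tags) tuples:
#
#   embeddings = Embeddings(
#       {"path": "sentence-transformers/all-MiniLM-L6-v2", "content": True}
#   )
#   docs = [(i, text, None) for i, text in enumerate(["some document text"])]
#   embeddings.index(docs)
#   embeddings.save("index/some_domain")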

def _prompt(question):
    """Build the extraction prompt that tells the model to answer from context only."""
    return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
Question: {question}
Context: """
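
# For example, _prompt("What is txtai?") renders as:
#   Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
#   Question: What is txtai?
#   Context: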

def _search(query, extractor, question=None):
    """Retrieve context for `query` and extract an answer to `question`."""
    # Default question to query if empty
    if not question:
        question = query

    # Alternative path using LangChain with a hosted Hugging Face model:
    # template = f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
    # Question: {question}
    # Context: """
    # prompt = PromptTemplate(template=template, input_variables=["question"])
    # llm_chain = LLMChain(prompt=prompt, llm=extractor)
    # return {"question": question, "answer": llm_chain.run(question)}

    # Each work item is a (name, query, question, snippet) tuple; the Extractor
    # returns a list of (name, answer) tuples, so [0][1] is the answer text.
    return extractor([("answer", query, _prompt(question), False)])[0][1]
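
# Direct-call sketch (assumes `embeddings` was loaded as above; the question
# is a hypothetical example):
#   extractor = Extractor(embeddings, "google/flan-t5-base")
#   answer = _search("What does the policy say about refunds?", extractor)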

@app.get("/rag")
def rag(domain: str, question: str):
    """Answer a question using the embeddings index for the given domain."""
    # Build the index path portably; the original Windows-style backslash
    # f-string breaks on Linux and relies on invalid escape sequences.
    db_exists = _check_if_db_exists(
        db_path=os.path.join(os.getcwd(), "index", domain, "documents")
    )
    print(f"Vector DB exists: {db_exists}")

    embeddings = _load_embeddings_from_db(db_exists, domain)

    # Create extractor instance
    extractor = Extractor(embeddings, "google/flan-t5-base")

    # Alternative: a hosted model via LangChain's HuggingFaceHub wrapper
    # llm = HuggingFaceHub(
    #     repo_id="google/flan-t5-xxl",
    #     model_kwargs={"temperature": 1, "max_length": 1000000},
    # )

    answer = _search(question, extractor)
    return {"question": question, "answer": answer}
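
# A minimal local-run sketch (assumes uvicorn is installed; 7860 is the
# conventional Spaces port, and the Space's own entrypoint may differ):
#
#   if __name__ == "__main__":
#       import uvicorn
#       uvicorn.run(app, host="0.0.0.0", port=7860)
#
# Example request (hypothetical domain): GET /rag?domain=tech_docs&question=...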