# rag_llm/index.py
from fastapi import FastAPI
# from transformers import pipeline
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor  # kept for the commented-out /rag endpoint below
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import sqlite3
import os
# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")
# app = FastAPI()
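# To serve this app locally (assuming `uvicorn` is installed), one option is:
#   uvicorn index:app --host 0.0.0.0 --port 7860
# Port 7860 is the port Hugging Face Spaces expects; adjust it as needed elsewhere.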
# pipe = pipeline("text2text-generation", model="google/flan-t5-small")
# @app.get("/generate")
# def generate(text: str):
#     """
#     Using the text2text-generation pipeline from `transformers`, generate text
#     from the given input text. The model used is `google/flan-t5-small`, which
#     can be found [here](https://huggingface.co/google/flan-t5-small).
#     """
#     output = pipe(text)
#     return {"output": output[0]["generated_text"]}
def load_embeddings(
    domain: str = "",
    db_present: bool = True,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
    index_name: str = "index",
):
    """Create an Embeddings model and, if a saved index exists, load it from disk."""
    # Create embeddings model with content support
    embeddings = Embeddings({"path": path, "content": True})
    # If no vector DB exists yet, return the fresh (empty) model
    if not db_present:
        return embeddings
    # Otherwise load the saved index, per-domain when a domain is given
    if domain == "":
        embeddings.load(index_name)  # change this later
    else:
        embeddings.load(f"{index_name}/{domain}")
    return embeddings

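# A minimal usage sketch (the "docs" domain name is hypothetical):
#   embeddings = load_embeddings(db_present=False)                  # fresh, empty model
#   embeddings = load_embeddings(domain="docs", db_present=True)    # loads "index/docs"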
def _check_if_db_exists(db_path: str) -> bool:
    return os.path.exists(db_path)

def _text_splitter(doc):
    # Split documents into ~500-character chunks with 50 characters of overlap
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )
    return text_splitter.transform_documents(doc)

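# Illustrative behavior (numbers are approximate): a 1,200-character page body
# becomes roughly three chunks of up to 500 characters each, with consecutive
# chunks sharing about 50 characters so context is not cut off at a boundary.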
def _load_docs(path: str):
    # Fetch the page at `path` and split it into chunks
    load_doc = WebBaseLoader(path).load()
    doc = _text_splitter(load_doc)
    return doc

def _stream(dataset, limit, index: int = 0):
    # Yield (id, text, tags) tuples in the format txtai expects, starting ids
    # at `index`. Note `limit` caps the highest id emitted, not the per-call count.
    for row in dataset:
        yield (index, row.page_content, None)
        index += 1
        if index >= limit:
            break

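# For example (chunk contents hypothetical), list(_stream(doc, limit=2)) would
# produce tuples in txtai's (id, data, tags) form:
#   [(0, "first chunk text ...", None), (1, "second chunk text ...", None)]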
def _max_index_id(path):
    # Read the highest indexid from txtai's SQLite content store
    db = sqlite3.connect(path)
    table = "sections"
    df = pd.read_sql_query(f"select * from {table}", db)
    db.close()
    return {"max_index": df["indexid"].max()}

def _upsert_docs(doc, embeddings, vector_doc_path: str, db_present: bool):
    print(vector_doc_path)
    if db_present:
        # Index already exists: append new chunks after the current max id.
        # Starting at max_index + 1 avoids overwriting the last indexed row,
        # since upsert() replaces any document whose id already exists.
        max_index = _max_index_id(f"{vector_doc_path}/documents")
        print(max_index)
        embeddings.upsert(_stream(doc, 500, max_index["max_index"] + 1))
        print("Embeddings done!!")
        embeddings.save(vector_doc_path)
        print("Embeddings done - 1!!")
    else:
        # First run for this domain: build a fresh index starting at id 0
        embeddings.index(_stream(doc, 500, 0))
        embeddings.save(vector_doc_path)
        max_index = _max_index_id(f"{vector_doc_path}/documents")
        print(max_index)
    return max_index

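# Sketch of the two paths: the first call for a domain builds a new index with
# embeddings.index() starting at id 0; later calls append with embeddings.upsert(),
# starting ids after the current maximum so existing rows are preserved.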
# def prompt(question):
#     return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
#     Question: {question}
#     Context: """
# def search(query, question=None):
#     # Default question to query if empty
#     if not question:
#         question = query
#     return extractor([("answer", query, prompt(question), False)])[0][1]
# @app.get("/rag")
# def rag(question: str):
#     # question = "what is the document about?"
#     answer = search(question)
#     # print(question, answer)
#     return {"answer": answer}
# @app.get("/index")
# def get_url_file_path(url_path: str):
#     embeddings = load_embeddings()
#     doc = _load_docs(url_path)
#     embeddings, max_index = _upsert_docs(doc, embeddings)
#     return max_index
@app.get("/index/{domain}/")
def get_domain_file_path(domain: str, file_path: str):
print(domain, file_path)
print(os.getcwd())
bool_value = _check_if_db_exists(db_path=f"{os.getcwd()}\index\{domain}\documents")
print(bool_value)
if bool_value:
embeddings = load_embeddings(domain=domain, db_present=bool_value)
print(embeddings)
doc = _load_docs(file_path)
max_index = _upsert_docs(
doc=doc,
embeddings=embeddings,
vector_doc_path=f"index/{domain}",
db_present=bool_value,
)
# print("-------")
else:
embeddings = load_embeddings(domain=domain, db_present=bool_value)
doc = _load_docs(file_path)
max_index = _upsert_docs(
doc=doc,
embeddings=embeddings,
vector_doc_path=f"index/{domain}",
db_present=bool_value,
)
# print("Final - output : ", max_index)
return "Executed Successfully!!"