import sqlite3

import pandas as pd
from fastapi import FastAPI
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor

app = FastAPI(docs_url="/")

# Content storage is enabled so the extractor can retrieve the raw text
embeddings = Embeddings(
    {"path": "sentence-transformers/all-MiniLM-L6-v2", "content": True}
)

def _stream(dataset, limit, index: int = 0):
    # Yield (id, text, tags) tuples for indexing, stopping after `limit` rows
    for row in dataset:
        yield (index, row.page_content, None)
        index += 1

        if index >= limit:
            break

def _max_index_id(path):
    # Read the highest indexid already stored so new rows continue from it
    db = sqlite3.connect(path)

    table = "sections"
    df = pd.read_sql_query(f"select * from {table}", db)
    return {"max_index": df["indexid"].max()}

def _prompt(question):
    return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.

Question: {question}

Context: """

async def _search(query, extractor, question=None):
    if not question:
        question = query

    # The extractor returns a list of (name, answer) tuples; return the answer text
    return extractor([("answer", query, _prompt(question), False)])[0][1]

def _text_splitter(doc):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )
    return text_splitter.transform_documents(doc)

def _load_docs(path: str):
    # Fetch the page and split it into overlapping chunks
    load_doc = WebBaseLoader(path).load()
    doc = _text_splitter(load_doc)
    return doc

async def _upsert_docs(doc):
    # Continue ids from the current max so existing rows are not overwritten
    max_index = _max_index_id("index/documents")
    embeddings.upsert(_stream(doc, 500, max_index["max_index"]))
    embeddings.save("index")

    return embeddings

@app.put("/rag/{path:path}")
async def set_doc_path(path: str):
    # Store the source document URL/path for subsequent /rag queries;
    # the :path converter lets the parameter contain slashes
    app.state.path = path
    return path

@app.get("/rag")
async def rag(question: str):
    # Load the saved index, ingest the stored document, then answer the question
    embeddings.load("index")
    path = app.state.path
    doc = _load_docs(path)
    await _upsert_docs(doc)

    extractor = Extractor(embeddings, "google/flan-t5-base")
    answer = await _search(question, extractor)

    return {"answer": answer}
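
# Usage sketch (not part of the original listing; the file name and URL below
# are assumptions for illustration). Run the API with uvicorn, register a
# document, then ask a question against it:
#
#   uvicorn app:app --reload
#
#   curl -X PUT "http://localhost:8000/rag/https%3A%2F%2Fexample.com%2Farticle"
#   curl "http://localhost:8000/rag?question=What%20is%20this%20page%20about%3F"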