File size: 2,508 Bytes
8c3e214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import sqlite3
from contextlib import closing
from itertools import islice
from typing import Optional

from fastapi import FastAPI

# from transformers import pipeline
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Mount the interactive API docs at "/" so they double as the app's landing
# page when it is hosted on Spaces.
app = FastAPI(docs_url="/")

# Embeddings model with content support, built on the all-MiniLM-L6-v2
# sentence-transformers model.
embeddings = Embeddings(
    dict(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
)


def _stream(dataset, limit, index: int = 0):
    for row in dataset:
        yield (index, row.page_content, None)
        index += 1

        if index >= limit:
            break


def _max_index_id(path):
    db = sqlite3.connect(path)

    table = "sections"
    df = pd.read_sql_query(f"select * from {table}", db)
    return {"max_index": df["indexid"].max()}


def _prompt(question):
    return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
            Question: {question}
            Context: """


async def _search(query, extractor, question=None):
    """Run ``extractor`` for ``query`` and return its result text.

    ``question`` is the text placed in the prompt. When it is empty or omitted,
    ``query`` is used instead. The return value is the second field of the
    first result tuple (presumably the answer text).
    """
    prompt = _prompt(question or query)
    results = extractor([("answer", query, prompt, False)])
    first = results[0]
    return first[1]


def _text_splitter(doc):
    """Split ``doc`` into chunks of at most 500 characters with 50 characters of overlap."""
    splitter_options = dict(chunk_size=500, chunk_overlap=50, length_function=len)
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    return splitter.transform_documents(doc)


def _load_docs(path: str):
    """Load ``path`` with ``WebBaseLoader`` and return the result split into chunks."""
    return _text_splitter(WebBaseLoader(path).load())


async def _upsert_docs(doc):
    """Upsert ``doc`` chunks into the module-level ``embeddings`` and save to ``index``.

    New rows are numbered starting just past the largest ``indexid`` found in
    ``index/documents``, or from 0 when that table is empty. A limit of 500 is
    passed to ``_stream``.

    Returns:
        The module-level ``embeddings`` instance.
    """
    last_id = _max_index_id("index/documents")["max_index"]
    # An empty table has no maximum. The helper reports None, or NaN when the
    # value comes through pandas; ``x != x`` is the standard NaN test.
    if last_id is None or last_id != last_id:
        start = 0
    else:
        # Start after the current maximum. Reusing it would give the first new
        # chunk an id that may already exist, and upsert would replace that row.
        start = int(last_id) + 1
    embeddings.upsert(_stream(doc, 500, start))
    embeddings.save("index")

    return embeddings


@app.put("/rag/{path:path}")
async def get_doc_path(path: str):
    """Echo back the document path given in the URL.

    The ``:path`` converter lets ``path`` span several segments. Without it,
    URL-like values such as ``example.com/docs/page`` would not match this
    route. Nothing is stored; the value is only returned.
    """
    return path


@app.get("/rag")
async def rag(question: str, path: Optional[str] = None):
    """Answer ``question`` from the saved ``index`` using the flan-t5 extractor.

    Args:
        question: natural-language question to answer.
        path: optional document URL. When given, the page is loaded, split into
            chunks and upserted into the index before answering.

    Returns:
        A one-element list holding the answer, serialized as a JSON array. This
        is the same shape a ``{answer}`` set literal produces.
    """
    # Use the module-level index directly. Assigning ``embeddings`` (or
    # ``path``) anywhere in this function would make the name local and raise
    # UnboundLocalError on its first read.
    embeddings.load("index")
    if path:
        doc = _load_docs(path)
        # _upsert_docs is a coroutine; without ``await`` it would never run.
        await _upsert_docs(doc)

    # NOTE(review): this constructs the flan-t5 extractor (and loads its model)
    # on every request, and the calls above block the event loop. Consider
    # creating the extractor once and offloading blocking work to a thread.
    extractor = Extractor(embeddings, "google/flan-t5-base")
    answer = await _search(question, extractor)
    return [answer]