import os
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core import VectorStoreIndex
from llama_index.core import Settings

app = FastAPI()

# rerank = SentenceTransformerRerank(
#     model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=3
# )

@app.get("/")
def greet_json():
    return {"Hello": "World!"}

class SearchRequest(BaseModel):
    query: str
    limit: int = 10

class Metadata(BaseModel):
    window: str
    original_text: str

class MyNodeWithScore(BaseModel):
    node: Metadata
    relationships: List[Metadata]
    score: float

class MyResult(BaseModel):
    results: List[MyNodeWithScore]


# BAAI/bge-small-en-v1.5 produces 384-dimensional embeddings, which must match
# the `dim` of the Milvus collection configured below.
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    cache_folder=".cache",
)

# Connect to an existing remote Milvus collection; credentials and the
# collection name come from environment variables, so no secrets are hard-coded.
vector_store = MilvusVectorStore(
    overwrite=False,
    uri=os.getenv('MILVUS_CLOUD_URI'),
    token=os.getenv('MILVUS_CLOUD_TOKEN'),
    collection_name=os.getenv('COLLECTION_NAME'),
    dim=384,
)


@app.post("/search/")
def search(search_request: SearchRequest):
    sentence_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)

    retriever = sentence_index.as_retriever(
        include_text=True,  # include source chunk with matching paths
        similarity_top_k=search_request.limit,
        # node_postprocessors=[rerank]
    )

    retrieved_nodes = retriever.retrieve(search_request.query)

    # Flatten each retrieved node (and the nodes it is related to) into the
    # plain Pydantic response models defined above.
    results = []
    for result in retrieved_nodes:
        relationships = [
            Metadata(
                window=related.metadata.get('window', " "),
                original_text=related.metadata.get('original_text', " "),
            )
            for related in result.node.relationships.values()
        ]
        results.append(
            MyNodeWithScore(
                node=Metadata(
                    window=result.metadata['window'],
                    original_text=result.metadata['original_text'],
                ),
                relationships=relationships,
                score=result.get_score(),
            )
        )

    return MyResult(results=results)
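
# Example request against the endpoint above (a sketch; assumes the server is
# running locally on port 7860 and that the Milvus collection stores `window` /
# `original_text` metadata, e.g. produced by a sentence-window ingestion pipeline):
#
#   curl -X POST http://localhost:7860/search/ \
#        -H "Content-Type: application/json" \
#        -d '{"query": "what does the document say about embeddings?", "limit": 5}'

# Minimal local entry point (also a sketch; on a hosted platform the ASGI server
# is usually started externally, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)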