Spaces:
Runtime error
Runtime error
import os | |
from typing import List | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from llama_index.vector_stores.milvus import MilvusVectorStore | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
# from llama_index.core.postprocessor import SentenceTransformerRerank | |
from llama_index.core import VectorStoreIndex | |
from llama_index.core import Settings | |
# FastAPI application instance for this service.
app = FastAPI()
# Optional cross-encoder reranker, currently disabled and kept for reference.
# Re-enabling it also requires un-commenting the SentenceTransformerRerank
# import at the top of the file and applying the reranker to retrieved nodes.
# rerank = SentenceTransformerRerank(
# model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=3
# )
# NOTE(review): no route decorator was visible on this handler, so it was never
# registered on ``app``. "/" appears to match the Hugging Face FastAPI Space
# template this greeting comes from -- confirm against the deployed Space.
@app.get("/")
def greet_json():
    """Return a fixed greeting (e.g. usable as a simple "is the Space up" check)."""
    return {"Hello": "World!"}
class SearchRequest(BaseModel):
    # Input model for search(): what to look for and how many hits to return.
    # Documented with comments rather than a docstring because pydantic would
    # copy a class docstring into the generated JSON schema.
    query: str  # text passed to the retriever's retrieve() call
    limit: int = 10  # forwarded to the retriever as similarity_top_k
class Metadata(BaseModel):
    # The two metadata values reported for a retrieved node or one of its
    # related nodes; search() reads them from the node metadata keys
    # "window" and "original_text".
    # NOTE(review): presumably the sentence-window context and the matched
    # sentence written by a sentence-window node parser at ingestion time --
    # confirm against the indexing pipeline.
    window: str
    original_text: str
class MyNodeWithScore(BaseModel):
    # One search hit in the response.
    node: Metadata  # metadata of the retrieved node itself
    relationships: List[Metadata]  # metadata of the node's related nodes
    score: float  # similarity score reported by the retriever for this hit
class MyResult(BaseModel):
    # Response of search(): the hits, in the order the retriever returned them.
    results: List[MyNodeWithScore]
# Query embedding model. It must be the same model the Milvus collection was
# indexed with, and EMBED_DIM must equal that model's output size
# (bge-small-en-v1.5 produces 384-dimensional vectors).
EMBED_MODEL_NAME = "BAAI/bge-small-en-v1.5"
EMBED_DIM = 384


def _require_env(name: str) -> str:
    """Return the value of the environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty. This makes a
            misconfigured deployment fail at startup with a message naming
            the missing setting, instead of passing ``None`` into the Milvus
            client and failing there with an unrelated-looking error.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Required environment variable {name!r} is not set")
    return value


Settings.embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_NAME,
                                            cache_folder=".cache")

vector_store = MilvusVectorStore(
    overwrite=False,  # use the existing collection as-is; do not overwrite it
    uri=_require_env('MILVUS_CLOUD_URI'),
    # The token is left optional, exactly as before: it may legitimately be
    # absent for a server without authentication.
    token=os.getenv('MILVUS_CLOUD_TOKEN'),
    collection_name=_require_env('COLLECTION_NAME'),
    dim=EMBED_DIM,
)
def _metadata_from(metadata: dict, default: str = " ") -> Metadata:
    """Build a ``Metadata`` object from a node's metadata dict.

    Keys missing from *metadata* fall back to *default* instead of raising
    ``KeyError``. As a result, a single node stored without these keys does not
    turn the whole search into a server error. The single-space default is
    the value previously used for related nodes.
    """
    return Metadata(
        window=metadata.get("window", default),
        original_text=metadata.get("original_text", default),
    )


def _related_entries(relationships: dict) -> list:
    """Return every related-node entry of a node, flattening list values.

    A relationship value may be either a single entry or a list of entries.
    NOTE(review): llama_index stores CHILD relationships as lists -- confirm.
    Both forms are normalised here, so ``.metadata`` is never read from a list.
    """
    entries = []
    for related in relationships.values():
        if isinstance(related, list):
            entries.extend(related)
        else:
            entries.append(related)
    return entries


# NOTE(review): no route decorator was visible on this handler, so it was never
# exposed by ``app``. The POST method and "/search" path are assumed -- confirm
# against the clients of this Space.
@app.post("/search")
def search(search_request: SearchRequest) -> MyResult:
    """Run a similarity search against the Milvus collection.

    Args:
        search_request: the query text and the number of hits to retrieve.

    Returns:
        A ``MyResult`` whose ``results`` list the hits in retriever order.
        Each hit carries its own ``window``/``original_text`` metadata, the
        same two fields for each of its related nodes, and its score.
    """
    sentence_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
    # The retriever matches the query against the module-level ``vector_store``.
    # (The old ``include_text`` kwarg is not a vector-retriever option and was
    # ignored, so it has been dropped.)
    retriever = sentence_index.as_retriever(
        similarity_top_k=search_request.limit,
    )
    hits = retriever.retrieve(search_request.query)

    return MyResult(
        results=[
            MyNodeWithScore(
                node=_metadata_from(hit.node.metadata),
                relationships=[
                    _metadata_from(related.metadata)
                    for related in _related_entries(hit.node.relationships)
                ],
                score=hit.get_score(),
            )
            for hit in hits
        ]
    )