File size: 4,526 Bytes
11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 de5a712 63df3f2 11f96c1 63df3f2 de5a712 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 11f96c1 63df3f2 de5a712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from fastapi import FastAPI, HTTPException, Query, Path
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List
import json
import os
import logging
from txtai.embeddings import Embeddings
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application; these values populate the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Embeddings API",
    description="An API for creating and querying text embeddings indexes.",
)

# CORS policy: every origin, method and header is accepted, credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; none of the endpoints in this file use cookies or auth, so
# allow_credentials=False or an explicit origin list may be preferable.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# A single shared txtai Embeddings instance backs every endpoint: create_index
# re-indexes it and query_index re-loads it from disk, so it holds one index
# at a time.
_EMBEDDINGS_MODEL_PATH = "avsolatorio/GIST-all-MiniLM-L6-v2"
embeddings = Embeddings({"path": _EMBEDDINGS_MODEL_PATH})
class DocumentRequest(BaseModel):
    # Request body for POST /create_index/; both fields are required.
    index_id: str = Field(default=..., description="Unique identifier for the index")
    documents: List[str] = Field(default=..., description="List of documents to be indexed")
class QueryRequest(BaseModel):
    # Request body for POST /query_index/; all fields are required and
    # num_results must be at least 1.
    index_id: str = Field(default=..., description="Unique identifier for the index to query")
    query: str = Field(default=..., description="The search query")
    num_results: int = Field(default=..., ge=1, description="Number of results to return")
def save_embeddings(index_id: str, document_list: List[str]):
    """Persist the shared ``embeddings`` index and its source documents.

    Writes ``/app/indexes/<index_id>/embeddings`` (via ``embeddings.save``)
    and ``/app/indexes/<index_id>/document_list.json``.

    Args:
        index_id: Client-supplied index name. It must be a single path
            component: non-empty, no path separators, and not "." or "..".
        document_list: Original document texts in indexing order. A search
            result id is a position in this list.

    Raises:
        HTTPException: 400 if ``index_id`` is not a safe directory name;
            500 if creating the folder or writing either file fails.
    """
    # index_id comes straight from the request body. Reject anything that
    # could resolve outside /app/indexes (e.g. "../etc" or "a/b") before
    # touching the filesystem.
    if not index_id or index_id in (".", "..") or os.path.basename(index_id) != index_id:
        logger.error(f"Invalid index_id rejected: {index_id!r}")
        raise HTTPException(status_code=400, detail="Invalid index_id")

    folder_path = f"/app/indexes/{index_id}"
    try:
        os.makedirs(folder_path, exist_ok=True)
        # Save embeddings
        embeddings.save(f"{folder_path}/embeddings")
        # Save document_list so query results (positions) can be mapped back to text
        with open(f"{folder_path}/document_list.json", "w", encoding="utf-8") as f:
            json.dump(document_list, f)
    except Exception as e:
        logger.exception(f"Error saving embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}") from e
    logger.info(f"Embeddings and document list saved for index_id: {index_id}")
def load_embeddings(index_id: str) -> List[str]:
    """Load a saved index into the shared ``embeddings`` object.

    Reads ``/app/indexes/<index_id>/embeddings`` (via ``embeddings.load``,
    which replaces whatever index was previously loaded) and the companion
    ``document_list.json``.

    Args:
        index_id: Client-supplied index name. Names that are not a single
            path component are treated as unknown.

    Returns:
        The original document texts, in the order they were indexed.

    Raises:
        HTTPException: 404 if ``index_id`` is unsafe or no index folder
            exists for it; 500 if loading the index or document list fails.
    """
    # Checked before the try block. Previously the 404 was raised inside the
    # try, caught by ``except Exception`` and turned into a 500.
    # Unsafe names (separators, "." / "..") can never name a saved index, so
    # they are reported as not found instead of being joined into a path.
    unsafe_id = not index_id or index_id in (".", "..") or os.path.basename(index_id) != index_id
    folder_path = f"/app/indexes/{index_id}"
    if unsafe_id or not os.path.isdir(folder_path):
        logger.error(f"Index not found for index_id: {index_id}")
        raise HTTPException(status_code=404, detail="Index not found")

    try:
        # Load embeddings
        embeddings.load(f"{folder_path}/embeddings")
        # Load document_list
        with open(f"{folder_path}/document_list.json", "r", encoding="utf-8") as f:
            document_list = json.load(f)
    except Exception as e:
        logger.exception(f"Error loading embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading embeddings: {str(e)}") from e
    logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
    return document_list
@app.post("/create_index/", response_model=dict, tags=["Index Operations"])
async def create_index(request: DocumentRequest):
    """
    Create a new index with the given documents.
    - **index_id**: Unique identifier for the index
    - **documents**: List of documents to be indexed
    """
    # NOTE(review): this rebuilds the single module-level `embeddings` object.
    # The handler is `async` with no awaits, so requests do not interleave, but
    # the (CPU-bound) indexing blocks the event loop while it runs.
    try:
        # Each entry is an (id, text, tags) tuple. The id is the document's
        # position in request.documents, which is how query_index maps results
        # back to text.
        document_list = [(i, text, None) for i, text in enumerate(request.documents)]
        embeddings.index(document_list)
        save_embeddings(request.index_id, request.documents)  # Save the original documents
    except HTTPException:
        # save_embeddings already chose the status code and detail (e.g. 400 for
        # an unsafe index_id); re-wrapping would turn it into a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Error creating index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}") from e
    logger.info(f"Index created successfully for index_id: {request.index_id}")
    return {"message": "Index created successfully"}
@app.post("/query_index/", response_model=dict, tags=["Index Operations"])
async def query_index(request: QueryRequest):
    """
    Query an existing index with the given search query.
    - **index_id**: Unique identifier for the index to query
    - **query**: The search query
    - **num_results**: Number of results to return
    """
    try:
        # Loads the requested index into the shared `embeddings` object,
        # replacing whatever it held before.
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        # idx[0] is taken as the document id, i.e. the enumerate() position
        # assigned in create_index, so it indexes the saved document list.
        queried_texts = [document_list[idx[0]] for idx in results]
    except HTTPException:
        # Preserve the status chosen by load_embeddings (e.g. 404 for an
        # unknown index) instead of collapsing it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error querying index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error querying index: {str(e)}") from e
    logger.info(f"Query executed successfully for index_id: {request.index_id}")
    return {"queried_texts": queried_texts}
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    # (Removed a stray trailing "|" that made this line a SyntaxError.)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)