Spaces:
Paused
Paused
File size: 2,767 Bytes
48e00e6 e359b32 48e00e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
from typing import Union
from PIL import Image
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sentence_transformers import SentenceTransformer
import uvicorn
import vecs
# Postgres connection string for the vector store; the DB_URL environment
# variable overrides the local default below.
DB_CONNECTION = os.getenv(
    "DB_URL",
    "postgresql://postgres:postgres@localhost:54322/postgres",
)

# Application object that the route decorators and static mounts attach to.
app = FastAPI()
@app.get("/seed")
def seed():
    """Create, populate and index the ``image_vectors`` collection.

    If a collection with that name already exists nothing is written and a
    message saying so is returned. Otherwise a 512-dimension collection is
    created, each sample image under ``./images`` is encoded with the CLIP
    model and upserted (id = file name, metadata ``{"type": "jpg"}``), and
    the collection is indexed.

    Returns:
        A dict with a single ``"message"`` key describing what happened.
    """
    # Local import so this function does not rely on a new file-level import.
    from vecs.exc import CollectionNotFound

    collection_name = "image_vectors"
    image_names = ("one.jpg", "two.jpg", "three.jpg", "four.jpg")

    # create vector store client
    # NOTE(review): a new client is created on every request and never
    # released here - confirm whether the installed vecs version offers a
    # disconnect / context-manager API and use it if so.
    vx = vecs.create_client(DB_CONNECTION)

    # vecs documents get_collection as raising CollectionNotFound when the
    # collection is missing, so a plain truthiness test on its return value
    # would never reach the creation path. Treat both "raised" and a None
    # return as "does not exist yet".
    try:
        existing = vx.get_collection(name=collection_name)
    except CollectionNotFound:
        existing = None
    if existing is not None:
        return {"message": "Collection already exists."}

    # create a collection of vectors with 512 dimensions
    images = vx.create_collection(name=collection_name, dimension=512)

    # Load CLIP model
    model = SentenceTransformer('clip-ViT-B-32')

    # Encode every image while its file handle is open, then close it.
    records = []
    for image_name in image_names:
        with Image.open(f"./images/{image_name}") as img:
            embedding = model.encode(img)
        records.append((image_name, embedding, {"type": "jpg"}))

    images.upsert(vectors=records)
    print("Inserted images")

    # index the collection for fast search performance
    images.create_index()
    return {"message": "Collection created and indexed."}
@app.get("/search")
def search(query: Union[str, None] = None):
    """Return the stored image that best matches a text query.

    The query text is encoded with the CLIP model and used as the query
    vector against the ``image_vectors`` collection, restricted to entries
    whose ``type`` metadata equals ``"jpg"``; only the top hit is requested.

    Args:
        query: Free-text search string supplied as a query parameter.

    Returns:
        ``{"result": <top match or None>, "query": <query>}``. ``result`` is
        ``None`` when the collection yields no match.

    Raises:
        HTTPException: 400 when ``query`` is not supplied.
    """
    # Local import so this function does not rely on a new file-level import.
    from fastapi import HTTPException

    # Without this guard None would be handed to the encoder and surface as
    # an unhandled server error.
    if query is None:
        raise HTTPException(
            status_code=400, detail="Missing required 'query' parameter.")

    # create vector store client
    vx = vecs.create_client(DB_CONNECTION)
    images = vx.get_collection(name="image_vectors")

    # Load CLIP model
    # NOTE(review): the model is constructed on every request; consider
    # sharing one instance at module level if latency matters.
    model = SentenceTransformer('clip-ViT-B-32')

    # Encode text query
    text_emb = model.encode(query)

    # query the collection filtering metadata for "type" = "jpg"
    results = images.query(
        query_vector=text_emb,
        limit=1,
        filters={"type": {"$eq": "jpg"}},
    )

    # An empty result list must not raise IndexError.
    result = results[0] if results else None
    return {"result": result, "query": query}
# Expose the local "images" directory under the /images URL prefix.
app.mount("/images", StaticFiles(directory="images"), name="images")
# Expose the local "static" directory at the root path (StaticFiles is given
# html=True). Keep this after the /images mount.
# NOTE(review): this root mount is registered before the @app.get("/") route
# that follows in this file - confirm which of the two actually serves "/".
app.mount("/", StaticFiles(directory="static", html=True), name="static")
@app.get("/")
def index() -> FileResponse:
    """Respond with the static landing page as an HTML file.

    Returns:
        A ``FileResponse`` for ``static/index.html`` with media type
        ``text/html``.
    """
    landing_page = "static/index.html"
    return FileResponse(path=landing_page, media_type="text/html")
def start():
    """Run the app under uvicorn; entry point for `poetry run start` (run from the project root)."""
    # Listen on all interfaces on port 7860 with auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("image_search.main:app", **server_options)
|