import os
from typing import Union
from PIL import Image
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sentence_transformers import SentenceTransformer
import uvicorn
import vecs

DB_CONNECTION = os.environ.get(
    'DB_URL', "postgresql://postgres:postgres@localhost:54322/postgres")

app = FastAPI()


@app.get("/seed")
def seed():
    # create vector store client
    vx = vecs.create_client(DB_CONNECTION)

    # skip seeding if the collection already exists
    # (newer vecs versions raise on a missing collection instead of returning None)
    try:
        iv = vx.get_collection(name="image_vectors")
    except Exception:
        iv = None

    if iv:
        return {"message": "Collection already exists."}

    # create a collection of vectors with 512 dimensions
    images = vx.create_collection(name="image_vectors", dimension=512)

    # Load CLIP model
    model = SentenceTransformer('clip-ViT-B-32')

    # Encode each sample image into a 512-dimension CLIP embedding:
    img_emb1 = model.encode(Image.open('./images/one.jpg'))
    img_emb2 = model.encode(Image.open('./images/two.jpg'))
    img_emb3 = model.encode(Image.open('./images/three.jpg'))
    img_emb4 = model.encode(Image.open('./images/four.jpg'))

    images.upsert(
        vectors=[
            (
                "one.jpg",
                img_emb1,
                {"type": "jpg"}
            ), (
                "two.jpg",
                img_emb2,
                {"type": "jpg"}
            ), (
                "three.jpg",
                img_emb3,
                {"type": "jpg"}
            ), (
                "four.jpg",
                img_emb4,
                {"type": "jpg"}
            )
        ]
    )
    print("Inserted images")

    # index the collection for fast search performance
    images.create_index()
    return {"message": "Collection created and indexed."}


@app.get("/search")
def search(query: Union[str, None] = None):
    # create vector store client
    vx = vecs.create_client(DB_CONNECTION)
    images = vx.get_collection(name="image_vectors")

    # Load CLIP model
    model = SentenceTransformer('clip-ViT-B-32')
    # Encode text query
    query_string = query
    text_emb = model.encode(query_string)

    # query the collection filtering metadata for "type" = "jpg"
    results = images.query(
        query_vector=text_emb,
        limit=1,
        filters={"type": {"$eq": "jpg"}},
    )
    result = results[0]
    return {"result": result, "query": query}


app.mount("/images", StaticFiles(directory="images"), name="images")
app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")


def start():
    """Launched with `poetry run start` at root level"""
    uvicorn.run("image_search.main:app",
                host="0.0.0.0", port=7860, reload=True)


if __name__ == "__main__":
    start()