# xbgp / app.py
# Last commit by methodw: "use model locally" (d9b1d5d)
import functools
import json
import os

import faiss
import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from PIL import Image
from pymongo import MongoClient
from transformers import AutoImageProcessor, AutoModel
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# --- Similarity-search model -------------------------------------------------
# Both the DINOv2 weights and its image-processor config come from the local
# ./dinov2-base directory. Inference runs on CUDA when it is available.
_DINO_V2_PATH = "./dinov2-base"
torch_device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
dino_v2_model = AutoModel.from_pretrained(_DINO_V2_PATH)
dino_v2_model = dino_v2_model.to(torch_device)
dino_v2_image_processor = AutoImageProcessor.from_pretrained(_DINO_V2_PATH)

# --- MongoDB ------------------------------------------------------------------
# NOTE(review): if MONGO_URI is unset this passes None to MongoClient.
MONGO_URI = os.getenv("MONGO_URI")
mongo = MongoClient(MONGO_URI)
db = mongo.get_database("xbgp")
# Number of nearest neighbours requested from the FAISS index per query.
_TOP_K = 50
# Length (in pixels) the shorter side of the upload is resized to before embedding.
_RESIZE_SHORT_SIDE = 64


@functools.lru_cache(maxsize=1)
def _load_search_assets():
    """Load the gamerpic URL list and FAISS index once per process.

    Returns:
        ``(images, index)``: ``images`` is the list parsed from
        ``images.json``. Entry ``i`` is the gamerpic for FAISS id ``i``.
        ``index`` is the FAISS index read from ``vector.index``.
    """
    with open("images.json", "r", encoding="utf-8") as f:
        images = json.load(f)
    index = faiss.read_index("vector.index")
    return images, index


def _resize_short_side(image, size=_RESIZE_SHORT_SIDE):
    """Return *image* resized so its shorter side is *size* px, keeping aspect ratio."""
    width, height = image.size
    if width < height:
        new_width = size
        new_height = int(float(height) * (size / float(width)))
    else:
        new_height = size
        new_width = int(float(width) * (size / float(height)))
    return image.resize((new_width, new_height), Image.LANCZOS)


def _embed(image):
    """Return the L2-normalised float32 DINOv2 embedding of *image* as a 2-D array."""
    with torch.no_grad():
        inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
            torch_device
        )
        outputs = dino_v2_model(**inputs)
        # Mean-pool the hidden states over dim 1 (presumably the token axis)
        # to get one vector per image.
        embeddings = outputs.last_hidden_state.mean(dim=1)
    vector = embeddings.detach().cpu().numpy()
    # faiss.normalize_L2 normalises in place and needs a C-contiguous float32 array.
    vector = np.ascontiguousarray(vector, dtype=np.float32)
    faiss.normalize_L2(vector)
    return vector


def process_image(image):
    """Find the gamerpics most similar to an uploaded image.

    Steps:
    1. Convert the image to RGB and resize it so its shorter side is 64 px.
    2. Embed it with DINOv2 and L2-normalise the embedding.
    3. Search the FAISS index in ``vector.index`` for the top 50 hits.
    4. Map each hit through ``images.json`` to a gamerpic CDN URL.
    5. Look that URL up in the ``titles`` collection.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A list of title documents projected to ``name``, ``type``, ``url``
        and the matched ``gamerpics`` entry. Each document is annotated with:

        - ``rank``: position of the hit in the FAISS result.
        - ``score``: a percentage string derived from the search distance.

        Hits that FAISS could not fill, and gamerpics with no matching title,
        are skipped. Returns ``[]`` when no image was provided.
    """
    if image is None:
        return []

    images, index = _load_search_assets()

    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_short_side(image)

    vector = _embed(image)
    distances, indices = index.search(vector, _TOP_K)

    matches = []
    for rank, (distance, image_id) in enumerate(zip(distances[0], indices[0])):
        # FAISS pads results with -1 when the index holds fewer than k vectors.
        # images[-1] would silently return an unrelated gamerpic.
        if image_id < 0:
            continue
        gamerpic = images[image_id]
        # Return the title that owns this gamerpic, projecting only the matched entry.
        title = db.titles.find_one(
            {"gamerpics.cdn": gamerpic},
            {"name": 1, "type": 1, "url": 1, "gamerpics.$": 1},
        )
        if title is None:
            # No title references this gamerpic; skip it instead of crashing.
            continue
        title["rank"] = rank
        # NOTE(review): maps distance 0 to 100%. How meaningful this is depends
        # on the index metric (L2 vs inner product) — confirm.
        title["score"] = str(round((1 / (distance + 1) * 100), 2)) + "%"
        matches.append(title)
    return matches
# Gradio UI: one uploaded image in, the list of matching titles out as JSON.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),  # hand uploads to process_image as PIL images
    outputs="json",  # process_image returns a list of dicts
)

# share expects a bool. The previous string "true" only worked because any
# non-empty string is truthy.
iface.launch(share=True)