import gradio as gr
import os
from transformers import AutoImageProcessor, AutoModel
import torch
from pymongo import MongoClient
from PIL import Image
import json
import numpy as np
import faiss
from dotenv import load_dotenv

load_dotenv()

# Initialize the DINOv2 model and image processor used for similarity search
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dino_v2_model = AutoModel.from_pretrained("facebook/dinov2-base").to(torch_device)
dino_v2_image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# MongoDB connection (URI comes from the environment / .env file)
MONGO_URI = os.environ.get("MONGO_URI")
mongo = MongoClient(MONGO_URI)
db = mongo["xbgp"]


def process_image(image):
    """
    Extract DINOv2 features from the uploaded image and return the closest
    matching gamerpics (and the titles they belong to) from the FAISS index
    and MongoDB.
    """
    # Load the list of indexed gamerpics; its order matches the vectors in the FAISS index
    with open("images.json", "r") as f:
        images = json.load(f)

    # Convert to RGB if it isn't already
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize so the shorter side is 64px while maintaining the aspect ratio
    width, height = image.size
    if width < height:
        w_percent = 64 / float(width)
        new_width = 64
        new_height = int(float(height) * float(w_percent))
    else:
        h_percent = 64 / float(height)
        new_height = 64
        new_width = int(float(width) * float(h_percent))
    image = image.resize((new_width, new_height), Image.LANCZOS)

    # Extract features from the uploaded image
    with torch.no_grad():
        inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
            torch_device
        )
        outputs = dino_v2_model(**inputs)

    # Mean-pool the patch embeddings into a single vector, then L2-normalize it
    # so the FAISS L2 search behaves like a cosine-similarity search
    embeddings = outputs.last_hidden_state
    embeddings = embeddings.mean(dim=1)
    vector = embeddings.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)

    # Read the index file and search it for the 50 nearest gamerpics
    index = faiss.read_index("vector.index")
    distances, indices = index.search(vector, 50)

    matches = []
    for idx, matching_gamerpic in enumerate(indices[0]):
        gamerpic = images[matching_gamerpic]
        # Look up the title this gamerpic belongs to; the positional projection
        # ("gamerpics.$") returns only the matched gamerpic, and "_id" is
        # excluded so the result stays JSON-serializable for the Gradio output
        title = db.titles.find_one(
            {"gamerpics.cdn": gamerpic},
            {"_id": 0, "name": 1, "type": 1, "url": 1, "gamerpics.$": 1},
        )
        if title is None:
            # Skip index entries with no matching document in MongoDB
            continue
        title["rank"] = idx
        title["score"] = str(round((1 / (distances[0][idx] + 1) * 100), 2)) + "%"
        matches.append(title)
    return matches
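

# Not used by the app: a hypothetical helper illustrating how the distances
# returned by the search above relate to cosine similarity, assuming the
# vectors stored in "vector.index" were L2-normalized the same way as the
# query vector and the index is a flat L2 index. For unit vectors,
# ||a - b||^2 = 2 - 2 * cos(a, b).
def distance_to_cosine_similarity(squared_l2_distance: float) -> float:
    """Convert a squared L2 distance between unit vectors to cosine similarity."""
    return 1.0 - squared_l2_distance / 2.0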


# Create the Gradio interface: image in, JSON list of matches out
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
)

# Launch the Gradio app
iface.launch(share=True)
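
# Not part of the app above: a minimal, hypothetical sketch of how the
# "images.json" / "vector.index" pair read by process_image() could be built
# offline with the same model. The source of the gamerpic URLs, the use of
# requests, and the choice of IndexFlatL2 are assumptions, not taken from the
# original project.
#
# import requests
# from io import BytesIO
#
# gamerpic_urls = [
#     gp["cdn"]
#     for title in db.titles.find({}, {"gamerpics.cdn": 1})
#     for gp in title.get("gamerpics", [])
# ]
# vectors = []
# for url in gamerpic_urls:
#     img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
#     with torch.no_grad():
#         inputs = dino_v2_image_processor(images=img, return_tensors="pt").to(torch_device)
#         emb = dino_v2_model(**inputs).last_hidden_state.mean(dim=1)
#     vectors.append(emb.squeeze(0).cpu().numpy())
# vectors = np.float32(np.stack(vectors))
# faiss.normalize_L2(vectors)
# index = faiss.IndexFlatL2(vectors.shape[1])  # 768 dimensions for dinov2-base
# index.add(vectors)
# faiss.write_index(index, "vector.index")
# with open("images.json", "w") as f:
#     json.dump(gamerpic_urls, f)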