# xbgp / app.py
# Last commit by methodw: "fix key name" (bd9a56b), 2.58 kB
import functools
import json

import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import numpy as np
import faiss
# Initialise the DINOv2 feature extractor used for similarity search. The
# model weights and the image-preprocessing config are both loaded from the
# local ./dinov2-base directory; the model is moved to the GPU when CUDA is
# available, otherwise it stays on the CPU.
_DINOV2_DIR = "./dinov2-base"
if torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")
dino_v2_model = AutoModel.from_pretrained(_DINOV2_DIR).to(torch_device)
dino_v2_image_processor = AutoImageProcessor.from_pretrained(_DINOV2_DIR)
# Artifacts produced when the search index was built: a JSON map from FAISS
# row number to gamerpic id, and the serialized FAISS index itself.
_IMAGE_MAP_PATH = "xbgp-faiss-map.json"
_INDEX_PATH = "xbgp-faiss.index"
# Number of nearest neighbours requested from the index per query.
_TOP_K = 50
# Length (px) the shorter side of an upload is scaled to before embedding.
_RESIZE_SHORT_SIDE = 64


@functools.lru_cache(maxsize=1)
def _load_search_assets():
    """Load the id map and FAISS index once and reuse them for every request.

    Note: because the result is cached, changes to the files on disk are not
    picked up until the process restarts.

    Returns:
        ``(images, index)`` -- the JSON-decoded id map (presumably a list
        indexed by FAISS row number; verify against the index builder) and
        the FAISS index read from disk.
    """
    with open(_IMAGE_MAP_PATH, "r", encoding="utf-8") as f:
        images = json.load(f)
    index = faiss.read_index(_INDEX_PATH)
    return images, index


def _resize_short_side(image, target=_RESIZE_SHORT_SIDE):
    """Return *image* scaled so its shorter side is *target* px, keeping aspect ratio."""
    width, height = image.size
    if width < height:
        scale = target / float(width)
        new_size = (target, int(float(height) * scale))
    else:
        scale = target / float(height)
        new_size = (int(float(width) * scale), target)
    return image.resize(new_size, Image.LANCZOS)


def _embed(image):
    """Return the L2-normalised DINOv2 embedding of *image* as a float32 array.

    The array presumably has shape (1, hidden_size) for a single image --
    TODO confirm against the processor's batching.
    """
    with torch.no_grad():
        inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
            torch_device
        )
        outputs = dino_v2_model(**inputs)
    # Mean-pool over the token axis (dim=1) to get one vector per image.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    # faiss.normalize_L2 needs a C-contiguous float32 array and works in place.
    vector = np.ascontiguousarray(embeddings.detach().cpu().numpy(), dtype=np.float32)
    # Scale to unit L2 norm; assumes the indexed vectors were normalised the
    # same way when the index was built -- NOTE(review): confirm.
    faiss.normalize_L2(vector)
    return vector


def process_image(image):
    """Find the gamerpics most visually similar to *image*.

    The upload is converted to RGB and scaled so its shorter side is 64 px.
    It is then embedded with DINOv2 (mean-pooled last hidden state,
    L2-normalised) and searched against the prebuilt FAISS index for the top
    50 neighbours.

    Args:
        image: The uploaded PIL image, or ``None`` when the form is submitted
            without a file.

    Returns:
        A list of ``{"id": ..., "score": "NN.NN%"}`` dicts in the order FAISS
        returned them. ``score`` is ``100 / (d + 1)`` for the value ``d`` that
        FAISS reports. Whether ``d`` is a distance or a similarity depends on
        the index metric -- NOTE(review): confirm. Padding rows (id -1) are
        omitted. The list is empty when no image was given.
    """
    if image is None:
        return []

    images, index = _load_search_assets()

    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_short_side(image)
    vector = _embed(image)

    distances, indices = index.search(vector, _TOP_K)
    matches = []
    for row_id, distance in zip(indices[0], distances[0]):
        # FAISS fills missing results with -1 when the index holds fewer than
        # k vectors; images[-1] would silently yield the last entry instead.
        if row_id < 0:
            continue
        gamerpic = {
            "id": images[int(row_id)],
            "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
        }
        print(gamerpic)
        matches.append(gamerpic)
    return matches
# Build the web UI: a single PIL-image input whose matches are rendered as
# JSON, then enable request queuing and start the server.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
)
iface = iface.queue()
iface.launch()