# xbgp / app.py
# methodw's picture
# jit fix 5
# 1932c4f
# raw
# history blame
# 2.8 kB
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss
# Similarity-search backbone: DINOv2 (large) and its matching image processor,
# both kept on the CPU.
_DINOV2_CHECKPOINT = "facebook/dinov2-large"
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(_DINOV2_CHECKPOINT)
model = AutoModel.from_pretrained(_DINOV2_CHECKPOINT).to(device)
# return_dict is switched off for the benefit of JIT tracing; the forward pass
# then yields a plain tuple.
model.config.return_dict = False
# Trace once at import time using a random dummy batch of one 3x224x224 image;
# requests are later served by the traced module.
example_input = torch.rand(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(model, example_input).to(device)
# Process-wide cache for the FAISS index and its id map. Both files are
# read-only lookup data, so they are loaded on first use and then reused
# instead of being re-read from disk on every request.
_SEARCH_ASSETS = {}


def _load_search_assets():
    """Return ``(id_map, index)``, reading both files from disk only once."""
    if not _SEARCH_ASSETS:
        with open("xbgp-faiss-map.json", "r", encoding="utf-8") as f:
            id_map = json.load(f)
        index = faiss.read_index("xbgp-faiss.index")
        # Publish both entries together so a failed load never leaves the
        # cache half-filled.
        _SEARCH_ASSETS.update(images=id_map, index=index)
    return _SEARCH_ASSETS["images"], _SEARCH_ASSETS["index"]


def _resize_shortest_side(image, target=224):
    """Scale *image* so its shorter side is *target* px, keeping aspect ratio."""
    width, height = image.size
    if width < height:
        new_size = (target, int(float(height) * (target / float(width))))
    else:
        new_size = (int(float(width) * (target / float(height))), target)
    return image.resize(new_size, Image.LANCZOS)


def process_image(image):
    """
    Find the indexed gamerpics most similar to an uploaded image.

    The image is converted to RGB, resized so its shorter side is 224 px,
    embedded with the traced DINOv2 model, and the L2-normalised embedding is
    searched against the FAISS index for up to 50 nearest neighbours.

    Args:
        image: A PIL image, or ``None`` when nothing was submitted.

    Returns:
        A list of ``{"id": ..., "score": "NN.NN%"}`` dicts in the order the
        index returned them. Empty when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded; report "no matches" rather than failing on
        # the attribute access below.
        return []

    images, index = _load_search_assets()

    # Convert to RGB if it isn't already
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_shortest_side(image)

    # Extract the features from the uploaded image
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        # Use the traced model for inference
        outputs = traced_model(inputs)

    # Average the first output over dim 1 to get one vector per image
    # (presumably the token axis of the last hidden state - TODO confirm).
    embeddings = outputs[0].mean(dim=1)
    vector = np.float32(embeddings.detach().cpu().numpy())
    # Scale the query to unit length in place; presumably the indexed vectors
    # were normalised the same way - TODO confirm against the index builder.
    faiss.normalize_L2(vector)

    distances, indices = index.search(vector, 50)

    matches = []
    for distance, match_idx in zip(distances[0], indices[0]):
        if match_idx < 0:
            # FAISS pads the result with -1 labels when it finds fewer than
            # k neighbours; those are not real matches.
            continue
        gamerpic = {
            # assumes the id map can be indexed by the integer FAISS label
            "id": images[match_idx],
            # Maps a distance d to 100 / (1 + d); assumes smaller distance
            # means a closer match - TODO confirm against the index metric.
            "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
        }
        print(gamerpic)
        matches.append(gamerpic)
    return matches
# Web UI: a single image upload (handed to process_image as a PIL image)
# whose matches are rendered as JSON.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="json",
)
# Turn on request queuing, then start serving.
iface = iface.queue()
iface.launch()