xbgp / app.py
methodw's picture
change torch device to cpu only
d24bb26
raw
history blame
2.78 kB
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss
# Initialise the similarity-search backbone: the DINOv2-large image processor
# and model are fetched by name via from_pretrained() and kept on the CPU.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
model.config.return_dict = False # Set return_dict to False for JIT tracing
model.to(device)
# Trace the model once, at import time, using a random 1x3x224x224 tensor as
# the example input. The traced module is what process_image() runs inference on.
# NOTE(review): the trace is recorded with this input size only -- presumably
# the processor always yields 224x224 inputs; confirm against its config.
example_input = torch.rand(1, 3, 224, 224).to(device) # Adjust size if needed
traced_model = torch.jit.trace(model, example_input)
# The model was already moved to `device` before tracing, so this is presumably
# a no-op that just makes the placement explicit.
traced_model = traced_model.to(device)
# Lazily populated (id_map, faiss_index) pair shared by every request.
_search_assets = None


def _load_search_assets():
    """Return the ``(id_map, index)`` pair, reading both files on first use only.

    The id map comes from ``xbgp-faiss-map.json`` and the index from
    ``xbgp-faiss.index``; both are cached for the lifetime of the process so
    that requests do not re-read them from disk.
    """
    global _search_assets
    if _search_assets is None:
        with open("xbgp-faiss-map.json", "r", encoding="utf-8") as f:
            id_map = json.load(f)
        # Assign in one step so a concurrent caller never sees a half-built pair.
        _search_assets = (id_map, faiss.read_index("xbgp-faiss.index"))
    return _search_assets


def _resize_shortest_side(image, target=224):
    """Scale *image* so its shorter side equals *target*, keeping the aspect ratio."""
    width, height = image.size
    if width < height:
        scale = target / float(width)
        new_size = (target, int(float(height) * float(scale)))
    else:
        scale = target / float(height)
        new_size = (int(float(width) * float(scale)), target)
    return image.resize(new_size, Image.LANCZOS)


def _embed(image):
    """Return the unit-length, mean-pooled model embedding of *image*.

    The result is a 2-D float32 NumPy array, normalised in place by
    ``faiss.normalize_L2`` and ready to be passed to ``index.search``.
    """
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        # Use the traced model for inference.
        outputs = traced_model(**inputs)
    # Average the first model output over dim 1 to get one vector per image.
    embeddings = outputs[0].mean(dim=1)
    vector = np.ascontiguousarray(
        embeddings.detach().cpu().numpy(), dtype=np.float32
    )
    # L2-normalise so every query has unit length; distances between unit
    # vectors depend only on the angle between them (cosine similarity).
    # NOTE(review): presumably the indexed vectors were normalised the same
    # way when the index was built -- confirm.
    faiss.normalize_L2(vector)
    return vector


def process_image(image):
    """Find the indexed images most similar to *image*.

    The image is converted to RGB, resized so its shorter side is 224 px,
    embedded with the traced model, and looked up in the FAISS index (top 50).

    Args:
        image: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A list of ``{"id": ..., "score": "NN.NN%"}`` dicts in the order
        returned by the index. ``id`` is the id-map entry for the matched
        vector; ``score`` is ``100 / (1 + distance)`` rounded to two decimals.
        Returns an empty list when *image* is ``None``.
    """
    if image is None:
        # Nothing was uploaded: there is nothing to search for.
        return []

    id_map, index = _load_search_assets()

    # Convert to RGB if it isn't already.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_shortest_side(image)

    vector = _embed(image)
    distances, indices = index.search(vector, 50)

    matches = []
    for label, distance in zip(indices[0], distances[0]):
        # FAISS pads the result with -1 labels when fewer than k vectors are
        # available; indexing the id map with -1 would silently return its
        # last entry, so skip those slots.
        if label < 0:
            continue
        gamerpic = {
            "id": id_map[int(label)],
            "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
        }
        print(gamerpic)
        matches.append(gamerpic)
    return matches
# Build the web UI: a single PIL-image input wired to process_image, with the
# list of matches rendered as JSON.
iface = gr.Interface(fn=process_image, inputs=gr.Image(type="pil"), outputs="json")
# Turn on request queuing, then start serving the app.
iface = iface.queue()
iface.launch()