|
import functools
import json

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification, Dinov2Config, Dinov2Model
|
|
|
|
|
# Hugging Face Hub checkpoint served by this demo.
model_name = "DinoVdeau-large-2024_04_03-with_data_aug_batch-size32_epochs150_freeze"
checkpoint_name = "lombardata/" + model_name

# Read label metadata straight from the checkpoint's config.json.
config_path = hf_hub_download(repo_id=checkpoint_name, filename="config.json")
# JSON is UTF-8 by specification; don't rely on the platform's default encoding
# (label names could contain non-ASCII characters).
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

# JSON object keys are always strings, so id2label is keyed "0", "1", ...
# (predict() looks labels up with str(i) for that reason).
id2label = config["id2label"]
label2id = config["label2id"]
image_size = config["image_size"]
num_labels = len(id2label)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build an MLP head: num_features -> num_features//2 -> num_features//4 -> number_classes.

    Each hidden stage is ``Linear -> activation_func() -> BatchNorm1d``, followed
    by ``Dropout(dropout_prob)`` unless ``dropout_prob`` is 0. A fresh activation
    instance is created per stage.

    NOTE: the module order fixes the ``nn.Sequential`` child indices, which
    presumably are the state-dict keys the pretrained checkpoint was saved with,
    so the layer order must not change.
    """
    widths = (num_features, num_features // 2, num_features // 4)
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), activation_func(), nn.BatchNorm1d(fan_out)]
        if dropout_prob != 0:
            modules.append(nn.Dropout(dropout_prob))
    # Final projection onto the label space.
    modules.append(nn.Linear(widths[-1], number_classes))
    return nn.Sequential(*modules)
|
|
|
class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """DINOv2 classifier whose head is replaced by the MLP from ``create_head``.

    The head's input width is ``2 * config.hidden_size``. NOTE(review): this
    presumably matches the parent class's pooled features (CLS token
    concatenated with mean patch token); confirm against the installed
    transformers version.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)
        head_in_features = 2 * config.hidden_size
        # Replace the parent's single linear classifier with the deeper MLP head.
        self.classifier = create_head(head_in_features, config.num_labels)
|
|
|
# Load the fine-tuned weights into the custom-head architecture and place the
# model on the selected device (nn.Module.to returns the module itself).
model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name).to(device)
|
def sigmoid(_outputs):
    """Return the logistic function ``1 / (1 + exp(-x))``, element-wise for arrays.

    ``predict`` below uses ``torch.sigmoid`` instead; this is a NumPy helper.
    """
    exp_neg = np.exp(-_outputs)
    return 1.0 / (1.0 + exp_neg)
|
|
|
def download_thresholds(repo_id, filename):
    """Download ``filename`` from Hub repo ``repo_id`` and return its parsed JSON.

    In this app it fetches the per-label decision thresholds that ``predict``
    indexes by label name.
    """
    threshold_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # JSON is UTF-8 by specification; label-name keys may be non-ASCII.
    with open(threshold_path, "r", encoding="utf-8") as threshold_file:
        return json.load(threshold_file)
|
|
|
@functools.lru_cache(maxsize=1)
def _get_processor():
    """Load the checkpoint's image processor on first use and reuse it afterwards."""
    return AutoImageProcessor.from_pretrained(checkpoint_name)


def predict(image, slider_threshold=0.5, fixed_thresholds=None):
    """Run multi-label classification on ``image`` and threshold the results two ways.

    Args:
        image: image passed to the Hugging Face image processor (the Gradio
            input component delivers a PIL image).
        slider_threshold: single probability cut-off applied to every label.
        fixed_thresholds: optional mapping label name -> per-label cut-off.

    Returns:
        A tuple ``(slider_results, fixed_threshold_labels_str)``:
        ``slider_results`` maps each label whose probability exceeds
        ``slider_threshold`` to that probability; ``fixed_threshold_labels_str``
        is a comma-separated string of labels exceeding their own fixed
        threshold, or ``None`` when ``fixed_thresholds`` is ``None``.

    Raises:
        KeyError: if ``fixed_thresholds`` has no entry for one of the labels.
    """
    # Previously the processor was rebuilt (and re-checked against the Hub) on
    # every request; it only depends on the checkpoint, so load it once.
    processor = _get_processor()
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        # A single image is processed, so take the first row of the batch.
        logits = model(**inputs).logits[0]
    # Independent sigmoid per label: labels are not mutually exclusive.
    probabilities = torch.sigmoid(logits).cpu().numpy()
    # id2label comes from JSON, so its keys are strings.
    labels = [id2label[str(i)] for i in range(len(probabilities))]

    slider_results = {
        label: float(prob)
        for label, prob in zip(labels, probabilities)
        if prob > slider_threshold
    }

    fixed_threshold_labels_str = None
    if fixed_thresholds is not None:
        fixed_threshold_labels_str = ", ".join(
            label
            for label, prob in zip(labels, probabilities)
            if prob > fixed_thresholds[label]
        )

    return slider_results, fixed_threshold_labels_str
|
|
|
@functools.lru_cache(maxsize=1)
def _cached_thresholds():
    """Fetch and parse ``threshold.json`` once per process.

    lru_cache does not cache exceptions, so a failed download is retried on
    the next call.
    """
    return download_thresholds(checkpoint_name, "threshold.json")


def predict_wrapper(image, slider_threshold=0.5):
    """Gradio callback: classify ``image`` with both slider and per-label thresholds.

    Returns ``(slider_results, fixed_threshold_results)`` exactly as produced
    by ``predict``; the Interface's ``outputs`` must be listed in that order.
    """
    # The thresholds are fixed for the checkpoint; downloading them on every
    # prediction added a Hub round-trip per request.
    thresholds = _cached_thresholds()
    slider_results, fixed_threshold_results = predict(image, slider_threshold, thresholds)
    return slider_results, fixed_threshold_results
|
|
|
|
|
|
|
title = "Victor - DinoVd'eau image classification"
model_link = "https://huggingface.co/" + checkpoint_name

# The last fragment must be an f-string so the markdown link points at the real
# model URL rather than the literal text "model_link".
description = (
    "This application showcases the capability of artificial intelligence-based systems "
    "to identify objects within underwater images. To utilize it, you can either upload "
    "your own image or select one of the provided examples for analysis. "
    f"\nFor predictions, we use this [open-source model]({model_link})"
)
subtitle = "Note: the model runs on CPU, so it may take a while to run the prediction."
full_description = f"{description}\n\n{subtitle}"

iface = gr.Interface(
    fn=predict_wrapper,
    inputs=[
        gr.components.Image(type="pil"),
        gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold"),
    ],
    # Order must match predict_wrapper's return value:
    # (slider_results dict -> Label, fixed-threshold label string -> Textbox).
    outputs=[
        gr.components.Label(label="Slider Threshold Predictions"),
        gr.components.Textbox(label="Fixed Threshold Predictions"),
    ],
    title=title,
    description=full_description,
    examples=[
        ["session_GOPR0106.JPG"],
        ["session_2021_08_30_Mayotte_10_image_00066.jpg"],
        ["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],
        ["session_2023_06_28_caplahoussaye_plancha_body_v1B_00_GP1_3_1327.jpeg"],
    ],
)

if __name__ == "__main__":
    iface.launch()