AdaCLIP / app.py
Caoyunkang's picture
first commit
a25563f verified
raw
history blame
4.09 kB
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import warnings
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import json
import os
import torch
from scipy.ndimage import gaussian_filter
import cv2
from method import AdaCLIP_Trainer
import numpy as np
############ Init Model
# Checkpoints selectable through the "Pre-trained Datasets" radio in the UI.
ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
ckt_path3 = 'weights/pretrained_all.pth'

# Configurations
image_size = 518
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# To force CPU inference, set: device = 'cpu'
# Kept separate from `model` (which becomes the trainer below) so the
# backbone name is not shadowed by the model object.
backbone_name = "ViT-L-14-336"
prompting_depth = 4
prompting_length = 5
prompting_type = 'SD'
prompting_branch = 'VL'
use_hsf = True
k_clusters = 20
config_path = os.path.join('./model_configs', f'{backbone_name}.json')

# Prepare model
# JSON is UTF-8 by specification; do not depend on the platform locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    model_configs = json.load(f)

# Set up the feature hierarchy: layer indices at each quarter of the
# vision encoder's depth (read from model_configs['vision_cfg']['layers']).
n_layers = model_configs['vision_cfg']['layers']
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]

# learning_rate=0. — this app only runs inference with pre-trained weights.
model = AdaCLIP_Trainer(
    backbone=backbone_name,
    feat_list=features_list,
    input_dim=model_configs['vision_cfg']['width'],
    output_dim=model_configs['embed_dim'],
    learning_rate=0.,
    device=device,
    image_size=image_size,
    prompting_depth=prompting_depth,
    prompting_length=prompting_length,
    prompting_branch=prompting_branch,
    prompting_type=prompting_type,
    use_hsf=use_hsf,
    k_clusters=k_clusters,
).to(device)
# Checkpoint path currently loaded into `model`; lets process_image skip
# re-reading identical weights from disk on every request.
_loaded_ckpt_path = None


def process_image(image, text, options):
    """Run zero-shot anomaly detection on one image and visualise the result.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        text: class name used as the text prompt (e.g. "candle").
        options: the selected "Pre-trained Datasets" choice
            ('MVTec AD+Colondb', 'VisA+Clinicdb' or 'All'). None or any other
            value falls back to the 'All' checkpoint.

    Returns:
        (overlay, score): a PIL image with the JET-coloured anomaly map blended
        50/50 over the resized input, and the anomaly score formatted to three
        decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    global _loaded_ckpt_path

    if image is None:
        raise gr.Error('Please upload an image.')

    # Membership test kept from the original (works for a str or a list of
    # choices); `or ()` guards against None when no radio option is selected.
    selected = options or ()
    if 'MVTec AD+Colondb' in selected:
        ckpt_path = ckt_path1
    elif 'VisA+Clinicdb' in selected:
        ckpt_path = ckt_path2
    elif 'All' in selected:
        ckpt_path = ckt_path3
    else:
        # Default to 'All' if no valid option is provided
        ckpt_path = ckt_path3
        print('Invalid option. Defaulting to All.')

    # Only hit the disk when the requested checkpoint differs from the loaded one.
    # NOTE(review): `model` is shared global state; concurrent requests with
    # different options could still race, as in the original code.
    if ckpt_path != _loaded_ckpt_path:
        model.load(ckpt_path)
        _loaded_ckpt_path = ckpt_path

    image = image.convert('RGB')
    # OpenCV works in BGR order; this copy is only used for the visual overlay.
    np_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    np_image = cv2.resize(np_image, (image_size, image_size))

    # The model gets its own preprocessing of the original PIL image.
    img_input = model.preprocess(image).unsqueeze(0).to(model.device)
    with torch.no_grad():
        anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)

    anomaly_map = anomaly_map[0, :, :].cpu().numpy()
    score = anomaly_score[0].item()
    anomaly_map = gaussian_filter(anomaly_map, sigma=4)
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    anomaly_map = (np.clip(anomaly_map, 0.0, 1.0) * 255).astype(np.uint8)

    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    # addWeighted requires identical sizes; align the heat map to the image if needed.
    if heat_map.shape[:2] != np_image.shape[:2]:
        heat_map = cv2.resize(heat_map, (np_image.shape[1], np_image.shape[0]))
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)

    # Back to RGB for PIL / Gradio.
    vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))
    return vis_map_pil, f'{score:.3f}'
# Example rows shown under the interface. Each row is
# [image path, class name, pre-trained dataset choice], in the same order as
# the interface inputs.
examples = [
    list(example_row)
    for example_row in (
        ("asset/img.png", "candle", "MVTec AD+Colondb"),
        ("asset/img2.png", "bottle", "VisA+Clinicdb"),
        ("asset/img3.png", "button", "All"),
    )
]
# Gradio interface layout.
# The `inputs` order must match process_image(image, text, options) and the
# column order of each row in `examples`.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Class Name"),
        gr.Radio(
            ["MVTec AD+Colondb", "VisA+Clinicdb", "All"],
            # Preselect a checkpoint so the handler does not receive None
            # when the user never touches the radio.
            value="All",
            label="Pre-trained Datasets",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Anomaly Score"),
    ],
    examples=examples,
    title="AdaCLIP -- Zero-shot Anomaly Detection",
    description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection"
)
# Launch the demo only when executed as a script, so importing this module
# (e.g. to reuse `demo` or `process_image`) does not start a server.
if __name__ == "__main__":
    demo.launch()
    # To serve on all interfaces instead:
    # demo.launch(server_name="0.0.0.0", server_port=10002)