File size: 4,088 Bytes
32faf2b
 
 
 
 
 
 
8e0a94c
cfa61ba
32faf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2984e25
cfa61ba
32faf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4dc515
32faf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
a3e6571
 
98a649d
a3e6571
 
 
 
 
 
98a649d
 
32faf2b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import subprocess
import sys

# One-time environment bootstrap: install the vendored GroundingDINO and SAM
# packages plus runtime dependencies, then fetch the pretrained checkpoints.
# Every step is best-effort, as with the original ``os.system`` calls: a
# failure is reported on stderr and start-up continues. Paths are relative,
# so this assumes the script is launched from the repository root. The
# process working directory is never changed.


def _run_setup_step(cmd, cwd=None):
    """Run ``cmd`` (an argv list) inside ``cwd`` and warn on stderr if it fails.

    Returns the process exit code, or ``None`` when the command could not be
    started at all (e.g. ``wget`` missing or ``cwd`` does not exist).
    """
    try:
        completed = subprocess.run(cmd, cwd=cwd)
    except OSError as err:
        print(f"[setup] could not run {' '.join(cmd)!r}: {err}", file=sys.stderr)
        return None
    if completed.returncode != 0:
        print(
            f"[setup] {' '.join(cmd)!r} exited with status {completed.returncode}",
            file=sys.stderr,
        )
    return completed.returncode


# Install into the interpreter that is running this script; a bare ``pip``
# on PATH is not guaranteed to belong to the same environment.
_PIP_INSTALL = [sys.executable, '-m', 'pip', 'install']

_run_setup_step(_PIP_INSTALL + ['-e', '.'], cwd='GroundingDINO')
_run_setup_step(_PIP_INSTALL + ['-e', '.'], cwd='SAM')
_run_setup_step(_PIP_INSTALL + [
    'opencv-python', 'pycocotools', 'matplotlib', 'onnxruntime', 'onnx',
    'ipykernel', 'gradio', 'loguru', 'transformers', 'timm', 'addict', 'yapf',
    'tqdm', 'scikit-image', 'scikit-learn', 'pandas', 'tensorboard', 'seaborn',
    'open_clip_torch', 'einops',
])
_run_setup_step(_PIP_INSTALL + [
    'torch==1.10.0', 'torchvision==0.11.1',
    '-f', 'https://download.pytorch.org/whl/cu113/torch_stable.html',
])

# exist_ok: restarting in the same container must not die on FileExistsError.
os.makedirs('weights', exist_ok=True)
for _checkpoint_url in (
    'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth',
):
    # ``-c`` resumes an interrupted download and leaves a complete file alone,
    # instead of fetching a fresh copy as ``<name>.1`` on every start.
    _run_setup_step(['wget', '-c', _checkpoint_url], cwd='weights')

# Make the GroundingDINO and SAM checkouts (installed above with
# ``pip install -e .``) and the repository root importable; the root is
# presumably needed for the local ``SAA`` and ``utils`` packages imported
# below. These paths are relative, so they assume the script is started
# from the repository root. The appends must precede the project imports.
import sys
sys.path.append('./GroundingDINO')
sys.path.append('./SAM')
sys.path.append('.')
# NOTE(review): ``plt`` is not used in the visible code; confirm before removing.
import matplotlib.pyplot as plt
import SAA as SegmentAnyAnomaly
# NOTE(review): star import -- which of its names this script relies on cannot
# be determined from here; consider importing them explicitly.
from utils.training_utils import *
import os



# Model / inference configuration. The checkpoint paths are the files the
# setup step downloads into ``weights/``.
dino_config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
dino_checkpoint = 'weights/groundingdino_swint_ogc.pth'
sam_checkpoint = 'weights/sam_vit_h_4b8939.pth'
box_threshold = 0.1   # forwarded to SegmentAnyAnomaly.Model as-is
text_threshold = 0.1  # forwarded to SegmentAnyAnomaly.Model as-is
# Square side length (pixels): used as the model's ``out_size`` and as the
# size every input image and output map is resized to in ``func``.
eval_resolution = 256
# Defaults to CPU as before; set the SAA_DEVICE environment variable
# (e.g. "cuda") to run on a different torch device without editing code.
device = os.environ.get('SAA_DEVICE', 'cpu')
root_dir = 'result'  # NOTE(review): unused in the visible code; confirm before removing.

# get the model
model = SegmentAnyAnomaly.Model(
    dino_config_file=dino_config_file,
    dino_checkpoint=dino_checkpoint,
    sam_checkpoint=sam_checkpoint,
    box_threshold=box_threshold,
    text_threshold=text_threshold,
    out_size=eval_resolution,
    device=device,
)

model = model.to(device)

import cv2
import numpy as np
import gradio as gr


def process_image(heatmap, image):
    """Blend a colour-mapped heatmap over ``image`` for display.

    The heatmap is min-max normalised to the full 0-255 range and colour-mapped
    with OpenCV's JET palette. It is then blended 50/50 with ``image`` and
    converted from BGR to RGB. Finally it is scaled so its brightest value is 1.0.

    Args:
        heatmap: 2-D array of real-valued scores. It must have the same height
            and width as ``image`` (``cv2.addWeighted`` requires matching sizes).
        image: 8-bit, 3-channel image; the colour conversion treats it as BGR.

    Returns:
        Float RGB array with values in [0, 1].
    """
    heatmap = np.asarray(heatmap, dtype=float)
    low, high = heatmap.min(), heatmap.max()
    spread = high - low
    if spread > 0:
        # Subtract the minimum AND divide by the range. Dividing by the max
        # alone leaves part of the scale unused when the minimum is positive
        # and overflows uint8 (wrap-around) when it is negative.
        heatmap = (heatmap - low) / spread * 255
    else:
        # Constant (or NaN-containing) map: plain division would yield NaN,
        # whose cast to uint8 is undefined. Show it as uniformly "cold".
        heatmap = np.zeros_like(heatmap)
    heatmap = heatmap.astype(np.uint8)
    heat_map = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # NOTE(review): Gradio's gr.Image presumably delivers RGB arrays, while the
    # conversion below assumes ``image`` is BGR, so the photo's red/blue
    # channels may come out swapped -- confirm the intended channel order.
    visz_map = cv2.addWeighted(heat_map, 0.5, image, 0.5, 0)
    visz_map = cv2.cvtColor(visz_map, cv2.COLOR_BGR2RGB)

    visz_map = visz_map.astype(float)
    peak = visz_map.max()
    if peak > 0:  # avoid 0/0 on an all-black blend
        visz_map = visz_map / peak
    return visz_map


def func(image, anomaly_description, object_name, object_number, mask_number, area_threashold):
    """Run Segment-Any-Anomaly on one image and return two visualisations.

    This is the Gradio click handler; the text arguments arrive as Textbox strings.

    Args:
        image: input image array from the UI, or ``None`` if nothing was uploaded.
        anomaly_description: anomaly phrases used as the detection prompt.
        object_name: name of the inspected object (detection filter phrase).
        object_number: expected number of objects, interpolated into the
            property prompt.
        mask_number: maximum number of anomalies, interpolated into the
            property prompt.
        area_threashold: maximum anomaly size relative to the object area,
            interpolated into the property prompt. (The misspelt name is kept
            for interface compatibility.)

    Returns:
        ``(viz_score, viz_sim)``: float RGB overlays of the anomaly score and
        of ``appendix['similarity_map']``, each ``eval_resolution`` square.

    Raises:
        gr.Error: if no image was provided, shown to the user by Gradio
        instead of an opaque OpenCV failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before running inference.")

    textual_prompts = [
        [anomaly_description, object_name]
    ]  # detect prompts, filtered phrase
    # Kept verbatim: the model may parse this wording -- TODO confirm before editing.
    property_text_prompts = f'the image of {object_name} have {object_number} dissimilar {object_name}, with a maximum of {mask_number} anomaly. The anomaly would not exceed {area_threashold} object area. '

    model.set_ensemble_text_prompts(textual_prompts, verbose=True)
    model.set_property_text_prompts(property_text_prompts, verbose=True)

    size = (eval_resolution, eval_resolution)
    image = cv2.resize(image, size)
    score, appendix = model(image)
    similarity_map = appendix['similarity_map']

    # Bring both maps to the same size as ``image`` so they can be blended
    # over it (``image`` itself is already at that size).
    similarity_map = cv2.resize(similarity_map, size)
    score = cv2.resize(score, size)

    viz_score = process_image(score, image)
    viz_sim = process_image(similarity_map, image)

    return viz_score, viz_sim


def _build_demo():
    """Assemble the Gradio UI: prompt inputs on the left, visualisations on the right.

    Clicking the button runs ``func`` on the inputs and fills the two output images.
    """
    with gr.Blocks() as blocks:
        with gr.Row():
            # Left column: the image plus the text fields that build the prompts.
            with gr.Column():
                image_in = gr.Image(label="Image")
                description_box = gr.Textbox(label="Anomaly Description (e.g. color defect. hole. black defect. wick hole. spot. )")
                name_box = gr.Textbox(label="Object Name (e.g. candle)")
                count_box = gr.Textbox(label="Object Number (e.g. 4)")
                masks_box = gr.Textbox(label="Mask Number (e.g. 1)")
                area_box = gr.Textbox(label="Area Threshold (e.g. 0.3)")
            # Right column: the two overlays returned by ``func``.
            with gr.Column():
                score_out = gr.Image(label="Anomaly Score")
                saliency_out = gr.Image(label="Saliency Map")

        run_button = gr.Button("Inference")
        run_button.click(
            fn=func,
            inputs=[image_in, description_box, name_box, count_box, masks_box, area_box],
            outputs=[score_out, saliency_out],
            api_name="Segment-Any-Anomaly",
        )
    return blocks


demo = _build_demo()
demo.launch()