Caoyunkang commited on
Commit
a25563f
·
verified ·
1 Parent(s): d3da3ef

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -35
  2. .gitignore +4 -0
  3. LICENSE +21 -0
  4. README.md +176 -13
  5. app.py +133 -0
  6. asset/Fig_app.png +3 -0
  7. asset/Fig_detection_results.png +3 -0
  8. asset/Table_industrial.png +3 -0
  9. asset/Table_medical.png +3 -0
  10. asset/framework.png +3 -0
  11. asset/img.png +3 -0
  12. asset/img2.png +3 -0
  13. asset/img3.png +3 -0
  14. config.py +1 -0
  15. data_preprocess/br35h.py +50 -0
  16. data_preprocess/brain_mri.py +51 -0
  17. data_preprocess/btad.py +52 -0
  18. data_preprocess/clinicdb.py +52 -0
  19. data_preprocess/colondb.py +52 -0
  20. data_preprocess/dagm-pre.py +82 -0
  21. data_preprocess/dagm.py +52 -0
  22. data_preprocess/dtd.py +52 -0
  23. data_preprocess/endo.py +52 -0
  24. data_preprocess/headct-pre.py +41 -0
  25. data_preprocess/headct.py +52 -0
  26. data_preprocess/isic.py +52 -0
  27. data_preprocess/mpdd.py +52 -0
  28. data_preprocess/mvtec.py +52 -0
  29. data_preprocess/sdd-pre.py +75 -0
  30. data_preprocess/sdd.py +52 -0
  31. data_preprocess/tn3k.py +52 -0
  32. data_preprocess/visa.py +52 -0
  33. dataset/__init__.py +68 -0
  34. dataset/__pycache__/__init__.cpython-39.pyc +0 -0
  35. dataset/__pycache__/br35h.cpython-39.pyc +0 -0
  36. dataset/__pycache__/brain_mri.cpython-39.pyc +0 -0
  37. dataset/__pycache__/btad.cpython-39.pyc +0 -0
  38. dataset/__pycache__/clinicdb.cpython-39.pyc +0 -0
  39. dataset/__pycache__/colondb.cpython-39.pyc +0 -0
  40. dataset/__pycache__/dagm.cpython-39.pyc +0 -0
  41. dataset/__pycache__/dtd.cpython-39.pyc +0 -0
  42. dataset/__pycache__/headct.cpython-39.pyc +0 -0
  43. dataset/__pycache__/isic.cpython-39.pyc +0 -0
  44. dataset/__pycache__/mpdd.cpython-39.pyc +0 -0
  45. dataset/__pycache__/mvtec.cpython-39.pyc +0 -0
  46. dataset/__pycache__/sdd.cpython-39.pyc +0 -0
  47. dataset/__pycache__/tn3k.cpython-39.pyc +0 -0
  48. dataset/__pycache__/visa.cpython-39.pyc +0 -0
  49. dataset/base_dataset.py +138 -0
  50. dataset/br35h.py +18 -0
.gitattributes CHANGED
@@ -1,35 +1,5 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.txt.gz filter=lfs diff=lfs merge=lfs -text
3
+ weights/pretrained_all.pth filter=lfs diff=lfs merge=lfs -text
4
+ weights/pretrained_mvtec_colondb.pth filter=lfs diff=lfs merge=lfs -text
5
+ weights/pretrained_visa_clinicdb.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /result/
2
+ /.idea/
3
+ /__pycache__/
4
+ /weights/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yunkang Cao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,176 @@
1
- ---
2
- title: AdaCLIP
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 4.38.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AdaCLIP (Detecting Anomalies for Novel Categories)
2
+ [![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)]()
3
+
4
+ > [**ECCV 24**] [**AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection**]().
5
+ >
6
+ > by [Yunkang Cao](https://caoyunkang.github.io/), [Jiangning Zhang](https://zhangzjn.github.io/), [Luca Frittoli](https://scholar.google.com/citations?user=cdML_XUAAAAJ),
7
+ > [Yuqi Cheng](https://scholar.google.com/citations?user=02BC-WgAAAAJ&hl=en), [Weiming Shen](https://scholar.google.com/citations?user=FuSHsx4AAAAJ&hl=en), [Giacomo Boracchi](https://boracchi.faculty.polimi.it/)
8
+ >
9
+
10
+ ## Introduction
11
+ Zero-shot anomaly detection (ZSAD) targets the identification of anomalies within images from arbitrary novel categories.
12
+ This study introduces AdaCLIP for the ZSAD task, leveraging a pre-trained vision-language model (VLM), CLIP.
13
+ AdaCLIP incorporates learnable prompts into CLIP and optimizes them through training on auxiliary annotated anomaly detection data.
14
+ Two types of learnable prompts are proposed: \textit{static} and \textit{dynamic}. Static prompts are shared across all images, serving to preliminarily adapt CLIP for ZSAD.
15
+ In contrast, dynamic prompts are generated for each test image, providing CLIP with dynamic adaptation capabilities.
16
+ The combination of static and dynamic prompts is referred to as hybrid prompts, and yields enhanced ZSAD performance.
17
+ Extensive experiments conducted across 14 real-world anomaly detection datasets from industrial and medical domains indicate that AdaCLIP outperforms other ZSAD methods and can generalize better to different categories and even domains.
18
+ Finally, our analysis highlights the importance of diverse auxiliary data and optimized prompts for enhanced generalization capacity.
19
+
20
+ ## Overview of AdaCLIP
21
+ ![overview](asset/framework.png)
22
+
23
+ ## 🛠️ Getting Started
24
+
25
+ ### Installation
26
+ To set up the AdaCLIP environment, follow one of the methods below:
27
+
28
+ - Clone this repo:
29
+ ```shell
30
+ git clone https://github.com/caoyunkang/AdaCLIP.git && cd AdaCLIP
31
+ ```
32
+ - You can use our provided installation script for an automated setup::
33
+ ```shell
34
+ sh install.sh
35
+ ```
36
+ - If you prefer to construct the experimental environment manually, follow these steps:
37
+ ```shell
38
+ conda create -n AdaCLIP python=3.9.5 -y
39
+ conda activate AdaCLIP
40
+ pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
41
+ pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4
42
+ pip install gradio # Optional, for app
43
+ ```
44
+ - Remember to update the dataset root in config.py according to your preference:
45
+ ```python
46
+ DATA_ROOT = '../datasets' # Original setting
47
+ ```
48
+
49
+ ### Dataset Preparation
50
+ Please download our processed visual anomaly detection datasets to your `DATA_ROOT` as needed.
51
+
52
+ #### Industrial Visual Anomaly Detection Datasets
53
+ Note: some links are still in processing...
54
+
55
+ | Dataset | Google Drive | Baidu Drive | Task
56
+ |------------|------------------|------------------| ------------------|
57
+ | MVTec AD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
58
+ | VisA | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
59
+ | MPDD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
60
+ | BTAD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
61
+ | KSDD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
62
+ | DAGM | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
63
+ | DTD-Synthetic | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
64
+
65
+
66
+
67
+
68
+ #### Medical Visual Anomaly Detection Datasets
69
+ | Dataset | Google Drive | Baidu Drive | Task
70
+ |------------|------------------|------------------| ------------------|
71
+ | HeadCT | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
72
+ | BrainMRI | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
73
+ | Br35H | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
74
+ | ISIC | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
75
+ | ColonDB | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
76
+ | ClinicDB | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
77
+ | TN3K | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
78
+
79
+ #### Custom Datasets
80
+ To use your custom dataset, follow these steps:
81
+
82
+ 1. Refer to the instructions in `./data_preprocess` to generate the JSON file for your dataset.
83
+ 2. Use `./dataset/base_dataset.py` to construct your own dataset.
84
+
85
+
86
+ ### Weight Preparation
87
+
88
+ We offer various pre-trained weights on different auxiliary datasets.
89
+ Please download the pre-trained weights in `./weights`.
90
+
91
+ | Pre-trained Datasets | Google Drive | Baidu Drive
92
+ |------------|------------------|------------------|
93
+ | MVTec AD & ClinicDB | [Google Drive](https://drive.google.com/file/d/1xVXANHGuJBRx59rqPRir7iqbkYzq45W0/view?usp=drive_link) | [Baidu Drive](链接) |
94
+ | VisA & ColonDB | [Google Drive](https://drive.google.com/file/d/1QGmPB0ByPZQ7FucvGODMSz7r5Ke5wx9W/view?usp=drive_link) | [Baidu Drive](链接) |
95
+ | All Datasets Mentioned Above | [Google Drive](https://drive.google.com/file/d/1Cgkfx3GAaSYnXPLolx-P7pFqYV0IVzZF/view?usp=drive_link) | [Baidu Drive](链接) |
96
+
97
+
98
+ ### Train
99
+
100
+ By default, we use MVTec AD & ClinicDB for training and VisA for validation:
101
+ ```shell
102
+ CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data mvtec colondb --testing_data visa
103
+ ```
104
+
105
+
106
+ Alternatively, for evaluation on MVTec AD & ClinicDB, we use VisA & ColonDB for training and MVTec AD for validation.
107
+ ```shell
108
+ CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec
109
+ ```
110
+ Since we have utilized half-precision (FP16) for training, the training process can occasionally be unstable.
111
+ It is recommended to run the training process multiple times and choose the best model based on performance
112
+ on the validation set as the final model.
113
+
114
+
115
+ To construct a robust ZSAD model for demonstration, we also train our AdaCLIP on all AD datasets mentioned above:
116
+ ```shell
117
+ CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True \
118
+ --training_data \
119
+ br35h brain_mri btad clinicdb colondb \
120
+ dagm dtd headct isic mpdd mvtec sdd tn3k visa \
121
+ --testing_data mvtec
122
+ ```
123
+
124
+ ### Test
125
+
126
+ Manually select the best models from the validation set and place them in the `weights/` directory. Then, run the following testing script:
127
+ ```shell
128
+ sh test.sh
129
+ ```
130
+
131
+ If you want to test on a single image, you can refer to `test_single_image.sh`:
132
+ ```shell
133
+ CUDA_VISIBLE_DEVICES=0 python test.py --testing_model image --ckt_path weights/pretrained_all.pth --save_fig True \
134
+ --image_path asset/img.png --class_name candle --save_name test.png
135
+ ```
136
+
137
+ ## Main Results
138
+
139
+ Due to differences in versions utilized, the reported performance may vary slightly compared to the detection performance
140
+ with the provided pre-trained weights. Some categories may show higher performance while others may show lower.
141
+
142
+ ![Table_industrial](./asset/Table_industrial.png)
143
+ ![Table_medical](./asset/Table_medical.png)
144
+ ![Fig_detection_results](./asset/Fig_detection_results.png)
145
+
146
+ ### :page_facing_up: Demo App
147
+
148
+ To run the demo application, use the following command:
149
+
150
+ ```bash
151
+ python app.py
152
+ ```
153
+
154
+ ![Demo](./asset/Fig_app.png)
155
+
156
+ ## 💘 Acknowledgements
157
+ Our work is largely inspired by the following projects. Thanks for their admiring contribution.
158
+
159
+ - [VAND-APRIL-GAN](https://github.com/ByChelsea/VAND-APRIL-GAN)
160
+ - [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP)
161
+ - [SAA](https://github.com/caoyunkang/Segment-Any-Anomaly)
162
+
163
+
164
+ ## Stargazers over time
165
+ [![Stargazers over time](https://starchart.cc/caoyunkang/AdaCLIP.svg?variant=adaptive)](https://starchart.cc/caoyunkang/AdaCLIP)
166
+
167
+
168
+ ## Citation
169
+
170
+ If you find this project helpful for your research, please consider citing the following BibTeX entry.
171
+
172
+ ```BibTex
173
+
174
+
175
+
176
+ ```
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import warnings
4
+ import os
5
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
6
+ import json
7
+ import os
8
+ import torch
9
+ from scipy.ndimage import gaussian_filter
10
+ import cv2
11
+ from method import AdaCLIP_Trainer
12
+ import numpy as np
13
+
14
+ ############ Init Model
15
+ ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
16
+ ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
17
+ ckt_path3 = 'weights/pretrained_all.pth'
18
+
19
+ # Configurations
20
+ image_size = 518
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ # device = 'cpu'
23
+ model = "ViT-L-14-336"
24
+ prompting_depth = 4
25
+ prompting_length = 5
26
+ prompting_type = 'SD'
27
+ prompting_branch = 'VL'
28
+ use_hsf = True
29
+ k_clusters = 20
30
+
31
+ config_path = os.path.join('./model_configs', f'{model}.json')
32
+
33
+ # Prepare model
34
+ with open(config_path, 'r') as f:
35
+ model_configs = json.load(f)
36
+
37
+ # Set up the feature hierarchy
38
+ n_layers = model_configs['vision_cfg']['layers']
39
+ substage = n_layers // 4
40
+ features_list = [substage, substage * 2, substage * 3, substage * 4]
41
+
42
+ model = AdaCLIP_Trainer(
43
+ backbone=model,
44
+ feat_list=features_list,
45
+ input_dim=model_configs['vision_cfg']['width'],
46
+ output_dim=model_configs['embed_dim'],
47
+ learning_rate=0.,
48
+ device=device,
49
+ image_size=image_size,
50
+ prompting_depth=prompting_depth,
51
+ prompting_length=prompting_length,
52
+ prompting_branch=prompting_branch,
53
+ prompting_type=prompting_type,
54
+ use_hsf=use_hsf,
55
+ k_clusters=k_clusters
56
+ ).to(device)
57
+
58
+
59
+ def process_image(image, text, options):
60
+ # Load the model based on selected options
61
+ if 'MVTec AD+Colondb' in options:
62
+ model.load(ckt_path1)
63
+ elif 'VisA+Clinicdb' in options:
64
+ model.load(ckt_path2)
65
+ elif 'All' in options:
66
+ model.load(ckt_path3)
67
+ else:
68
+ # Default to 'All' if no valid option is provided
69
+ model.load(ckt_path3)
70
+ print('Invalid option. Defaulting to All.')
71
+
72
+ # Ensure image is in RGB mode
73
+ image = image.convert('RGB')
74
+
75
+ # Convert PIL image to NumPy array
76
+ np_image = np.array(image)
77
+
78
+ # Convert RGB to BGR for OpenCV
79
+ np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
80
+ np_image = cv2.resize(np_image, (image_size, image_size))
81
+ # Preprocess the image and run the model
82
+ img_input = model.preprocess(image).unsqueeze(0)
83
+ img_input = img_input.to(model.device)
84
+
85
+ with torch.no_grad():
86
+ anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)
87
+
88
+ # Process anomaly map
89
+ anomaly_map = anomaly_map[0, :, :].cpu().numpy()
90
+ anomaly_score = anomaly_score[0].cpu().numpy()
91
+ anomaly_map = gaussian_filter(anomaly_map, sigma=4)
92
+ anomaly_map = (anomaly_map * 255).astype(np.uint8)
93
+
94
+ # Apply color map and blend with original image
95
+ heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
96
+ vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
97
+
98
+ # Convert OpenCV image back to PIL image for Gradio
99
+ vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))
100
+
101
+ return vis_map_pil, f'{anomaly_score:.3f}'
102
+
103
+ # Define examples
104
+ examples = [
105
+ ["asset/img.png", "candle", "MVTec AD+Colondb"],
106
+ ["asset/img2.png", "bottle", "VisA+Clinicdb"],
107
+ ["asset/img3.png", "button", "All"],
108
+ ]
109
+
110
+ # Gradio interface layout
111
+ demo = gr.Interface(
112
+ fn=process_image,
113
+ inputs=[
114
+ gr.Image(type="pil", label="Upload Image"),
115
+ gr.Textbox(label="Class Name"),
116
+ gr.Radio(["MVTec AD+Colondb",
117
+ "VisA+Clinicdb",
118
+ "All"],
119
+ label="Pre-trained Datasets")
120
+ ],
121
+ outputs=[
122
+ gr.Image(type="pil", label="Output Image"),
123
+ gr.Textbox(label="Anomaly Score"),
124
+ ],
125
+ examples=examples,
126
+ title="AdaCLIP -- Zero-shot Anomaly Detection",
127
+ description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection"
128
+ )
129
+
130
+ # Launch the demo
131
+ demo.launch()
132
+ # demo.launch(server_name="0.0.0.0", server_port=10002)
133
+
asset/Fig_app.png ADDED

Git LFS Details

  • SHA256: f71ab8be0e45353c1660526ff450754e82ddf4a2b7f18bb5a33ac3b704b0d76b
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
asset/Fig_detection_results.png ADDED

Git LFS Details

  • SHA256: c00bd303a99d981d964b12e981bd1f2954d469766839523e76f7d7162fbb24cb
  • Pointer size: 131 Bytes
  • Size of remote file: 363 kB
asset/Table_industrial.png ADDED

Git LFS Details

  • SHA256: 5fa4d9ab1ff1b3ca90b45f4b92ee7b12a89e5327cb22621d4081fb5f160d3d68
  • Pointer size: 131 Bytes
  • Size of remote file: 402 kB
asset/Table_medical.png ADDED

Git LFS Details

  • SHA256: d2424190619dbbd134b943ef9e38a6523635ab0d279f2445da6bdd266d3dafac
  • Pointer size: 131 Bytes
  • Size of remote file: 291 kB
asset/framework.png ADDED

Git LFS Details

  • SHA256: 3804c7f5ae141257dbe5dd43cb20f4216a1061051fd8754d6f0c730dd085ad7d
  • Pointer size: 131 Bytes
  • Size of remote file: 440 kB
asset/img.png ADDED

Git LFS Details

  • SHA256: 3eaff97d07132f9b06998737b976d4a0e0a3a2168b40aee43aad6e62d040f87e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
asset/img2.png ADDED

Git LFS Details

  • SHA256: a3918b94553a8922b3c16d064ef73e9062710b35639a949c56d926037e4c0d0a
  • Pointer size: 131 Bytes
  • Size of remote file: 548 kB
asset/img3.png ADDED

Git LFS Details

  • SHA256: 9394757293585aa9de542f3e70025788e5a3e1ad5a1277a8648f8050f8d7e868
  • Pointer size: 131 Bytes
  • Size of remote file: 624 kB
config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ DATA_ROOT = '../datasets'
data_preprocess/br35h.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection')
7
+ class Br35hSolver(object):
8
+ CLSNAMES = [
9
+ 'br35h',
10
+ ]
11
+
12
+ def __init__(self, root=Br35h_ROOT, train_ratio=0.5):
13
+ self.root = root
14
+ self.meta_path = f'{root}/meta.json'
15
+ self.train_ratio = train_ratio
16
+
17
+ def run(self):
18
+ self.generate_meta_info()
19
+
20
+ def generate_meta_info(self):
21
+ info = dict(train={}, test={})
22
+ for cls_name in self.CLSNAMES:
23
+ cls_dir = f'{self.root}/{cls_name}'
24
+ for phase in ['train', 'test']:
25
+ cls_info = []
26
+ species = os.listdir(f'{cls_dir}/{phase}')
27
+ for specie in species:
28
+ is_abnormal = True if specie not in ['good'] else False
29
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
30
+ img_names.sort()
31
+
32
+ for idx, img_name in enumerate(img_names):
33
+ info_img = dict(
34
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
35
+ mask_path=f'',
36
+ cls_name=cls_name,
37
+ specie_name=specie,
38
+ anomaly=1 if is_abnormal else 0,
39
+ )
40
+ cls_info.append(info_img)
41
+
42
+ info[phase][cls_name] = cls_info
43
+
44
+ with open(self.meta_path, 'w') as f:
45
+ f.write(json.dumps(info, indent=4) + "\n")
46
+
47
+
48
+ if __name__ == '__main__':
49
+ runner = Br35hSolver(root=Br35h_ROOT)
50
+ runner.run()
data_preprocess/brain_mri.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI')
7
+
8
+ class BrainMRISolver(object):
9
+ CLSNAMES = [
10
+ 'brain_mri',
11
+ ]
12
+
13
+ def __init__(self, root=BrainMRI_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ img_names.sort()
32
+
33
+ for idx, img_name in enumerate(img_names):
34
+ info_img = dict(
35
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
36
+ mask_path=f'',
37
+ cls_name=cls_name,
38
+ specie_name=specie,
39
+ anomaly=1 if is_abnormal else 0,
40
+ )
41
+ cls_info.append(info_img)
42
+
43
+ info[phase][cls_name] = cls_info
44
+
45
+ with open(self.meta_path, 'w') as f:
46
+ f.write(json.dumps(info, indent=4) + "\n")
47
+
48
+
49
+ if __name__ == '__main__':
50
+ runner = BrainMRISolver(root=BrainMRI_ROOT)
51
+ runner.run()
data_preprocess/btad.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed')
7
+
8
+ class BTADSolver(object):
9
+ CLSNAMES = [
10
+ '01', '02', '03',
11
+ ]
12
+
13
+ def __init__(self, root=BTAD_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['ok'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = BTADSolver(root=BTAD_ROOT)
52
+ runner.run()
data_preprocess/clinicdb.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB')
7
+
8
+ class ClinicDBSolver(object):
9
+ CLSNAMES = [
10
+ 'ClinicDB',
11
+ ]
12
+
13
+ def __init__(self, root=ClinicDB_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = ClinicDBSolver(root=ClinicDB_ROOT)
52
+ runner.run()
data_preprocess/colondb.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB')
7
+
8
+ class ColonDBSolver(object):
9
+ CLSNAMES = [
10
+ 'ColonDB',
11
+ ]
12
+
13
+ def __init__(self, root=ColonDB_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = ColonDBSolver(root=ColonDB_ROOT)
52
+ runner.run()
data_preprocess/dagm-pre.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ import cv2
5
+ import argparse
6
+ from config import DATA_ROOT
7
+
8
+ dataset_root = os.path.join(DATA_ROOT, 'DAGM2007')
9
+
10
+ class_names = os.listdir(dataset_root)
11
+
12
+
13
+ for class_name in class_names:
14
+ states = os.listdir(os.path.join(dataset_root, class_name))
15
+ for state in states:
16
+ images = list()
17
+ mask = list()
18
+ files = os.listdir(os.path.join(dataset_root, class_name,state))
19
+ for f in files:
20
+ if 'PNG' in f[-3:]:
21
+ images.append(f)
22
+ files = os.listdir(os.path.join(dataset_root, class_name, state,'Label'))
23
+ for f in files:
24
+ if 'PNG' in f[-3:]:
25
+ mask.append(f)
26
+ normal_image_path_train = list()
27
+ normal_image_path_test = list()
28
+ normal_image_path = list()
29
+ abnormal_image_path = list()
30
+ abnormal_image_label = list()
31
+ for f in images:
32
+ id = f[-8:-4]
33
+ flag = 0
34
+ for y in mask:
35
+ if id in y:
36
+ abnormal_image_path.append(f)
37
+ abnormal_image_label.append(y)
38
+ flag = 1
39
+ break
40
+ if flag == 0:
41
+ normal_image_path.append(f)
42
+
43
+ if len(abnormal_image_path) != len(abnormal_image_label):
44
+ raise ValueError
45
+ length = len(abnormal_image_path)
46
+
47
+ normal_image_path_test = normal_image_path[:length]
48
+ normal_image_path_train = normal_image_path[length:]
49
+
50
+ target_root = '../datasets/DAGM_anomaly_detection'
51
+
52
+ train_root = os.path.join(target_root, class_name, 'train','good')
53
+ if not os.path.exists(train_root):
54
+ os.makedirs(train_root)
55
+ for f in normal_image_path_train:
56
+ image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
57
+ cv2.imwrite(os.path.join(train_root,f), image_data)
58
+
59
+ test_root = os.path.join(target_root, class_name, 'test','good')
60
+ if not os.path.exists(test_root):
61
+ os.makedirs(test_root)
62
+ for f in normal_image_path_test:
63
+ image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
64
+ cv2.imwrite(os.path.join(test_root,f), image_data)
65
+
66
+ test_root = os.path.join(target_root, class_name, 'test','defect')
67
+ if not os.path.exists(test_root):
68
+ os.makedirs(test_root)
69
+ for f in abnormal_image_path:
70
+ image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
71
+ cv2.imwrite(os.path.join(test_root,f), image_data)
72
+
73
+ test_root = os.path.join(target_root, class_name, 'ground_truth','defect')
74
+ if not os.path.exists(test_root):
75
+ os.makedirs(test_root)
76
+ for f in mask:
77
+ image_data = cv2.imread(os.path.join(dataset_root, class_name, state,'Label',f))
78
+ cv2.imwrite(os.path.join(test_root,f), image_data)
79
+
80
+
81
+
82
+ print("Done")
data_preprocess/dagm.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection')
7
+
8
+ class DAGMSolver(object):
9
+ CLSNAMES = [
10
+ 'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10',
11
+ ]
12
+
13
+ def __init__(self, root=DAGM_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = DAGMSolver(root=DAGM_ROOT)
52
+ runner.run()
data_preprocess/dtd.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic')
7
+
8
+ class DTDSolver(object):
9
+ CLSNAMES = [
10
+ 'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127',
11
+ ]
12
+
13
+ def __init__(self, root=DTD_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = DTDSolver(root=DTD_ROOT)
52
+ runner.run()
data_preprocess/endo.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ ENDO_ROOT = os.path.join(DATA_ROOT, 'EndoTect')
7
+
8
+ class ENDOSolver(object):
9
+ CLSNAMES = [
10
+ 'endo',
11
+ ]
12
+
13
+ def __init__(self, root=ENDO_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = ENDOSolver(root=ENDO_ROOT)
52
+ runner.run()
data_preprocess/headct-pre.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ import shutil
5
+ import argparse
6
+
7
+ from config import DATA_ROOT
8
+
9
+ dataset_root = os.path.join(DATA_ROOT, 'head_ct')
10
+
11
+ label_file = os.path.join(dataset_root, 'labels.csv')
12
+
13
+ data = np.loadtxt(label_file, dtype=int, delimiter=',', skiprows=1)
14
+
15
+ fnames = data[:, 0]
16
+ label = data[:, 1]
17
+
18
+ normal_fnames = fnames[label==0]
19
+ outlier_fnames = fnames[label==1]
20
+
21
+
22
+ target_root = '../datasets/HeadCT_anomaly_detection/headct'
23
+ train_root = os.path.join(target_root, 'train/good')
24
+ if not os.path.exists(train_root):
25
+ os.makedirs(train_root)
26
+
27
+ test_normal_root = os.path.join(target_root, 'test/good')
28
+ if not os.path.exists(test_normal_root):
29
+ os.makedirs(test_normal_root)
30
+ for f in normal_fnames:
31
+ source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f))
32
+ shutil.copy(source, test_normal_root)
33
+
34
+ test_outlier_root = os.path.join(target_root, 'test/defect')
35
+ if not os.path.exists(test_outlier_root):
36
+ os.makedirs(test_outlier_root)
37
+ for f in outlier_fnames:
38
+ source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f))
39
+ shutil.copy(source, test_outlier_root)
40
+
41
+ print('Done')
data_preprocess/headct.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ # from dataset import MPDD_ROOT
5
+ # from dataset.mpdd import MPDD_ROOT
6
+
7
+
8
+ HEADCT_ROOT = '../datasets/HeadCT_anomaly_detection'
9
+ class HEADCTSolver(object):
10
+ CLSNAMES = [
11
+ 'headct',
12
+ ]
13
+
14
+ def __init__(self, root=HEADCT_ROOT, train_ratio=0.5):
15
+ self.root = root
16
+ self.meta_path = f'{root}/meta.json'
17
+ self.train_ratio = train_ratio
18
+
19
+ def run(self):
20
+ self.generate_meta_info()
21
+
22
+ def generate_meta_info(self):
23
+ info = dict(train={}, test={})
24
+ for cls_name in self.CLSNAMES:
25
+ cls_dir = f'{self.root}/{cls_name}'
26
+ for phase in ['train', 'test']:
27
+ cls_info = []
28
+ species = os.listdir(f'{cls_dir}/{phase}')
29
+ for specie in species:
30
+ is_abnormal = True if specie not in ['good'] else False
31
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
32
+ img_names.sort()
33
+
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = HEADCTSolver(root=HEADCT_ROOT)
52
+ runner.run()
data_preprocess/isic.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC')
7
+
8
+ class ISICSolver(object):
9
+ CLSNAMES = [
10
+ 'isic',
11
+ ]
12
+
13
+ def __init__(self, root=ISIC_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = ISICSolver(root=ISIC_ROOT)
52
+ runner.run()
data_preprocess/mpdd.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD')
7
+
8
+ class MPDDSolver(object):
9
+ CLSNAMES = [
10
+ 'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes',
11
+ ]
12
+
13
+ def __init__(self, root=MPDD_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = MPDDSolver(root=MPDD_ROOT)
52
+ runner.run()
data_preprocess/mvtec.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from dataset import MVTEC_ROOT
5
+
6
+ class MVTecSolver(object):
7
+ CLSNAMES = [
8
+ 'bottle', 'cable', 'capsule', 'carpet', 'grid',
9
+ 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
10
+ 'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
11
+ ]
12
+
13
+ def __init__(self, root=MVTEC_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = MVTecSolver(root=MVTEC_ROOT)
52
+ runner.run()
data_preprocess/sdd-pre.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ import cv2
5
+ import argparse
6
+
7
+ from config import DATA_ROOT
8
+
9
+ dataset_root = os.path.join(DATA_ROOT, 'KolektorSDD')
10
+
11
+ dirs = os.listdir(dataset_root)
12
+ normal_images = list()
13
+ normal_labels = list()
14
+ normal_fname = list()
15
+ outlier_images = list()
16
+ outlier_labels = list()
17
+ outlier_fname = list()
18
+ for d in dirs:
19
+ files = os.listdir(os.path.join(dataset_root, d))
20
+ images = list()
21
+ for f in files:
22
+ if 'jpg' in f[-3:]:
23
+ images.append(f)
24
+
25
+ for image in images:
26
+ split_images = list()
27
+ split_labels = list()
28
+ image_name = image.split('.')[0]
29
+ image_data = cv2.imread(os.path.join(dataset_root, d, image))
30
+ label_data = cv2.imread(os.path.join(dataset_root, d, image_name + '_label.bmp'))
31
+ if image_data.shape != label_data.shape:
32
+ raise ValueError
33
+ image_length = image_data.shape[0]
34
+ split_images.append(image_data[:image_length // 3, :, :])
35
+ split_images.append(image_data[image_length // 3:image_length * 2 // 3, :, :])
36
+ split_images.append(image_data[image_length * 2 // 3:, :, :])
37
+ split_labels.append(label_data[:image_length // 3, :, :])
38
+ split_labels.append(label_data[image_length // 3:image_length * 2 // 3, :, :])
39
+ split_labels.append(label_data[image_length * 2 // 3:, :, :])
40
+ for i, (im, la) in enumerate(zip(split_images, split_labels)):
41
+ if np.max(la) != 0:
42
+ outlier_images.append(im)
43
+ outlier_labels.append(la)
44
+ outlier_fname.append(d + '_' + image_name + '_' + str(i))
45
+ else:
46
+ normal_images.append(im)
47
+ normal_labels.append(la)
48
+ normal_fname.append(d + '_' + image_name + '_' + str(i))
49
+
50
+ normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42)
51
+
52
+ target_root = '../datasets/SDD_anomaly_detection/SDD'
53
+ train_root = os.path.join(target_root, 'train/good')
54
+ if not os.path.exists(train_root):
55
+ os.makedirs(train_root)
56
+ for image, name in zip(normal_train, normal_name_train):
57
+ cv2.imwrite(os.path.join(train_root, name + '.png'), image)
58
+
59
+ test_root = os.path.join(target_root, 'test/good')
60
+ if not os.path.exists(test_root):
61
+ os.makedirs(test_root)
62
+ for image, name in zip(normal_test, normal_name_test):
63
+ cv2.imwrite(os.path.join(test_root, name + '.png'), image)
64
+
65
+ defect_root = os.path.join(target_root, 'test/defect')
66
+ label_root = os.path.join(target_root, 'ground_truth/defect')
67
+ if not os.path.exists(defect_root):
68
+ os.makedirs(defect_root)
69
+ if not os.path.exists(label_root):
70
+ os.makedirs(label_root)
71
+ for image, label, name in zip(outlier_images, outlier_labels, outlier_fname):
72
+ cv2.imwrite(os.path.join(defect_root, name + '.png'), image)
73
+ cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label)
74
+
75
+ print("Done")
data_preprocess/sdd.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection')
7
+
8
+ class SDDSolver(object):
9
+ CLSNAMES = [
10
+ 'SDD',
11
+ ]
12
+
13
+ def __init__(self, root=SDD_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = SDDSolver(root=SDD_ROOT)
52
+ runner.run()
data_preprocess/tn3k.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from config import DATA_ROOT
5
+
6
+ TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K')
7
+
8
+ class TN3KSolver(object):
9
+ CLSNAMES = [
10
+ 'tn3k',
11
+ ]
12
+
13
+ def __init__(self, root=TN3K_ROOT, train_ratio=0.5):
14
+ self.root = root
15
+ self.meta_path = f'{root}/meta.json'
16
+ self.train_ratio = train_ratio
17
+
18
+ def run(self):
19
+ self.generate_meta_info()
20
+
21
+ def generate_meta_info(self):
22
+ info = dict(train={}, test={})
23
+ for cls_name in self.CLSNAMES:
24
+ cls_dir = f'{self.root}/{cls_name}'
25
+ for phase in ['train', 'test']:
26
+ cls_info = []
27
+ species = os.listdir(f'{cls_dir}/{phase}')
28
+ for specie in species:
29
+ is_abnormal = True if specie not in ['good'] else False
30
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
31
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
32
+ img_names.sort()
33
+ mask_names.sort() if mask_names is not None else None
34
+ for idx, img_name in enumerate(img_names):
35
+ info_img = dict(
36
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
37
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
38
+ cls_name=cls_name,
39
+ specie_name=specie,
40
+ anomaly=1 if is_abnormal else 0,
41
+ )
42
+ cls_info.append(info_img)
43
+
44
+ info[phase][cls_name] = cls_info
45
+
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = TN3KSolver(root=TN3K_ROOT)
52
+ runner.run()
data_preprocess/visa.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ import random
5
+ from dataset import VISA_ROOT
6
+
7
+ class VisASolver(object):
8
+ CLSNAMES = [
9
+ 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
10
+ 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
11
+ 'pcb4', 'pipe_fryum',
12
+ ]
13
+
14
+ def __init__(self, root=VISA_ROOT, train_ratio=0.5):
15
+ self.root = root
16
+ self.meta_path = f'{root}/meta.json'
17
+ self.phases = ['train', 'test']
18
+ self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
19
+ self.train_ratio = train_ratio
20
+
21
+ def run(self):
22
+ self.generate_meta_info()
23
+
24
+ def generate_meta_info(self):
25
+ columns = self.csv_data.columns # [object, split, label, image, mask]
26
+ info = {phase: {} for phase in self.phases}
27
+ for cls_name in self.CLSNAMES:
28
+ cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name]
29
+ for phase in self.phases:
30
+ cls_info = []
31
+ cls_data_phase = cls_data[cls_data[columns[1]] == phase]
32
+ cls_data_phase.index = list(range(len(cls_data_phase)))
33
+ for idx in range(cls_data_phase.shape[0]):
34
+ data = cls_data_phase.loc[idx]
35
+ is_abnormal = True if data[2] == 'anomaly' else False
36
+ info_img = dict(
37
+ img_path=data[3],
38
+ mask_path=data[4] if is_abnormal else '',
39
+ cls_name=cls_name,
40
+ specie_name='',
41
+ anomaly=1 if is_abnormal else 0,
42
+ )
43
+ cls_info.append(info_img)
44
+ info[phase][cls_name] = cls_info
45
+ with open(self.meta_path, 'w') as f:
46
+ f.write(json.dumps(info, indent=4) + "\n")
47
+
48
+
49
+
50
+ if __name__ == '__main__':
51
+ runner = VisASolver(root=VISA_ROOT)
52
+ runner.run()
dataset/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT
2
+ from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT
3
+ from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT
4
+ from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT
5
+ from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT
6
+ from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT
7
+ from .dtd import DTD_CLS_NAMES,DTDDataset,DTD_ROOT
8
+ from .isic import ISIC_CLS_NAMES,ISICDataset,ISIC_ROOT
9
+ from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT
10
+ from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT
11
+ from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT
12
+ from .headct import HEADCT_CLS_NAMES,HEADCTDataset,HEADCT_ROOT
13
+ from .brain_mri import BrainMRI_CLS_NAMES,BrainMRIDataset,BrainMRI_ROOT
14
+ from .br35h import Br35h_CLS_NAMES,Br35hDataset,Br35h_ROOT
15
+ from torch.utils.data import ConcatDataset
16
+
17
+ dataset_dict = {
18
+ 'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT),
19
+ 'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT),
20
+ 'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT),
21
+ 'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT),
22
+ 'colondb': (ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT),
23
+ 'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT),
24
+ 'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT),
25
+ 'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT),
26
+ 'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT),
27
+ 'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT),
28
+ 'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT),
29
+ 'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT),
30
+ 'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT),
31
+ 'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT),
32
+ }
33
+
34
+ def get_data(dataset_type_list, transform, target_transform, training):
35
+ if not isinstance(dataset_type_list, list):
36
+ dataset_type_list = [dataset_type_list]
37
+
38
+ dataset_cls_names_list = []
39
+ dataset_instance_list = []
40
+ dataset_root_list = []
41
+ for dataset_type in dataset_type_list:
42
+ if dataset_dict.get(dataset_type, ''):
43
+ dataset_cls_names, dataset_instance, dataset_root = dataset_dict[dataset_type]
44
+ dataset_instance = dataset_instance(
45
+ clsnames=dataset_cls_names,
46
+ transform=transform,
47
+ target_transform=target_transform,
48
+ training=training
49
+ )
50
+
51
+ dataset_cls_names_list.append(dataset_cls_names)
52
+ dataset_instance_list.append(dataset_instance)
53
+ dataset_root_list.append(dataset_root)
54
+
55
+ else:
56
+ print(f'Only support {list(dataset_dict.keys())}, but entered {dataset_type}...')
57
+ raise NotImplementedError
58
+
59
+ if len(dataset_type_list) > 1:
60
+ dataset_instance = ConcatDataset(dataset_instance_list)
61
+ dataset_cls_names = dataset_cls_names_list
62
+ dataset_root = dataset_root_list
63
+ else:
64
+ dataset_instance = dataset_instance_list[0]
65
+ dataset_cls_names = dataset_cls_names_list[0]
66
+ dataset_root = dataset_root_list[0]
67
+
68
+ return dataset_cls_names, dataset_instance, dataset_root
dataset/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.67 kB). View file
 
dataset/__pycache__/br35h.cpython-39.pyc ADDED
Binary file (3.38 kB). View file
 
dataset/__pycache__/brain_mri.cpython-39.pyc ADDED
Binary file (3.38 kB). View file
 
dataset/__pycache__/btad.cpython-39.pyc ADDED
Binary file (3.59 kB). View file
 
dataset/__pycache__/clinicdb.cpython-39.pyc ADDED
Binary file (3.74 kB). View file
 
dataset/__pycache__/colondb.cpython-39.pyc ADDED
Binary file (3.6 kB). View file
 
dataset/__pycache__/dagm.cpython-39.pyc ADDED
Binary file (3.66 kB). View file
 
dataset/__pycache__/dtd.cpython-39.pyc ADDED
Binary file (3.7 kB). View file
 
dataset/__pycache__/headct.cpython-39.pyc ADDED
Binary file (3.37 kB). View file
 
dataset/__pycache__/isic.cpython-39.pyc ADDED
Binary file (3.56 kB). View file
 
dataset/__pycache__/mpdd.cpython-39.pyc ADDED
Binary file (3.63 kB). View file
 
dataset/__pycache__/mvtec.cpython-39.pyc ADDED
Binary file (3.71 kB). View file
 
dataset/__pycache__/sdd.cpython-39.pyc ADDED
Binary file (3.57 kB). View file
 
dataset/__pycache__/tn3k.cpython-39.pyc ADDED
Binary file (3.56 kB). View file
 
dataset/__pycache__/visa.cpython-39.pyc ADDED
Binary file (2.59 kB). View file
 
dataset/base_dataset.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base class for our zero-shot anomaly detection dataset
3
+ """
4
+ import json
5
+ import os
6
+ import random
7
+ import numpy as np
8
+ import torch.utils.data as data
9
+ from PIL import Image
10
+ import cv2
11
+ from config import DATA_ROOT
12
+
13
+
14
+ class DataSolver:
15
+ def __init__(self, root, clsnames):
16
+ self.root = root
17
+ self.clsnames = clsnames
18
+ self.path = os.path.join(root, 'meta.json')
19
+
20
+ def run(self):
21
+ with open(self.path, 'r') as f:
22
+ info = json.load(f)
23
+
24
+ info_required = dict(train={}, test={})
25
+ for cls in self.clsnames:
26
+ for k in info.keys():
27
+ info_required[k][cls] = info[k][cls]
28
+
29
+ return info_required
30
+
31
+
32
+ class BaseDataset(data.Dataset):
33
+ def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True):
34
+ self.root = root
35
+ self.transform = transform
36
+ self.target_transform = target_transform
37
+ self.aug_rate = aug_rate
38
+ self.training = training
39
+ self.data_all = []
40
+ self.cls_names = clsnames
41
+
42
+ solver = DataSolver(root, clsnames)
43
+ meta_info = solver.run()
44
+
45
+ self.meta_info = meta_info['test'] # Only utilize the test dataset for both training and testing
46
+ for cls_name in self.cls_names:
47
+ self.data_all.extend(self.meta_info[cls_name])
48
+
49
+ self.length = len(self.data_all)
50
+
51
+ def __len__(self):
52
+ return self.length
53
+
54
+ def combine_img(self, cls_name):
55
+ """
56
+ From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN
57
+ Here we combine four images into a single image for data augmentation.
58
+ """
59
+ img_info = random.sample(self.meta_info[cls_name], 4)
60
+
61
+ img_ls = []
62
+ mask_ls = []
63
+
64
+ for data in img_info:
65
+ img_path = os.path.join(self.root, data['img_path'])
66
+ mask_path = os.path.join(self.root, data['mask_path'])
67
+
68
+ img = Image.open(img_path).convert('RGB')
69
+ img_ls.append(img)
70
+
71
+ if not data['anomaly']:
72
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
73
+ else:
74
+ img_mask = np.array(Image.open(mask_path).convert('L')) > 0
75
+ img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
76
+
77
+ mask_ls.append(img_mask)
78
+
79
+ # Image
80
+ image_width, image_height = img_ls[0].size
81
+ result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
82
+ for i, img in enumerate(img_ls):
83
+ row = i // 2
84
+ col = i % 2
85
+ x = col * image_width
86
+ y = row * image_height
87
+ result_image.paste(img, (x, y))
88
+
89
+ # Mask
90
+ result_mask = Image.new("L", (2 * image_width, 2 * image_height))
91
+ for i, img in enumerate(mask_ls):
92
+ row = i // 2
93
+ col = i % 2
94
+ x = col * image_width
95
+ y = row * image_height
96
+ result_mask.paste(img, (x, y))
97
+
98
+ return result_image, result_mask
99
+
100
+ def __getitem__(self, index):
101
+ data = self.data_all[index]
102
+ img_path = os.path.join(self.root, data['img_path'])
103
+ mask_path = os.path.join(self.root, data['mask_path'])
104
+ cls_name = data['cls_name']
105
+ anomaly = data['anomaly']
106
+ random_number = random.random()
107
+
108
+ if self.training and random_number < self.aug_rate:
109
+ img, img_mask = self.combine_img(cls_name)
110
+ else:
111
+ if img_path.endswith('.tif'):
112
+ img = cv2.imread(img_path)
113
+ img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
114
+ else:
115
+ img = Image.open(img_path).convert('RGB')
116
+ if anomaly == 0:
117
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
118
+ else:
119
+ if data['mask_path']:
120
+ img_mask = np.array(Image.open(mask_path).convert('L')) > 0
121
+ img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
122
+ else:
123
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
124
+ # Transforms
125
+ if self.transform is not None:
126
+ img = self.transform(img)
127
+ if self.target_transform is not None and img_mask is not None:
128
+ img_mask = self.target_transform(img_mask)
129
+ if img_mask is None:
130
+ img_mask = []
131
+
132
+ return {
133
+ 'img': img,
134
+ 'img_mask': img_mask,
135
+ 'cls_name': cls_name,
136
+ 'anomaly': anomaly,
137
+ 'img_path': img_path
138
+ }
dataset/br35h.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .base_dataset import BaseDataset
3
+ from config import DATA_ROOT
4
+
5
+ '''dataset source: https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection'''
6
+
7
+ Br35h_CLS_NAMES = [
8
+ 'br35h',
9
+ ]
10
+ Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection')
11
+
12
+ class Br35hDataset(BaseDataset):
13
+ def __init__(self, transform, target_transform, clsnames=Br35h_CLS_NAMES, aug_rate=0.0, root=Br35h_ROOT, training=True):
14
+ super(Br35hDataset, self).__init__(
15
+ clsnames=clsnames, transform=transform, target_transform=target_transform,
16
+ root=root, aug_rate=aug_rate, training=training
17
+ )
18
+