---
license: apache-2.0
---

[STFPM](https://github.com/openvinotoolkit/anomalib/tree/main/anomalib/models/stfpm) (Student-Teacher Feature Pyramid Matching) model from [Anomalib](https://github.com/openvinotoolkit/anomalib), trained on the capsule category of the [MVTec AD dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad). The checkpoint was trained using the following [notebook](https://github.com/openvinotoolkit/anomalib/blob/main/notebooks/000_getting_started/001_getting_started.ipynb).

| Test metric | DataLoader 0 |
|---|---|
| image_AUROC | 0.8436378240585327 |
| image_F1Score | 0.9356223344802856 |
| pixel_AUROC | 0.9719913601875305 |
| pixel_F1Score | 0.41566985845565796 |

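The image-level metrics in the table are computed from one anomaly score per test image, while the pixel-level ones treat every pixel of the anomaly map as a sample. To illustrate what image_AUROC measures, the sketch below computes AUROC directly from its probabilistic definition: the chance that a randomly chosen anomalous image receives a higher score than a randomly chosen normal one, with ties counted as one half. It is a quadratic-time sketch on made-up scores, not the way Anomalib computes the metric.

```python
from itertools import product
from typing import Sequence


def image_auroc(normal_scores: Sequence[float], anomalous_scores: Sequence[float]) -> float:
    """AUROC as the probability that an anomalous image scores higher than a
    normal one, averaged over all (normal, anomalous) pairs; ties count 0.5."""
    if not normal_scores or not anomalous_scores:
        raise ValueError("need at least one normal and one anomalous score")
    wins = 0.0
    for neg, pos in product(normal_scores, anomalous_scores):
        if pos > neg:
            wins += 1.0  # anomalous image correctly ranked above the normal one
        elif pos == neg:
            wins += 0.5  # a tie is counted as half a correct ranking
    return wins / (len(normal_scores) * len(anomalous_scores))
```

For example, `image_auroc([0.1, 0.4, 0.35], [0.8, 0.3])` returns 4/6 ≈ 0.667: the anomalous image scored 0.8 ranks above all three normal images, while the one scored 0.3 ranks above only one of them.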
The main intent is to use this checkpoint in samples and demos of model optimization. Its advantages are:

- The MVTec AD dataset can be downloaded automatically and is fairly small.
- Anomaly detection models such as STFPM are sensitive to optimization methods, which makes them well suited to demonstrating methods with accuracy control.

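The idea of accuracy control can be illustrated with a small, generic sketch: an optimization (for example, quantization) is applied to all layers, and layers are then reverted to their original precision until the drop in a validation metric such as image_AUROC fits within a budget. This is a simplified greedy loop written for illustration, not the algorithm of any particular toolkit; the function name, the layer names, the revert order, and the `evaluate` callback are all hypothetical. In a real demo, `evaluate` could run `trainer.test` on the optimized model and return one of the metrics from the table.

```python
from typing import Callable, List, Set


def accuracy_controlled_revert(
    layers: List[str],
    evaluate: Callable[[Set[str]], float],
    baseline: float,
    max_drop: float,
) -> Set[str]:
    """Hypothetical greedy sketch: start with every layer optimized, then revert
    layers one by one (in the given order) until baseline - metric <= max_drop.
    Returns the set of layers that stay optimized."""
    optimized = set(layers)
    for layer in layers:
        if baseline - evaluate(optimized) <= max_drop:
            return optimized  # accuracy budget satisfied
        optimized.discard(layer)  # fall back to original precision for this layer
    return optimized  # every layer had to be reverted
```

With a toy `evaluate` in which each optimized layer costs a fixed amount of accuracy, the loop keeps only the cheapest layers optimized. Real tools typically rank layers by their measured impact rather than using a fixed order.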
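The following code tests the checkpoint. It is written against the pre-1.0 Anomalib API (`anomalib.config`, `anomalib.utils.callbacks`), and the config path is relative to the root of an Anomalib repository clone, so run it from there:
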
```python
import os
from urllib.request import urlretrieve

from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import LoadModelCallback, get_callbacks

CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/pytorch_model.bin"
# Expand "~" explicitly, since file APIs do not expand it on their own
CHECKPOINT_PATH = os.path.expanduser("~/pytorch_model.bin")

# Download CHECKPOINT_URL to CHECKPOINT_PATH (skipped if the file already exists)
if not os.path.exists(CHECKPOINT_PATH):
    urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)

config = get_configurable_parameters(config_path="./anomalib/models/stfpm/config.yaml")
config["dataset"]["path"] = "<path_to_dataset>"  # replace with your MVTec AD directory
config["dataset"]["category"] = "capsule"

datamodule = get_datamodule(config)
datamodule.prepare_data()  # Downloads the dataset if it is not in the specified directory
datamodule.setup()  # Creates train/val/test/prediction sets

model = get_model(config)

# Load the downloaded weights into the model before testing
callbacks = get_callbacks(config)
load_model_callback = LoadModelCallback(weights_path=CHECKPOINT_PATH)
callbacks.insert(0, load_model_callback)

trainer = Trainer(**config.trainer, callbacks=callbacks)
trainer.test(model=model, datamodule=datamodule)
```
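To check that a run reproduces the metrics reported in the table, the value returned by `trainer.test` (a list with one dictionary of logged metrics per test dataloader) can be compared against the reference numbers. The sketch below is a hypothetical helper; the `check_metrics` name and the absolute tolerance of 1e-3 are assumptions, since results may differ slightly across hardware and library versions.

```python
import math
from typing import Dict, List, Mapping, Sequence

# Reference values copied from the metrics table of this model card
REFERENCE_METRICS: Dict[str, float] = {
    "image_AUROC": 0.8436378240585327,
    "image_F1Score": 0.9356223344802856,
    "pixel_AUROC": 0.9719913601875305,
    "pixel_F1Score": 0.41566985845565796,
}


def check_metrics(
    results: Sequence[Mapping[str, float]],
    reference: Mapping[str, float] = REFERENCE_METRICS,
    abs_tol: float = 1e-3,  # assumed tolerance, not an official threshold
) -> List[str]:
    """Return names of reference metrics that are missing from, or differ by
    more than abs_tol in, the first dictionary returned by Trainer.test."""
    measured = results[0]  # one dict per test dataloader; this setup uses one
    mismatches = []
    for name, expected in reference.items():
        if name not in measured or not math.isclose(
            float(measured[name]), expected, rel_tol=0.0, abs_tol=abs_tol
        ):
            mismatches.append(name)
    return mismatches
```

For example, pass the return value of `trainer.test(model=model, datamodule=datamodule)` from the snippet above to `check_metrics`; an empty list means all four metrics are within the tolerance.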