---
license: apache-2.0
---
[STFPM](https://github.com/openvinotoolkit/anomalib/tree/main/anomalib/models/stfpm) model from [Anomalib](https://github.com/openvinotoolkit/anomalib) fine-tuned for the capsule category of the [MVTec dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad). The checkpoint was trained using the following [notebook](https://github.com/openvinotoolkit/anomalib/blob/main/notebooks/000_getting_started/001_getting_started.ipynb).
```
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
       Test metric             DataLoader 0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
        image_AUROC          0.9541285037994385
       image_F1Score         0.9680365324020386
        pixel_AUROC          0.9857622385025024
       pixel_F1Score         0.4696350395679474
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
```
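The checkpoint can be fetched by its direct URL (as done in the test script below) or, as a minimal hedged alternative, with `huggingface_hub`; the `repo_id` and `filename` below simply mirror this repository:
```python
# Sketch: download the checkpoint from this repository with huggingface_hub.
# Assumes the `huggingface_hub` package is installed; the returned local path
# can be used as CHECKPOINT_PATH in the test script below.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="alexsu52/stfpm_mvtec_capsule",
    filename="pytorch_model.bin",
)
print(checkpoint_path)
```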
The main intent is to use this checkpoint in samples and demos for model optimization; a hedged quantization sketch is included after the test code below. The main advantages are:
- The MVTec dataset is downloaded automatically and is quite small.
- Anomaly detection models such as STFPM are sensitive to optimization methods, which makes them well suited for demonstrating optimization with accuracy control.
Here is the code to test the checkpoint:
```python
import os
import urllib.request

from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import LoadModelCallback, get_callbacks
CHECKPOINT_URL = 'https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/pytorch_model.bin'
CHECKPOINT_PATH = os.path.expanduser('~/pytorch_model.bin')

# Download CHECKPOINT_URL to CHECKPOINT_PATH if the checkpoint is not present yet
if not os.path.exists(CHECKPOINT_PATH):
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)

config = get_configurable_parameters(config_path="./anomalib/models/stfpm/config.yaml")
config["dataset"]["path"] = <path_to_dataset>
config['dataset']['category'] = 'capsule'
datamodule = get_datamodule(config)
datamodule.prepare_data()  # Downloads the dataset if it is not in the specified `root` directory
datamodule.setup()  # Creates the train/val/test/prediction splits
model = get_model(config)
callbacks = get_callbacks(config)

# Load the downloaded checkpoint into the model before running the test loop
load_model_callback = LoadModelCallback(weights_path=CHECKPOINT_PATH)
callbacks.insert(0, load_model_callback)

trainer = Trainer(**config.trainer, callbacks=callbacks)
trainer.test(model=model, datamodule=datamodule)
```
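Since the stated purpose of the checkpoint is model optimization demos, here is a minimal sketch of 8-bit post-training quantization with NNCF. It assumes a recent NNCF release that provides the `nncf.quantize`/`nncf.Dataset` API, that `model` and `datamodule` come from the script above, that the Anomalib Lightning module keeps its Torch model in `model.model`, and that Anomalib dataloaders yield dict batches with an `"image"` tensor:
```python
# Sketch: 8-bit post-training quantization of the STFPM Torch model with NNCF.
# Assumptions (not confirmed by this card): NNCF with the nncf.quantize API,
# `model`/`datamodule` from the test script above, batches are dicts with "image".
import nncf

def transform_fn(batch):
    # NNCF passes each calibration sample to the model's forward(), so only
    # the image tensor is taken from the Anomalib batch dict.
    return batch["image"]

calibration_dataset = nncf.Dataset(datamodule.train_dataloader(), transform_fn)
quantized_model = nncf.quantize(model.model, calibration_dataset)
```
One way to exercise the accuracy-control aspect mentioned above is to put the quantized module back into the Lightning module (e.g. `model.model = quantized_model`) and rerun `trainer.test(...)`, comparing image/pixel AUROC against the reference values at the top of this card.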