Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
bfa5ced
·
verified ·
1 Parent(s): 34784d3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +65 -1
README.md CHANGED
@@ -1,3 +1,67 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: apache-2.0
3
+ tags:
4
+ - mlx
5
+ - mlx-image
6
+ - vision
7
+ - image-classification
8
+ datasets:
9
+ - imagenet-1k
10
+ library_name: mlx-image
11
  ---
12
+ # vit_base_patch16_384.swag_e2e
13
+
14
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model. Weights are learned with [SWAG](https://arxiv.org/abs/2201.08371) on ImageNet-1k data.
15
+
16
+ Disclaimer: This is a porting of the torchvision model weights to Apple MLX Framework.
17
+
18
+
19
+ ## How to use
20
+ ```bash
21
+ pip install mlx-image
22
+ ```
23
+
24
+ Here is how to use this model for image classification:
25
+
26
+ ```python
27
+ from mlxim.model import create_model
28
+ from mlxim.io import read_rgb
29
+ from mlxim.transform import ImageNetTransform
30
+
31
+ transform = ImageNetTransform(train=False, img_size=224)
32
+ x = transform(read_rgb("cat.png"))
33
+ x = mx.expand_dims(x, 0)
34
+
35
+ model = create_model("vit_base_patch16_384.swag_e2e")
36
+ model.eval()
37
+
38
+ logits = model(x)
39
+ ```
40
+
41
+ You can also use the embeds from layer before head:
42
+ ```python
43
+ from mlxim.model import create_model
44
+ from mlxim.io import read_rgb
45
+ from mlxim.transform import ImageNetTransform
46
+
47
+ transform = ImageNetTransform(train=False, img_size=224)
48
+ x = transform(read_rgb("cat.png"))
49
+ x = mx.expand_dims(x, 0)
50
+
51
+ # first option
52
+ model = create_model("vit_base_patch16_384.swag_e2e", num_classes=0)
53
+ model.eval()
54
+
55
+ embeds = model(x)
56
+
57
+ # second option
58
+ model = create_model("vit_base_patch16_384.swag_e2e")
59
+ model.eval()
60
+
61
+ embeds = model.features(x)
62
+ ```
63
+
64
+
65
+ ## Model Comparison
66
+
67
+ Explore the metrics of this model in [mlx-image model results](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv).