Commit 244c461 by riccardomusmeci (1 parent: a149422)

Create README.md

  1. README.md +67 -0
README.md CHANGED
@@ -0,0 +1,67 @@
+ ---
+ license: apache-2.0
+ tags:
+ - mlx
+ - mlx-image
+ - vision
+ - image-classification
+ datasets:
+ - imagenet-1k
+ library_name: mlx-image
+ ---
+ # Model card for vit_base_patch16_224.swag_lin
+
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model. Weights are learned with [SWAG](https://arxiv.org/abs/2201.08371) on ImageNet-1k data.
+
+ Disclaimer: This is a port of the torchvision model weights to the Apple MLX framework.
+
+
+ ## How to use
+ ```bash
+ pip install mlx-image
+ ```
+
+ Here is how to use this model for image classification:
+
+ ```python
+ import mlx.core as mx
+ from mlxim.model import create_model
+ from mlxim.io import read_rgb
+ from mlxim.transform import ImageNetTransform
+
+ # preprocess the image and add a batch dimension
+ transform = ImageNetTransform(train=False, img_size=224)
+ x = transform(read_rgb("cat.png"))
+ x = mx.expand_dims(x, 0)
+
+ model = create_model("vit_base_patch16_224.swag_lin")
+ model.eval()
+
+ logits = model(x)
+ ```
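The `logits` above are unnormalized class scores. The sketch below shows one way to turn them into top-k probabilities. It is plain Python, so it does not rely on any particular array API. It assumes the scores for a single image are already a list of floats (for example via `logits[0].tolist()`). The helper name `top_k_probs` and the sample scores are illustrative and not part of mlx-image.

```python
import math

def top_k_probs(scores, k=5):
    """Return the k most likely (class_index, probability) pairs."""
    # subtract the max score before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # sort class indices by descending probability and keep the first k
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]

# made-up scores for a 4-class toy problem (hypothetical, not real model output)
example_scores = [1.0, 3.0, 0.5, 2.0]
top2 = top_k_probs(example_scores, k=2)
```

Mapping the returned indices to human-readable ImageNet-1k labels requires a separate label file, which is not shown here.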
+
+ You can also use the embeddings produced before the classification head:
+ ```python
+ import mlx.core as mx
+ from mlxim.model import create_model
+ from mlxim.io import read_rgb
+ from mlxim.transform import ImageNetTransform
+
+ # preprocess the image and add a batch dimension
+ transform = ImageNetTransform(train=False, img_size=224)
+ x = transform(read_rgb("cat.png"))
+ x = mx.expand_dims(x, 0)
+
+ # first option: build the model without a classification head
+ model = create_model("vit_base_patch16_224.swag_lin", num_classes=0)
+ model.eval()
+
+ embeds = model(x)
+
+ # second option: keep the head but call the feature extractor directly
+ model = create_model("vit_base_patch16_224.swag_lin")
+ model.eval()
+
+ embeds = model.features(x)
+ ```
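One common use of such embeddings is comparing two images by cosine similarity. The sketch below is plain Python and assumes each image has already been reduced to a single 1-D vector stored as a list of floats (for example via `embeds[0].tolist()`, if the output has one vector per image). `cosine_similarity` and the toy vectors are illustrative names, not mlx-image functions.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)

# toy 3-dimensional vectors standing in for real embeddings (hypothetical values)
emb_a = [1.0, 0.0, 2.0]
emb_b = [2.0, 0.0, 4.0]   # same direction as emb_a
emb_c = [0.0, 3.0, 0.0]   # orthogonal to emb_a
sim_same = cosine_similarity(emb_a, emb_b)
sim_orth = cosine_similarity(emb_a, emb_c)
```

Values close to 1 indicate embeddings pointing in nearly the same direction, while values near 0 indicate unrelated directions.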
+
+
+ ### Model Comparison
+
+ Explore the metrics of this model in [mlx-image model results](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv).