riccardomusmeci commited on
Commit
531604b
·
verified ·
1 Parent(s): 39bc2cd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +61 -64
README.md CHANGED
@@ -1,67 +1,64 @@
1
- ---
2
- {}
3
- ---
4
-
5
- ---
6
- license: apache-2.0
7
- tags:
8
- - mlx
9
- - mlx-image
10
- - vision
11
- - image-classification
12
- datasets:
13
- - imagenet-1k
14
- library_name: mlx-image
15
- ---
16
- # regnet_y_800mf
17
-
18
- A RegNetY-800MF image classification model. Pretrained in ImageNet by torchvision contributors (see ImageNet1K-V2 weight details https://github.com/pytorch/vision/issues/3995#new-recipe).
19
-
20
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
21
-
22
- ## How to use
23
- ```bash
24
- pip install mlx-image
25
- ```
26
-
27
- Here is how to use this model for image classification:
28
-
29
- ```python
30
- from mlxim.model import create_model
31
- from mlxim.io import read_rgb
32
- from mlxim.transform import ImageNetTransform
33
-
34
- transform = ImageNetTransform(train=False, img_size=224)
35
- x = transform(read_rgb("cat.png"))
36
- x = mx.expand_dims(x, 0)
37
 
38
- model = create_model("regnet_y_800mf")
39
- model.eval()
40
-
41
- logits = model(x)
42
- ```
43
-
44
- You can also use the embeds from layer before head:
45
- ```python
46
- from mlxim.model import create_model
47
- from mlxim.io import read_rgb
48
- from mlxim.transform import ImageNetTransform
49
-
50
- transform = ImageNetTransform(train=False, img_size=224)
51
- x = transform(read_rgb("cat.png"))
52
- x = mx.expand_dims(x, 0)
53
-
54
- # first option
55
- model = create_model("regnet_y_800mf", num_classes=0)
56
- model.eval()
57
-
58
- embeds = model(x)
59
-
60
- # second option
61
- model = create_model("regnet_y_800mf")
62
- model.eval()
63
-
64
- embeds = model.get_features(x)
65
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ ---
3
+ license: apache-2.0
4
+ tags:
5
+ - mlx
6
+ - mlx-image
7
+ - vision
8
+ - image-classification
9
+ datasets:
10
+ - imagenet-1k
11
+ library_name: mlx-image
12
+ ---
13
+ # regnet_y_800mf
14
+
15
+ A RegNetY-800MF image classification model. Pretrained in ImageNet by torchvision contributors (see ImageNet1K-V2 weight details https://github.com/pytorch/vision/issues/3995#new-recipe).
16
+
17
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
18
+
19
+ ## How to use
20
+ ```bash
21
+ pip install mlx-image
22
+ ```
23
+
24
+ Here is how to use this model for image classification:
25
+
26
+ ```python
27
+ from mlxim.model import create_model
28
+ from mlxim.io import read_rgb
29
+ from mlxim.transform import ImageNetTransform
30
+
31
+ transform = ImageNetTransform(train=False, img_size=224)
32
+ x = transform(read_rgb("cat.png"))
33
+ x = mx.expand_dims(x, 0)
34
+
35
+ model = create_model("regnet_y_800mf")
36
+ model.eval()
37
+
38
+ logits = model(x)
39
+ ```
40
+
41
+ You can also use the embeds from layer before head:
42
+ ```python
43
+ from mlxim.model import create_model
44
+ from mlxim.io import read_rgb
45
+ from mlxim.transform import ImageNetTransform
46
+
47
+ transform = ImageNetTransform(train=False, img_size=224)
48
+ x = transform(read_rgb("cat.png"))
49
+ x = mx.expand_dims(x, 0)
50
+
51
+ # first option
52
+ model = create_model("regnet_y_800mf", num_classes=0)
53
+ model.eval()
54
+
55
+ embeds = model(x)
56
+
57
+ # second option
58
+ model = create_model("regnet_y_800mf")
59
+ model.eval()
60
+
61
+ embeds = model.get_features(x)
62
+ ```
63
 
64