Files changed (1) hide show
  1. README.md +53 -3
README.md CHANGED
@@ -1,3 +1,53 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ # AGE-ViT
6
+ Age-classifying Generative Entity Vision Transformer
7
+
8
+ A Vision Transformer finetuned to classify images of human faces into 'minor' or 'adult'.
9
+
10
+ This model is a finetuned version of https://huggingface.co/nateraw/vit-age-classifier which was finetuned on the fairface dataset.
11
+
12
+ ## Datasets
13
+ These datasets were used in finetuning, with fairface finetuning the classifier we built on top of.
14
+
15
+ ### FairFace dataset
16
+ https://github.com/dchen236/FairFace
17
+
18
+ This is a balanced dataset for race, gender, and age and was initial intended for bias mitigation. The majority of the images in this dataset are direct and front facing.
19
+
20
+ ### Synthetic Dataset
21
+ https://civitai.com/models/668458/synthetic-human-dataset
22
+
23
+ This dataset was fully generated by flux and contains 15k images of men, women, boys, and girls from the front, side, and slightly above. This dataset will be expanded with sd15 images and the model will be retrained.
24
+
25
+ To use the model
26
+
27
+ ```
28
+ import requests
29
+ from PIL import Image
30
+ from io import BytesIO
31
+
32
+ from transformers import ViTImageProcessor, ViTForImageClassification
33
+
34
+ # Get example image from official fairface repo + read it in as an image
35
+ r = requests.get('https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/9488af10-7f1f-4361-877b-d9cfafeab131/original=true,quality=90/24599129.jpeg')
36
+ im = Image.open(BytesIO(r.content))
37
+
38
+ model_dir = 'civitai/age-vit'
39
+
40
+ # Init model, transforms
41
+ model = ViTForImageClassification.from_pretrained(model_dir)
42
+ transforms = ViTFeatureExtractor.from_pretrained(model_dir)
43
+
44
+ # Transform our image and pass it through the model
45
+ inputs = transforms(im, return_tensors='pt')
46
+ output = model(**inputs)
47
+
48
+ # Predicted Class probabilities
49
+ proba = output.logits.softmax(1)
50
+
51
+ # Predicted Classes
52
+ preds = proba.argmax(1)
53
+ ```