PhilHolst committed on
Commit
b9e51b5
1 Parent(s): d618452

Create app.py

Files changed (1)
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
+ from PIL import Image
+ import gradio as gr
+ import os
+
+ # Load the feature extractor and the classification model once at startup
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
+
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
+
+ def inference(image):
+     inputs = feature_extractor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+     logits = outputs.logits
+     # the model predicts one of the 1000 ImageNet classes
+     predicted_class_idx = logits.argmax(-1).item()
+     return "Predicted class: " + model.config.id2label[predicted_class_idx]
+
+ # Build the Gradio UI
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(
+         """
+         # Welcome to this Replit Template for Gradio!
+         Start by adding an image. This demo uses the google/vit-base-patch16-224 model from the Hugging Face model Hub for image classification; for more details, read the [model card on Hugging Face](https://huggingface.co/google/vit-base-patch16-224).
+         """)
+     inp = gr.Image(type="pil")
+     out = gr.Label()
+
+     button = gr.Button(value="Run")
+     # Clicking an example loads it into the input; cache_examples=False runs the model on demand
+     gr.Examples(
+         examples=[os.path.join(os.path.dirname(__file__), "lion.jpeg")],
+         inputs=inp,
+         outputs=out,
+         fn=inference,
+         cache_examples=False)
+
+     button.click(fn=inference,
+                  inputs=inp,
+                  outputs=out)
+
+
+ demo.launch(share=True)
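
A note on the output component: gr.Label can also render per-class confidence bars when the handler returns a dict mapping labels to probabilities, rather than a single string. Below is a minimal sketch of such a variant, assuming torch is available (transformers depends on it) and reusing the model and feature_extractor defined in app.py above; inference_with_scores and top_k are hypothetical names for illustration.

import torch

def inference_with_scores(image, top_k=5):
    # Preprocess the PIL image into a batched tensor
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs).logits
    # Turn logits into probabilities over the 1000 ImageNet classes
    probs = logits.softmax(-1)[0]
    top = probs.topk(top_k)
    # gr.Label displays a {label: confidence} dict as ranked bars
    return {model.config.id2label[i.item()]: p.item()
            for i, p in zip(top.indices, top.values)}

Wiring this in would only require passing fn=inference_with_scores to button.click and gr.Examples in place of fn=inference.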