PhilHolst committed on
Commit f3bbc2a
1 Parent(s): 4c67f38

Update app.py

Files changed (1)
  1. app.py +29 -44
app.py CHANGED
@@ -1,45 +1,30 @@
- from transformers import ViTFeatureExtractor, ViTForImageClassification
- from PIL import Image
- import requests
  import gradio as gr
- import os
-
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
-
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
-
- def inference(image):
-     inputs = feature_extractor(images=image, return_tensors="pt")
-     outputs = model(**inputs)
-     logits = outputs.logits
-     # model predicts one of the 1000 ImageNet classes
-     predicted_class_idx = logits.argmax(-1).item()
-     print(type(model.config.id2label[predicted_class_idx]))
-     return "Predicted class: " + model.config.id2label[predicted_class_idx]
-
- demo = gr.Blocks()
-
- with demo:
-     gr.Markdown(
-         """
-         # Welcome to this Replit Template for Gradio!
-         Start by adding an image. This demo uses the google/vit-base-patch16-224 model from the Hugging Face model Hub for image classification; for more details, read the [model card on Hugging Face](https://huggingface.co/google/vit-base-patch16-224).
-         """)
-     inp = gr.Image(type="pil")
-     out = gr.Label()
-
-     button = gr.Button(value="Run")
-     gr.Examples(
-         examples=[os.path.join(os.path.dirname(__file__), "lion.jpeg")],
-         inputs=inp,
-         outputs=out,
-         fn=inference,
-         cache_examples=False)
-
-     button.click(fn=inference,
-                  inputs=inp,
-                  outputs=out)
-
-
-
- demo.launch(share=True)
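Note on the removed code: ViTFeatureExtractor has since been deprecated in transformers in favor of ViTImageProcessor. For reference only (not part of this commit), a minimal sketch of the same inference step against the current API, assuming a PIL image input:

# Editorial sketch, not part of the commit: the removed inference path
# rewritten with ViTImageProcessor (same checkpoint, same behavior).
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

def inference(image: Image.Image) -> str:
    inputs = processor(images=image, return_tensors="pt")
    logits = model(**inputs).logits
    # The checkpoint predicts one of the 1000 ImageNet classes.
    predicted_class_idx = logits.argmax(-1).item()
    return "Predicted class: " + model.config.id2label[predicted_class_idx]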
  import gradio as gr
+ import requests
+ from io import BytesIO
+ from PIL import Image
+ import torch
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+ # Load GPT-2 model and tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2LMHeadModel.from_pretrained('gpt2')
+
+ def generate_caption(image):
+     # Preprocess image
+     response = requests.get(image)
+     img = Image.open(BytesIO(response.content)).convert('RGB')
+     img = img.resize((224, 224))
+
+     # Generate caption using GPT-2
+     input_text = "This is an image of " + tokenizer.decode(tokenizer.encode(image)) + ". "
+     input_ids = tokenizer.encode(input_text, return_tensors='pt')
+     output = model.generate(input_ids=input_ids, max_length=200, do_sample=True)
+     caption = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     return caption
+
+ # Create Gradio interface
+ inputs = gr.inputs.Image()
+ outputs = gr.outputs.Textbox()
+
+ gr.Interface(fn=generate_caption, inputs=inputs, outputs=outputs, title='Image Captioning with GPT-2', description='Upload an image and get a detailed caption generated by GPT-2.').launch()
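Note on the added code: as committed, generate_caption receives the uploaded image itself from gr.inputs.Image() (a numpy array by default), not a URL, so requests.get(image) and tokenizer.encode(image) will raise at runtime; and plain gpt2 is a text-only model that never sees the pixels, so it cannot describe the upload. The gr.inputs / gr.outputs namespaces are also deprecated in current Gradio. A minimal sketch of what the app appears to intend, assuming the nlpconnect/vit-gpt2-image-captioning checkpoint (a ViT encoder paired with a GPT-2 decoder) is an acceptable substitute; editorial illustration, not part of this commit:

# Editorial sketch, not part of the commit: a captioning app that actually
# conditions on the pixels, assuming the transformers image-to-text pipeline
# and the nlpconnect/vit-gpt2-image-captioning checkpoint are acceptable here.
import gradio as gr
from transformers import pipeline

captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

def generate_caption(image):
    # With type="pil", Gradio hands the upload over as a PIL.Image;
    # no requests.get() round-trip is needed.
    return captioner(image)[0]["generated_text"]

gr.Interface(fn=generate_caption,
             inputs=gr.Image(type="pil"),
             outputs=gr.Textbox(),
             title="Image Captioning",
             description="Upload an image and get a generated caption.").launch()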