desc / app.py
PhilHolst's picture
Update app.py
f3bbc2a
raw
history blame contribute delete
No virus
1.11 kB
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the pretrained GPT-2 tokenizer and language-model head once, at import
# time, so every Gradio request reuses the same two objects.
_CHECKPOINT = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT)
def _prepare_image(image, size=(224, 224)):
    """Return *image* as an RGB ``PIL.Image.Image`` resized to *size*.

    Accepts a URL string (fetched over HTTP), a ``PIL.Image.Image``, or an
    array-like such as the ``numpy.ndarray`` that Gradio's Image component
    passes by default (presumably uint8 HxWxC -- confirm if ``type`` changes).

    Raises:
        requests.HTTPError: if a URL is given and the server returns an
            error status.
        requests.Timeout: if the URL does not respond in time.
    """
    if isinstance(image, str):
        # Without a timeout, requests can block the worker forever.
        response = requests.get(image, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    elif isinstance(image, Image.Image):
        img = image
    else:
        img = Image.fromarray(image)
    return img.convert('RGB').resize(size)


def _build_prompt(image):
    """Return the GPT-2 prompt text for *image*.

    GPT-2 only sees text, so the only image-derived information it can use
    is the source URL when one was given; pixel inputs get a generic prompt.
    """
    if isinstance(image, str):
        return "This is an image of " + image + ". "
    return "This is an image of "


def generate_caption(image):
    """Generate caption-like text for *image* with GPT-2.

    NOTE(review): GPT-2 is a text-only language model. The image is decoded
    and resized (which validates the input), but its pixels are never given
    to the model. The output is sampled text continuing the prompt, not a
    description of the picture. Real captioning needs a vision-encoder model.

    Args:
        image: A URL string, a PIL image, or a numpy array (Gradio's default
            for ``gr.Image``).

    Returns:
        The decoded generated text, including the prompt prefix.
    """
    # Preprocess image: this validates the input even though GPT-2 cannot
    # consume it (see NOTE in the docstring).
    _prepare_image(image)

    # Generate caption using GPT-2.
    encoded = tokenizer(_build_prompt(image), return_tensors='pt')
    output = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=200,
        do_sample=True,
        # GPT-2 has no pad token; reusing EOS avoids the open-end warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Create Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components below replace them. gr.Image() hands the upload
# to the callback as a numpy array by default.
inputs = gr.Image()
outputs = gr.Textbox()
demo = gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title='Image Captioning with GPT-2',
    description='Upload an image and get a detailed caption generated by GPT-2.',
)
# Launched at import time, as before, so the hosting runner serves the app
# when it executes this file.
demo.launch()