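# Streamlit demo: zero-shot image "classification" with OpenAI CLIP.
# CLIP embeds the uploaded image and a handful of text prompts into a shared
# space and ranks the prompts by similarity, so the candidate labels below
# can be changed freely without retraining the model.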
import streamlit as st
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model and processor once at startup
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")


def predict_image_description(image):
    # Compare the image against a fixed set of candidate text prompts
    inputs = processor(
        text=["a photo of an animal", "a photo of a human", "a photo of a car",
              "a photo of a tree", "a photo of a house"],
        images=image,
        return_tensors="pt",
        padding=True,
    )

    # Forward pass: CLIP scores the image against each candidate prompt
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    # Keep the three most likely labels and their probabilities
    top_3_probabilities, top_3_indices = torch.topk(probs, 3)
    labels = ["an animal", "a human", "a car", "a tree", "a house"]

    predictions = []
    for i in range(3):
        prediction = labels[top_3_indices[0][i].item()]
        probability = top_3_probabilities[0][i].item()
        predictions.append(f"{prediction}: {probability * 100:.2f}%")

    return predictions


st.title("Real-Time Image-to-Text Generator")
st.markdown("Upload an image, and I will tell you what it is!")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Show the uploaded image back to the user
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Run CLIP and display the top-3 guesses with their confidences
    predictions = predict_image_description(image)

    st.write("Predictions:")
    for prediction in predictions:
        st.write(prediction)