# Residue from the hosting page (not part of the program):
#   user-agent's picture / commit message "Update app.py" / commit 4c6f845 (verified)
import spaces
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr
# Load the CLIP preprocessor and weights once at import time so that every
# call to the embedding function reuses the same objects.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
@spaces.GPU  # Use the GPU decorator for the function that requires GPU
def get_embedding(image_or_text):
    """Return the CLIP embedding of an image URL or of a piece of text.

    Input that starts with ``http:`` or ``https:`` is treated as the URL of
    an image: the image is downloaded and passed through the model's image
    encoder. Any other input is passed through the text encoder.

    Args:
        image_or_text: Either an image URL or arbitrary text.

    Returns:
        The embedding, flattened into a plain list of floats.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
        PIL.UnidentifiedImageError: If the downloaded bytes are not an image
            that Pillow can decode.
    """
    # Define device within the function to ensure it uses the GPU when available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    if image_or_text.startswith(("http:", "https:")):
        # Image URL.
        # NOTE(review): the URL is user-supplied and fetched server-side
        # (SSRF surface) -- consider restricting allowed hosts/schemes.
        # A timeout keeps a stalled host from blocking the worker forever,
        # and raise_for_status() surfaces 4xx/5xx responses as HTTP errors
        # instead of letting an error page reach Image.open().
        response = requests.get(image_or_text, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs)
    else:
        # Text input. truncation=True clips prompts that exceed the
        # tokenizer's maximum length, which would otherwise make the text
        # encoder fail on long inputs.
        inputs = processor(
            text=[image_or_text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            features = model.get_text_features(**inputs)

    # Move to host memory and convert to a JSON-serialisable flat list.
    return features.cpu().numpy().flatten().tolist()
# Gradio UI: one text box in, the result of get_embedding rendered as JSON.
interface = gr.Interface(
    fn=get_embedding,
    inputs="text",
    outputs="json",
    title="CLIP Model Embeddings",
    description="Enter an Image URL or text to get embeddings from CLIP.",
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch(share=True)