import spaces
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr

# Initialize the model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
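
# Note: both encoders of this ViT-B/32 checkpoint project into a shared
# 512-dimensional embedding space (projection_dim in the model config), so
# image and text embeddings returned below are directly comparable.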

# Use the GPU decorator for the function that requires GPU
@spaces.GPU
def get_embedding(image_or_text):
    # Define device within the function to ensure it uses the GPU when available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if image_or_text.startswith(('http:', 'https:')):
        # Image URL
        response = requests.get(image_or_text)
        response.raise_for_status()  # fail early on a bad download
        image = Image.open(BytesIO(response.content))
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs).cpu().numpy()
    else:
        # Text input
        inputs = processor(text=[image_or_text], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            features = model.get_text_features(**inputs).cpu().numpy()
    return features.flatten().tolist()
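
# Example usage (a minimal sketch; the URL below is a hypothetical placeholder):
#   get_embedding("a photo of a cat")             # -> list of 512 floats
#   get_embedding("https://example.com/cat.jpg")  # -> list of 512 floats
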
# Define the Gradio interface
interface = gr.Interface(
    fn=get_embedding,
    inputs="text",
    outputs="json",
    title="CLIP Model Embeddings",
    description="Enter an image URL or text to get embeddings from CLIP.",
)

if __name__ == "__main__":
    interface.launch(share=True)
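
# Downstream use (a hedged sketch, not invoked by this app): CLIP embeddings
# are typically compared with cosine similarity, e.g.:
#
#   import numpy as np
#   a = np.array(get_embedding("a photo of a cat"))
#   b = np.array(get_embedding("https://example.com/cat.jpg"))  # hypothetical URL
#   sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))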