import os

import torch
from torch import nn
from einops import reduce
from tqdm.auto import tqdm

import gradio as gr
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPModel, CLIPProcessor

# Log in to the Hugging Face Hub with the token stored in the Space's secrets.
login(os.environ['hf_token'])
def load_distillclip(model_id, revision=None):
    """Load DistillCLIP weights from the Hub into a stock CLIPModel."""
    ckpt_path = hf_hub_download(repo_id=model_id, filename='model.safetensors', revision=revision)
    config = CLIPConfig.from_pretrained(model_id)
    model = CLIPModel(config)
    # DistillCLIP's patch embedding carries a bias, unlike stock CLIP's, so swap it in.
    model.vision_model.embeddings.patch_embedding = nn.Conv2d(
        in_channels=model.config.vision_config.num_channels,
        out_channels=model.vision_model.embeddings.embed_dim,
        kernel_size=model.vision_model.embeddings.patch_size,
        stride=model.vision_model.embeddings.patch_size,
        bias=True,
    )
    # DistillCLIP drops the pre-layernorm; replace it with a no-op.
    # (`pre_layrnorm` is the attribute's actual spelling in transformers.)
    model.vision_model.pre_layrnorm = nn.Identity()
    # The checkpoint stores weights under a 'student.' prefix; strip it before loading.
    print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
    return model
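
# Illustrative sanity check (not run at startup): after loading, the patched modules
# should reflect DistillCLIP's two architectural tweaks. `m` here is hypothetical.
#   m = load_distillclip('Ramos-Ramos/distillclip')
#   assert m.vision_model.embeddings.patch_embedding.bias is not None
#   assert isinstance(m.vision_model.pre_layrnorm, nn.Identity)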
class ZeroShotCLIP(nn.Module):
    """Wraps a CLIP model as a zero-shot classifier over a fixed set of classes."""

    def __init__(self, model=None, processor=None, classes=None, templates=None):
        super().__init__()
        self.model = model.eval()
        self.processor = processor
        self.classes = classes or []
        self.templates = templates or []
        self._init_weights()

    @torch.no_grad()
    def _init_weights(self):
        """Build one classifier weight per class by averaging its prompt embeddings."""
        self.model.eval()
        # Embed prompts on whatever device the model already lives on.
        device = next(self.model.parameters()).device
        weights = []
        for classname in tqdm(self.classes):
            prompts = [template.format(classname) for template in self.templates]
            prompts = self.processor(text=prompts, truncation=True, padding=True, return_tensors='pt')
            embeddings = self.model.get_text_features(**{k: v.to(device) for k, v in prompts.items()}).cpu()
            # Normalize each prompt embedding, average over templates, then re-normalize.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = reduce(embeddings, 'b d -> d', 'mean')
            embeddings /= embeddings.norm()
            weights.append(embeddings)
        weights = torch.stack(weights)
        self.register_buffer('weights', weights)

    @torch.no_grad()
    def forward(self, pixel_values):
        x = self.model.get_image_features(pixel_values=pixel_values)
        x /= x.norm(dim=-1, keepdim=True)
        # Scale cosine similarities by the teacher CLIP's temperature, i.e. its
        # logit_scale.exp() of ~100 (see the note in `article` below).
        return x.mm(self.weights.t()) * 100.00000762939453

    def preprocess_and_forward(self, x):
        x = self.processor(images=x, return_tensors='pt')
        return self(x['pixel_values'].to(self.weights.device))
# Load DistillCLIP and pair it with the stock CLIP processor from the same repo.
model = load_distillclip('Ramos-Ramos/distillclip')
processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')
def infer(image, classes, templates):
    """Classify `image` against comma-separated classes using semicolon-separated prompt templates."""
    classes = [label.strip() for label in classes.split(',')]
    templates = [template.strip() for template in templates.split(';')]
    clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
    preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
    return {label: score.item() for label, score in zip(classes, preds)}
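
# Illustrative call (assumes `img` is a PIL.Image; the scores shown are made up):
#   infer(img, 'cat, truck', 'a photo of a {}.; a blurry photo of a {}.')
#   -> {'cat': 0.97, 'truck': 0.03}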
title = 'DistillCLIP'
description = 'Zero-shot image classification demo with DistillCLIP'
article = '''DistillCLIP is a distilled version of [CLIP-ViT/B-32](https://huggingface.co/openai/clip-vit-base-patch32).
Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details on DistillCLIP.
Note: Multiplying the logits by a temperature before the softmax better separates the final scores, so we scale DistillCLIP's text-image similarities by the teacher CLIP's temperature.'''
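
# Worked illustration of the note above, with made-up similarity values: cosine
# similarities live in [-1, 1], so an unscaled softmax is nearly uniform, while
# scaling by the temperature (~100) first clearly separates the scores.
_sims = torch.tensor([[0.31, 0.28, 0.25]])
assert _sims.softmax(dim=1).max() < 0.35            # ~[0.34, 0.33, 0.32]
assert (_sims * 100.0).softmax(dim=1).max() > 0.94  # ~[0.95, 0.05, 0.00]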
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label='Image', type='pil'),
        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
        gr.Textbox(label='Prompts', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as a placeholder for the class. Separate prompts with semicolons.')
    ],
    outputs=gr.Label(label='Class scores'),
    title=title,
    description=description,
    article=article
)

demo.launch()