#!/usr/bin/env python
# coding: utf-8
# In[37]: | |
import os
import pickle
from functools import partial
from glob import glob

import clip
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm
# In[17]: | |
# Use the first CUDA device when one is visible; otherwise fall back to CPU so
# that importing this module does not hard-require a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Per-channel RGB mean/std applied after ToTensor(). These are the same
# constants used by the OpenAI CLIP reference preprocessing.
clip_norm = T.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# PIL image -> float tensor in [0, 1] with layout (C, H, W) -> normalized.
# NOTE(review): no resize/crop happens here; inputs are presumably already at
# the resolution the CLIP model expects -- confirm against callers.
clip_transform = T.Compose([T.ToTensor(), clip_norm])
# In[1]: | |
class ClipImageEncoder(nn.Module):
    """Frozen CLIP image encoder operating on groups of images.

    Wraps a model returned by ``clip.load``. The model is switched to eval
    mode and every parameter has ``requires_grad`` disabled, so the encoder
    is used for inference only.
    """

    def __init__(
        self,
        model='ViT-L/14',
        context_dim=None,
        jit=False,
        device='cuda',
    ):
        """Load the CLIP model and freeze its weights.

        Args:
            model: CLIP model name forwarded to ``clip.load``.
            context_dim: Stored unchanged on ``self.context_dim``; it is not
                used anywhere else in this class. Defaults to a fresh empty
                list per instance.
            jit: Forwarded to ``clip.load``.
            device: Device forwarded to ``clip.load``.
        """
        super().__init__()
        # A literal ``[]`` default would be one list shared by every instance;
        # build a new one per instance instead.
        self.context_dim = [] if context_dim is None else context_dim
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Encode a 5-D batch of images.

        Args:
            x: Tensor of shape ``(b, n, c, h, w)``.

        Returns:
            Tensor of shape ``(b, n, w)`` where ``w`` is the width of the
            output of ``self.model.encode_image``.
        """
        # Unpacking all five dims also rejects inputs that are not 5-D.
        b, n, _, _, _ = x.shape
        # Fold the group axis into the batch axis so CLIP sees a 4-D batch.
        batch = rearrange(x, 'b n c h w -> (b n) c h w')
        ret = self.model.encode_image(batch)
        return rearrange(ret, '(b n) w -> b n w', b=b, n=n)

    def preprocess(self, style_file):
        """Load an image file and turn it into a model-ready tensor.

        If ``style_file`` does not exist, an all-black 224x224 RGB image is
        used as a placeholder instead of raising.

        Args:
            style_file: Path to the image file.

        Returns:
            Tensor of shape ``(1, 1, 3, H, W)`` produced by ``clip_transform``.
            It is not moved to any device here; the caller must do that
            before passing it to ``forward``.
        """
        if os.path.exists(style_file):
            # Force 3-channel RGB: grayscale, palette or RGBA files would
            # otherwise yield a channel count that the 3-channel
            # normalization in ``clip_transform`` cannot handle.
            with Image.open(style_file) as img:
                style_image = img.convert('RGB')
        else:
            style_image = Image.fromarray(
                np.zeros((224, 224, 3), dtype=np.uint8)
            )
        # Two leading singleton axes give the (b, n, c, h, w) layout that
        # ``forward`` unpacks.
        return clip_transform(style_image).unsqueeze(0).unsqueeze(0)

    def postprocess(self, x):
        """Drop the leading axis (if it has size 1) and return a NumPy array.

        Args:
            x: Tensor, typically the output of ``forward``.

        Returns:
            ``numpy.ndarray`` holding a detached CPU copy of ``x.squeeze(0)``.
        """
        return x.squeeze(0).detach().cpu().numpy()
# In[23]: | |
# Build the shared encoder directly on the module-level ``device`` instead of
# loading on the constructor's default ('cuda') and moving it afterwards.
encoder = ClipImageEncoder(device=device)
# Kept from the original flow; guarantees the whole module ends up on ``device``.
encoder = encoder.to(device)
# In[6]: | |
# Reference only (disabled): precompute a CLIP embedding for every style image
# and pickle it next to the source file (same path, ".p" extension).
# style_files = glob("/home/soon/datasets/deepfashion_inshop/styles/**/*.jpg", recursive=True)
# # In[39]:
# for style_file in tqdm(style_files[:]):
#     style_image = Image.open(style_file)
#     x = clip_transform(style_image).unsqueeze(0).unsqueeze(0).to(device)
#     emb = encoder(x).detach().cpu().squeeze(0).numpy()
#     emb_file = style_file.replace('.jpg','.p')
#     with open(emb_file, 'wb') as file:
#         pickle.dump(emb, file)