Spaces:
Running
Running
import os | |
import torch | |
import torch.nn as nn | |
import pytorch_lightning as pl | |
from typing import List, Union | |
from PIL import Image | |
# Known model repositories in `<username>/<repo_id>` form. The first entry is
# the default that WaifuScorer falls back to when no model path is given.
WS_REPOS = ["Eugeoter/waifu-scorer-v3"]
class MLP(pl.LightningModule):
    """Feed-forward regression head mapping an embedding to a single value.

    The network is four hidden stages of ``Linear -> ReLU -> BatchNorm1d (or
    Identity) -> Dropout`` followed by ``Linear(128, 32) -> ReLU ->
    Linear(32, 1)``.

    Args:
        input_size: Width of the input feature vector.
        xcol: Stored on the instance as ``self.xcol``; not used by the code
            in this class.
        ycol: Stored on the instance as ``self.ycol``; not used by the code
            in this class.
        batch_norm: When true each hidden stage contains ``nn.BatchNorm1d``;
            otherwise an ``nn.Identity`` placeholder takes its slot so the
            layer indices stay the same.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        # (output width, dropout probability) for each hidden stage, in order.
        hidden_stages = ((2048, 0.3), (512, 0.3), (256, 0.2), (128, 0.1))

        stack = []
        in_features = self.input_size
        for out_features, drop_p in hidden_stages:
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.ReLU())
            stack.append(nn.BatchNorm1d(out_features) if batch_norm else nn.Identity())
            stack.append(nn.Dropout(drop_p))
            in_features = out_features

        # Output head: no normalisation or dropout after the last hidden stage.
        stack.extend([nn.Linear(in_features, 32), nn.ReLU(), nn.Linear(32, 1)])

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
class WaifuScorer(object):
    """Score images by feeding CLIP image embeddings through a pretrained MLP.

    Args:
        model_path: Path to the MLP weights. When ``None`` the default is
            derived from ``WS_REPOS[0]`` via ``repo2path``. If the resulting
            path is not an existing file it is handed to
            ``download_from_url``.
        device: Device passed to ``load_model`` and ``load_clip_models``.
        cache_dir: Cache directory forwarded to ``download_from_url``.
        verbose: When true, report the fallback to the default model path.
    """

    def __init__(self, model_path=None, device='cuda', cache_dir=None, verbose=False):
        self.verbose = verbose
        if model_path is None:
            model_path = repo2path(WS_REPOS[0])
            if self.verbose:
                print(f"model path not set, switch to default: `{model_path}`")
        if not os.path.isfile(model_path):
            # Not a local file: treat it as `<username>/<repo_id>/<filename>`.
            model_path = download_from_url(model_path, cache_dir=cache_dir)
        print(f"loading pretrained model from `{model_path}`")
        self.mlp = load_model(model_path, input_size=768, device=device)
        self.model2, self.preprocess = load_clip_models("ViT-L/14", device=device)
        # Take device/dtype from the loaded MLP so inputs can be matched to it.
        self.device = self.mlp.device
        self.dtype = self.mlp.dtype
        self.mlp.eval()

    def __call__(self, images: List[Image.Image]) -> Union[List[float], float]:
        """Return one score per input image, each clamped to ``[0, 10]``.

        Args:
            images: A list of PIL images, or a single PIL image (treated as a
                one-element list).

        Returns:
            A list of floats with exactly one entry per input image, in input
            order. An empty input yields an empty list.
        """
        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        if n == 0:
            # Nothing to encode; torch.cat would fail on an empty batch.
            return []
        if n == 1:
            # Pad to a batch of two (the original "batch norm" workaround);
            # the padding row is dropped again before returning.
            images = images * 2
        # Inference only: without no_grad the prediction tensor requires grad
        # and Tensor.numpy() below raises a RuntimeError.
        with torch.no_grad():
            features = encode_images(
                images, self.model2, self.preprocess, device=self.device
            ).to(device=self.device, dtype=self.dtype)
            predictions = self.mlp(features)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        # Drop the score of the duplicated padding image, if any.
        return scores[:n]
def repo2path(model_repo_and_path: str):
    """Resolve a file, directory or known repo id to a ``model.pth`` location.

    Args:
        model_repo_and_path: An existing file path, an existing directory, or
            one of the entries of ``WS_REPOS``.

    Returns:
        The input itself when it is an existing file; ``<dir>/model.pth``
        (joined with ``os.path.join``) when it is an existing directory;
        ``<repo>/model.pth`` when it is listed in ``WS_REPOS``.

    Raises:
        ValueError: If the argument matches none of the cases above.
    """
    if os.path.isfile(model_repo_and_path):
        return model_repo_and_path
    if os.path.isdir(model_repo_and_path):
        return os.path.join(model_repo_and_path, "model.pth")
    if model_repo_and_path in WS_REPOS:
        return model_repo_and_path + '/model.pth'
    raise ValueError(f"Invalid model_repo_and_path: {model_repo_and_path}")
def download_from_url(url, cache_dir=None, verbose=True):
    """Download a model file with ``huggingface_hub.hf_hub_download``.

    The last three ``/``-separated components of ``url`` are interpreted as
    ``<username>/<repo_id>/<filename>``; any leading components are ignored.

    Args:
        url: Slash-separated location ending in
            ``<username>/<repo_id>/<filename>``.
        cache_dir: Forwarded to ``hf_hub_download`` as ``cache_dir``.
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        The value returned by ``hf_hub_download``.

    Raises:
        ValueError: If ``url`` does not end in three non-empty components.
    """
    parts = url.split("/")
    # Validate up front so a malformed location fails with a clear message
    # instead of an IndexError from the unpacking below.
    if len(parts) < 3 or not all(parts[-3:]):
        raise ValueError(
            f"Invalid model url: {url!r}; expected `<username>/<repo_id>/<filename>`"
        )
    username, repo_id, model_name = parts[-3:]

    # Imported lazily so the module can be imported without huggingface_hub.
    from huggingface_hub import hf_hub_download
    return hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)
def load_clip_models(name: str = "ViT-L/14", device='cuda'):
    """Load a CLIP model and its preprocessing callable via ``clip.load``.

    Args:
        name: Model name passed to ``clip.load``.
        device: Device passed to ``clip.load``.

    Returns:
        A ``(model, preprocess)`` tuple as produced by ``clip.load``.
    """
    # Imported lazily so the `clip` package is only needed when this is called.
    import clip

    clip_model, clip_preprocess = clip.load(name, device=device)
    return clip_model, clip_preprocess
def load_model(model_path: str = None, input_size=768, device: str = 'cuda', dtype=None):
    """Build an ``MLP`` and optionally load weights into it.

    Args:
        model_path: File to read a state dict from with ``torch.load``. When
            falsy, the freshly initialised model is returned without loading.
        input_size: Input width passed to ``MLP``.
        device: Used as ``map_location`` when loading and as the target of
            ``model.to``.
        dtype: When truthy, the model is additionally converted to this dtype.

    Returns:
        The ``MLP`` instance on ``device``.
    """
    net = MLP(input_size=input_size)
    if model_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from sources you trust.
        state_dict = torch.load(model_path, map_location=device)
        net.load_state_dict(state_dict)
    net.to(device)
    return net.to(dtype=dtype) if dtype else net
def normalized(a: torch.Tensor, order=2, dim=-1):
    """Divide ``a`` by its ``order``-norm taken along ``dim``.

    Slices whose norm is exactly zero are divided by 1 instead, so they are
    returned unchanged rather than producing a division by zero.

    Args:
        a: Tensor to normalise.
        order: Order of the norm (passed as ``p`` to ``Tensor.norm``).
        dim: Dimension along which the norm is computed.

    Returns:
        A tensor with the same shape as ``a``.
    """
    norms = a.norm(p=order, dim=dim, keepdim=True)
    # Replace zero norms by 1 so all-zero slices pass through untouched.
    safe_norms = norms.masked_fill(norms == 0, 1)
    return a / safe_norms
def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
    """Embed images with ``model2.encode_image`` and normalise the result.

    Args:
        images: A list of PIL images, or a single PIL image (treated as a
            one-element list).
        model2: Object exposing ``encode_image(batch)``.
        preprocess: Callable that turns one image into a tensor.
        device: Device the preprocessed batch is moved to before encoding.

    Returns:
        The embeddings after ``normalized``, moved to the CPU and cast to
        float32.
    """
    if isinstance(images, Image.Image):
        images = [images]

    # Give every preprocessed image a leading batch axis, then concatenate.
    frames = [preprocess(picture).unsqueeze(0) for picture in images]
    batch = torch.cat(frames)
    batch = batch.to(device)

    embeddings = normalized(model2.encode_image(batch))
    return embeddings.cpu().float()