Vadhwid / Core /compression_scorer.py
QinOwen
add-vader-videocrafter
824b515
raw
history blame
3.91 kB
# Adapted from Cheng An Hsieh et al.: https://github.com/RewardMultiverse/reward-multiverse
from PIL import Image
import io
import numpy as np
import torch.nn as nn
import torch
import torchvision
import albumentations as A
from transformers import CLIPModel, CLIPProcessor
# import ipdb
# st = ipdb.set_trace
def jpeg_compressibility(device, quality=95):
    """Build a scorer that measures how large each image is as a JPEG.

    Args:
        device: Device the re-quantized image tensor is moved to.
        quality: JPEG quality passed to ``PIL.Image.save`` (default 95, the
            value that used to be hard-coded).

    Returns:
        A function ``_fn(images)`` returning ``(sizes, transform_images_tensor)``:
        ``sizes`` is a NumPy array with each image's encoded JPEG size in
        kilobytes (bytes / 1000), and ``transform_images_tensor`` is the
        8-bit-quantized batch as an NCHW tensor in [0, 1] on ``device``.
    """
    def _fn(images):
        """JPEG-encode every image in the batch and report the encoded sizes.

        Args:
            images: Either a ``torch.Tensor`` of shape NCHW (presumably float
                values in [0, 1] -- TODO confirm with callers), or an
                already-quantized NHWC array that ``PIL.Image.fromarray``
                accepts (e.g. ``uint8``).
        """
        if isinstance(images, torch.Tensor):
            # Preserve the caller's dtype for the returned tensor.
            org_type = images.dtype
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        else:
            # A NumPy dtype cannot be passed to ``Tensor.to``; use torch's
            # default floating dtype (what ``torch.Tensor(...)`` produced).
            org_type = torch.get_default_dtype()

        pixels = np.ascontiguousarray(images)
        transform_images_tensor = torch.from_numpy(pixels).to(device=device, dtype=org_type)
        # NHWC -> NCHW, back to the [0, 1] range.
        transform_images_tensor = (transform_images_tensor.permute(0, 3, 1, 2) / 255).clamp(0, 1)

        sizes = []
        for image in (Image.fromarray(frame) for frame in images):
            with io.BytesIO() as buffer:
                image.save(buffer, format="JPEG", quality=quality)
                # After save() the stream position equals the bytes written.
                sizes.append(buffer.tell() / 1000)
        return np.array(sizes), transform_images_tensor

    return _fn
class MLP(nn.Module):
    """Feed-forward regression head mapping a 768-dim input to one scalar.

    Four ``Linear -> ReLU -> Dropout`` stages (768 -> 512 -> 256 -> 128 -> 32)
    are followed by a final ``Linear(32, 1)``. The whole stack is stored in
    ``self.layers`` as an ``nn.Sequential``, so state-dict keys have the form
    ``layers.<index>.weight`` / ``layers.<index>.bias``.
    """

    def __init__(self):
        super().__init__()
        widths = [768, 512, 256, 128, 32]
        dropout_rates = [0.2, 0.2, 0.2, 0.1]
        stack = []
        # Linear layers are created in the same order as a hand-written list,
        # so parameter initialisation draws from the RNG identically.
        for fan_in, fan_out, rate in zip(widths, widths[1:], dropout_rates):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(rate)))
        stack.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, embed):
        """Run ``embed`` (last dimension 768) through the layer stack."""
        scores = self.layers(embed)
        return scores
def jpegcompression_loss_fn(target=None,
                            grad_scale=0,
                            device=None,
                            accelerator=None,
                            torch_dtype=None,
                            reward_model_resume_from=None):
    """Build a loss function backed by a frozen ``JpegCompressionScorer``.

    Args:
        target: When ``None`` the loss is the predicted score itself;
            otherwise it is the absolute distance ``|score - target|``.
        grad_scale: Factor the loss is multiplied by. With the default of 0
            the returned loss is identically zero.
        device: Device the scorer is moved to.
        accelerator: Optional object exposing ``mixed_precision`` and
            ``autocast()``. When ``mixed_precision == "fp16"`` the scorer runs
            inside ``accelerator.autocast()``. When ``accelerator`` is
            ``None`` the scorer runs without autocast.
        torch_dtype: dtype forwarded to the scorer constructor and used when
            moving the scorer to ``device``.
        reward_model_resume_from: Optional path to a saved state dict for the
            scorer's MLP head.

    Returns:
        ``loss_fn(im_pix_un)`` returning ``(loss * grad_scale, rewards)``,
        where ``rewards`` are the raw scorer outputs.
    """
    import contextlib

    scorer = JpegCompressionScorer(dtype=torch_dtype, model_path=reward_model_resume_from).to(device, dtype=torch_dtype)
    # Scorer parameters are frozen; any gradient can only reach the input.
    scorer.requires_grad_(False)
    # eval() switches off dropout layers (e.g. in the MLP head).
    scorer.eval()

    def loss_fn(im_pix_un):
        # Checked per call, as before, in case the accelerator config changes.
        use_autocast = accelerator is not None and accelerator.mixed_precision == "fp16"
        ctx = accelerator.autocast() if use_autocast else contextlib.nullcontext()
        with ctx:
            rewards = scorer(im_pix_un)
        if target is None:
            loss = rewards
        else:
            loss = abs(rewards - target)
        return loss * grad_scale, rewards

    return loss_fn
class JpegCompressionScorer(nn.Module):
    """Predict a JPEG-compression score from CLIP ViT-L/14 image features.

    Images are resized to 224x224, normalized with the OpenAI CLIP mean/std,
    embedded by the frozen CLIP image tower, L2-normalized, and mapped by an
    ``MLP`` head to one scalar per image.
    """

    def __init__(self, dtype=None, model_path=None):
        """
        Args:
            dtype: Stored on ``self.dtype`` when truthy; not otherwise used by
                this class.
            model_path: Optional path to a state dict for ``score_generator``.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip.requires_grad_(False)
        self.score_generator = MLP()
        if model_path:
            # map_location="cpu" lets GPU-saved checkpoints load on any host;
            # load_state_dict copies into the (CPU) MLP params either way, and
            # the caller moves the module to its device afterwards.
            # NOTE(review): torch.load unpickles arbitrary objects -- only load
            # trusted checkpoints (consider weights_only=True on torch>=1.13).
            state_dict = torch.load(model_path, map_location="cpu")
            self.score_generator.load_state_dict(state_dict)
        if dtype:
            self.dtype = dtype
        self.target_size = (224, 224)
        # OpenAI CLIP image-preprocessing normalization constants.
        self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                          std=[0.26862954, 0.26130258, 0.27577711])

    def set_device(self, device, inference_type):
        """Move only the MLP head to ``device``; ``inference_type`` is unused."""
        self.score_generator.to(device)

    def forward(self, images):
        """Score a batch of images.

        Defined as ``forward`` (not ``__call__``) so ``scorer(images)`` goes
        through ``nn.Module.__call__`` and registered hooks still run.

        Args:
            images: Image batch accepted by ``Resize``/``Normalize``;
                presumably NCHW float in [0, 1] -- TODO confirm with callers.

        Returns:
            One score per image (the MLP's ``(N, 1)`` output squeezed on dim 1).
        """
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(images.dtype)
        embed = self.clip.get_image_features(pixel_values=im_pix)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.score_generator(embed).squeeze(1)