import io

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from transformers import CLIPModel
|
|
def jpeg_compressibility(device):
    def _fn(images):
        """Measure how compressible a batch of images is.

        Args:
            images: an NCHW float tensor in [0, 1], or a sequence of
                uint8 HWC numpy arrays.

        Returns:
            A numpy array of JPEG file sizes in kilobytes (quality 95)
            and the images as a normalized NCHW tensor on `device`.
        """
        # Assume float32 for non-tensor inputs (an added fallback; the
        # original dtype is only known for tensor inputs).
        org_type = images.dtype if isinstance(images, torch.Tensor) else torch.float32
        if isinstance(images, torch.Tensor):
            # Convert to uint8 NHWC numpy arrays for PIL.
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)

        # Rebuild a normalized NCHW tensor to return alongside the sizes.
        transform_images_tensor = torch.as_tensor(np.asarray(images)).to(device, dtype=org_type)
        transform_images_tensor = (transform_images_tensor.permute(0, 3, 1, 2) / 255).clamp(0, 1)

        # Encode each image as a JPEG in memory and record its byte size.
        transform_images_pil = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in transform_images_pil]
        for image, buffer in zip(transform_images_pil, buffers):
            image.save(buffer, format="JPEG", quality=95)

        # Size in kilobytes; smaller means more compressible.
        sizes = [buffer.tell() / 1000 for buffer in buffers]

        return np.array(sizes), transform_images_tensor

    return _fn
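
# A minimal usage sketch (illustrative, not part of the pipeline above): the
# closure returned by `jpeg_compressibility` maps a batch to JPEG sizes in kB.
# The random batch below is an assumption for demonstration only.
#
#   reward_fn = jpeg_compressibility(device="cpu")
#   batch = torch.rand(4, 3, 224, 224)           # NCHW floats in [0, 1]
#   sizes_kb, images_tensor = reward_fn(batch)   # sizes_kb has shape (4,)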
|
|
class MLP(nn.Module):
    """Small regression head mapping a 768-d CLIP embedding to a scalar score."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
        )

    def forward(self, embed):
        return self.layers(embed)
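
# Shape sketch (illustrative): one scalar score per 768-d CLIP embedding.
#
#   head = MLP()
#   scores = head(torch.randn(8, 768))   # -> shape (8, 1)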
|
|
|
def jpegcompression_loss_fn(target=None,
                            grad_scale=0,
                            device=None,
                            accelerator=None,
                            torch_dtype=None,
                            reward_model_resume_from=None):
    """Build a loss from a frozen JPEG-compression scorer.

    With no `target`, the predicted compressed size itself is the loss;
    otherwise the loss is the absolute distance to `target`. `grad_scale`
    scales the loss (the default of 0 zeroes the gradient signal).
    """
    scorer = JpegCompressionScorer(dtype=torch_dtype,
                                   model_path=reward_model_resume_from).to(device, dtype=torch_dtype)
    scorer.requires_grad_(False)
    scorer.eval()

    def loss_fn(im_pix_un):
        if accelerator.mixed_precision == "fp16":
            with accelerator.autocast():
                rewards = scorer(im_pix_un)
        else:
            rewards = scorer(im_pix_un)

        if target is None:
            loss = rewards
        else:
            loss = abs(rewards - target)
        return loss * grad_scale, rewards

    return loss_fn
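
# A hedged usage sketch: `accelerator` is assumed to be a Hugging Face
# accelerate.Accelerator, and the argument values are illustrative only.
#
#   loss_fn = jpegcompression_loss_fn(target=None,
#                                     grad_scale=1.0,
#                                     device=accelerator.device,
#                                     accelerator=accelerator,
#                                     torch_dtype=torch.float32,
#                                     reward_model_resume_from="scorer.pt")
#   loss, rewards = loss_fn(pred_images)   # pred_images: NCHW floats in [0, 1]
#   accelerator.backward(loss.mean())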
|
|
|
class JpegCompressionScorer(nn.Module):
    """Predict JPEG-compressed size from CLIP image features with a frozen backbone."""

    def __init__(self, dtype=None, model_path=None):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip.requires_grad_(False)
        self.score_generator = MLP()

        if model_path:
            state_dict = torch.load(model_path)
            self.score_generator.load_state_dict(state_dict)
        if dtype:
            self.dtype = dtype
        self.target_size = (224, 224)
        # CLIP's standard image normalization constants.
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def set_device(self, device, inference_type):
        # `inference_type` is unused; kept for interface compatibility.
        self.score_generator.to(device)

    def forward(self, images):
        # Resize and normalize to CLIP's expected input, then score the
        # L2-normalized image embedding with the MLP head.
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(images.dtype)
        embed = self.clip.get_image_features(pixel_values=im_pix)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.score_generator(embed).squeeze(1)
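

# A minimal smoke test, assuming network access for the CLIP checkpoint and no
# pretrained MLP weights (scores come from a randomly initialized head), so it
# is purely illustrative.
if __name__ == "__main__":
    scorer = JpegCompressionScorer()
    scorer.eval()
    with torch.no_grad():
        demo = torch.rand(2, 3, 256, 256)   # NCHW floats in [0, 1]
        print(scorer(demo).shape)           # -> torch.Size([2])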