from typing import Any, Dict, List

import base64
from io import BytesIO
from pathlib import Path

import torch
from torch import autocast
import open_clip
from open_clip import tokenizer
from rudalle import get_vae
from einops import rearrange
from PIL import Image

from modules import DenoiseUNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Defaults used by EndpointHandler.__call__ below.
batch_size = 1  # prompts per request
steps = 11      # renoise_steps passed to sample()
scale = 5       # classifier-free guidance scale


def to_pil(images):
    """Convert a batch of [0, 1] float tensors (NCHW) to a list of PIL images."""
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    images = (images * 255).round().astype("uint8")
    images = [Image.fromarray(image) for image in images]
    return images


def log(t, eps=1e-20):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1):
    # Gumbel-max trick: adding Gumbel noise to the temperature-scaled logits and
    # taking the argmax is equivalent to sampling from the softmax distribution.
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0,
           temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2,
           typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11,
           renoise_mode='start'):
    with torch.inference_mode():
        r_range = torch.linspace(0, 1, T + 1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        preds = []
        if x is None:
            # Start from uniformly random token indices.
            x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
        elif mask is not None:
            # Re-noise only the masked region and keep the rest of x.
            noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
            x = noise * mask + (1 - mask) * x
        init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                # Classifier-free guidance: mix unconditional and conditional logits.
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            x = logits
            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
            if typical_filtering:
                # Typical sampling: keep tokens whose information content is closest
                # to the entropy of the distribution, up to typical_mass probability.
                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
                x_flat_norm_p = torch.exp(x_flat_norm)
                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)

                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)

                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
                if typical_min_tokens > 1:
                    sorted_indices_to_remove[..., :typical_min_tokens] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))

            x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(x.size(0), *x.shape[2:])
            if mask is not None:
                x = x * mask + (1 - mask) * init_x
            if i < renoise_steps:
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=prev_x)
                else:
                    x, _ = model.add_noise(x, r_range[i + 1])
            preds.append(x.detach())
    return preds


class EndpointHandler:
    def __init__(self, path=""):
        model_path = Path(path) / "model_600000.pt"
        state_dict = torch.load(model_path, map_location=device)
        model = DenoiseUNet(num_labels=8192).to(device)
        model.load_state_dict(state_dict)
        model.eval().requires_grad_(False)
        self.model = model

        vqmodel = get_vae().to(device)
        vqmodel.eval().requires_grad_(False)
        self.vqmodel = vqmodel

        clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
        clip_model = clip_model.to(device).eval().requires_grad_(False)
        self.clip_model = clip_model

    def encode(self, x):
        # Map an image in [0, 1] to discrete VQ token indices.
        return self.vqmodel.model.encode((2 * x - 1))[-1][-1]

    def decode(self, img_seq, shape=(32, 32)):
        # Map VQ token indices back to an image in [0, 1].
        img_seq = img_seq.view(img_seq.shape[0], -1)
        one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.vqmodel.num_tokens).float()
        z = one_hot_indices @ self.vqmodel.model.quantize.embed.weight
        z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
        img = self.vqmodel.model.decode(z)
        img = (img.clamp(-1., 1.) + 1) * 0.5
        return img

    def __call__(self, data: Any) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict` containing a base64-encoded JPEG under the ``"image"`` key.
        """
        inputs = data.pop("inputs", data)

        latent_shape = (32, 32)
        tokenized_text = tokenizer.tokenize([inputs] * batch_size).to(device)
        with autocast(device.type):
            clip_embeddings = self.clip_model.encode_text(tokenized_text)
            images = sample(
                self.model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
                typical_filtering=True, typical_mass=0.2, typical_min_tokens=1,
                classifier_free_scale=scale, renoise_steps=steps, renoise_mode="start"
            )
            images = self.decode(images[-1], latent_shape)
            images = to_pil(images)

        buffered = BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        return {"image": img_str.decode()}
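

# --- Illustrative local smoke test (not part of the endpoint contract) -------
# A minimal sketch of how this handler might be exercised outside of an
# Inference Endpoints deployment, assuming `model_600000.pt` is available in
# the working directory and a CUDA device is present. The prompt and output
# filename below are placeholders, not values used by the service.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({"inputs": "a painting of a lighthouse at sunset"})  # hypothetical prompt
    # The handler returns a base64-encoded JPEG string; decode it back to bytes to inspect it.
    with open("sample_output.jpg", "wb") as f:
        f.write(base64.b64decode(result["image"]))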
|