|
from typing import Any, Dict, List, Tuple |
|
|
|
import clip |
|
from hydra import compose, initialize |
|
from hydra.utils import instantiate |
|
from omegaconf import OmegaConf |
|
import torch |
|
from torchtyping import TensorType |
|
from torch.utils.data import DataLoader |
|
import torch.nn.functional as F |
|
|
|
from src.diffuser import Diffuser |
|
from src.datasets.multimodal_dataset import MultimodalDataset |
|
|
|
|
|
|
|
batch_size, context_length, embed_dim = None, None, None
|
# Default PyTorch collate function, used to turn a single sample into a batch of one.
collate_fn = DataLoader([]).collate_fn
|
|
|
|
|
|
|
|
|
def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: |
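    """Move every tensor value in the batch dict to `device` (in place)."""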
|
for key, value in batch.items(): |
|
if isinstance(value, torch.Tensor): |
|
batch[key] = value.to(device) |
|
return batch |
|
|
|
|
|
def load_clip_model(version: str, device: str) -> clip.model.CLIP: |
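    """Load a frozen CLIP model in eval mode for text encoding."""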
|
model, _ = clip.load(version, device=device, jit=False) |
|
model.eval() |
|
for p in model.parameters(): |
|
p.requires_grad = False |
|
return model |
|
|
|
|
|
def encode_text( |
|
caption_raws: List[str], |
|
clip_model: clip.model.CLIP, |
|
max_token_length: int, |
|
device: str, |
|
) -> Tuple[List[TensorType], TensorType["batch_size", "embed_dim"]]:
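    """Encode captions with CLIP's text transformer.

    Returns a list of per-token feature sequences (one per caption, trimmed at
    the EOT token) and the pooled per-caption features taken at the EOT position.
    """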
|
if max_token_length is not None: |
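        # Tokenize with a reduced context length, then zero-pad the token ids
        # back to CLIP's default 77-token context.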
|
default_context_length = 77 |
|
context_length = max_token_length + 2 |
|
assert context_length < default_context_length |
|
|
|
texts = clip.tokenize( |
|
caption_raws, context_length=context_length, truncate=True |
|
) |
|
zero_pad = torch.zeros( |
|
[texts.shape[0], default_context_length - context_length], |
|
dtype=texts.dtype, |
|
device=texts.device, |
|
) |
|
texts = torch.cat([texts, zero_pad], dim=1) |
|
else: |
|
|
|
texts = clip.tokenize(caption_raws, truncate=True) |
|
|
|
|
|
    # Re-run CLIP's text encoder manually (token embedding -> transformer -> ln_final)
    # to keep the full per-token activations; the final text projection is skipped.
    x = clip_model.token_embedding(texts.to(device)).type(clip_model.dtype)
|
x = x + clip_model.positional_embedding.type(clip_model.dtype) |
|
x = x.permute(1, 0, 2) |
|
x = clip_model.transformer(x) |
|
x = x.permute(1, 0, 2) |
|
x = clip_model.ln_final(x).type(clip_model.dtype) |
|
|
|
|
|
    # texts.argmax(dim=-1) is the position of the EOT token (the highest token id
    # in CLIP's vocabulary): it gives the pooled per-caption feature and marks
    # where each per-token sequence ends.
    x_tokens = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)].float()

    x_seq = [x[k, : (m + 1)].float() for k, m in enumerate(texts.argmax(dim=-1))]
|
|
|
return x_seq, x_tokens |
|
|
|
|
|
def get_batch( |
|
prompt: str, |
|
sample_id: str, |
|
character_position: str, |
|
clip_model: clip.model.CLIP, |
|
dataset: MultimodalDataset, |
|
seq_feat: bool, |
|
device: torch.device, |
|
) -> Dict[str, Any]: |
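    """Build a single-sample batch for a custom prompt and character position.

    The dataset sample identified by `sample_id` is collated into a batch of
    one; its caption features are replaced with the CLIP encoding of `prompt`
    and its character features with the provided `character_position`.
    """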
|
|
|
sample_index = dataset.root_filenames.index(sample_id) |
|
raw_batch = dataset[sample_index] |
|
batch = collate_fn([to_device(raw_batch, device)]) |
|
|
|
|
|
caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device) |
|
|
|
|
    if seq_feat:

        # Use the full per-token sequence: zero-pad it to 77 tokens and reshape
        # it to (batch, channels, tokens).
        caption_feat = caption_seq[0]

        caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))

        caption_feat = caption_feat.unsqueeze(0).permute(0, 2, 1)

    else:

        # Use only the pooled EOT feature.
        caption_feat = caption_tokens
|
|
|
|
|
batch["caption_raw"] = [prompt] |
|
batch["caption_feat"] = caption_feat |
|
|
|
    # Parse the bracketed, comma-separated position string into a tensor on the
    # target device.
    character_position = torch.tensor(
|
[float(coord) for coord in character_position.strip("[]").split(",")], |
|
device=device, |
|
) |
|
|
|
    # Tile the parsed character position along a 300-step axis and zero out the
    # character mesh inputs.
    batch["char_feat"] = character_position.unsqueeze(-1).repeat(1, 1, 300)

    batch["char_raw"]["char_raw_feat"] = character_position.unsqueeze(-1).repeat(1, 1, 300)

    batch["char_raw"]["char_vertices"] = torch.zeros_like(batch["char_raw"]["char_vertices"])

    batch["char_raw"]["char_faces"] = torch.zeros_like(batch["char_raw"]["char_faces"])
|
|
|
return batch |
|
|
|
|
|
def init( |
|
config_name: str, |
|
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, str]:
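    """Instantiate the diffuser, the CLIP text encoder, and the demo dataset.

    Returns the diffuser, the CLIP model, the dataset, and the device string
    taken from the Hydra config.
    """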
|
with initialize(version_base="1.3", config_path="../configs"): |
|
config = compose(config_name=config_name) |
|
|
|
OmegaConf.register_new_resolver("eval", eval) |
|
|
|
|
|
|
|
diffuser = instantiate(config.diffuser) |
|
state_dict = torch.load(config.checkpoint_path, map_location="cpu")["state_dict"] |
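    # Keep the freshly initialized EMA bookkeeping buffers so they stay
    # compatible with the current EMA wrapper during the non-strict load.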
|
state_dict["ema.initted"] = diffuser.ema.initted |
|
state_dict["ema.step"] = diffuser.ema.step |
|
diffuser.load_state_dict(state_dict, strict=False) |
|
diffuser.to("cpu").eval() |
|
|
|
|
|
clip_model = load_clip_model("ViT-B/32", "cpu") |
|
|
|
|
|
    # Build the demo split of the dataset and expose its modality list and
    # matrix helpers to the diffuser.
    config.dataset.char.load_vertices = True
|
config.batch_size = 1 |
|
dataset = instantiate(config.dataset) |
|
dataset.set_split("demo") |
|
diffuser.modalities = list(dataset.modality_datasets.keys()) |
|
diffuser.get_matrix = dataset.get_matrix |
|
diffuser.v_get_matrix = dataset.get_matrix |
|
|
|
return diffuser, clip_model, dataset, config.compnode.device |
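

# Example usage sketch: the config name, prompt, and character position below
# are hypothetical placeholders to be replaced with real values.
if __name__ == "__main__":
    diffuser, clip_model, dataset, device = init(config_name="config")

    # init() leaves the models on CPU, so build the batch on CPU as well.
    batch = get_batch(
        prompt="a character walks forward",
        sample_id=dataset.root_filenames[0],
        character_position="[0.0,0.0,0.0]",
        clip_model=clip_model,
        dataset=dataset,
        seq_feat=True,
        device=torch.device("cpu"),
    )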
|
|