import os

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from timm.models.vision_transformer import Block
from transformers import PretrainedConfig, PreTrainedModel

from .mar import MAR
class MARConfig(PretrainedConfig):
    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)

        # store parameters in the config
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing
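
# Note: the config stores `norm_layer` as a string (the name of a torch.nn class) so it
# stays JSON-serializable; MARModel.__init__ below resolves it back to the actual class
# with getattr(nn, config.norm_layer) before building the MAR module.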
class MARModel(PreTrainedModel):
    # links to the MARConfig class
    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # convert norm_layer from its string name to the actual nn class
        norm_layer = getattr(nn, config.norm_layer)

        # init the MAR model using the parameters from the config
        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,  # pass the resolved class, not the string
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )
    def forward(self, imgs, labels):
        # calls the forward method of the wrapped MAR module, passing imgs & labels
        return self.model(imgs, labels)

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None,
                      temperature=1.0, progress=False):
        # call the sample_tokens method of the wrapped MAR module
        return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels,
                                        temperature, progress)
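
    # The two methods below override the default Transformers (de)serialization so that
    # only the inner MAR module's weights are stored (state_dict keys without the leading
    # "model." prefix) and loaded back into self.model.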
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = cls(config)
        # a .safetensors checkpoint must be read with safetensors' load_file
        # (torch.load cannot parse this format); note the hard-coded local path
        state_dict = load_file('./checkpoint-last.safetensors')
        model.model.load_state_dict(state_dict)
        return model
    def save_pretrained(self, save_directory):
        # we will save the weights to safetensors
        os.makedirs(save_directory, exist_ok=True)
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # save the configuration as usual
        self.config.save_pretrained(save_directory)
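

# Minimal usage sketch (not part of the original file): it assumes a CUDA device is
# available for the MAR sampler, and that the label id 207 and the "./mar-checkpoint"
# directory are placeholders to replace with your own class ids / paths.
if __name__ == "__main__":
    config = MARConfig()                    # default hyper-parameters defined above
    model = MARModel(config).cuda().eval()  # wrap a freshly initialised MAR

    # Round-trip the weights and config through the custom save_pretrained.
    model.save_pretrained("./mar-checkpoint")

    # Sample latent tokens for 4 images of an (assumed) ImageNet class id.
    labels = torch.tensor([207, 207, 207, 207], device="cuda")
    with torch.no_grad():
        tokens = model.sample_tokens(bsz=4, num_iter=64, cfg=1.0,
                                     cfg_schedule="linear", labels=labels,
                                     temperature=1.0, progress=True)
    print(tokens.shape)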