import os

import torch
import torch.nn as nn
from safetensors.torch import save_file, load_file
from timm.models.vision_transformer import Block
from transformers import PretrainedConfig, PreTrainedModel

from .mar import MAR

class MARConfig(PretrainedConfig):
    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)

        # store parameters in the config
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing

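# Illustrative sketch (values below are hypothetical): because MARConfig extends
# PretrainedConfig, it serializes to and from JSON without extra code, e.g.
#   cfg = MARConfig(encoder_depth=24, decoder_depth=24)
#   cfg.save_pretrained("./mar-config")              # writes config.json
#   cfg = MARConfig.from_pretrained("./mar-config")  # reloads the same settings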


class MARModel(PreTrainedModel):
    # links to MARConfig class
    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # convert norm_layer from string to class
        norm_layer = getattr(nn, config.norm_layer)

        # init the MAR model using the parameters from the config
        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer, # use the actual class for the layer
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )

    def forward(self, imgs, labels):
        # calls the forward method from the mar class - passing imgs & labels
        return self.model(imgs, labels)

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        # call the sample_tokens method from the MAR class
        return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # rebuild the model from its config, then load the weights written by save_pretrained
        # note: this assumes pretrained_model_name_or_path is a local directory
        config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = cls(config)
        # safetensors checkpoints must be read with safetensors' load_file, not torch.load
        safetensors_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.safetensors")
        state_dict = load_file(safetensors_path)
        model.model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        # we will save to safetensors
        os.makedirs(save_directory, exist_ok=True)
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # save the configuration as usual
        self.config.save_pretrained(save_directory)
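

if __name__ == "__main__":
    # Minimal usage sketch, not part of the wrapper itself. It assumes a local
    # checkpoint directory "./mar-checkpoint" previously written by
    # MARModel.save_pretrained, and a CUDA device (MAR's sampler allocates CUDA
    # tensors internally); the class labels below are arbitrary ImageNet indices.
    device = "cuda"
    model = MARModel.from_pretrained("./mar-checkpoint").to(device).eval()

    labels = torch.tensor([1, 7, 282, 904], device=device)
    with torch.no_grad():
        tokens = model.sample_tokens(bsz=4, num_iter=64, cfg=1.5, labels=labels)
    # `tokens` are latent-space outputs; decoding them to images requires the
    # matching VAE, which is outside the scope of this wrapper.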