jadechoghari commited on
Commit
e8aa2e4
·
verified ·
1 Parent(s): d059bcc

Create modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +129 -0
modeling.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ import torch
5
+ from safetensors.torch import save_file
6
+ import os
7
+ from timm.models.vision_transformer import Block
8
+ from .mar import MAR
9
+
10
+ class MARConfig(PretrainedConfig):
11
+ model_type = "mar"
12
+
13
+ def __init__(self,
14
+ img_size=256,
15
+ vae_stride=16,
16
+ patch_size=1,
17
+ encoder_embed_dim=1024,
18
+ encoder_depth=16,
19
+ encoder_num_heads=16,
20
+ decoder_embed_dim=1024,
21
+ decoder_depth=16,
22
+ decoder_num_heads=16,
23
+ mlp_ratio=4.,
24
+ norm_layer="LayerNorm",
25
+ vae_embed_dim=16,
26
+ mask_ratio_min=0.7,
27
+ label_drop_prob=0.1,
28
+ class_num=1000,
29
+ attn_dropout=0.1,
30
+ proj_dropout=0.1,
31
+ buffer_size=64,
32
+ diffloss_d=3,
33
+ diffloss_w=1024,
34
+ num_sampling_steps='100',
35
+ diffusion_batch_mul=4,
36
+ grad_checkpointing=False,
37
+ **kwargs):
38
+ super().__init__(**kwargs)
39
+
40
+ # store parameters in the config
41
+ self.img_size = img_size
42
+ self.vae_stride = vae_stride
43
+ self.patch_size = patch_size
44
+ self.encoder_embed_dim = encoder_embed_dim
45
+ self.encoder_depth = encoder_depth
46
+ self.encoder_num_heads = encoder_num_heads
47
+ self.decoder_embed_dim = decoder_embed_dim
48
+ self.decoder_depth = decoder_depth
49
+ self.decoder_num_heads = decoder_num_heads
50
+ self.mlp_ratio = mlp_ratio
51
+ self.norm_layer = norm_layer
52
+ self.vae_embed_dim = vae_embed_dim
53
+ self.mask_ratio_min = mask_ratio_min
54
+ self.label_drop_prob = label_drop_prob
55
+ self.class_num = class_num
56
+ self.attn_dropout = attn_dropout
57
+ self.proj_dropout = proj_dropout
58
+ self.buffer_size = buffer_size
59
+ self.diffloss_d = diffloss_d
60
+ self.diffloss_w = diffloss_w
61
+ self.num_sampling_steps = num_sampling_steps
62
+ self.diffusion_batch_mul = diffusion_batch_mul
63
+ self.grad_checkpointing = grad_checkpointing
64
+
65
+
66
+
67
+ class MARModel(PreTrainedModel):
68
+ # links to MARConfig class
69
+ config_class = MARConfig
70
+
71
+ def __init__(self, config):
72
+ super().__init__(config)
73
+ self.config = config
74
+
75
+ # convert norm_layer from string to class
76
+ norm_layer = getattr(nn, config.norm_layer)
77
+
78
+ # init the mar model using the parameters from config
79
+ self.model = MAR(
80
+ img_size=config.img_size,
81
+ vae_stride=config.vae_stride,
82
+ patch_size=config.patch_size,
83
+ encoder_embed_dim=config.encoder_embed_dim,
84
+ encoder_depth=config.encoder_depth,
85
+ encoder_num_heads=config.encoder_num_heads,
86
+ decoder_embed_dim=config.decoder_embed_dim,
87
+ decoder_depth=config.decoder_depth,
88
+ decoder_num_heads=config.decoder_num_heads,
89
+ mlp_ratio=config.mlp_ratio,
90
+ norm_layer=norm_layer, # use the actual class for the layer
91
+ vae_embed_dim=config.vae_embed_dim,
92
+ mask_ratio_min=config.mask_ratio_min,
93
+ label_drop_prob=config.label_drop_prob,
94
+ class_num=config.class_num,
95
+ attn_dropout=config.attn_dropout,
96
+ proj_dropout=config.proj_dropout,
97
+ buffer_size=config.buffer_size,
98
+ diffloss_d=config.diffloss_d,
99
+ diffloss_w=config.diffloss_w,
100
+ num_sampling_steps=config.num_sampling_steps,
101
+ diffusion_batch_mul=config.diffusion_batch_mul,
102
+ grad_checkpointing=config.grad_checkpointing,
103
+ )
104
+
105
+ def forward(self, imgs, labels):
106
+ # calls the forward method from the mar class - passing imgs & labels
107
+ return self.model(imgs, labels)
108
+
109
+ def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
110
+ # call the sample_tokens method from the MAR class
111
+ return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
112
+
113
+ @classmethod
114
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
115
+ config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
116
+ model = cls(config)
117
+ state_dict = torch.load('./checkpoint-last.safetensors')
118
+ model.model.load_state_dict(state_dict)
119
+ return model
120
+
121
+ def save_pretrained(self, save_directory):
122
+ # we will save to safetensors
123
+ os.makedirs(save_directory, exist_ok=True)
124
+ state_dict = self.model.state_dict()
125
+ safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
126
+ save_file(state_dict, safetensors_path)
127
+
128
+ # save the configuration as usual
129
+ self.config.save_pretrained(save_directory)