bczhou commited on
Commit
5f3b360
·
1 Parent(s): 3738ba5

Rename linear_mapping.py to clip_gpt2.py

Browse files
Files changed (1) hide show
  1. linear_mapping.py → clip_gpt2.py +17 -35
linear_mapping.py → clip_gpt2.py RENAMED
@@ -1,41 +1,25 @@
1
- from config import LinearMappingConfig
2
  from transformers import (
3
- GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
4
- CLIPVisionModel, AutoProcessor, BatchEncoding,
 
5
  AutoConfig, CLIPVisionConfig
6
  )
7
  from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
8
  import torch
9
  import torch.nn as nn
10
  from typing import List, Optional, Union, Tuple, Dict
11
- from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
12
- from torchvision.transforms.functional import InterpolationMode
13
 
14
-
15
- class Transform(torch.nn.Module):
16
- def __init__(self, image_size, mean, std):
17
- super().__init__()
18
- self.transforms = torch.nn.Sequential(
19
- Resize([image_size], interpolation=InterpolationMode.BICUBIC, antialias=True),
20
- CenterCrop(image_size),
21
- ConvertImageDtype(torch.float32),
22
- Normalize(mean, std),
23
- )
24
-
25
- def forward(self, x) -> torch.Tensor:
26
- """`x` should be an instance of `PIL.Image.Image`"""
27
- with torch.no_grad():
28
- x = self.transforms(x)
29
- return x
30
 
31
 
32
- class LinearMappingProcessor:
33
  """
34
- A combination of ImageProcessor and GPT2TokenizerFast
35
  """
36
 
37
- def __init__(self, config: LinearMappingConfig):
38
- self.image_processor = AutoProcessor.from_pretrained(config.image_model)
39
  self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
40
  self.add_image_token = config.add_image_token
41
  if config.add_image_token:
@@ -103,7 +87,7 @@ class ImagePrefix(nn.Module):
103
  Converts pixel values to prefix image prompts that are later fed to a LLM
104
  """
105
 
106
- def __init__(self, config: LinearMappingConfig):
107
  super().__init__()
108
  clip_config = CLIPVisionConfig.from_pretrained(config.image_model)
109
 
@@ -126,21 +110,16 @@ class ImagePrefix(nn.Module):
126
  return self.ln(prefix_prompts)
127
 
128
 
129
- class LinearMapping(nn.Module):
130
 
131
- def __init__(self, config: LinearMappingConfig):
132
  super().__init__()
133
  self.image_prefix = ImagePrefix(config)
134
  self.language_model = GPT2LMHeadModel(AutoConfig.from_pretrained(config.text_model))
135
  if config.text_from_pretrained:
136
  self.language_model = self.language_model.from_pretrained(config.text_model)
137
- self.processor = LinearMappingProcessor(config)
138
- self.tokenizer = self.processor.tokenizer
139
- self.image_processor = self.processor.image_processor
140
- self.add_image_token = config.add_image_token
141
- if config.add_image_token:
142
- self.language_model.resize_token_embeddings(len(self.tokenizer))
143
 
 
144
  if config.freeze_text_model:
145
  for module in self.language_model.modules():
146
  if not isinstance(module, nn.LayerNorm) or config.freeze_ln:
@@ -179,7 +158,7 @@ class LinearMapping(nn.Module):
179
 
180
  for label in labels:
181
  for k, token in enumerate(label):
182
- if token == self.tokenizer.eos_token_id:
183
  label[k + 1:] = -100
184
  break
185
  return {"hidden_states": inputs_embeddings, "labels": labels.to(dtype=torch.int64)}
@@ -208,6 +187,8 @@ class LinearMapping(nn.Module):
208
  pixel_values: Optional[torch.Tensor] = None,
209
  **kwargs
210
  ):
 
 
211
  if pixel_values is None:
212
  return self.language_model.generate(
213
  input_ids=input_ids,
@@ -249,6 +230,7 @@ class LinearMapping(nn.Module):
249
  )
250
  if past_input_ids is not None:
251
  generated_token_ids = torch.cat([past_input_ids, generated_token_ids], dim=-1)
 
252
  return generated_token_ids
253
 
254
  def forward(
 
1
+ from config import CLIPGPT2Config
2
  from transformers import (
3
+ GPT2TokenizerFast, GPT2LMHeadModel,
4
+ CLIPVisionModel, BatchEncoding,
5
+ CLIPImageProcessor,
6
  AutoConfig, CLIPVisionConfig
7
  )
8
  from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
9
  import torch
10
  import torch.nn as nn
11
  from typing import List, Optional, Union, Tuple, Dict
 
 
12
 
13
+ EOS_TOKEN_ID = 50256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
+ class CLIPGPT2Processor:
17
  """
18
+ A combination of CLIP ImageProcessor and GPT2TokenizerFast
19
  """
20
 
21
+ def __init__(self, config: CLIPGPT2Config):
22
+ self.image_processor = CLIPImageProcessor.from_pretrained(config.image_model)
23
  self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
24
  self.add_image_token = config.add_image_token
25
  if config.add_image_token:
 
87
  Converts pixel values to prefix image prompts that are later fed to a LLM
88
  """
89
 
90
+ def __init__(self, config: CLIPGPT2Config):
91
  super().__init__()
92
  clip_config = CLIPVisionConfig.from_pretrained(config.image_model)
93
 
 
110
  return self.ln(prefix_prompts)
111
 
112
 
113
+ class CLIPGPT2(nn.Module):
114
 
115
+ def __init__(self, config: CLIPGPT2Config):
116
  super().__init__()
117
  self.image_prefix = ImagePrefix(config)
118
  self.language_model = GPT2LMHeadModel(AutoConfig.from_pretrained(config.text_model))
119
  if config.text_from_pretrained:
120
  self.language_model = self.language_model.from_pretrained(config.text_model)
 
 
 
 
 
 
121
 
122
+ self.language_model.resize_token_embeddings(config.vocab_size)
123
  if config.freeze_text_model:
124
  for module in self.language_model.modules():
125
  if not isinstance(module, nn.LayerNorm) or config.freeze_ln:
 
158
 
159
  for label in labels:
160
  for k, token in enumerate(label):
161
+ if token == EOS_TOKEN_ID:
162
  label[k + 1:] = -100
163
  break
164
  return {"hidden_states": inputs_embeddings, "labels": labels.to(dtype=torch.int64)}
 
187
  pixel_values: Optional[torch.Tensor] = None,
188
  **kwargs
189
  ):
190
+ in_training = self.training
191
+ self.eval()
192
  if pixel_values is None:
193
  return self.language_model.generate(
194
  input_ids=input_ids,
 
230
  )
231
  if past_input_ids is not None:
232
  generated_token_ids = torch.cat([past_input_ids, generated_token_ids], dim=-1)
233
+ self.train(in_training)
234
  return generated_token_ids
235
 
236
  def forward(