Support device type `cpu` for `generate`

#26
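This change removes the hard-coded CUDA calls in `chat()` and `chat_crop()`: input ids and image tensors are moved to `self.model.device`, `torch.autocast` uses a device string derived from the model, and an external `streamer` can now be passed in. A minimal usage sketch of what that enables is below; the checkpoint id and image path are placeholders for illustration, not taken from this PR.

```python
# Minimal sketch: CPU-only inference after this change.
# Assumptions: "stepfun-ai/GOT-OCR2_0" and "page.jpg" are placeholders.
from transformers import AutoModel, AutoTokenizer

model_id = "stepfun-ai/GOT-OCR2_0"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# No explicit .cuda()/.to("cuda") needed: chat() and chat_crop() now derive
# the device from self.model.device at call time.
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()

result = model.chat(tokenizer, "page.jpg", ocr_type="ocr")
print(result)
```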
Files changed (1)
  1. modeling_GOT.py +200 -134
modeling_GOT.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
2
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
  from typing import List, Optional, Tuple, Union
@@ -19,7 +20,7 @@ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
  DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
 
22
- from enum import auto, Enum
23
  class SeparatorStyle(Enum):
24
  """Different separator style."""
25
  SINGLE = auto()
@@ -65,7 +66,7 @@ class Conversation:
65
  return ret
66
  if self.sep_style == SeparatorStyle.MPT:
67
  if self.system:
68
- ret = self.system + self.sep
69
  else:
70
  ret = ''
71
  for role, message in self.messages:
@@ -79,7 +80,6 @@ class Conversation:
79
  else:
80
  raise ValueError(f"Invalid style: {self.sep_style}")
81
 
82
-
83
  def append_message(self, role, message):
84
  self.messages.append([role, message])
85
 
@@ -94,12 +94,12 @@ class Conversation:
94
  sep2=self.sep2)
95
 
96
 
97
-
98
  class KeywordsStoppingCriteria(StoppingCriteria):
99
  def __init__(self, keywords, tokenizer, input_ids):
100
  self.keywords = keywords
101
  self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
102
- self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
103
  self.tokenizer = tokenizer
104
  self.start_len = None
105
  self.input_ids = input_ids
@@ -111,12 +111,13 @@ class KeywordsStoppingCriteria(StoppingCriteria):
111
  for keyword_id in self.keyword_ids:
112
  if output_ids[0, -1] == keyword_id:
113
  return True
114
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
115
  for keyword in self.keywords:
116
  if keyword in outputs:
117
  return True
118
  return False
119
-
120
 
121
  class GOTImageEvalProcessor:
122
  def __init__(self, image_size=384, mean=None, std=None):
@@ -136,11 +137,11 @@ class GOTImageEvalProcessor:
136
  self.normalize,
137
  ]
138
  )
 
139
  def __call__(self, item):
140
  return self.transform(item)
141
 
142
 
143
-
144
  class GOTConfig(Qwen2Config):
145
  model_type = "GOT"
146
 
@@ -153,28 +154,25 @@ class GOTQwenModel(Qwen2Model):
153
 
154
  self.vision_tower_high = build_GOT_vit_b()
155
 
156
- self.mm_projector_vary = nn.Linear(1024, 1024)
157
-
158
 
159
  def initialize_vision_modules(
160
- self,
161
  vision_tower,
162
  pretrained_stage1_model=None,
163
  freeze_vision_tower=False,
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
167
- device="cuda"
168
  ):
169
 
170
-
171
  image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
-
173
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
 
175
  self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
176
 
177
-
178
  image_token_len = 256
179
 
180
  self.config.vision_tower = vision_tower
@@ -184,13 +182,12 @@ class GOTQwenModel(Qwen2Model):
184
 
185
  self.config.vision_select_layer = vision_select_layer
186
  self.config.freeze_vision_tower = freeze_vision_tower
187
-
188
  return dict(
189
  image_processor_high=image_processor_high,
190
  image_token_len=image_token_len,
191
  )
192
-
193
-
194
  def forward(
195
  self,
196
  input_ids: torch.LongTensor = None,
@@ -209,16 +206,17 @@ class GOTQwenModel(Qwen2Model):
209
  orig_embeds_params = getattr(self, 'orig_embeds_params', None)
210
  if orig_embeds_params is not None:
211
  with torch.no_grad():
212
- self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
213
 
214
  if inputs_embeds is None:
215
  inputs_embeds = self.embed_tokens(input_ids)
216
 
217
-
218
  vision_tower_high = getattr(self, 'vision_tower_high', None)
219
 
220
-
221
- if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
222
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
223
 
224
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
@@ -232,15 +230,15 @@ class GOTQwenModel(Qwen2Model):
232
  im_start_token = 151857
233
 
234
  im_end_token = 151858
235
-
236
  image_features = []
237
-
238
  for image in images:
239
  P, C, H, W = image.shape
240
  if P == 1:
241
  with torch.set_grad_enabled(False):
242
  cnn_feature = vision_tower_high(image)
243
- cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
244
  image_feature = self.mm_projector_vary(cnn_feature)
245
  image_features.append(image_feature)
246
 
@@ -249,7 +247,7 @@ class GOTQwenModel(Qwen2Model):
249
  image_patches_features = []
250
  for image_patch in image_patches:
251
  image_p = torch.stack([image_patch])
252
-
253
  with torch.set_grad_enabled(False):
254
  cnn_feature_p = vision_tower_high(image_p)
255
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -258,39 +256,44 @@ class GOTQwenModel(Qwen2Model):
258
  image_feature = torch.cat(image_patches_features, dim=1)
259
  image_features.append(image_feature)
260
 
261
-
262
- dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
263
  dummy_image_features = dummy_image_features_2
264
  use_im_start_end = True
265
  new_input_embeds = []
266
- for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
267
  if (cur_input_ids == im_patch_token).sum() == 0:
268
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
269
  new_input_embeds.append(cur_input_embeds)
270
  continue
271
 
272
  if use_im_start_end:
273
- if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
274
- raise ValueError("The number of image start tokens and image end tokens should be the same.")
275
-
276
  image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
277
- for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
278
- per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
279
  num_patches = per_cur_image_features.shape[0]
280
 
281
  if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
282
- raise ValueError("The image end token should follow the image start token.")
283
-
284
  cur_input_embeds = torch.cat(
285
  (
286
- cur_input_embeds[:image_start_token_pos+1],
287
- per_cur_image_features,
288
  cur_input_embeds[image_start_token_pos + num_patches + 1:]
289
- ),
290
  dim=0
291
  )
292
 
293
-
294
  new_input_embeds.append(cur_input_embeds)
295
  else:
296
  raise NotImplementedError
@@ -299,13 +302,12 @@ class GOTQwenModel(Qwen2Model):
299
 
300
  return super(GOTQwenModel, self).forward(
301
  input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
302
- inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
303
  output_attentions=output_attentions, output_hidden_states=output_hidden_states,
304
  return_dict=return_dict
305
  )
306
 
307
 
308
-
309
  class GOTQwenForCausalLM(Qwen2ForCausalLM):
310
  config_class = GOTConfig
311
  # supports_gradient_checkpointing = True
@@ -336,15 +338,14 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
336
  output_hidden_states: Optional[bool] = None,
337
  images: Optional[torch.FloatTensor] = None,
338
  return_dict: Optional[bool] = None,
339
-
340
  ) -> Union[Tuple, CausalLMOutputWithPast]:
341
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342
  output_hidden_states = (
343
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344
- )
345
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
 
347
- outputs = self.model(
348
  input_ids=input_ids,
349
  past_key_values=past_key_values,
350
  attention_mask=attention_mask,
@@ -355,7 +356,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
355
  output_hidden_states=output_hidden_states,
356
  images=images,
357
  return_dict=return_dict
358
-
359
  )
360
 
361
  hidden_states = outputs[0]
@@ -389,7 +390,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
389
  attentions=outputs.attentions,
390
  )
391
 
392
-
393
  def prepare_inputs_for_generation(
394
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
395
  ):
@@ -408,14 +408,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
408
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
409
  # input)
410
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
411
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
412
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
413
  # input_ids based on the past_length.
414
  elif past_length < input_ids.shape[1]:
415
  input_ids = input_ids[:, past_length:]
416
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
417
 
418
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
419
  if (
420
  max_cache_length is not None
421
  and attention_mask is not None
@@ -429,7 +431,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
429
  position_ids = attention_mask.long().cumsum(-1) - 1
430
  position_ids.masked_fill_(attention_mask == 0, 1)
431
  if past_key_values:
432
- position_ids = position_ids[:, -input_ids.shape[1] :]
433
 
434
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
435
  if inputs_embeds is not None and past_key_values is None:
@@ -449,15 +451,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
449
  return model_inputs
450
 
451
  def initialize_vision_tokenizer(
452
- self,
453
- tokenizer,
454
- freeze_lm_model=False,
455
  pretrained_stage1_model=None,
456
- device="cuda"
457
  ):
458
  config = self.get_model().config
459
 
460
-
461
  self.resize_token_embeddings(len(tokenizer))
462
 
463
  config.im_patch_token = 151859
@@ -484,12 +484,23 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
488
 
489
  self.disable_torch_init()
490
 
491
-
492
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
493
 
494
  use_im_start_end = True
495
 
@@ -501,7 +512,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
501
  image = self.load_image(image_file)
502
 
503
  w, h = image.size
504
-
505
  if ocr_type == 'format':
506
  qs = 'OCR with format: '
507
  else:
@@ -510,13 +521,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
510
  if ocr_box:
511
  bbox = eval(ocr_box)
512
  if len(bbox) == 2:
513
- bbox[0] = int(bbox[0]/w*1000)
514
- bbox[1] = int(bbox[1]/h*1000)
515
  if len(bbox) == 4:
516
- bbox[0] = int(bbox[0]/w*1000)
517
- bbox[1] = int(bbox[1]/h*1000)
518
- bbox[2] = int(bbox[2]/w*1000)
519
- bbox[3] = int(bbox[3]/h*1000)
520
  if ocr_type == 'format':
521
  qs = str(bbox) + ' ' + 'OCR with format: '
522
  else:
@@ -529,11 +540,11 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
529
  qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
530
 
531
  if use_im_start_end:
532
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
533
  else:
534
  qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
535
 
536
-
537
  conv_mpt = Conversation(
538
  system="""<|im_start|>system
539
  You should follow the instructions carefully and explain your answers in detail.""",
@@ -558,40 +569,42 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
558
 
559
  image_tensor_1 = image_processor_high(image)
560
 
561
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
562
 
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
565
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
566
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
 
568
  if stream_flag:
569
- with torch.autocast("cuda", dtype=torch.bfloat16):
570
  output_ids = self.generate(
571
  input_ids,
572
- images=[image_tensor_1.unsqueeze(0).half().cuda()],
573
  do_sample=False,
574
- num_beams = 1,
575
- no_repeat_ngram_size = 20,
576
  streamer=streamer,
577
  max_new_tokens=4096,
578
  stopping_criteria=[stopping_criteria]
579
- )
580
  else:
581
- with torch.autocast("cuda", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
- images=[image_tensor_1.unsqueeze(0).half().cuda()],
585
  do_sample=False,
586
- num_beams = 1,
587
- no_repeat_ngram_size = 20,
588
  # streamer=streamer,
589
  max_new_tokens=4096,
590
  stopping_criteria=[stopping_criteria]
591
- )
592
-
593
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
594
-
595
  if outputs.endswith(stop_str):
596
  outputs = outputs[:-len(stop_str)]
597
  outputs = outputs.strip()
@@ -606,8 +619,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
606
  tk = verovio.toolkit()
607
  tk.loadData(outputs)
608
  tk.setOptions({"pageWidth": 2100, "footer": 'none',
609
- 'barLineWidth': 0.5, 'beamMaxSlope': 15,
610
- 'staffLineWidth': 0.2, 'spacingStaff': 6})
611
  tk.getPageCount()
612
  svg = tk.renderToSVG()
613
  svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
@@ -616,35 +629,52 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
616
 
617
  if ocr_type == 'format' and '**kern' not in outputs:
618
 
619
-
620
- if '\\begin{tikzpicture}' not in outputs:
621
  html_path_2 = save_render_file
622
  right_num = outputs.count('\\right')
623
- left_num = outputs.count('\left')
624
 
625
  if right_num != left_num:
626
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
627
-
628
 
629
  outputs = outputs.replace('"', '``').replace('$', '')
630
 
631
  outputs_list = outputs.split('\n')
632
- gt= ''
633
  for out in outputs_list:
634
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
635
-
636
- gt = gt[:-2]
637
 
638
 
639
  lines = content_mmd_to_html
640
  lines = lines.split("const text =")
641
- new_web = lines[0] + 'const text =' + gt + lines[1]
642
 
643
  else:
644
  html_path_2 = save_render_file
645
  outputs = outputs.translate(translation_table)
646
  outputs_list = outputs.split('\n')
647
- gt= ''
648
  for out in outputs_list:
649
  if out:
650
  if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
@@ -652,7 +682,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
652
  out = out[:-1]
653
  if out is None:
654
  break
655
-
656
  if out:
657
  if out[-1] != ';':
658
  gt += out[:-1] + ';\n'
@@ -661,7 +691,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
661
  else:
662
  gt += out + '\n'
663
 
664
-
665
  lines = tik_html
666
  lines = lines.split("const text =")
667
  new_web = lines[0] + gt + lines[1]
@@ -671,7 +700,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
671
  return response_str
672
 
673
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
674
-
675
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
676
  best_ratio_diff = float('inf')
677
  best_ratio = (1, 1)
@@ -687,14 +716,25 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
687
  best_ratio = ratio
688
  # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
689
  return best_ratio
690
-
691
  orig_width, orig_height = image.size
692
  aspect_ratio = orig_width / orig_height
693
 
694
  # calculate the existing image aspect ratio
695
  target_ratios = set(
696
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
697
- i * j <= max_num and i * j >= min_num)
698
  # print(target_ratios)
699
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
700
 
@@ -727,18 +767,25 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
727
  processed_images.append(thumbnail_img)
728
  return processed_images
729
 
730
-
731
- def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
732
  # Model
733
  self.disable_torch_init()
734
- multi_page=False
735
 
736
-
737
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
738
 
739
  use_im_start_end = True
740
 
741
-
742
  image_token_len = 256
743
 
744
  image_list = []
@@ -778,18 +825,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
778
  image_tensor_1 = image_processor_high(image)
779
  image_list.append(image_tensor_1)
780
 
781
-
782
  image_list = torch.stack(image_list)
783
 
784
- print('====new images batch size======: \n',image_list.shape)
785
-
786
 
787
  if use_im_start_end:
788
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
 
  else:
790
  qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
791
 
792
-
793
  conv_mpt = Conversation(
794
  system="""<|im_start|>system
795
  You should follow the instructions carefully and explain your answers in detail.""",
@@ -812,43 +857,45 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
812
 
813
  inputs = tokenizer([prompt])
814
 
815
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
816
 
817
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
818
  keywords = [stop_str]
819
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
820
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
 
822
  if stream_flag:
823
- with torch.autocast("cuda", dtype=torch.bfloat16):
824
  output_ids = self.generate(
825
  input_ids,
826
- images=[image_list.half().cuda()],
827
  do_sample=False,
828
- num_beams = 1,
829
  # no_repeat_ngram_size = 20,
830
  streamer=streamer,
831
  max_new_tokens=4096,
832
  stopping_criteria=[stopping_criteria]
833
- )
834
  else:
835
- with torch.autocast("cuda", dtype=torch.bfloat16):
836
  output_ids = self.generate(
837
  input_ids,
838
- images=[image_list.half().cuda()],
839
  do_sample=False,
840
- num_beams = 1,
841
  # no_repeat_ngram_size = 20,
842
  # streamer=streamer,
843
  max_new_tokens=4096,
844
  stopping_criteria=[stopping_criteria]
845
- )
846
 
847
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
848
-
849
  if outputs.endswith(stop_str):
850
  outputs = outputs[:-len(stop_str)]
851
- outputs = outputs.strip()
852
  response_str = outputs
853
 
854
  if render:
@@ -856,26 +903,45 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
856
  from .render_tools import content_mmd_to_html
857
  html_path_2 = save_render_file
858
  right_num = outputs.count('\\right')
859
- left_num = outputs.count('\left')
860
 
861
  if right_num != left_num:
862
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
863
-
864
 
865
  outputs = outputs.replace('"', '``').replace('$', '')
866
 
867
  outputs_list = outputs.split('\n')
868
- gt= ''
869
  for out in outputs_list:
870
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
871
-
872
  gt = gt[:-2]
873
 
874
  lines = content_mmd_to_html
875
  lines = lines.split("const text =")
876
- new_web = lines[0] + 'const text =' + gt + lines[1]
877
-
878
  with open(html_path_2, 'w') as web_f_new:
879
  web_f_new.write(new_web)
880
 
881
- return response_str
 
1
+ from enum import auto, Enum
2
  from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
3
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
4
  from typing import List, Optional, Tuple, Union
 
20
  DEFAULT_IM_START_TOKEN = '<img>'
21
  DEFAULT_IM_END_TOKEN = '</img>'
22
 
23
+
24
  class SeparatorStyle(Enum):
25
  """Different separator style."""
26
  SINGLE = auto()
 
66
  return ret
67
  if self.sep_style == SeparatorStyle.MPT:
68
  if self.system:
69
+ ret = self.system + self.sep
70
  else:
71
  ret = ''
72
  for role, message in self.messages:
 
80
  else:
81
  raise ValueError(f"Invalid style: {self.sep_style}")
82
 
 
83
  def append_message(self, role, message):
84
  self.messages.append([role, message])
85
 
 
94
  sep2=self.sep2)
95
 
96
 
 
97
  class KeywordsStoppingCriteria(StoppingCriteria):
98
  def __init__(self, keywords, tokenizer, input_ids):
99
  self.keywords = keywords
100
  self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
101
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(
102
+ keyword_id) is list and len(keyword_id) == 1]
103
  self.tokenizer = tokenizer
104
  self.start_len = None
105
  self.input_ids = input_ids
 
111
  for keyword_id in self.keyword_ids:
112
  if output_ids[0, -1] == keyword_id:
113
  return True
114
+ outputs = self.tokenizer.batch_decode(
115
+ output_ids[:, self.start_len:], skip_special_tokens=True)[0]
116
  for keyword in self.keywords:
117
  if keyword in outputs:
118
  return True
119
  return False
120
+
121
 
122
  class GOTImageEvalProcessor:
123
  def __init__(self, image_size=384, mean=None, std=None):
 
137
  self.normalize,
138
  ]
139
  )
140
+
141
  def __call__(self, item):
142
  return self.transform(item)
143
 
144
 
 
145
  class GOTConfig(Qwen2Config):
146
  model_type = "GOT"
147
 
 
154
 
155
  self.vision_tower_high = build_GOT_vit_b()
156
 
157
+ self.mm_projector_vary = nn.Linear(1024, 1024)
 
158
 
159
  def initialize_vision_modules(
160
+ self,
161
  vision_tower,
162
  pretrained_stage1_model=None,
163
  freeze_vision_tower=False,
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
 
167
  ):
168
 
169
+ device = self.device
170
  image_processor_high = GOTImageEvalProcessor(image_size=1024)
171
+
172
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
173
 
174
  self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
175
 
 
176
  image_token_len = 256
177
 
178
  self.config.vision_tower = vision_tower
 
182
 
183
  self.config.vision_select_layer = vision_select_layer
184
  self.config.freeze_vision_tower = freeze_vision_tower
185
+
186
  return dict(
187
  image_processor_high=image_processor_high,
188
  image_token_len=image_token_len,
189
  )
190
+
 
191
  def forward(
192
  self,
193
  input_ids: torch.LongTensor = None,
 
206
  orig_embeds_params = getattr(self, 'orig_embeds_params', None)
207
  if orig_embeds_params is not None:
208
  with torch.no_grad():
209
+ self.get_input_embeddings().weight[:-
210
+ self.num_new_tokens] = orig_embeds_params[:-
211
+ self.num_new_tokens].data
212
 
213
  if inputs_embeds is None:
214
  inputs_embeds = self.embed_tokens(input_ids)
215
 
 
216
  vision_tower_high = getattr(self, 'vision_tower_high', None)
217
 
218
+ if vision_tower_high is not None and (
219
+ input_ids.shape[1] != 1 or self.training) and images is not None:
220
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
221
 
222
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
 
230
  im_start_token = 151857
231
 
232
  im_end_token = 151858
233
+
234
  image_features = []
235
+
236
  for image in images:
237
  P, C, H, W = image.shape
238
  if P == 1:
239
  with torch.set_grad_enabled(False):
240
  cnn_feature = vision_tower_high(image)
241
+ cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
242
  image_feature = self.mm_projector_vary(cnn_feature)
243
  image_features.append(image_feature)
244
 
 
247
  image_patches_features = []
248
  for image_patch in image_patches:
249
  image_p = torch.stack([image_patch])
250
+
251
  with torch.set_grad_enabled(False):
252
  cnn_feature_p = vision_tower_high(image_p)
253
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
 
256
  image_feature = torch.cat(image_patches_features, dim=1)
257
  image_features.append(image_feature)
258
 
259
+ dummy_image_features_2 = torch.zeros(
260
+ 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
261
  dummy_image_features = dummy_image_features_2
262
  use_im_start_end = True
263
  new_input_embeds = []
264
+ for cur_input_ids, cur_input_embeds, cur_image_features in zip(
265
+ input_ids, inputs_embeds, image_features):
266
  if (cur_input_ids == im_patch_token).sum() == 0:
267
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
268
  new_input_embeds.append(cur_input_embeds)
269
  continue
270
 
271
  if use_im_start_end:
272
+ if (cur_input_ids == im_start_token).sum() != (
273
+ cur_input_ids == im_end_token).sum():
274
+ raise ValueError(
275
+ "The number of image start tokens and image end tokens should be the same.")
276
+
277
  image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
278
+ for image_start_token_pos, per_cur_image_features in zip(
279
+ image_start_tokens, cur_image_features):
280
+ per_cur_image_features = per_cur_image_features.to(
281
+ device=cur_input_embeds.device)
282
  num_patches = per_cur_image_features.shape[0]
283
 
284
  if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
285
+ raise ValueError(
286
+ "The image end token should follow the image start token.")
287
+
288
  cur_input_embeds = torch.cat(
289
  (
290
+ cur_input_embeds[:image_start_token_pos + 1],
291
+ per_cur_image_features,
292
  cur_input_embeds[image_start_token_pos + num_patches + 1:]
293
+ ),
294
  dim=0
295
  )
296
 
 
297
  new_input_embeds.append(cur_input_embeds)
298
  else:
299
  raise NotImplementedError
 
302
 
303
  return super(GOTQwenModel, self).forward(
304
  input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
305
+ inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
306
  output_attentions=output_attentions, output_hidden_states=output_hidden_states,
307
  return_dict=return_dict
308
  )
309
 
310
 
 
311
  class GOTQwenForCausalLM(Qwen2ForCausalLM):
312
  config_class = GOTConfig
313
  # supports_gradient_checkpointing = True
 
338
  output_hidden_states: Optional[bool] = None,
339
  images: Optional[torch.FloatTensor] = None,
340
  return_dict: Optional[bool] = None,
341
+
342
  ) -> Union[Tuple, CausalLMOutputWithPast]:
343
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
344
  output_hidden_states = (
345
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
 
346
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
347
 
348
+ outputs = self.model(
349
  input_ids=input_ids,
350
  past_key_values=past_key_values,
351
  attention_mask=attention_mask,
 
356
  output_hidden_states=output_hidden_states,
357
  images=images,
358
  return_dict=return_dict
359
+
360
  )
361
 
362
  hidden_states = outputs[0]
 
390
  attentions=outputs.attentions,
391
  )
392
 
 
393
  def prepare_inputs_for_generation(
394
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
395
  ):
 
408
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
409
  # input)
410
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
411
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
412
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
413
  # input_ids based on the past_length.
414
  elif past_length < input_ids.shape[1]:
415
  input_ids = input_ids[:, past_length:]
416
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume
417
+ # input_ids only has unprocessed tokens.
418
 
419
+ # If we are about to go beyond the maximum cache length, we need to crop
420
+ # the input attention mask.
421
  if (
422
  max_cache_length is not None
423
  and attention_mask is not None
 
431
  position_ids = attention_mask.long().cumsum(-1) - 1
432
  position_ids.masked_fill_(attention_mask == 0, 1)
433
  if past_key_values:
434
+ position_ids = position_ids[:, -input_ids.shape[1]:]
435
 
436
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
437
  if inputs_embeds is not None and past_key_values is None:
 
451
  return model_inputs
452
 
453
  def initialize_vision_tokenizer(
454
+ self,
455
+ tokenizer,
456
+ freeze_lm_model=False,
457
  pretrained_stage1_model=None,
 
458
  ):
459
  config = self.get_model().config
460
 
 
461
  self.resize_token_embeddings(len(tokenizer))
462
 
463
  config.im_patch_token = 151859
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(
488
+ self,
489
+ tokenizer,
490
+ image_file,
491
+ ocr_type,
492
+ ocr_box='',
493
+ ocr_color='',
494
+ render=False,
495
+ save_render_file=None,
496
+ print_prompt=False,
497
+ gradio_input=False,
498
+ stream_flag=False,
499
+ streamer=None):
500
 
501
  self.disable_torch_init()
502
 
503
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
504
 
505
  use_im_start_end = True
506
 
 
512
  image = self.load_image(image_file)
513
 
514
  w, h = image.size
515
+
516
  if ocr_type == 'format':
517
  qs = 'OCR with format: '
518
  else:
 
521
  if ocr_box:
522
  bbox = eval(ocr_box)
523
  if len(bbox) == 2:
524
+ bbox[0] = int(bbox[0] / w * 1000)
525
+ bbox[1] = int(bbox[1] / h * 1000)
526
  if len(bbox) == 4:
527
+ bbox[0] = int(bbox[0] / w * 1000)
528
+ bbox[1] = int(bbox[1] / h * 1000)
529
+ bbox[2] = int(bbox[2] / w * 1000)
530
+ bbox[3] = int(bbox[3] / h * 1000)
531
  if ocr_type == 'format':
532
  qs = str(bbox) + ' ' + 'OCR with format: '
533
  else:
 
540
  qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
541
 
542
  if use_im_start_end:
543
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * \
544
+ image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
545
  else:
546
  qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
547
 
 
548
  conv_mpt = Conversation(
549
  system="""<|im_start|>system
550
  You should follow the instructions carefully and explain your answers in detail.""",
 
569
 
570
  image_tensor_1 = image_processor_high(image)
571
 
572
+ input_ids = torch.as_tensor(inputs.input_ids).to(self.model.device)
573
 
574
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
575
  keywords = [stop_str]
576
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
577
+ streamer = streamer if streamer else TextStreamer(
578
+ tokenizer, skip_prompt=True, skip_special_tokens=True)
579
 
580
+ device = "cuda" if "cuda" in str(self.model.device) else "cpu"
581
  if stream_flag:
582
+ with torch.autocast(device, dtype=torch.bfloat16):
583
  output_ids = self.generate(
584
  input_ids,
585
+ images=[image_tensor_1.unsqueeze(0).half().to(self.model.device)],
586
  do_sample=False,
587
+ num_beams=1,
588
+ no_repeat_ngram_size=20,
589
  streamer=streamer,
590
  max_new_tokens=4096,
591
  stopping_criteria=[stopping_criteria]
592
+ )
593
  else:
594
+ with torch.autocast(device, dtype=torch.bfloat16):
595
  output_ids = self.generate(
596
  input_ids,
597
+ images=[image_tensor_1.unsqueeze(0).half().to(self.model.device)],
598
  do_sample=False,
599
+ num_beams=1,
600
+ no_repeat_ngram_size=20,
601
  # streamer=streamer,
602
  max_new_tokens=4096,
603
  stopping_criteria=[stopping_criteria]
604
+ )
605
+
606
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
607
+
608
  if outputs.endswith(stop_str):
609
  outputs = outputs[:-len(stop_str)]
610
  outputs = outputs.strip()
 
619
  tk = verovio.toolkit()
620
  tk.loadData(outputs)
621
  tk.setOptions({"pageWidth": 2100, "footer": 'none',
622
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
623
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
624
  tk.getPageCount()
625
  svg = tk.renderToSVG()
626
  svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
 
629
 
630
  if ocr_type == 'format' and '**kern' not in outputs:
631
 
632
+ if '\\begin{tikzpicture}' not in outputs:
 
633
  html_path_2 = save_render_file
634
  right_num = outputs.count('\\right')
635
+ left_num = outputs.count('\\left')
636
 
637
  if right_num != left_num:
638
+ outputs = outputs.replace(
639
+ '\\left(',
640
+ '(').replace(
641
+ '\\right)',
642
+ ')').replace(
643
+ '\\left[',
644
+ '[').replace(
645
+ '\\right]',
646
+ ']').replace(
647
+ '\\left{',
648
+ '{').replace(
649
+ '\\right}',
650
+ '}').replace(
651
+ '\\left|',
652
+ '|').replace(
653
+ '\\right|',
654
+ '|').replace(
655
+ '\\left.',
656
+ '.').replace(
657
+ '\\right.',
658
+ '.')
659
 
660
  outputs = outputs.replace('"', '``').replace('$', '')
661
 
662
  outputs_list = outputs.split('\n')
663
+ gt = ''
664
  for out in outputs_list:
665
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
 
 
666
 
667
+ gt = gt[:-2]
668
 
669
  lines = content_mmd_to_html
670
  lines = lines.split("const text =")
671
+ new_web = lines[0] + 'const text =' + gt + lines[1]
672
 
673
  else:
674
  html_path_2 = save_render_file
675
  outputs = outputs.translate(translation_table)
676
  outputs_list = outputs.split('\n')
677
+ gt = ''
678
  for out in outputs_list:
679
  if out:
680
  if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
 
682
  out = out[:-1]
683
  if out is None:
684
  break
685
+
686
  if out:
687
  if out[-1] != ';':
688
  gt += out[:-1] + ';\n'
 
691
  else:
692
  gt += out + '\n'
693
 
 
694
  lines = tik_html
695
  lines = lines.split("const text =")
696
  new_web = lines[0] + gt + lines[1]
 
700
  return response_str
701
 
702
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
703
+
704
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
705
  best_ratio_diff = float('inf')
706
  best_ratio = (1, 1)
 
716
  best_ratio = ratio
717
  # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
718
  return best_ratio
719
+
720
  orig_width, orig_height = image.size
721
  aspect_ratio = orig_width / orig_height
722
 
723
  # calculate the existing image aspect ratio
724
  target_ratios = set(
725
+ (i,
726
+ j) for n in range(
727
+ min_num,
728
+ max_num +
729
+ 1) for i in range(
730
+ 1,
731
+ n +
732
+ 1) for j in range(
733
+ 1,
734
+ n +
735
+ 1) if i *
736
+ j <= max_num and i *
737
+ j >= min_num)
738
  # print(target_ratios)
739
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
740
 
 
767
  processed_images.append(thumbnail_img)
768
  return processed_images
769
 
770
+ def chat_crop(
771
+ self,
772
+ tokenizer,
773
+ image_file,
774
+ ocr_type,
775
+ render=False,
776
+ save_render_file=None,
777
+ print_prompt=False,
778
+ gradio_input=False,
779
+ stream_flag=False,
780
+ streamer=None):
781
  # Model
782
  self.disable_torch_init()
783
+ multi_page = False
784
 
785
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
786
 
787
  use_im_start_end = True
788
 
 
789
  image_token_len = 256
790
 
791
  image_list = []
 
825
  image_tensor_1 = image_processor_high(image)
826
  image_list.append(image_tensor_1)
827
 
 
828
  image_list = torch.stack(image_list)
829
 
830
+ print('====new images batch size======: \n', image_list.shape)
 
831
 
832
  if use_im_start_end:
833
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * \
834
+ image_token_len * ll + DEFAULT_IM_END_TOKEN + '\n' + qs
835
  else:
836
  qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
837
 
 
838
  conv_mpt = Conversation(
839
  system="""<|im_start|>system
840
  You should follow the instructions carefully and explain your answers in detail.""",
 
857
 
858
  inputs = tokenizer([prompt])
859
 
860
+ input_ids = torch.as_tensor(inputs.input_ids).to(self.model.device)
861
 
862
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
863
  keywords = [stop_str]
864
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
865
+ streamer = streamer if streamer else TextStreamer(
866
+ tokenizer, skip_prompt=True, skip_special_tokens=True)
867
 
868
+ device = "cuda" if "cuda" in str(self.model.device) else "cpu"
869
  if stream_flag:
870
+ with torch.autocast(device, dtype=torch.bfloat16):
871
  output_ids = self.generate(
872
  input_ids,
873
+ images=[image_list.half().to(self.model.device)],
874
  do_sample=False,
875
+ num_beams=1,
876
  # no_repeat_ngram_size = 20,
877
  streamer=streamer,
878
  max_new_tokens=4096,
879
  stopping_criteria=[stopping_criteria]
880
+ )
881
  else:
882
+ with torch.autocast(device, dtype=torch.bfloat16):
883
  output_ids = self.generate(
884
  input_ids,
885
+ images=[image_list.half().to(self.model.device)],
886
  do_sample=False,
887
+ num_beams=1,
888
  # no_repeat_ngram_size = 20,
889
  # streamer=streamer,
890
  max_new_tokens=4096,
891
  stopping_criteria=[stopping_criteria]
892
+ )
893
 
894
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
895
+
896
  if outputs.endswith(stop_str):
897
  outputs = outputs[:-len(stop_str)]
898
+ outputs = outputs.strip()
899
  response_str = outputs
900
 
901
  if render:
 
903
  from .render_tools import content_mmd_to_html
904
  html_path_2 = save_render_file
905
  right_num = outputs.count('\\right')
906
+ left_num = outputs.count('\\left')
907
 
908
  if right_num != left_num:
909
+ outputs = outputs.replace(
910
+ '\\left(',
911
+ '(').replace(
912
+ '\\right)',
913
+ ')').replace(
914
+ '\\left[',
915
+ '[').replace(
916
+ '\\right]',
917
+ ']').replace(
918
+ '\\left{',
919
+ '{').replace(
920
+ '\\right}',
921
+ '}').replace(
922
+ '\\left|',
923
+ '|').replace(
924
+ '\\right|',
925
+ '|').replace(
926
+ '\\left.',
927
+ '.').replace(
928
+ '\\right.',
929
+ '.')
930
 
931
  outputs = outputs.replace('"', '``').replace('$', '')
932
 
933
  outputs_list = outputs.split('\n')
934
+ gt = ''
935
  for out in outputs_list:
936
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
937
+
938
  gt = gt[:-2]
939
 
940
  lines = content_mmd_to_html
941
  lines = lines.split("const text =")
942
+ new_web = lines[0] + 'const text =' + gt + lines[1]
943
+
944
  with open(html_path_2, 'w') as web_f_new:
945
  web_f_new.write(new_web)
946
 
947
+ return response_str
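Taken together, the recurring pattern in the new `chat()`/`chat_crop()` bodies is: read the device from the model, build the `torch.autocast` device string from it, and move tensors with `.to(...)` instead of `.cuda()`. A small self-contained illustration of that pattern, using a stand-in `nn.Linear` rather than the GOT model, is sketched below.

```python
# Illustration of the device-derivation pattern used above.
# The nn.Linear stand-in and tensor shapes are arbitrary placeholders.
import torch
import torch.nn as nn

module = nn.Linear(16, 16)                       # stand-in for the model
param_device = next(module.parameters()).device  # HF models expose .device directly
device = "cuda" if "cuda" in str(param_device) else "cpu"

x = torch.randn(2, 16).to(param_device)
with torch.autocast(device, dtype=torch.bfloat16):  # bfloat16 autocast works on CPU too
    y = module(x)

print(device, y.dtype)  # e.g. "cpu torch.bfloat16"
```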