czczup committed on
Commit d1dc194
1 Parent(s): 88c4cde

Upload folder using huggingface_hub

Files changed (1)
  1. modeling_internvl_chat.py +15 -7
modeling_internvl_chat.py CHANGED
@@ -17,10 +17,10 @@ from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging
+from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput
 
 from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel
-from transformers.generation.utils import GreedySearchOutput,validate_stopping_criteria,GreedySearchDecoderOnlyOutput,GreedySearchEncoderDecoderOutput
 
 logger = logging.get_logger(__name__)
 
@@ -375,16 +375,24 @@ class InternVLChatModel(PreTrainedModel):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds
 
-    def chat(self, tokenizer, pixel_values, question, generation_config,
+    def chat(self, tokenizer, pixel_values, question, generation_config, history=None,
              IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
 
         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
         self.img_context_token_id = img_context_token_id
 
         from .conversation import get_conv_template
+
         template = get_conv_template(self.template)
-        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
-        template.append_message(template.roles[0], image_tokens + '\n' + question)
+        if history is None:
+            history = []
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
+            question = image_tokens + '\n' + question
+        else:
+            for (old_question, old_answer) in history:
+                template.append_message(template.roles[0], old_question)
+                template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
         template.append_message(template.roles[1], None)
         query = template.get_prompt()
         model_inputs = tokenizer(query, return_tensors='pt')
@@ -398,9 +406,8 @@
             **generation_config
         )
         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-        query_to_print = query.replace(image_tokens, '<image>')
-        print(query_to_print, response)
-        return response
+        history.append((question, response))
+        return response, history
 
     @torch.no_grad()
     def generate(
@@ -421,6 +428,7 @@
             vit_embeds = visual_features
         else:
             vit_embeds = self.extract_feature(pixel_values)
+
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)
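
This commit turns chat into a multi-turn API: on the first turn, history defaults to None and the method prepends the <img>...<IMG_CONTEXT>...</img> placeholder tokens to the question; on later turns, the returned history is replayed into the conversation template before the new question, and the method now returns (response, history) instead of just the response. A minimal usage sketch of the new signature follows; the checkpoint id, image preprocessing, and 448x448 input resolution are assumptions for illustration, not taken from this commit.

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical checkpoint id; use whichever repo this modeling file ships in.
path = 'OpenGVLab/InternVL-Chat'
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Stand-in for a preprocessed image; (1, 3, 448, 448) is an assumed resolution.
pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16)
generation_config = dict(max_new_tokens=256, do_sample=False)

# First turn: history is None, so chat() prepends the image placeholder
# tokens to the question and starts a fresh history list.
response, history = model.chat(
    tokenizer, pixel_values, 'Describe this image.', generation_config)

# Follow-up turn: passing the returned history replays earlier
# (question, answer) pairs into the template, so the image tokens from
# turn one are carried along in the prompt.
response, history = model.chat(
    tokenizer, pixel_values, 'What else do you notice?', generation_config,
    history=history)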