tinyllava committed on
Commit f00108d
1 Parent(s): 3915e08

Update modeling_tinyllava_phi.py

Files changed (1)
  1. modeling_tinyllava_phi.py +165 -15
modeling_tinyllava_phi.py CHANGED
@@ -1,15 +1,9 @@
-
-from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
-
-#from .data_preprocess import load_image, process_images, tokenizer_image_token
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-import ast
 import re
 
 import torch
 import torch.utils.checkpoint
-from torch import nn, Tensor
+from torch import nn
 from torch.nn import functional as F
 
 from transformers import PreTrainedModel
@@ -17,12 +11,11 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
 from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel, SiglipImageProcessor
 
-import time
+from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+
 from transformers import AutoConfig, AutoModelForCausalLM, PhiForCausalLM
-import os
-import sys
-sys.path.append(os.path.dirname(sys.path[0]))
-from . import test
+
+import time
 
 # from tinyllava.utils.data_utils import get_value_from_kwargs
 CONTROLLER_HEART_BEAT_EXPIRATION = 30
@@ -39,12 +32,170 @@ logger = logging.get_logger(__name__)
 
 # this import has to be relative, otherwise, when setting trust_remote_code=True
 # huggingface transformers won't be able to load the module correctly
-from numbers import Number
 from typing import List, Optional, Union
+import requests
+from PIL import Image
+from io import BytesIO
+import base64
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    TINY_LLAMA = auto()
+    QWEN_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+
 
 
+conv_phi_v0 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
 
 
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+
+def load_image(image_file):
+    if image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        image = Image.open(image_file).convert("RGB")
+    return image
+
 ACT_TYPE = {
     'relu': nn.ReLU,
     'gelu': nn.GELU
@@ -138,7 +289,6 @@ class TinyLlavaPreTrainedModel(PreTrainedModel):
         return self.language_model._supports_sdpa
 
 
-
 class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
     def __init__(self, config: TinyLlavaConfig):
 
@@ -478,4 +628,4 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
 
 
 AutoConfig.register("tinyllava", TinyLlavaConfig)
-AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
+AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
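
This commit inlines TinyLLaVA's conversation template and prompt tokenizer directly into the modeling file, so a checkpoint loaded with trust_remote_code=True is self-contained. A minimal sketch of how the new helpers compose a prompt, assuming a tokenizer compatible with this checkpoint; the repo id below is illustrative, not taken from this commit:

from transformers import AutoTokenizer

# Illustrative checkpoint id; substitute the repo this file actually ships in.
tokenizer = AutoTokenizer.from_pretrained(
    "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B", trust_remote_code=True)

# Templates are defined with messages=(); copy() materialises a mutable list,
# so append_message() works on the copy.
conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# tokenizer_image_token() splits the prompt on "<image>" and splices in the
# sentinel IMAGE_TOKEN_INDEX (-200); the model later swaps that position for
# projected vision features. unsqueeze(0) adds the batch dimension.
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)

The sentinel is deliberately negative, outside any real vocabulary, so it can never collide with a genuine token id.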
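The inlined image helpers cover both the 'pad' and plain-resize preprocessing paths. A sketch of the 'pad' path, assuming a SigLIP processor; the vision-tower checkpoint id here is an assumption (the model's own config records the real one), and the image URL is hypothetical:

from types import SimpleNamespace
from transformers import SiglipImageProcessor

# Assumed vision tower for illustration.
image_processor = SiglipImageProcessor.from_pretrained(
    "google/siglip-so400m-patch14-384")

image = load_image("https://example.com/cat.jpg")  # URL or local path

# With image_aspect_ratio='pad', expand2square() letterboxes the image to a
# square, filling with the processor's per-channel mean scaled to 0-255, so
# the padding becomes neutral after normalisation.
cfg = SimpleNamespace(image_aspect_ratio='pad')
pixel_values = process_images([image], image_processor, cfg)  # (1, 3, H, W)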
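Finally, the AutoConfig/AutoModelForCausalLM registrations at the bottom of the file, together with the relative import the comments insist on, are what let transformers resolve the custom "tinyllava" model type under trust_remote_code=True. A loading sketch, again with an illustrative repo id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True downloads and executes this modeling_tinyllava_phi.py;
# the register() calls then map TinyLlavaConfig to the model class.
model = AutoModelForCausalLM.from_pretrained(
    "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B",  # illustrative repo id
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B", trust_remote_code=True,
)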