g-h-chen committed on
Commit 270e869
1 Parent(s): 0762d75

upload generation_utils.py

Files changed (1): generation_utils.py +288 -0
generation_utils.py ADDED
@@ -0,0 +1,288 @@
from typing import List
from queue import Queue

import torch
from PIL import Image
from copy import deepcopy
import os
import requests

IMAGE_TOKEN_INDEX = -200  # sentinel id spliced into input_ids wherever '<image>' appears
blacklist = ['<image>', '<s>', '</s>']
max_num_images = 3  # phi has a context length limit of 2048 and each image occupies 576 tokens.

def input_moderation(texts: list[list[str]]):
    # perform input moderation on each message
    for text_pair in texts:
        # in-place operation
        for b in blacklist:
            text_pair[0] = text_pair[0].replace(b, '')
            if text_pair[1] is not None:
                text_pair[1] = text_pair[1].replace(b, '')

    return texts

def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
    for _ in range(num_images):
        t = f"{placeholder}{sep}" + t
    return t

def get_conv(texts):
    ret = []

    for conv in texts:
        ret.append({'from': 'human', 'value': conv[0]})
        ret.append({'from': 'gpt', 'value': conv[1]})  # this is None for the last one

    return ret

# copied from llava
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids

def preprocess(tokenizer, data: list, return_tensors='pt'):
    '''
    data is a list of turns, e.g.
    [
        {
            'from': 'human',
            'value': xxx,
        },
        {
            'from': 'gpt',
            'value': xxx
        }
    ]
    '''
    # needs update
    if not isinstance(data, list):
        raise ValueError('data must be a list')

    # this is per model (tokenizer)
    return preprocess_allava(tokenizer, data, return_tensors=return_tensors)

def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list:  # tokenize and concat the conversations
    # NOTE: kept for reference but unused here (preprocess() routes to preprocess_allava);
    # it is written as a method and expects `self` to provide `tokenizer`, `tokenizer_image_token` and `maxlen`.
    input_ids = None
    for ind, conv in enumerate(convs):
        if ind % 2 == 0:  # human
            h = conv['value'].strip()
            h = f"USER: {h} "
            cur_input_ids = self.tokenizer_image_token(prompt=h, return_tensors=return_tensors)

            if input_ids is None:
                input_ids = cur_input_ids
            else:
                input_ids = torch.cat([input_ids, cur_input_ids])

        else:  # gpt
            g = conv['value']
            if g is not None:
                cur_input_ids = self.tokenizer(f"ASSISTANT: {g}</s>", add_special_tokens=False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0]
                input_ids = torch.cat([input_ids, cur_input_ids])
            else:
                cur_input_ids = self.tokenizer("ASSISTANT:", add_special_tokens=False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0]
                input_ids = torch.cat([input_ids, cur_input_ids])

    return input_ids

def preprocess_allava(tokenizer, convs: list, return_tensors) -> list:  # tokenize and concat the conversations
    input_ids = None

    for ind, conv in enumerate(convs):
        if ind % 2 == 0:  # human
            h = conv['value'].strip()
            h = f"[INST] {h} [/INST] "
            cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)

            if input_ids is None:
                input_ids = cur_input_ids
            else:
                input_ids = torch.cat([input_ids, cur_input_ids])

        else:  # gpt
            g = conv['value']
            if g is not None:
                cur_input_ids = tokenizer(f"{g}{tokenizer.eos_token}", add_special_tokens=False, truncation=True, return_tensors='pt').input_ids[0]
                input_ids = torch.cat([input_ids, cur_input_ids])

    return input_ids

# copied from llava
def get_image_tensors(processor, images, device):
    list_image_tensors = []
    crop_size = processor.crop_size
    for fp in images:
        if fp is None:  # None is used as a placeholder
            list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
            continue
        elif isinstance(fp, str):
            image = Image.open(fp).convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp  # already an image
        else:
            raise TypeError(f'Unsupported type {type(fp)}')

        # pad to a square (image_aspect_ratio == 'pad');
        # this is the way of preprocessing images we used in training, so we impose it here
        def expand2square(pil_img, background_color):
            width, height = pil_img.size
            if pil_img.mode == 'L':
                pil_img = pil_img.convert('RGB')

            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]  # a tensor
        list_image_tensors.append(image.to(device))
    return list_image_tensors

def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    '''
    texts: a str (single user query) or a list of [human, gpt] pairs, where the gpt value of the last pair is None
    images: a path / URL / PIL.Image, or a list of them; may also be None
    '''

    ############################
    # 1. preprocess texts
    ############################
    if isinstance(texts, str):
        texts = [[texts, None]]
    else:
        assert isinstance(texts, list) and isinstance(texts[0], list), 'texts must be a list of lists'

    if history is not None:
        texts = history + texts  # concat them together

    texts = input_moderation(texts)

    ############################
    # 2. preprocess images
    ############################
    if isinstance(images, str) or isinstance(images, Image.Image):
        images = [images]

    valid_images = []
    if images is None:
        images = [None]

    for img in images:
        try:
            if isinstance(img, Image.Image):  # already loaded; keep it as-is
                valid_images.append(img)
                continue
            if os.path.exists(img):  # make sure that the path exists
                img = Image.open(img).convert('RGB')
            else:  # else it must be a URL
                img = Image.open(requests.get(img, stream=True).raw).convert('RGB')

            valid_images.append(img)
        except Exception:
            continue

    images = valid_images

    if images == []:
        images = [None]

    assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'

    ############################
    # 3. collate conv
    ############################

    history = deepcopy(texts)  # history is the texts without <image> placeholders

    # insert <image>: only insert the placeholders for the user input of the 1st round
    image_place_holder_inserted = insert_image_placeholder(texts[0][0], len(images) if None not in images else 0)
    texts[0][0] = image_place_holder_inserted

    # collate strings into conv
    conv = get_conv(texts)

    # make input ids
    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)

    try:
        dtype = torch.bfloat16
        # if your hardware does not support bf16, the following line raises an error
        torch.tensor(1, dtype=dtype).cuda()
    except Exception:
        # fall back to fp16
        dtype = torch.float16

    list_image_tensors = get_image_tensors(processor, images, device)
    image_tensors = torch.stack(list_image_tensors).to(dtype)  # cast pixel values to the detected half-precision dtype

    if return_history:
        return input_ids, image_tensors, history

    return input_ids, image_tensors, None

class TextIterStreamer:
    # minimal streamer exposing the put()/end() interface that transformers' generate() expects,
    # plus an iterator that yields the decoded text generated so far
    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
        else:
            if len(value.shape) > 1:
                value = value[0]
            self.tokens.extend(value.tolist())
            self.text_queue.put(
                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get()
        if value is None:
            raise StopIteration()
        else:
            return value
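
For reference, a minimal usage sketch of these helpers (not part of the committed file). It assumes a checkpoint loaded with trust_remote_code, the CLIP ViT-L/14-336 image processor implied by the 576-tokens-per-image comment above, and a LLaVA-style generate() that accepts an images keyword; the model path, image path, and generation settings are placeholders.

# Usage sketch under the assumptions stated above.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

from generation_utils import TextIterStreamer, build_allava_input

model_path = 'path/to/allava-checkpoint'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14-336')  # assumed vision processor
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, trust_remote_code=True).to('cuda')

input_ids, image_tensors, history = build_allava_input(
    tokenizer, processor,
    texts='What is in this image?',
    images='example.jpg',  # a local path, URL or PIL.Image
    return_history=True,
    device='cuda',
)

# Stream the reply: generate() pushes token ids into the streamer via put()/end(),
# while the main thread iterates over the partial decodings.
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    inputs=input_ids,
    images=image_tensors.to(model.dtype),  # assumed LLaVA-style signature
    streamer=streamer,
    max_new_tokens=256,
    do_sample=False,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

for partial_text in streamer:
    print(partial_text, end='\r')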