ThomasSimonini (HF staff) committed
Commit ae09ff5 · verified · 1 parent: 0be24e7

Upload moondream.py

Files changed (1):
  moondream.py  +193 -0

moondream.py CHANGED
@@ -1,3 +1,4 @@
+"""
 import torch
 from .vision_encoder import VisionEncoder
 from .configuration_moondream import MoondreamConfig
@@ -113,6 +114,198 @@ class Moondream(PreTrainedModel):
         else:
             return cleaned_answer

+    def batch_answer(
+        self,
+        images,
+        prompts,
+        tokenizer,
+        **kwargs,
+    ):
+        image_embeds = self.encode_image(images)
+
+        templated_prompts = [
+            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
+        ]
+        prompt_embs = [
+            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
+            for prompt, image_embed in zip(templated_prompts, image_embeds)
+        ]
+
+        bos_emb = prompt_embs[0][0]
+        max_len = max([p.shape[0] for p in prompt_embs])
+
+        inputs_embeds = torch.cat(
+            [
+                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+        attention_mask = torch.cat(
+            [
+                torch.cat(
+                    [
+                        torch.zeros(
+                            1,
+                            max_len - p.shape[0],
+                            device=self.device,
+                            dtype=torch.long,
+                        ),
+                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
+                    ],
+                    dim=1,
+                )
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+
+        generate_config = {
+            "eos_token_id": tokenizer.eos_token_id,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
+            "max_new_tokens": 512,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            output_ids = self.text_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_config,
+            )
+
+        return [
+            x.strip()
+            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        ]
+"""
+import torch
+from threading import Thread  # needed by the background generation thread in generate()
+from .vision_encoder import VisionEncoder
+from .configuration_moondream import MoondreamConfig
+from transformers import PreTrainedModel, TextIteratorStreamer
+
+from .modeling_phi import PhiForCausalLM
+from .configuration_moondream import PhiConfig
+
+class Moondream(PreTrainedModel):
+    config_class = MoondreamConfig
+    _supports_flash_attn_2 = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vision_encoder = VisionEncoder(
+            use_flash_attn=config._attn_implementation == "flash_attention_2"
+        )
+
+        if type(config.text_config) == dict:
+            phi_config = PhiConfig(
+                **config.text_config, attn_implementation=config._attn_implementation
+            )
+        else:
+            phi_config = config.text_config
+        self.text_model = PhiForCausalLM(phi_config)
+
+    @property
+    def device(self):
+        return self.text_model.device
+
+    def encode_image(self, image):
+        with torch.no_grad():
+            return self.vision_encoder(image)
+
+    def input_embeds(self, prompt, image_embeds, tokenizer):
+        def _tokenize(txt):
+            return tokenizer(
+                txt, return_tensors="pt", add_special_tokens=False
+            ).input_ids.to(self.device)
+
+        text_emb = self.text_model.get_input_embeddings()
+
+        # Add BOS token
+        embeds = []
+        embeds.append(
+            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
+        )
+
+        if "<image>" not in prompt:
+            embeds.append(text_emb(_tokenize(prompt)))
+        else:
+            assert prompt.count("<image>") == 1
+            before, after = prompt.split("<image>")
+            if len(before) > 0:
+                embeds.append(text_emb(_tokenize(before)))
+            embeds.append(image_embeds.to(self.device))
+            if len(after) > 0:
+                embeds.append(text_emb(_tokenize(after)))
+
+        return torch.cat(embeds, dim=1)
+
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    async def generate(
+        self,
+        image_embeds,
+        prompt,
+        tokenizer,
+        max_new_tokens=128,
+        **kwargs,
+    ):
+        generate_config = {
+            "eos_token_id": tokenizer.eos_token_id,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
+            "max_new_tokens": max_new_tokens,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+            streamer = TextIteratorStreamer(tokenizer)
+
+            # Start generation in a separate thread
+            thread = Thread(target=self.text_model.generate, kwargs={
+                "inputs_embeds": inputs_embeds,
+                "streamer": streamer,
+                **generate_config
+            })
+            thread.start()
+
+            # Yield generated text as it becomes available
+            for new_text in streamer:
+                yield new_text
+
+            thread.join()
+            # All output has already been yielded above; an async generator
+            # cannot also return a decoded result, so nothing is returned here.
+
+    def answer_question(
+        self,
+        image_embeds,
+        question,
+        tokenizer,
+        chat_history="",
+        result_queue=None,
+        **kwargs,
+    ):
+        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
+        # Note: generate() is now an async generator, so its result can no
+        # longer be indexed like the previous synchronous list return.
+        answer = self.generate(
+            image_embeds,
+            prompt,
+            tokenizer=tokenizer,
+            max_new_tokens=512,
+            **kwargs,
+        )[0]
+        cleaned_answer = answer.strip()
+
+        # Use the result_queue to pass the result if it is provided
+        if result_queue:
+            result_queue.put(cleaned_answer)
+        else:
+            return cleaned_answer
+
     def batch_answer(
         self,
         images,
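
The rewritten generate() streams text through a TextIteratorStreamer fed by a background thread, so a caller consumes it as an async iterator instead of receiving a decoded list. Below is a minimal usage sketch under stated assumptions: the checkpoint id, the image path, and loading via AutoModelForCausalLM with trust_remote_code=True are not part of this commit and may need adjusting.

import asyncio

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint id and image path; replace with your own.
model_id = "vikhyatk/moondream1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("demo.jpg")
image_embeds = model.encode_image(image)
prompt = "<image>\n\nQuestion: What is in this picture?\n\nAnswer:"

async def stream_answer():
    # generate() is an async generator: each iteration yields the next chunk
    # of decoded text as the background generation thread produces it.
    async for chunk in model.generate(image_embeds, prompt, tokenizer):
        print(chunk, end="", flush=True)

asyncio.run(stream_answer())

Since answer_question still indexes the result of generate(), a caller that wants one complete answer can instead collect the streamed chunks as above and join them.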