vikhyatk committed
Commit c064b4c
1 Parent(s): fc49662

Upload Moondream

Files changed (3)
  1. model.safetensors +1 -1
  2. moondream.py +73 -2
  3. vision_encoder.py +8 -6
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7eb78782910cad8e8ba5a90146a640ef48b420524d8e946ecb65519834145acc
+oid sha256:493ac8972766b8e4b9005bfab11454b93aab4987b44a01debebec3fa96773105
 size 3715037856
moondream.py CHANGED
@@ -9,13 +9,16 @@ from .configuration_moondream import PhiConfig
 
 class Moondream(PreTrainedModel):
     config_class = MoondreamConfig
+    _supports_flash_attn_2 = True
 
     def __init__(self, config):
         super().__init__(config)
         self.vision_encoder = VisionEncoder()
 
         if type(config.phi_config) == dict:
-            phi_config = PhiConfig(**config.phi_config)
+            phi_config = PhiConfig(
+                **config.phi_config, attn_implementation=config._attn_implementation
+            )
         else:
             phi_config = config.phi_config
         self.text_model = PhiForCausalLM(phi_config)
@@ -94,7 +97,7 @@ class Moondream(PreTrainedModel):
             prompt,
             eos_text="<END>",
             tokenizer=tokenizer,
-            max_new_tokens=256,
+            max_new_tokens=512,
             **kwargs,
         )[0]
         cleaned_answer = re.sub("<$|<END$", "", answer).strip()
@@ -104,3 +107,71 @@ class Moondream(PreTrainedModel):
             result_queue.put(cleaned_answer)
         else:
             return cleaned_answer
+
+    def batch_answer(
+        self,
+        images,
+        prompts,
+        tokenizer,
+        **kwargs,
+    ):
+        eos_tokens = tokenizer("<END>", add_special_tokens=False)[0].ids
+
+        image_embeds = self.encode_image(images)
+
+        templated_prompts = [
+            f"<image>\n\nQuestion: {prompt}\n\nAnswer: " for prompt in prompts
+        ]
+        prompt_embs = [
+            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
+            for prompt, image_embed in zip(templated_prompts, image_embeds)
+        ]
+
+        bos_emb = prompt_embs[0][0]
+        max_len = max([p.shape[0] for p in prompt_embs])
+
+        inputs_embeds = torch.cat(
+            [
+                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+        attention_mask = torch.cat(
+            [
+                torch.cat(
+                    [
+                        torch.zeros(
+                            1,
+                            max_len - p.shape[0],
+                            device=self.device,
+                            dtype=torch.long,
+                        ),
+                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
+                    ],
+                    dim=1,
+                )
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+
+        generate_config = {
+            "eos_token_id": eos_tokens,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.eos_token_id,
+            "max_new_tokens": 512,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            output_ids = self.text_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_config,
+            )
+
+        return [
+            re.sub("<$|<END$", "", x).strip()
+            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        ]
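
For reference, a minimal usage sketch of the new batch_answer method together with the flash-attention pass-through enabled by _supports_flash_attn_2 and the attn_implementation forwarding. The repo id, image paths, and prompts below are placeholders, loading via AutoModelForCausalLM with trust_remote_code is assumed, and attn_implementation="flash_attention_2" only applies if flash-attn is installed.

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "vikhyatk/moondream1"  # placeholder repo id, assumption
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    # optional: forwarded into PhiConfig via config._attn_implementation
    # attn_implementation="flash_attention_2",
)

images = [Image.open("a.jpg"), Image.open("b.jpg")]
prompts = ["What is in this image?", "Describe the scene."]

# prompt embeddings are left-padded to a common length and decoded in a
# single generate() call; answers come back in the same order as the prompts
answers = model.batch_answer(images, prompts, tokenizer)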
vision_encoder.py CHANGED
@@ -121,13 +121,15 @@ class VisionEncoder(nn.Module):
     def dtype(self):
         return self.projection.mlp.fc1.weight.dtype
 
-    def __call__(self, image: Image) -> torch.Tensor:
+    def __call__(self, images) -> torch.Tensor:
+        if not isinstance(images, list):
+            images = [images]
+
         with torch.no_grad():
-            x = (
-                self.preprocess(image.convert("RGB"))
-                .unsqueeze(0)
-                .to(self.device, dtype=self.dtype)
-            )
+            x = torch.stack(
+                [self.preprocess(image.convert("RGB")) for image in images]
+            ).to(self.device, dtype=self.dtype)
+
             x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
 
             x = self.encoder(x)
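
With this change the vision encoder accepts either a single PIL image or a list of images, stacking the preprocessed tensors into one batch before patching and encoding. A small sketch, assuming vision_encoder is taken from an already-loaded Moondream model and the image paths are placeholders:

from PIL import Image

vision_encoder = model.vision_encoder  # assumes a loaded Moondream instance

# single image: wrapped into a one-element list internally, leading batch dim of 1
single_embed = vision_encoder(Image.open("photo.jpg"))

# list of images: preprocessed tensors are stacked, leading batch dim of 2
batch_embed = vision_encoder([Image.open("a.jpg"), Image.open("b.jpg")])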