Create SMOLLM_VisionModel.py
SMOLLM_VisionModel.py +41 -0
ADDED
@@ -0,0 +1,41 @@
import torch


class SMOLLm_VISION_ImageCaptioning(torch.nn.Module):
    def __init__(self, llm_model, hidden_dim):
        super().__init__()
        self.llm_model = llm_model
        # Project 768-d CLIP image features into the LLM embedding space
        # (hidden_dim is the LLM hidden size, e.g. 960 for SmolLM-360M).
        self.fc = torch.nn.Linear(768, hidden_dim)
        self.act = torch.nn.GELU()

    def forward(self, images, input_ids, att):
        # Encode images into the LLM embedding space.
        image_features = self.act(self.fc(images))

        # Embed the text tokens with the LLM's own embedding table,
        # without tracking gradients through the lookup.
        with torch.no_grad():
            text_embeds = self.llm_model.get_input_embeddings()(input_ids)

        # Prepend the projected image feature as a single pseudo-token.
        combined_inputs = torch.cat(
            [image_features.unsqueeze(1).float(), text_embeds], dim=1
        )
        # Extend the attention mask by one position so the image token is attended to.
        image_att = torch.ones(att.shape[0], 1, device=att.device, dtype=att.dtype)
        attention_mask = torch.cat((image_att, att), dim=-1)

        outputs = self.llm_model(
            inputs_embeds=combined_inputs, attention_mask=attention_mask
        )

        # Skip the logit emitted at the image position so outputs align with the text.
        return outputs.logits[:, 1:, :], combined_inputs

class SmoLLM_processor:
    def __init__(self, image_model, image_processor):
        self.image_model = image_model
        self.image_processor = image_processor

    def get_features(self, image):
        # Extract CLIP image features using the model/processor given at init.
        inputs = self.image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.image_model.get_image_features(
                **inputs.to('cuda:0')
            ).squeeze()
        return image_features
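For context, a minimal usage sketch (not part of the commit): it assumes a CLIP ViT-L/14 image encoder and SmolLM-360M as the language model, both loaded from the Hugging Face Hub. The checkpoint names, the prompt, and example.jpg are illustrative placeholders, not taken from this repository.

# Usage sketch: wire a CLIP encoder and SmolLM into the classes above.
# Assumptions: checkpoint names and the input image are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

device = 'cuda:0'

clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
llm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M").to(device)

processor = SmoLLM_processor(clip_model, clip_processor)
model = SMOLLm_VISION_ImageCaptioning(llm, hidden_dim=llm.config.hidden_size).to(device)

image = Image.open("example.jpg")                       # placeholder RGB image
features = processor.get_features(image)                # (768,) CLIP image features
tokenized = tokenizer(["a photo of"], return_tensors="pt").to(device)

logits, _ = model(
    features.unsqueeze(0),                              # (1, 768)
    tokenized["input_ids"],
    tokenized["attention_mask"],
)
print(logits.shape)                                     # (1, seq_len, vocab_size)

The image enters the LLM as a single projected pseudo-token prefixed to the text embeddings, so only the small linear projection needs to learn the mapping between the two embedding spaces.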