alibidaran committed
Commit a1fa153 · verified · 1 Parent(s): 8b904e9

Create SMOLLM_VisionModel.py

Files changed (1):
  1. SMOLLM_VisionModel.py +41 -0
SMOLLM_VisionModel.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ 
+ 
+ class SMOLLm_VISION_ImageCaptioning(torch.nn.Module):
+     def __init__(self, llm_model, hidden_dim):
+         super().__init__()
+         self.llm_model = llm_model
+         # Project 768-dim CLIP image features to the LLM embedding size (960);
+         # hidden_dim is accepted for configurability but the sizes are hardcoded here.
+         self.fc = torch.nn.Linear(768, 960)
+         self.relu = torch.nn.GELU()
+ 
+     def forward(self, images, input_ids, att):
+         # Encode images: project CLIP features into the LLM embedding space
+         image_features = self.relu(self.fc(images))
+ 
+         # Embed the caption tokens with the (frozen) LLM input embeddings
+         llama_inputs = self.llm_model.prepare_inputs_for_generation(input_ids)
+         with torch.no_grad():
+             llama_embeds = self.llm_model.get_input_embeddings()(llama_inputs['input_ids'])
+ 
+         # Prepend the image feature as a single visual token; att must cover
+         # the combined sequence (image token + text tokens).
+         combined_inputs = torch.cat([image_features.unsqueeze(1).float(), llama_embeds], dim=1)
+         outputs = self.llm_model(inputs_embeds=combined_inputs, attention_mask=att)
+ 
+         # Drop the logit at the image position so outputs align with the text tokens
+         return outputs.logits[:, 1:, :], combined_inputs
+ 
+ 
+ class SmoLLM_processor():
+     def __init__(self, image_model, image_processor):
+         self.image_model = image_model
+         self.image_processor = image_processor
+ 
+     def get_features(self, image):
+         # Preprocess the image and extract pooled CLIP image features
+         inputs = self.image_processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             image_features = self.image_model.get_image_features(**inputs.to('cuda:0')).squeeze()
+         return image_features
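
For reference, a minimal usage sketch of how these two classes could be wired together. The checkpoint names (HuggingFaceTB/SmolLM-360M, whose hidden size is 960, and openai/clip-vit-large-patch14, whose image features are 768-dim), the example image path, and the caption text are assumptions for illustration; none of them are specified in this commit.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

# Assumed checkpoints: any causal LM with a 960-dim hidden size and any CLIP
# model whose image features are 768-dim match the hardcoded projection.
llm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M").to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

model = SMOLLm_VISION_ImageCaptioning(llm, hidden_dim=960).to("cuda:0")
processor = SmoLLM_processor(clip_model, clip_processor)

# One forward pass on a single (image, caption) pair
image = Image.open("example.jpg")                 # hypothetical image file
features = processor.get_features(image)          # shape: (768,)
tokens = tokenizer("a photo of a cat", return_tensors="pt").to("cuda:0")
# The attention mask must cover the prepended image token plus the text tokens
att = torch.cat([torch.ones_like(tokens["attention_mask"][:, :1]), tokens["attention_mask"]], dim=1)
logits, combined = model(features.unsqueeze(0), tokens["input_ids"], att)

The design prepends a single projected CLIP feature to the text embeddings as a visual prefix token, and the logit at that position is discarded so the remaining outputs line up with the caption tokens.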