Mohammad Ibrahim committed on
Commit
b2db059
1 Parent(s): 9ad123d

initial commit

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+/model_chkpt/*.pth filter=lfs diff=lfs merge=lfs -text
+/model_chkpt/lora_adaptor/*.safetensors filter=lfs diff=lfs merge=lfs -text
+/model_chkpt/lora_adaptor/*.json filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,127 @@
+import gradio as gui
+import peft
+from peft import LoraConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
+import torch
+from peft import PeftModel
+import torch.nn as nn
+import whisper
+import os
+
+os.environ['https_proxy'] = 'http://185.46.212.90:80'
+os.environ['http_proxy'] = 'http://185.46.212.90:80'
+clip_model_name = "openai/clip-vit-base-patch32"
+phi_model_name = "microsoft/phi-2"
+tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(clip_model_name)
+tokenizer.pad_token = tokenizer.eos_token
+IMAGE_TOKEN_ID = 23893  # token id for the word "comment"
+QA_TOKEN_ID = 50295  # token id for "qa"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+clip_embed = 768
+phi_embed = 2560
+audio_batch_size = 16
+current_dir = os.getcwd()
+
+class SimpleResBlock(nn.Module):
+    def __init__(self, phi_embed):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(phi_embed)
+        self.proj = nn.Sequential(
+            nn.Linear(phi_embed, phi_embed),
+            nn.GELU(),
+            nn.Linear(phi_embed, phi_embed)
+        )
+    def forward(self, x):
+        x = self.pre_norm(x)
+        return x + self.proj(x)
+
+# models
+clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
+projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+resblock = SimpleResBlock(phi_embed).to(device)
+phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
+audio_model = whisper.load_model("tiny", device=device)
+
+lora_adaptor_path = os.path.join(current_dir, 'model_chkpt', 'lora_adaptor')
+projection_path = os.path.join(current_dir, 'model_chkpt', 'step2_projection.pth')
+resblock_path = os.path.join(current_dir, 'model_chkpt', 'step2_resblock.pth')
+
+# load weights
+model_to_merge = PeftModel.from_pretrained(phi_model, lora_adaptor_path, local_files_only=True, device_map={'': device})
+
+merged_model = model_to_merge.merge_and_unload()
+projection.load_state_dict(torch.load(projection_path, map_location=torch.device(device)))
+resblock.load_state_dict(torch.load(resblock_path, map_location=torch.device(device)))
+
+def generate_response(img=None, img_audio=None, val_q=None):
+
+    max_generate_length = 100
+    val_combined_embeds = []
+
+    with torch.no_grad():
+
+        # image
+        if img is not None:
+            image_processed = processor(images=img, return_tensors="pt").to(device)
+            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
+            val_image_embeds = projection(clip_val_outputs)
+            val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
+            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
+            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+
+            val_combined_embeds.append(val_image_embeds)
+            val_combined_embeds.append(img_token_embeds)
+
+        # audio
+        if img_audio is not None:
+            audio_result = audio_model.transcribe(img_audio)
+            audio_text = ''
+            for seg in audio_result['segments']:
+                audio_text += seg['text']
+            audio_text = audio_text.strip()
+            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+            val_combined_embeds.append(audio_embeds)
+
+        # text question
+        if len(val_q) != 0:
+            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+            val_combined_embeds.append(val_q_embeds)
+
+
+        if img_audio is not None or len(val_q) != 0:  # add QA token
+
+            QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
+            QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
+            val_combined_embeds.append(QA_token_embeds)
+
+        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
+        predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
+                                                  max_new_tokens=max_generate_length,
+                                                  return_dict_in_generate=True)
+
+    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
+    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
+
+    return predicted_captions_decoded
+
+
+# Gradio interface setup
+with gui.Blocks() as app_interface:
+
+    with gui.Row():
+        with gui.Column():
+            image_input = gui.Image(label='Upload Image', type="pil")
+        with gui.Column():
+            audio_input = gui.Audio(label="Audio Input", sources=['microphone', 'upload'], type='filepath')
+            text_input = gui.Text(label='Enter Text', placeholder="Type your query here...")
+    with gui.Row():
+        output_response = gui.Textbox(label='Generated Response', placeholder="Response will appear here...", lines=5)
+    submit_button = gui.Button("Generate Response", variant="primary")
+    submit_button.click(generate_response, inputs=[image_input, audio_input, text_input], outputs=output_response)
+
+if __name__ == "__main__":
+    app_interface.launch(share=True)
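
A quick way to sanity-check the app.py added above without launching the Gradio UI is to call generate_response directly. This is only a minimal sketch, assuming the checkpoints under model_chkpt/ have been pulled via Git LFS and that a local test image exists (sample.jpg is a hypothetical file name, not part of this commit):

# sketch: exercise generate_response outside Gradio
from PIL import Image
from app import generate_response  # importing app.py loads CLIP, phi-2, Whisper and the LoRA adapter

img = Image.open("sample.jpg")  # hypothetical test image
answer = generate_response(img=img, img_audio=None, val_q="What is happening in this picture?")
print(answer)
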
model_chkpt/lora_adaptor/.ipynb_checkpoints/adapter_config-checkpoint.json ADDED
@@ -0,0 +1,29 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": null,
+  "base_model_name_or_path": "microsoft/phi-2",
+  "bias": "none",
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 16,
+  "lora_dropout": 0.1,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": null,
+  "peft_type": "LORA",
+  "r": 64,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "fc2",
+    "v_proj",
+    "fc1",
+    "k_proj",
+    "q_proj"
+  ],
+  "task_type": "CAUSAL_LM"
+}
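
The checkpointed adapter_config above corresponds to a standard peft LoRA setup. For reference only, a sketch of the equivalent LoraConfig, reconstructed from the JSON fields above (the original training script is not part of this commit):

from peft import LoraConfig

# reconstructed from adapter_config-checkpoint.json; not taken from the training code
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["fc2", "v_proj", "fc1", "k_proj", "q_proj"],
    task_type="CAUSAL_LM",
)
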
model_chkpt/lora_adaptor/adapter_config.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23212cff0181dce73efde3004af2309e9ad4a13ace72aa36d3d236874d85b8e4
+size 603
model_chkpt/lora_adaptor/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ed4cfe545b28f8effe9ded98b472502a93a5e5e42edd6384591f0e1d71c3770
+size 335586800
model_chkpt/step2_projection.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ec5d298b71c4b50f2b626e6df9a73d02d012da0794c1b768610fe52f4a8f860
+size 7876174
model_chkpt/step2_resblock.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c161997824ad80b1af83639cd051a8beda84d4967be50982821249c509fab62c
+size 52472590
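
The checkpoint files in this commit are stored as Git LFS pointers, so only the sha256 oid and size of each blob are versioned here. After a `git lfs pull`, a downloaded file can be checked against its pointer; a minimal sketch for the resblock checkpoint, using only the standard library:

import hashlib

# oid recorded in the step2_resblock.pth pointer above
expected = "c161997824ad80b1af83639cd051a8beda84d4967be50982821249c509fab62c"

with open("model_chkpt/step2_resblock.pth", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("checksum OK" if digest == expected else "checksum mismatch")
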