Sijuade committed on
Commit
b7be07b
1 Parent(s): 9cbbd7c

Upload 6 files

Files changed (6)
  1. MLMbanner.py +144 -0
  2. app.py +26 -0
  3. config.py +15 -0
  4. networks.py +79 -0
  5. requirements.txt +6 -0
  6. utils.py +90 -0
MLMbanner.py ADDED
@@ -0,0 +1,144 @@
+
+
+ def get_html():
+     html_string = """
+     <div id="banner">
+         <div id="particles-js"></div>
+         <h1>Multimodal Chatbot</h1>
+         <p>A chatbot that accepts text, audio, and images.</p>
+         <div class="icons">
+             <div class="icon" id="text-icon">&#128172;</div> <!-- Text Bubble Emoji -->
+             <div class="icon" id="audio-icon">&#127911;</div> <!-- Headphone Emoji -->
+             <div class="icon" id="image-icon">&#128247;</div> <!-- Camera Emoji -->
+         </div>
+     </div>
+
+     <style>
+     #banner {
+         background: linear-gradient(270deg, #6c5ce7, #a29bfe, #fd79a8);
+         background-size: 600% 600%;
+         color: white;
+         text-align: center;
+         padding: 30px;
+         border-radius: 15px;
+         box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
+         position: relative;
+         overflow: hidden;
+         animation: AnimatedGradient 15s ease infinite;
+     }
+
+     #particles-js {
+         position: absolute;
+         width: 100%;
+         height: 100%;
+         top: 0;
+         left: 0;
+         z-index: 1;
+     }
+
+     #banner > * {
+         position: relative;
+         z-index: 2;
+     }
+
+     #banner h1 {
+         font-size: 2.8em;
+         margin-bottom: 10px;
+         animation: fadeInDown 1.5s ease-in-out;
+     }
+
+     #banner p {
+         font-size: 1.3em;
+         animation: fadeInUp 1.5s ease-in-out;
+     }
+
+     .icons {
+         display: flex;
+         justify-content: center;
+         margin-top: 20px;
+     }
+
+     .icon {
+         font-size: 2em;
+         margin: 0 10px;
+         animation: bounce 2s infinite;
+         transition: transform 0.2s;
+     }
+
+     .icon:hover {
+         transform: scale(1.1);
+     }
+
+     @keyframes fadeInDown {
+         from { opacity: 0; transform: translateY(-20px); }
+         to { opacity: 1; transform: translateY(0); }
+     }
+
+     @keyframes fadeInUp {
+         from { opacity: 0; transform: translateY(20px); }
+         to { opacity: 1; transform: translateY(0); }
+     }
+
+     @keyframes bounce {
+         0%, 100% { transform: translateY(0); }
+         50% { transform: translateY(-10px); }
+     }
+
+     @keyframes AnimatedGradient {
+         0% { background-position: 0% 50% }
+         50% { background-position: 100% 50% }
+         100% { background-position: 0% 50% }
+     }
+     </style>
+
+     <script src="https://cdn.jsdelivr.net/particles.js/2.0.0/particles.min.js"></script>
+     <script>
+     document.addEventListener("DOMContentLoaded", function() {
+         particlesJS("particles-js", {
+             "particles": {
+                 "number": {
+                     "value": 80,
+                     "density": {
+                         "enable": true,
+                         "value_area": 800
+                     }
+                 },
+                 "color": {
+                     "value": "#ffffff"
+                 },
+                 "shape": {
+                     "type": "circle",
+                     "stroke": {
+                         "width": 0,
+                         "color": "#000000"
+                     },
+                     "polygon": {
+                         "nb_sides": 5
+                     }
+                 },
+                 "opacity": {
+                     "value": 0.5,
+                     "random": false,
+                     "anim": {
+                         "enable": false,
+                         "speed": 1,
+                         "opacity_min": 0.1,
+                         "sync": false
+                     }
+                 },
+                 "size": {
+                     "value": 3,
+                     "random": true,
+                     "anim": {
+                         "enable": false,
+                         "speed": 40,
+                         "size_min": 0.1,
+                         "sync": false
+                     }
+                 },
+                 "line_linked": {
+                     "enable
+
+     """
+
+     return html_string
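The embedded particles.js configuration above is cut off inside "line_linked", so the particle animation never initializes as committed. As a minimal sketch, the string could be closed out along the following lines, assuming stock particles.js default values; everything below is illustrative and not part of this commit.

# Hypothetical remainder of html_string, replacing the dangling '"line_linked": { "enable'
# fragment above; the option values are common particles.js defaults, not from this repository.
PARTICLES_TAIL = """
                "line_linked": {
                    "enable": true,
                    "distance": 150,
                    "color": "#ffffff",
                    "opacity": 0.4,
                    "width": 1
                }
            }
        });
    });
    </script>
    """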
app.py ADDED
@@ -0,0 +1,26 @@
+ import gradio as gr
+ from MLMbanner import get_html
+ from utils import chatbot_response
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(value=get_html, show_label=True)
+
+     with gr.Row():
+         text_input = gr.Textbox(label="Enter text", lines=10)
+         image_input = gr.Image(label="Upload image", type="pil")
+         audio_input = gr.Audio(label="Record or upload audio",
+                                type="filepath",
+                                sources=['microphone', 'upload'])
+
+     submit_button = gr.Button("Submit")
+
+     output = gr.Textbox(label="Chatbot Response", lines=10)
+
+     submit_button.click(
+         fn=chatbot_response,
+         inputs=[text_input, image_input, audio_input],
+         outputs=output
+     )
+
+ demo.launch()
config.py ADDED
@@ -0,0 +1,15 @@
+ import torch
+ from transformers import AutoProcessor, AutoTokenizer
+
+ class Config:
+
+     EOS_TOKEN_ID = 50256
+     QUESTION_ANSWER_SEPARATOR_ID = 50295  # Special token ID for question-answer separation
+     IMAGE_SEPARATOR_TOKENS = [685, 36259, 14041, 60, 220]
+
+     phi_model_name = "microsoft/phi-2"
+     model_name = "openai/clip-vit-base-patch32"
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     processor = AutoProcessor.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
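config.py pins two special token IDs by hand. A quick interpreter check, assuming the phi-2 tokenizer resolves them as the rest of the code expects (the decoded separator text is simply printed, not asserted):

from config import Config

# phi-2 inherits the GPT-2-style end-of-text token, which should match EOS_TOKEN_ID.
assert Config.tokenizer.eos_token_id == Config.EOS_TOKEN_ID

# Decode the hard-coded image-separator prefix to see the literal text it prepends
# to every tokenized prompt.
print(Config.tokenizer.decode(Config.IMAGE_SEPARATOR_TOKENS))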
networks.py ADDED
@@ -0,0 +1,79 @@
+ import peft
+ import torch
+ import whisperx
+ import torch.nn as nn
+ from config import Config
+ from transformers import CLIPVisionModel, AutoModelForCausalLM
+
+
+ phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device
+
+
+ class Projections(nn.Module):
+     def __init__(
+         self,
+         clip_embed,
+         phi_embed,
+         num_projection_layers=6,
+     ):
+         super().__init__()
+
+         self.norm = nn.LayerNorm(phi_embed)
+         self.output = nn.Linear(clip_embed, phi_embed)
+         self.projection_layers = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.Linear(phi_embed, phi_embed),
+                     nn.GELU(),
+                     nn.Linear(phi_embed, phi_embed),
+                 )
+                 for _ in range(num_projection_layers)
+             ]
+         )
+
+     def forward(self, x):
+         x = self.output(x)
+         self.norm(x)
+         for layer in self.projection_layers:
+             residual = x
+             x = layer(x) + residual
+
+         return x
+
+
+ def load_projection_model(path, clip_embed, phi_embed):
+     """Loads a Projections model instance from a checkpoint and returns it with weights loaded.
+
+     Args:
+         path (str): Path to the checkpoint file.
+
+     Returns:
+         torch.nn.Module: The loaded Projections model instance.
+     """
+     state_dict = torch.load(path)['state_dict']
+     new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}
+
+     model = Projections(clip_embed, phi_embed)
+     model.load_state_dict(new_state_dict)
+
+     return model
+
+
+ text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
+                                                   torch_dtype=torch.float16,
+                                                   # device_map="cuda",
+                                                   low_cpu_mem_usage=True,
+                                                   return_dict=True,
+                                                   trust_remote_code=True)
+
+ peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')
+ projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)
+
+ clip_model = CLIPVisionModel.from_pretrained(model_name)
+ audio_model = whisperx.load_model("small", device.type, compute_type="float16")
+
+ projection = projection.to(device)
+ peft_model = peft_model.to(device)
+ clip_model = clip_model.to(device)
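For reference, a rough shape check of the projection path, assuming CLIP ViT-B/32 patch features of width 768 are mapped into phi-2's 2560-dimensional embedding space (dummy tensor, illustration only):

import torch
from networks import Projections  # note: importing networks also loads phi-2, CLIP and whisperX

# Dummy batch: 1 image, 49 patch tokens (7x7 grid for ViT-B/32 at 224px), 768-dim features.
dummy_clip_features = torch.randn(1, 49, 768)

projection = Projections(clip_embed=768, phi_embed=2560)
projected = projection(dummy_clip_features)

print(projected.shape)  # expected: torch.Size([1, 49, 2560])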
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ pandas
+ pillow
+ git+https://github.com/huggingface/transformers
+ git+https://github.com/m-bain/whisperx.git
+ git+https://github.com/huggingface/peft.git
utils.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ from config import Config
+ from networks import peft_model, audio_model, clip_model, projection
+
+
+ tokenizer = Config.tokenizer
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.add_tokens('<question-answer>')
+
+
+ def prepare_inputs(peft_model, audio_model, clip_model, projection, text_input=None, image_input=None, audio_input=None):
+
+     text_audio, text_embed, image_embed = None, None, None
+
+     if audio_input:
+         audio_transcribed = audio_model.transcribe(audio_input)
+         processed_audio = ''
+
+         for audio_segment in audio_transcribed['segments']:
+             processed_audio += audio_segment['text']
+
+         processed_audio = processed_audio.strip()
+
+     if image_input is not None:
+         image_processed = Config.processor(images=image_input, return_tensors="pt")
+
+         with torch.no_grad():
+             outputs = clip_model(**image_processed)
+             last_hidden_state = outputs.last_hidden_state[:, 1:, :]
+             image_embed = projection(last_hidden_state.to(Config.device)).to(torch.float16)
+
+     if audio_input is not None and text_input is not None:
+         text_audio = f"{text_input} {processed_audio}"
+     elif audio_input and text_input is None:
+         text_audio = processed_audio
+     elif audio_input is None and text_input:
+         text_audio = text_input
+
+     if text_audio:
+         tokenized_text_audio = tokenizer.encode(text_audio)
+         tokenized_text_audio = Config.IMAGE_SEPARATOR_TOKENS + tokenized_text_audio + [Config.QUESTION_ANSWER_SEPARATOR_ID]
+
+         with torch.no_grad():
+             tokenized_text_audio = torch.tensor(tokenized_text_audio)
+             text_embed = peft_model.model.model.embed_tokens(tokenized_text_audio.to(Config.device)).unsqueeze(0)
+
+     if text_audio is not None and image_input is not None:
+         combined_embed = torch.cat([image_embed, text_embed], dim=1)
+     elif text_audio and image_input is None:
+         combined_embed = text_embed
+     elif text_audio is None and image_input:
+         combined_embed = image_embed
+
+     return combined_embed
+
+
+ def chatbot_response(text_input, image_input, audio_input):
+
+     if text_input == '':
+         text_input = None
+
+     if text_input is None and image_input is None and audio_input is None:
+         return "Please enter text, upload an image, or record audio."
+
+     combined_embeds = prepare_inputs(peft_model, audio_model, clip_model, projection,
+                                      text_input, image_input, audio_input)
+     generated_tokens = generate_tokens(combined_embeds, max_tokens=60)
+     return tokenizer.decode(generated_tokens)
+
+
+ def generate_tokens(combined_embeds, max_tokens=100):
+     pred_tokens = []
+
+     combined_embed = combined_embeds
+
+     for _ in range(max_tokens):
+         logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
+         next_token_id = logits.argmax(dim=-1)
+
+         if next_token_id.item() == Config.EOS_TOKEN_ID:
+             break
+
+         pred_tokens.append(next_token_id.item())
+         next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
+
+         with torch.no_grad():
+             combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)
+
+     return pred_tokens
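Outside of the Gradio UI, chatbot_response can be driven directly. A minimal sketch with placeholder file paths (any of the three inputs may be empty or None, exactly as in app.py):

from PIL import Image
from utils import chatbot_response

# Placeholder inputs for illustration; substitute real files.
image = Image.open("example.jpg")
audio_path = "question.wav"

print(chatbot_response("What is happening in this picture?", image, None))
print(chatbot_response("", None, audio_path))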