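"""Gradio demo for a multimodal assistant: CLIP encodes the input image, WhisperX
transcribes the spoken question, and a QLoRA-fine-tuned Phi-2 model generates a
text answer from the combined embeddings."""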
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()

# CLIP vision encoder, moved to the inference device.
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))

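# Base Phi-2 causal LM in full precision; the QLoRA adapter is merged into it below.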
base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)


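# Load the QLoRA adapter and merge it into the base model for inference.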
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Projection from CLIP patch embeddings into Phi-2's embedding space.
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
projection_layer = projection_layer.to(config.get("device"))

# Phi-2 tokenizer and CLIP image processor.
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

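# WhisperX ASR model used to transcribe the spoken question (runs on CPU).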
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
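    """Build a combined embedding sequence from the image, audio, and text inputs
    and generate a text answer with the merged Phi-2 model."""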
    batch_size = 1
    # Embed the <iQ> ... </iQ> marker tokens that wrap the multimodal prompt.
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))
    
    inputs_embeddings = []
    inputs_embeddings.append(start_iq_embeds)

    if img is not None:
        # Encode the image with CLIP, drop the CLS token, and project the patch
        # embeddings into Phi-2's embedding space.
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)
    
    if aud is not None:
        # Transcribe the spoken question with WhisperX and embed the text tokens.
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)
        
    if q:
        # Embed the typed question, if one was provided.
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)
        
    inputs_embeddings.append(end_iq_embeds)

    # Concatenate all modality embeddings into a single prompt and generate.
    combined_embeds = torch.cat(inputs_embeddings, dim=1)
    predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
                                            max_new_tokens=max_tokens,
                                            return_dict_in_generate=True)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded

# Example inputs: (image, audio, question, max_tokens)
examples = [
    ["./examples/Image_2.jpg", "./examples/Image_2.wav", "How many animals are there in the image?", 10],
    ["./examples/Image_1.jpg", "./examples/General.wav", "What is there in the image?", 20],
    ["./examples/Image_3.jpg", "./examples/General.wav", "Which animal is this?", 20],
    ["./examples/Image_4.jpg", "./examples/General.wav", "What does this image represent?", 20],
]

with gr.Blocks() as demo:

    gr.Markdown(
    """
    # MultiModelLLM
    A multimodal GPT that takes image, audio, and text inputs and returns a text answer.
    """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])

    # Add examples
    gr.Examples(examples=examples, fn=generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=answer)


if __name__ == "__main__":
    demo.launch(share=True, debug=True)