import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()
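
# NOTE: the config keys below are inferred from how they are used in this file
# (configs.get_config_phase2 itself is not shown here), so treat them as assumptions:
#   clip_model_name - Hugging Face id of the CLIP vision encoder
#   phi2_model_name - Hugging Face id of the Phi-2 base model
#   clip_embed      - hidden size of the CLIP encoder (projection input)
#   phi_embed       - hidden size of Phi-2 (projection output)
#   device          - torch device string, e.g. "cuda" or "cpu"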

# Vision encoder: CLIP model whose patch embeddings describe the input image.
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))

base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)

# Attach the QLoRA adapter to the base Phi-2 model and merge it for inference.
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Linear projection trained in phase 2: maps CLIP embeddings into the Phi-2 embedding space.
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
projection_layer = projection_layer.to(config.get("device"))
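
# Shape flow (assuming a ViT-style CLIP encoder): the encoder returns
# last_hidden_state of shape (batch, 1 + num_patches, clip_embed); the CLS token
# is dropped in generate_answers() and the remaining patch embeddings are projected
# to (batch, num_patches, phi_embed) so they can be concatenated with text embeddings.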

tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# WhisperX (tiny) on CPU transcribes the spoken question.
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
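
# transcribe() returns a dict whose 'segments' entries each carry a 'text' field;
# generate_answers() below concatenates these into the transcribed question.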


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    batch_size = 1

    # The multimodal prompt is wrapped between the custom <iQ> ... </iQ> marker tokens.
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))

    # Collect embeddings for each provided modality, in order: image, audio, text.
    inputs_embeddings = []
    inputs_embeddings.append(start_iq_embeds)

    if img is not None:
        # Encode the image with CLIP and project its patch embeddings into Phi-2 space.
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        # Drop the CLS token; keep only the patch embeddings.
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)

    if aud is not None:
        # Transcribe the spoken question and embed the resulting text.
        trans = audio_model.transcribe(aud)
        audio_res = ""
        for seg in trans['segments']:
            audio_res += seg['text']
        audio_res = audio_res.strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)

    if q:  # covers both None and the empty string
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)

    inputs_embeddings.append(end_iq_embeds)

    # Concatenate all modality embeddings along the sequence dimension and generate.
    combined_embeds = torch.cat(inputs_embeddings, dim=1)
    predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
                                            max_new_tokens=max_tokens,
                                            return_dict_in_generate=True)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded
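
# Minimal usage sketch outside the Gradio UI (assumes the example assets below exist):
#   from PIL import Image
#   print(generate_answers(img=Image.open("./examples/Image_1.jpg"),
#                          aud="./examples/General.wav",
#                          q="What is there in the image?",
#                          max_tokens=20))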


examples = [
    ["./examples/Image_2.jpg", "./examples/Image_2.wav", "How many animals are there in the image?", 10],
    ["./examples/Image_1.jpg", "./examples/General.wav", "What is there in the image?", 20],
    ["./examples/Image_3.jpg", "./examples/General.wav", "Which animal is this?", 20],
    ["./examples/Image_4.jpg", "./examples/General.wav", "What does this image represent?", 20],
]
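
# Each example row supplies (image path, audio path, question text, max_tokens),
# matching the generate_answers() signature wired into the Gradio inputs below.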

with gr.Blocks() as demo:

    gr.Markdown(
        """
# MultiModelLLM

Multimodal GPT with image, audio, and text inputs and text output.
        """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])

    gr.Examples(examples=examples, fn=generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=answer)


if __name__ == "__main__":
    demo.launch(share=True, debug=True)