# MultimodalChatGPT_TSAI / app_gradio.py
import gradio as gr
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
import torch.nn as nn
from PIL import Image
import whisperx
from pydub import AudioSegment  # pydub needs the ffmpeg binary available for mp3 decoding/export
import gc
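
# Pipeline overview:
#   - image  : TinyCLIP vision encoder -> linear projection into phi-2's embedding space
#   - speech : whisperx transcription  -> phi-2 token embeddings
#   - text   : tokenizer               -> phi-2 token embeddings
# The embedding streams are concatenated and fed to a QLoRA-fine-tuned (merged) phi-2
# model, which generates the answer greedily token by token.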
clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # token id of the word "comment", appended as a marker after the image embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 640   # hidden size of the TinyCLIP vision encoder
phi_embed = 2560   # hidden size of phi-2
compute_type = "float16" if device == "cuda" else "int8"  # whisperx's CPU backend does not support float16
audio_batch_size = 1
# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)  # maps CLIP patch embeddings into phi-2's embedding space
gc.collect()
phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    trust_remote_code=True,
)
audio_model = whisperx.load_model("small", device, compute_type=compute_type)
# load fine-tuned weights: merge the QLoRA adapter into phi-2 and load the trained projection layer
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/qlora_adaptor')
merged_model = model_to_merge.merge_and_unload().to(device)
projection.load_state_dict(torch.load('./model_chkpt/ft_projection.pth', map_location=torch.device(device)))
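
# inference: build a single input-embedding sequence from the (optional) image, the
# (optional) spoken query and the (optional) written query, then decode greedily.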
def inference(img=None, img_audio=None, val_q=None):
    max_generate_length = 100
    val_combined_embeds = []
    with torch.no_grad():
        # image: project CLIP patch embeddings (CLS token dropped) into phi-2's embedding space
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
            val_image_embeds = projection(clip_val_outputs)
            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)
        # audio: transcribe only the first 15 seconds of speech and embed the transcript
        if img_audio is not None:
            audio = AudioSegment.from_file(img_audio)  # gradio may pass wav or mp3; from_file handles both
            clipped_audio = audio[:15 * 1000]
            clipped_audio.export('audio.mp3', format="mp3")
            result = audio_model.transcribe('audio.mp3', batch_size=audio_batch_size)
            audio_text = result["segments"][0]['text'].strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)
        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)
        # concatenate all available modalities along the sequence dimension
        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
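
        # greedy decoding: re-run the growing embedding sequence through phi-2 each step,
        # take the arg-max token and append its embedding to the input.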
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)  # pre-filled with the EOS/pad id
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, vocab)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)         # (1, 1, vocab)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)       # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)      # (1, 1, phi_embed)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]  # drops the EOS/pad ids
    return predicted_captions_decoded
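
# Gradio UI: image upload + optional spoken query + optional written query -> generated answer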
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # MultiModal GPT _TSAI
        Built on the TinyCLIP vision encoder and Microsoft's Phi-2 language model, further fine-tuned on Instruct150K.
        """
    )
    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Speak a Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Write a Query')
        with gr.Column():
            img_answer = gr.Text(label='Answer')
    section_btn = gr.Button("Generate")
    section_btn.click(inference, inputs=[img_input, img_audio, img_question], outputs=[img_answer])
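
# Example programmatic call (the image path below is a placeholder):
#   answer = inference(img=Image.open("example.jpg"), img_audio=None, val_q="What is shown in the image?")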
if __name__ == "__main__":
    demo.launch(debug=True)