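"""Gradio demo: spoken visual question answering.

Whisper transcribes the spoken question, a BERT + Swin classifier routes the
(question, image) pair to one of six pretrained VQA models, and gTTS reads the
predicted answer back as audio.
"""
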
import re

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
from gtts import gTTS
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    BertTokenizer,
    BlipForQuestionAnswering,
    BlipProcessor,
    DonutProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
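
# Question-category classifier: encodes the question with BERT and the image with a
# Swin Transformer, fuses the two pooled embeddings, and predicts one of six categories.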
class LabelClassifier(nn.Module):
    def __init__(self):
        super(LabelClassifier, self).__init__()
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.image_encoder = AutoModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
        self.intermediate_dim = 128
        # Project the concatenated BERT and Swin pooled outputs into a smaller fused representation
        self.fusion = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size, self.intermediate_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(self.intermediate_dim, 6)  # six question categories
        self.criterion = nn.CrossEntropyLoss()

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                attention_mask: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        encoded_text = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        encoded_image = self.image_encoder(pixel_values=pixel_values)
        fused_state = self.fusion(torch.cat((encoded_text['pooler_output'], encoded_image['pooler_output']), dim=1))
        # Pass the fused representation through the classifier head
        logits = self.classifier(fused_state)
        out = {"logits": logits}
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)
        return out
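
# Load the trained classifier weights and the preprocessors for its two encoders.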
model = LabelClassifier().to(device)
model.load_state_dict(torch.load('classifier.pth', map_location=device))
model.eval()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = ViTImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
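
# Category 0: general visual question answering with BLIP (blip-vqa-capfilt-large).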
def m1(que, image):
    processor3 = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    model3 = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)
    inputs = processor3(image, que, return_tensors="pt").to(device)
    out = model3.generate(**inputs)
    return processor3.decode(out[0], skip_special_tokens=True)
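
# Category 1: scene-text VQA with GIT fine-tuned on TextVQA (git-large-textvqa).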
def m2(que, image):
    processor3 = AutoProcessor.from_pretrained("microsoft/git-large-textvqa")
    model3 = AutoModelForCausalLM.from_pretrained("microsoft/git-large-textvqa").to(device)
    pixel_values = processor3(images=image, return_tensors="pt").pixel_values.to(device)
    input_ids = processor3(text=que, add_special_tokens=False).input_ids
    input_ids = [processor3.tokenizer.cls_token_id] + input_ids  # prepend the CLS token expected by GIT
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    generated_ids = model3.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    return processor3.batch_decode(generated_ids, skip_special_tokens=True)[0]
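
# Category 2: document VQA with Donut fine-tuned on DocVQA.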
def m3(que, image):
    processor3 = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
    model3 = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
    model3.to(device)
    # Donut expects the question embedded in its docvqa task prompt
    prompt = f"<s_docvqa><s_question>{que}</s_question><s_answer>"
    decoder_input_ids = processor3.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor3(image, return_tensors="pt").pixel_values
    outputs = model3.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model3.decoder.config.max_position_embeddings,
        pad_token_id=processor3.tokenizer.pad_token_id,
        eos_token_id=processor3.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor3.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor3.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
    return processor3.token2json(sequence)['answer']
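
# Category 3: chart/plot question answering with MatCha (a Pix2Struct variant) fine-tuned on PlotQA.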
def m4(que, image):
    processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v1')
    model3 = Pix2StructForConditionalGeneration.from_pretrained('google/matcha-plotqa-v1').to(device)
    inputs = processor3(images=image, text=que, return_tensors="pt").to(device)
    predictions = model3.generate(**inputs, max_new_tokens=512)
    return processor3.decode(predictions[0], skip_special_tokens=True)
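
# Category 4: text-in-image (OCR-VQA) question answering with Pix2Struct.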
def m5(que, image):
    processor3 = AutoProcessor.from_pretrained("google/pix2struct-ocrvqa-large")
    model3 = AutoModelForSeq2SeqLM.from_pretrained("google/pix2struct-ocrvqa-large").to(device)
    inputs = processor3(images=image, text=que, return_tensors="pt").to(device)
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)
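
# Category 5 (fallback): infographic VQA with Pix2Struct fine-tuned on InfographicVQA.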
def m6(que, image):
    processor3 = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-large")
    model3 = AutoModelForSeq2SeqLM.from_pretrained("google/pix2struct-infographics-vqa-large").to(device)
    inputs = processor3(images=image, text=que, return_tensors="pt").to(device)
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)
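
# Route the question to the VQA model that matches the predicted category index.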
def predict_answer(category, que, image):
    if category == 0:
        return m1(que, image)
    elif category == 1:
        return m2(que, image)
    elif category == 2:
        return m3(que, image)
    elif category == 3:
        return m4(que, image)
    elif category == 4:
        return m5(que, image)
    else:
        return m6(que, image)
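
# Speech-to-text: transcribe the recorded question with Whisper large-v3.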
def transcribe_audio(audio):
    processor2 = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language='en')
    model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
    # Gradio's microphone component returns a (sampling_rate, waveform) tuple
    sampling_rate = audio[0]
    audio_data = audio[1]
    # Convert to float32 and resample to the 16 kHz rate Whisper expects
    audio_data_float = np.array(audio_data).astype(np.float32)
    resampled_audio_data = librosa.resample(audio_data_float, orig_sr=sampling_rate, target_sr=16000)
    # Extract log-mel input features for the model
    input_features = processor2(
        resampled_audio_data, sampling_rate=16000, return_tensors="pt"
    ).input_features
    # Generate token ids
    predicted_ids = model2.generate(input_features)
    # Decode token ids to text
    transcription = processor2.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription
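
# Predict the question category by running the fused BERT + Swin classifier on the
# tokenized question and the preprocessed image.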
def predict_category(que, input_image):
    encoded_text = tokenizer(
        text=que,
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    encoded_image = processor(input_image, return_tensors='pt')
    with torch.no_grad():
        output = model(
            input_ids=encoded_text['input_ids'].to(device),
            token_type_ids=encoded_text['token_type_ids'].to(device),
            attention_mask=encoded_text['attention_mask'].to(device),
            pixel_values=encoded_image['pixel_values'].to(device),
        )
    preds = output["logits"].argmax(dim=-1).cpu().numpy()
    return preds[0]
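
# End-to-end pipeline: transcribe the question, classify it, answer it with the
# matching VQA model, and synthesize the answer as speech.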
def combine(audio, input_image):
    que = transcribe_audio(audio)
    image = Image.fromarray(input_image).convert('RGB')
    category = predict_category(que, image)
    # Route the question to the model selected by the category classifier
    answer = predict_answer(category, que, image)
    # Read the answer aloud with gTTS
    tts = gTTS(answer)
    tts.save('answer.mp3')
    return que, answer, 'answer.mp3'

# Gradio interface: record a spoken question, upload an image, and get the transcribed question, answer text, and spoken answer
model_interface = gr.Interface(
    fn=combine,
    inputs=[gr.Microphone(label="Ask your question"), gr.Image(label="Upload the image")],
    outputs=[gr.Text(label="Transcribed Question"), gr.Text(label="Answer"), gr.Audio(label="Audio Answer")],
)

# Launch the Gradio interface
model_interface.launch(debug=True)