"""Multimodal visual question answering demo.

A question (spoken or typed) and an image are routed by ``LabelClassifier`` to
one of six specialised VQA models (``m1`` .. ``m6``); the answer is returned as
text and as synthesised speech through a Gradio interface.
"""
import functools
import os
import re
from dataclasses import dataclass
from typing import List

import gradio as gr
import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from gtts import gTTS
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.nn.utils import rnn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    BertTokenizer,
    BlipForQuestionAnswering,
    BlipProcessor,
    DonutProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# How many answer-model (processor, model) pairs to keep loaded at once.
# The checkpoints are large, so the cache is deliberately small.
_VQA_CACHE_SIZE = 1


class LabelClassifier(nn.Module):
    """Classify a (question, image) pair into one of ``num_classes`` categories.

    The question is encoded with BERT and the image with a Swin Transformer;
    their ``pooler_output`` vectors are concatenated, projected to
    ``intermediate_dim`` features and fed to a linear classifier.
    """

    def __init__(self, num_classes: int = 6, intermediate_dim: int = 128):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.image_encoder = AutoModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
        self.intermediate_dim = intermediate_dim
        # NOTE: attribute names and the layer order inside ``fusion`` define the
        # state_dict keys stored in 'classifier.pth'; do not rename or reorder.
        self.fusion = nn.Sequential(
            nn.Linear(
                self.text_encoder.config.hidden_size + self.image_encoder.config.hidden_size,
                self.intermediate_dim,
            ),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(self.intermediate_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor,
                attention_mask: torch.LongTensor = None, token_type_ids: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        """Return ``{"logits": ...}``, plus ``"loss"`` when ``labels`` is given."""
        encoded_text = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask,
                                         token_type_ids=token_type_ids)
        encoded_image = self.image_encoder(pixel_values=pixel_values)
        # Concatenate the pooled BERT and Swin outputs along the feature axis.
        fused_state = self.fusion(
            torch.cat((encoded_text['pooler_output'], encoded_image['pooler_output']), dim=1)
        )
        logits = self.classifier(fused_state)
        out = {"logits": logits}
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)
        return out


model = LabelClassifier().to(device)
model.load_state_dict(torch.load('classifier.pth', map_location=torch.device('cpu')))
# The classifier is only used for inference: eval() disables the Dropout layers
# (its own and BERT/Swin's) so the predicted category is deterministic.
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = ViTImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')


@functools.lru_cache(maxsize=_VQA_CACHE_SIZE)
def _load_pretrained(processor_cls, model_cls, checkpoint, to_device=False):
    """Load a (processor, model) pair for ``checkpoint`` and cache it.

    Avoids re-instantiating multi-GB weights on every request; only the most
    recently used ``_VQA_CACHE_SIZE`` pairs stay resident.  When ``to_device``
    is true the model is moved to the module-level ``device``.
    """
    proc = processor_cls.from_pretrained(checkpoint)
    mdl = model_cls.from_pretrained(checkpoint)
    if to_device:
        mdl.to(device)
    return proc, mdl


def m1(que: str, image: Image.Image) -> str:
    """Answer ``que`` with BLIP VQA (category 0)."""
    processor3, model3 = _load_pretrained(
        BlipProcessor, BlipForQuestionAnswering, "Salesforce/blip-vqa-capfilt-large")
    inputs = processor3(image, que, return_tensors="pt")
    out = model3.generate(**inputs)
    return processor3.decode(out[0], skip_special_tokens=True)


def m2(que: str, image: Image.Image) -> str:
    """Answer ``que`` with GIT fine-tuned on TextVQA (category 1)."""
    processor3, model3 = _load_pretrained(
        AutoProcessor, AutoModelForCausalLM, "microsoft/git-large-textvqa")
    pixel_values = processor3(images=image, return_tensors="pt").pixel_values
    input_ids = processor3(text=que, add_special_tokens=False).input_ids
    # GIT expects the question prefixed with [CLS] and continues generating after it.
    input_ids = torch.tensor([processor3.tokenizer.cls_token_id] + input_ids).unsqueeze(0)
    generated_ids = model3.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    # GIT is decoder-only, so the output echoes the prompt: decode only the newly
    # generated tokens, and return a single string like the other answer models.
    answer_ids = generated_ids[:, input_ids.shape[1]:]
    return processor3.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()


def m3(que: str, image: Image.Image) -> str:
    """Answer ``que`` about a document image with Donut DocVQA (category 2)."""
    processor3, model3 = _load_pretrained(
        DonutProcessor, VisionEncoderDecoderModel,
        "naver-clova-ix/donut-base-finetuned-docvqa", to_device=True)
    # Donut's DocVQA task prompt; the question must be embedded in it.
    prompt = f"<s_docvqa><s_question>{que}</s_question><s_answer>"
    decoder_input_ids = processor3.tokenizer(prompt, add_special_tokens=False,
                                             return_tensors="pt").input_ids
    pixel_values = processor3(image, return_tensors="pt").pixel_values
    outputs = model3.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model3.decoder.config.max_position_embeddings,
        pad_token_id=processor3.tokenizer.pad_token_id,
        eos_token_id=processor3.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor3.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor3.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(
        processor3.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    parsed = processor3.token2json(sequence)
    # The model may fail to emit an <s_answer> field; return "" rather than raise.
    return parsed.get("answer", "") if isinstance(parsed, dict) else ""


def m4(que: str, image: Image.Image) -> str:
    """Answer ``que`` about a plot with MatCha PlotQA (category 3)."""
    processor3, model3 = _load_pretrained(
        Pix2StructProcessor, Pix2StructForConditionalGeneration, 'google/matcha-plotqa-v2')
    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs, max_new_tokens=512)
    return processor3.decode(predictions[0], skip_special_tokens=True)


def m5(que: str, image: Image.Image) -> str:
    """Answer ``que`` with Pix2Struct OCR-VQA (category 4)."""
    processor3, model3 = _load_pretrained(
        AutoProcessor, AutoModelForSeq2SeqLM, "google/pix2struct-ocrvqa-large")
    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)


def m6(que: str, image: Image.Image) -> str:
    """Answer ``que`` with Pix2Struct InfographicVQA (category 5 / fallback)."""
    processor3, model3 = _load_pretrained(
        AutoProcessor, AutoModelForSeq2SeqLM, "google/pix2struct-infographics-vqa-large")
    inputs = processor3(images=image, text=que, return_tensors="pt")
    predictions = model3.generate(**inputs)
    return processor3.decode(predictions[0], skip_special_tokens=True)


def predict_answer(category, que, image):
    """Dispatch to the answer model for ``category``.

    Categories 0-4 map to ``m1``-``m5``; any other value falls back to ``m6``.
    """
    handlers = {0: m1, 1: m2, 2: m3, 3: m4, 4: m5}
    return handlers.get(category, m6)(que, image)


@functools.lru_cache(maxsize=1)
def _load_whisper():
    """Load and cache the Whisper processor/model used for transcription."""
    processor2 = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language='en')
    model2 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
    return processor2, model2


def transcribe_audio(audio):
    """Transcribe a Gradio ``(sampling_rate, samples)`` recording to text."""
    processor2, model2 = _load_whisper()
    sampling_rate = audio[0]
    samples = np.asarray(audio[1])
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM (e.g. int16): scale to floats in [-1, 1] as Whisper expects.
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        # Multi-channel audio arrives as (samples, channels): downmix to mono
        # so resampling runs along the time axis.
        samples = samples.mean(axis=1)
    resampled_audio_data = librosa.resample(samples, orig_sr=sampling_rate, target_sr=16000)
    input_features = processor2(
        resampled_audio_data, sampling_rate=16000, return_tensors="pt"
    ).input_features
    predicted_ids = model2.generate(input_features)
    return processor2.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def predict_category(que, input_image):
    """Return the classifier's category index (int) for ``que`` and ``input_image``."""
    encoded_text = tokenizer(
        text=que,
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    encoded_image = processor(input_image, return_tensors='pt')
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(
            input_ids=encoded_text['input_ids'].to(device),
            token_type_ids=encoded_text['token_type_ids'].to(device),
            attention_mask=encoded_text['attention_mask'].to(device),
            pixel_values=encoded_image['pixel_values'].to(device),
        )
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return int(preds[0])


def combine(audio, input_image, text_question=""):
    """Gradio callback: question (voice or text) + image -> (question, answer, mp3 path).

    Spoken input takes precedence over ``text_question``.  The audio path is
    ``None`` when the answer is empty (gTTS cannot synthesise empty text).
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    que = transcribe_audio(audio) if audio else text_question
    if not que or not que.strip():
        raise gr.Error("Please ask a question by voice or text.")
    image = Image.fromarray(input_image).convert('RGB')
    category = predict_category(que, image)
    # Route to the model chosen by the classifier (previously hard-coded to 0).
    answer = predict_answer(category, que, image)
    audio_path = None
    if answer and answer.strip():
        gTTS(answer).save('answer.mp3')
        audio_path = 'answer.mp3'
    return que, answer, audio_path


# Gradio interface: record audio, type a question, and upload an image.
model_interface = gr.Interface(
    fn=combine,
    inputs=[
        gr.Microphone(label="Ask your question"),
        gr.Image(label="Upload the image"),
        gr.Textbox(label="Text Question"),
    ],
    outputs=[
        gr.Text(label="Transcribed Question"),
        gr.Text(label="Answer"),
        gr.Audio(label="Audio Answer"),
    ],
)

model_interface.launch(debug=True)