import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the VQA model: a frozen ViT image encoder and a frozen BERT text
# encoder, fused by concatenation and classified over the answer vocabulary
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super().__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)  # Input size matches the concatenated ViT + BERT features
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is determined from the checkpoint

    def forward(self, image, input_ids, attention_mask=None):
        # Extract image features (the ViT backbone stays frozen)
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)

        # Extract text features; the attention mask lets BERT ignore padded
        # positions (it is optional so callers that omit it still work)
        with torch.no_grad():
            question_features = self.bert_model(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)

        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_features), dim=1)  # Shape: (batch_size, 1536)

        # Project the fused features, then classify
        combined_features = self.fc(combined_features)   # Shape: (batch_size, hidden_size)
        output = self.classifier(combined_features)      # Shape: (batch_size, num_classes)
        return output

# Load the saved model checkpoint
checkpoint_path = 'vqa_vit_best_model.pth'  # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the pretrained ViT and BERT backbones
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Initialize the VQA model with the number of classes stored in the checkpoint
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)

# Load the trained weights
model.load_state_dict(checkpoint['model_state_dict'])

# Load the answer-to-label mapping and build the reverse mapping for inference
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()}

# Set the model to evaluation mode
model.eval()

# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalization used by this ViT checkpoint
])

# Preprocess an image/question pair and predict an answer
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Maximum sequence length
    ).to(device)

    # Perform inference, passing the attention mask along with the token ids
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'], question_encoded['attention_mask'])

    # Get the predicted label
    predicted_label = torch.argmax(output, dim=1).item()

    # Map the label back to the answer string
    return label_to_answer[predicted_label]

# The fixed question answered by the demo
question = "What is the overall complexity of this model?"
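# For reference, a minimal sketch of how a compatible checkpoint could have
# been saved at training time. The training script is not part of this file,
# so the exact call is an assumption; only the key names ('model_state_dict',
# 'num_classes', 'answer_to_label') are taken from the loading code above:
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'num_classes': len(answer_to_label),
#       'answer_to_label': answer_to_label,  # e.g. {'low': 0, 'medium': 1, 'high': 2}
#   }, 'vqa_vit_best_model.pth')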
# Define the Gradio interface function
def vqa_interface(image):
    # Predict the answer using the provided image and the predefined question
    return predict(image, question)

# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,                  # Function to call
    inputs=gr.Image(type="filepath"),  # Input type: image file path
    outputs="text",                    # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"],
        ["2ndIterationClassDiagram.png"],
        ["4-gameUML.png"],
    ],
)

# Launch the Gradio interface
iface.launch()
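# Quick sanity check without the Gradio UI: call predict() directly. The
# filename below is one of the example images listed above and is assumed to
# exist next to this script; if used, place the call before iface.launch(),
# which blocks until the server is stopped:
#
#   print(predict("02_uml.png", question))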