Rehman1603 committed on
Commit 2a2d54d · verified · 1 Parent(s): c5e3eb3

Create app.py

Files changed (1)
app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+ from PIL import Image
+ from transformers import BertTokenizer, BertModel, ViTModel
+ import gradio as gr
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define the VQA Model
+ class VQAModel(nn.Module):
+     def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
+         super(VQAModel, self).__init__()
+         self.vit_model = vit_model
+         self.bert_model = bert_model
+         self.fc = nn.Linear(768 + hidden_size, hidden_size)  # Adjust input size to match concatenated features
+         self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is dynamically determined
+
+     def forward(self, image, question):
+         # Extract image features
+         with torch.no_grad():
+             image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, Shape: (batch_size, 768)
+
+         # Extract text features
+         with torch.no_grad():
+             question_encoded = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, Shape: (batch_size, 768)
+
+         # Concatenate image and text features
+         combined_features = torch.cat((image_features, question_encoded), dim=1)  # Shape: (batch_size, 1536)
+
+         # Pass through fully connected layer
+         combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)
+
+         # Classify
+         output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
+         return output
+
+ # Load the saved model checkpoint
+ checkpoint_path = 'vqa_vit_best_model.pth'  # Path to the saved model
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # Load ViT and BERT models
+ vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
+ bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
+
+ # Initialize the VQA model with the correct number of classes
+ model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)
+
+ # Load the model state dict
+ model.load_state_dict(checkpoint['model_state_dict'])
+
+ # Load the answer-to-label mapping
+ answer_to_label = checkpoint['answer_to_label']
+ label_to_answer = {v: k for k, v in answer_to_label.items()}  # Reverse mapping for inference
+
+ # Set the model to evaluation mode
+ model.eval()
+
+ # Define transformations for the image
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
+ ])
+
+ # Function to preprocess and predict
+ def predict(image_path, question):
+     # Load and transform the image
+     image = Image.open(image_path).convert('RGB')
+     image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
+
+     # Tokenize the question
+     question_encoded = bert_tokenizer(
+         question,
+         return_tensors='pt',
+         padding='max_length',  # Pad to the maximum length
+         truncation=True,  # Truncate if the question is too long
+         max_length=32  # Set a maximum sequence length
+     ).to(device)
+
+     # Perform inference
+     with torch.no_grad():
+         output = model(image, question_encoded['input_ids'])
+
+     # Get the predicted label
+     _, predicted_label = torch.max(output, 1)
+     predicted_label = predicted_label.item()
+
+     # Map the label back to the answer
+     predicted_answer = label_to_answer[predicted_label]
+
+     return predicted_answer
+
+ # Define the question (already set)
+ question = "What is the overall complexity of this model?"
+
+ # Define the Gradio interface function
+ def vqa_interface(image):
+     # Predict the answer using the provided image and the predefined question
+     predicted_answer = predict(image, question)
+     return predicted_answer
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=vqa_interface,  # Function to call
+     inputs=gr.Image(type="filepath"),  # Input type: image file path
+     outputs="text",  # Output type: text (predicted answer)
+     title="Visual Question Answering (VQA) System",
+     description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
+     examples=[
+         ["02_uml.png"], ["2ndIterationClassDiagram.png"], ["4-gameUML.png"]]
+ )
+
+ # Launch the Gradio interface
+ iface.launch()
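
Note: the loading code above assumes vqa_vit_best_model.pth is a dict with the keys model_state_dict, num_classes, and answer_to_label. The checkpoint itself is not part of this commit, so the following is only a minimal sketch of a training-side save that would be compatible; the answer vocabulary shown is illustrative, not taken from the repository.

# Hypothetical training-side checkpoint save (assumption, not part of this commit).
# It writes exactly the keys that app.py reads back via torch.load().
answer_to_label = {'low': 0, 'medium': 1, 'high': 2}  # illustrative answer vocabulary
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'num_classes': len(answer_to_label),
        'answer_to_label': answer_to_label,
    },
    'vqa_vit_best_model.pth',
)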
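
For a quick check of predict() before launching the Gradio UI, one of the bundled example images can be passed directly, assuming it sits in the working directory:

# Hypothetical smoke test using one of the example images listed above.
print(predict("02_uml.png", "What is the overall complexity of this model?"))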