ankanpy committed on
Commit a60cc48 · verified · 1 Parent(s): 5477e29

Upload 3 files

Files changed (3)
  1. app.py +237 -0
  2. final_model.pth +3 -0
  3. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,237 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ import torchvision.models as models
+ from torchvision.models import ResNet50_Weights
+ import gradio as gr
+ import pickle
+ import re
+ from collections import Counter
+
+
+ class Vocabulary:
+     def __init__(self, freq_threshold=5):
+         self.freq_threshold = freq_threshold
+         # self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
+         self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
+         self.stoi = {v: k for k, v in self.itos.items()}
+         self.index = 4
+
+     def __len__(self):
+         return len(self.itos)
+
+     def tokenizer(self, text):
+         text = text.lower()
+         tokens = re.findall(r"\w+", text)
+         return tokens
+
+     def build_vocabulary(self, sentence_list):
+         frequencies = Counter()
+         for sentence in sentence_list:
+             tokens = self.tokenizer(sentence)
+             frequencies.update(tokens)
+
+         for word, freq in frequencies.items():
+             if freq >= self.freq_threshold:
+                 self.stoi[word] = self.index
+                 self.itos[self.index] = word
+                 self.index += 1
+
+     def numericalize(self, text):
+         tokens = self.tokenizer(text)
+         numericalized = []
+         for token in tokens:
+             if token in self.stoi:
+                 numericalized.append(self.stoi[token])
+             else:
+                 # special tokens use bare names ("unk"), not "<unk>"
+                 numericalized.append(self.stoi["unk"])
+         return numericalized
+
+
+ # These values must match the ones used in train.py
+ EMBED_DIM = 256
+ HIDDEN_DIM = 512
+ MAX_SEQ_LENGTH = 25
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DEVICE = "cpu"
+
+ # Path where train.py saved the model
+ # MODEL_SAVE_PATH = "best_checkpoint.pth"
+ MODEL_SAVE_PATH = "final_model.pth"
+
+ with open("vocab.pkl", "rb") as f:
+     vocab = pickle.load(f)
+
+ print(vocab)
+
+ vocab_size = len(vocab)
+
+ print(vocab_size)
+
+
+ # -----------------------------------------------------------------
+ # 2. MODEL (must match the structure in train.py)
+ # -----------------------------------------------------------------
+ class ResNetEncoder(nn.Module):
+     def __init__(self, embed_dim):
+         super().__init__()
+         resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
+         for param in resnet.parameters():
+             param.requires_grad = True
+         modules = list(resnet.children())[:-1]  # drop the final classification layer
+         self.resnet = nn.Sequential(*modules)
+
+         self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
+         self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)
+
+     def forward(self, images):
+         with torch.no_grad():
+             features = self.resnet(images)  # (batch_size, 2048, 1, 1)
+         features = features.view(features.size(0), -1)
+         features = self.fc(features)
+         features = self.batch_norm(features)
+         return features
+
+
+ class DecoderLSTM(nn.Module):
+     def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
+         super().__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim)
+         self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
+         self.fc = nn.Linear(hidden_dim, vocab_size)
+
+     def forward(self, features, captions, states):
+         # embed the caption tokens and prepend the image features as the first step
+         captions_in = captions
+         emb = self.embedding(captions_in)
+         features = features.unsqueeze(1)
+
+         # print(features.shape)
+         # print(emb.shape)
+
+         lstm_input = torch.cat((features, emb), dim=1)
+         outputs, returned_states = self.lstm(lstm_input, states)
+         logits = self.fc(outputs)
+         return logits, returned_states
+
+     def generate(self, features, max_len=20):
+         """
+         Greedy decoding, using the image features as the initial context.
+         """
+         batch_size = features.size(0)
+         states = None
+         generated_captions = []
+
+         start_idx = 1  # startofseq
+         end_idx = 2  # endofseq
+
+         inputs = features
+         # current_tokens = torch.LongTensor([start_idx] * batch_size).to(features.device).unsqueeze(0)
+         current_tokens = [start_idx]
+
+         for _ in range(max_len):
+             # re-feed the image features plus the full token prefix at every step
+             input_tokens = torch.LongTensor(current_tokens).to(features.device).unsqueeze(0)
+             logits, states = self.forward(inputs, input_tokens, states)
+
+             # greedy pick: argmax of the last time step
+             logits = logits.contiguous().view(-1, vocab_size)
+             predicted = logits.argmax(dim=1)[-1].item()
+
+             generated_captions.append(predicted)
+             current_tokens.append(predicted)
+
+             # check if all ended
+             # all_ended = True
+             # for i, w in enumerate(predicted.numpy()):
+             #     print(w)
+             #     if w != end_idx:
+             #         all_ended = False
+             #         break
+             # if all_ended:
+             #     break
+
+         return generated_captions
+
+
+ class ImageCaptioningModel(nn.Module):
+     def __init__(self, encoder, decoder):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+
+     def generate(self, images, max_len=MAX_SEQ_LENGTH):
+         features = self.encoder(images)
+         return self.decoder.generate(features, max_len=max_len)
+
+
+ # -----------------------------------------------------------------
+ # 3. LOAD THE TRAINED MODEL
+ # -----------------------------------------------------------------
+ def load_trained_model():
+     encoder = ResNetEncoder(embed_dim=EMBED_DIM)
+     decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
+     model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
+
+     # Load weights from disk
+     state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
+     model.load_state_dict(state_dict["model_state_dict"])
+     model.eval()
+     # print(model)
+     return model
+
+
+ model = load_trained_model()
+
+ # -----------------------------------------------------------------
+ # 4. INFERENCE FUNCTION (FOR GRADIO)
+ # -----------------------------------------------------------------
+ transform_inference = transforms.Compose(
+     [
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ]
+ )
+
+
+ def generate_caption_for_image(img):
+     """
+     Gradio callback: takes a PIL image, returns a string caption.
+     """
+     pil_img = img.convert("RGB")
+     img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)
+
+     with torch.no_grad():
+         output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)
+     # output_indices is a flat list of token ids for the single input image
+     idx_list = output_indices
+
+     result_words = []
+     # end_token_idx = vocab.stoi["<end>"]
+     end_token_idx = vocab.stoi["endofseq"]
+     for idx in idx_list:
+         if idx == end_token_idx:
+             break
+         # word = vocab.itos.get(idx, "<unk>")
+         word = vocab.itos.get(idx, "unk")
+         # skip startofseq/pad in the final output
+         # if word not in ["<start>", "<pad>", "<end>"]:
+         if word not in ["startofseq", "pad", "endofseq"]:
+             result_words.append(word)
+     return " ".join(result_words)
+
+
+ # -----------------------------------------------------------------
+ # 5. BUILD GRADIO INTERFACE
+ # -----------------------------------------------------------------
+ def main():
+     iface = gr.Interface(
+         fn=generate_caption_for_image,
+         inputs=gr.Image(type="pil"),
+         outputs="text",
+         title="Image Captioning (ResNet + LSTM)",
+         description="Upload an image to get a generated caption from the trained model.",
+     )
+     iface.launch(share=True)
+
+
+ if __name__ == "__main__":
+     print("Loaded model. Starting Gradio interface...")
+     main()
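
Note that app.py reads vocab.pkl at import time, but that file is not among the three files uploaded here, so it has to sit next to app.py (along with final_model.pth pulled via git-lfs) for the app to start. Below is a minimal local smoke test of the inference path, assuming those files are in place; the script name and example.jpg are illustrative:

# smoke_test.py -- quick local check of the Gradio callback without launching the UI.
# Assumes app.py, vocab.pkl, and final_model.pth are in the current directory and
# that example.jpg is any image available on disk.
from PIL import Image

from app import generate_caption_for_image  # importing app.py loads the vocab and model

img = Image.open("example.jpg")
print(generate_caption_for_image(img))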
final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5f42b553e46644720d5ba5b1adc0c7c6740dc9eb1d454d2b7519c36dc6f49cd5
+ size 128890708
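
load_trained_model() in app.py indexes the loaded object with ["model_state_dict"], so this LFS pointer should resolve to a dict-style checkpoint rather than a bare state_dict. A quick way to confirm that locally, assuming the weights have been pulled with git-lfs:

# Illustrative check of the checkpoint layout expected by app.py.
import torch

ckpt = torch.load("final_model.pth", map_location="cpu")
print(type(ckpt))         # expected: dict
print(list(ckpt.keys()))  # expected to include "model_state_dict"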
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ # requirements.txt
+
+ # Core deep learning libraries
+ # (ResNet50_Weights and the weights= API used in app.py require torchvision >= 0.13 / torch >= 1.12)
+ torch>=1.12.0
+ torchvision>=0.13.0
+ torchaudio>=0.12.0
+
+ # For displaying training progress in the command line
+ tqdm>=4.62.0
+
+ # Gradio for the web-based inference UI
+ gradio>=3.0.0
+
+ # Image handling
+ Pillow>=8.0.0
+
+ # pickle is part of the Python standard library, so it is not listed as a pip dependency.
+ # pickle5>=0.0.11  # backport, only needed on very old Python versions
+
+ # If you need other libraries (e.g., for advanced tokenization), add them here.
+ # spacy>=3.0.0
+ # nltk>=3.6.0