import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import gradio as gr
import pickle
import re
from collections import Counter
class Vocabulary:
def __init__(self, freq_threshold=5):
self.freq_threshold = freq_threshold
# self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
self.stoi = {v: k for k, v in self.itos.items()}
self.index = 4
def __len__(self):
return len(self.itos)
def tokenizer(self, text):
text = text.lower()
tokens = re.findall(r"\w+", text)
return tokens
def build_vocabulary(self, sentence_list):
frequencies = Counter()
for sentence in sentence_list:
tokens = self.tokenizer(sentence)
frequencies.update(tokens)
for word, freq in frequencies.items():
if freq >= self.freq_threshold:
self.stoi[word] = self.index
self.itos[self.index] = word
self.index += 1
def numericalize(self, text):
tokens = self.tokenizer(text)
numericalized = []
for token in tokens:
if token in self.stoi:
numericalized.append(self.stoi[token])
else:
                # the vocab uses plain "unk" (not "<unk>") as its marker
                numericalized.append(self.stoi["unk"])
return numericalized
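# A quick sketch of how Vocabulary behaves (toy corpus, assumed
# freq_threshold=1 so every word survives the cutoff):
#   v = Vocabulary(freq_threshold=1)
#   v.build_vocabulary(["a dog runs", "a dog sleeps"])
#   v.numericalize("a dog flies")  # unseen "flies" falls back to stoi["unk"]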
# These hyperparameters must match the values used in train.py.
EMBED_DIM = 256
HIDDEN_DIM = 512
MAX_SEQ_LENGTH = 25
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"
# Where you saved your model in train.py
# MODEL_SAVE_PATH = "best_checkpoint.pth"
MODEL_SAVE_PATH = "final_model.pth"
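# Unpickling vocab.pkl requires the Vocabulary class defined above, since
# pickle stores only the instance state, not the class definition.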
with open("vocab.pkl", "rb") as f:
vocab = pickle.load(f)
print(f"Loaded vocabulary: {vocab}")
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")
# -----------------------------------------------------------------
# 2. Model (Must match structure in train.py)
# -----------------------------------------------------------------
class ResNetEncoder(nn.Module):
def __init__(self, embed_dim):
super().__init__()
resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # gradients are irrelevant here since forward() runs the backbone
        # under torch.no_grad(); this flag simply mirrors train.py
        for param in resnet.parameters():
            param.requires_grad = True
modules = list(resnet.children())[:-1]
self.resnet = nn.Sequential(*modules)
self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)
def forward(self, images):
with torch.no_grad():
features = self.resnet(images) # (batch_size, 2048, 1, 1)
features = features.view(features.size(0), -1)
features = self.fc(features)
features = self.batch_norm(features)
return features
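# Shape sanity check for the encoder (sketch; eval() mode so BatchNorm1d
# accepts a batch of one):
#   enc = ResNetEncoder(EMBED_DIM).eval()
#   enc(torch.randn(1, 3, 224, 224)).shape  # -> torch.Size([1, EMBED_DIM])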
class DecoderLSTM(nn.Module):
def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, features, captions, states):
        # the full caption is used as input; the image features are
        # prepended as the first timestep below
        captions_in = captions
emb = self.embedding(captions_in)
        features = features.unsqueeze(1)  # (batch, 1, embed_dim)
        # prepend the image features as the first step of the sequence
        lstm_input = torch.cat((features, emb), dim=1)
outputs, returned_states = self.lstm(lstm_input, states)
logits = self.fc(outputs)
return logits, returned_states
def generate(self, features, max_len=20):
"""
Greedy generation from the features as initial context.
"""
        # greedy decoding for a single image (batch size 1)
        generated_captions = []
        start_idx = 1  # "startofseq"
        end_idx = 2  # "endofseq"
        inputs = features
        current_tokens = [start_idx]
        for _ in range(max_len):
            input_tokens = (
                torch.LongTensor(current_tokens).to(features.device).unsqueeze(0)
            )
            # re-run the full prefix from fresh states each step; carrying
            # the previous states forward would process the prefix twice
            logits, _ = self.forward(inputs, input_tokens, None)
            logits = logits.contiguous().view(-1, logits.size(-1))
            predicted = logits.argmax(dim=1)[-1].item()
            generated_captions.append(predicted)
            current_tokens.append(predicted)
            # stop early once the end-of-sequence token is produced
            if predicted == end_idx:
                break
return generated_captions
class ImageCaptioningModel(nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def generate(self, images, max_len=MAX_SEQ_LENGTH):
features = self.encoder(images)
return self.decoder.generate(features, max_len=max_len)
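# End-to-end sketch (assumed random input; a real image goes through
# transform_inference first):
#   m = ImageCaptioningModel(ResNetEncoder(EMBED_DIM),
#                            DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)).eval()
#   with torch.no_grad():
#       token_ids = m.generate(torch.randn(1, 3, 224, 224))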
# -----------------------------------------------------------------
# 3. LOAD THE TRAINED MODEL
# -----------------------------------------------------------------
def load_trained_model():
encoder = ResNetEncoder(embed_dim=EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
    # the checkpoint saved by train.py is a dict; the weights live under
    # the "model_state_dict" key
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model
model = load_trained_model()
# -----------------------------------------------------------------
# 4. INFERENCE FUNCTION (FOR GRADIO)
# -----------------------------------------------------------------
transform_inference = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
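# The mean/std above are the standard ImageNet statistics the pretrained
# ResNet-50 expects; keep this pipeline in sync with train.py.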
def generate_caption_for_image(img):
"""
Gradio callback: takes a PIL image, returns a string caption.
"""
    if img is None:  # Gradio passes None when no image is provided
        return ""
    pil_img = img.convert("RGB")
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)
with torch.no_grad():
output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)
    # generate() returns a flat list of token indices for the single image
    idx_list = output_indices
result_words = []
# end_token_idx = vocab.stoi["<end>"]
end_token_idx = vocab.stoi["endofseq"]
for idx in idx_list:
if idx == end_token_idx:
break
# word = vocab.itos.get(idx, "<unk>")
word = vocab.itos.get(idx, "unk")
# skip <start>/<pad> in final output
# if word not in ["<start>", "<pad>", "<end>"]:
if word not in ["startofseq", "pad", "endofseq"]:
result_words.append(word)
return " ".join(result_words)
# -----------------------------------------------------------------
# 5. BUILD GRADIO INTERFACE
# -----------------------------------------------------------------
def main():
iface = gr.Interface(
fn=generate_caption_for_image,
inputs=gr.Image(type="pil"),
outputs="text",
title="Image Captioning (ResNet + LSTM)",
description="Upload an image to get a generated caption from the trained model.",
)
iface.launch(share=True)
if __name__ == "__main__":
print("Loaded model. Starting Gradio interface...")
main()