Image Captioning - Fine-Tuned ViT-PhoBERT Model
This is a ViT-PhoBERT model fine-tuned on the vietnamese_face_wiki dataset.
Model Evaluation
The model was trained for 50 epochs on Kaggle using the "GPU T4 x2" accelerator (two NVIDIA T4 GPUs).
It is evaluated using three different metrics.
How to use
Import the required libraries
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
Load the dataset you need
# Fetch the "train" split of the augmented Vietnamese face-wiki dataset from the Hub.
from datasets import load_dataset

dataset = load_dataset(path="Seeker38/augmented_vi_face_wiki", split="train")
Load the model and tokenizer
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

# Fine-tuned encoder-decoder captioning checkpoint.
model = VisionEncoderDecoderModel.from_pretrained("Seeker38/ViT_PhoBert_face_vi_wiki")

# PhoBERT tokenizer, used below only to decode generated token ids.
# `add_special_tokens` is an option of the tokenizer *call* (encoding time), not a
# loading option, so it is not passed to `from_pretrained` here.
phobert_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
if phobert_tokenizer.pad_token is None:
    # Fallback only if the tokenizer ships without a pad token.
    # NOTE(review): adding a token grows the tokenizer vocab, but the model's
    # embeddings are not resized here; this is harmless for decoding only.
    phobert_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
Construct the caption-generation function
def _images_to_pixel_values(images, image_processor=None):
    """Convert a list of PIL images into one float32 batch tensor (N, C, H, W)."""
    # Force 3 channels: grayscale images would otherwise yield a 2-D array and
    # RGBA images a 4-channel one, neither of which the vision encoder accepts.
    rgb_images = [image.convert("RGB") for image in images]
    if image_processor is not None:
        # Let the processor handle resizing and normalization.
        return image_processor(images=rgb_images, return_tensors="pt").pixel_values
    # Original manual preprocessing: 224x224 resize, raw 0-255 values as float32,
    # channels moved first. No normalization is applied.
    # NOTE(review): confirm which preprocessing the checkpoint was trained with.
    tensors = []
    for image in rgb_images:
        array = np.array(image.resize((224, 224)), dtype=np.uint8)
        tensors.append(torch.tensor(np.moveaxis(array, -1, 0), dtype=torch.float32))
    return torch.stack(tensors)


def _plot_captions(images, captions, output_file):
    """Show each image (left column) next to its caption (right column) and save the figure."""
    n_rows = len(images)
    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, j] always works.
    fig, axs = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)
    for i, (image, caption) in enumerate(zip(images, captions)):
        axs[i, 0].imshow(image)
        axs[i, 0].axis('off')
        axs[i, 1].text(0, 0.5, caption, wrap=True, fontsize=12)
        axs[i, 1].axis('off')
    plt.tight_layout()
    # Save the plot to a local file (before show(), which may clear the figure).
    fig.savefig(output_file)
    plt.show()
    plt.close(fig)  # free the figure so repeated calls don't accumulate open figures
    print(f"Plot saved as {output_file}")


def generate_caption(
    model,
    dataset,
    tokenizer,
    device,
    num_images=20,
    max_length=50,
    image_processor=None,
    output_file="/kaggle/working/generated_captions.png",
):
    """Caption randomly sampled dataset images and plot each image beside its caption.

    Args:
        model: Encoder-decoder model exposing ``eval()`` and ``generate()``; it must
            already be on ``device``.
        dataset: Indexable collection whose rows have an ``'image'`` entry holding
            a PIL image.
        tokenizer: Tokenizer used to decode the generated token ids.
        device: Device the pixel tensor is moved to before generation.
        num_images: Number of distinct images to sample; clamped to ``len(dataset)``.
        max_length: Maximum generated sequence length passed to ``generate``.
        image_processor: Optional image processor (e.g. from ``AutoImageProcessor``).
            When omitted, the original manual preprocessing is used.
        output_file: Path the figure is saved to. The default is the Kaggle working
            directory, which does not exist outside Kaggle.

    Returns:
        list[str]: Decoded captions, in the same order as the plotted images.
        Empty list when there is nothing to sample.
    """
    model.eval()
    # random.sample raises if asked for more items than the dataset holds.
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        return []
    sampled_indices = random.sample(range(len(dataset)), num_images)
    sampled_images = [dataset[idx]['image'] for idx in sampled_indices]

    pixel_values = _images_to_pixel_values(sampled_images, image_processor).to(device)

    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            num_beams=10,
            max_length=max_length,
            early_stopping=True,
            length_penalty=1.0,
        )
    decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Display the images and their captions in a single column.
    _plot_captions(sampled_images, decoded_preds, output_file)
    return decoded_preds
Run and enjoy
# `device` must be defined before the call. Use the GPU when available, otherwise
# the CPU, and move the model there so it matches the pixel tensor's device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
generate_caption(model, dataset, phobert_tokenizer, device, num_images=5, max_length=70)
- Downloads last month: 57
The serverless Inference API does not yet support Transformers models for this pipeline type.