# picpilot-server/scripts/products10k_captions.py
# Vikramjeet Singh
# Update products10k_captions.py
# 53c7a9e
import torch
from datasets import load_dataset, Dataset
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm
from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ImageCaptioner:
    """
    A class for generating captions for images using a pre-trained model.

    Args:
        dataset (str): The name or path of the dataset handed to ``load_dataset``.
        processor (str): The pre-trained processor to use for image/text processing.
        model (str): The pre-trained model to use for caption generation.
        prompt (str): The conditioning prompt to use for caption generation.
        max_samples (int): Upper bound on the number of rows kept from the
            "test" split. It is clamped to the size of the split.

    Attributes:
        dataset: The loaded (and truncated) dataset.
        processor: The pre-trained processor.
        model: The pre-trained caption generation model, moved to ``device``.
        prompt: The conditioning prompt for generating captions.

    Methods:
        process_dataset: Preprocesses the dataset.
        generate_caption: Generates a caption for a single image.
        generate_captions: Generates captions for all images in the dataset.
    """

    def __init__(
        self,
        dataset: str,
        processor: str,
        model: str,
        prompt: str = "Product photo of",
        max_samples: int = 10000,
    ):
        self.dataset = load_dataset(dataset, split="test")
        # Clamp the sample count: selecting indices past the end of the split
        # would fail when the split holds fewer than ``max_samples`` rows.
        sample_count = min(max_samples, len(self.dataset))
        self.dataset = self.dataset.select(range(sample_count))
        self.processor = BlipProcessor.from_pretrained(processor)
        self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
        self.prompt = prompt

    def process_dataset(self):
        """
        Preprocesses the dataset so that it has an 'image' column, no 'label'
        column and a 'text' column that will receive the captions.

        Returns:
            The preprocessed dataset.

        Raises:
            ValueError: If the dataset has neither an 'image' nor a
                'pixel_values' column.
        """
        columns = self.dataset.column_names
        if "image" not in columns:
            if "pixel_values" not in columns:
                raise ValueError(
                    "Dataset must contain an 'image' or 'pixel_values' column, "
                    f"found: {columns}"
                )
            self.dataset = self.dataset.rename_column("pixel_values", "image")
        # When 'image' already exists nothing is renamed: renaming a column onto
        # a name that is already present is rejected by ``rename_column``.

        if "label" in self.dataset.column_names:
            self.dataset = self.dataset.remove_columns(["label"])

        # Add an empty 'text' column for captions if it doesn't exist
        if "text" not in self.dataset.column_names:
            self.dataset = self.dataset.add_column("text", [""] * len(self.dataset))
        return self.dataset

    def generate_caption(self, example):
        """
        Generates a caption for a single image, conditioned on ``self.prompt``.

        Args:
            example (dict): A dictionary containing the image under the 'image' key.

        Returns:
            dict: The same dictionary with the generated caption stored under 'text'.
        """
        image = example["image"].convert("RGB")
        # A single processor call prepares both the image and the prompt tokens.
        inputs = self.processor(images=image, text=self.prompt, return_tensors="pt").to(device)
        outputs = self.model.generate(**inputs)
        example["text"] = self.processor.decode(outputs[0], skip_special_tokens=True)
        return example

    def generate_captions(self):
        """
        Generates captions for all images in the dataset.

        Returns:
            Dataset: The dataset with generated captions.
        """
        self.dataset = self.process_dataset()
        # Examples are captioned one at a time (non-batched map).
        self.dataset = self.dataset.map(self.generate_caption, batched=False)
        return self.dataset
# Build the captioner; the same checkpoint name supplies both the processor and the model
ic = ImageCaptioner(
    PRODUCTS_10k_DATASET,
    CAPTIONING_MODEL_NAME,
    CAPTIONING_MODEL_NAME,
    prompt="Commercial photography of",
)
# Caption every image kept in the dataset
products10k_dataset = ic.generate_captions()
# Upload the captioned dataset to the hub
products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")