""" train.py A complete example of fine-tuning BLIP on 'agentsea/computer-thoughts' for captioning. All processing is done in the collate function. This is simpler and avoids shape mismatches. """ import torch from datasets import load_dataset, Image as HFImage from transformers import ( BlipProcessor, BlipForConditionalGeneration, TrainingArguments, Trainer ) # 1. Load dataset dataset = load_dataset("agentsea/computer-thoughts") # 2. Rename "image_before" -> "image" and cast to HFImage so it becomes a PIL Image dataset = dataset.rename_column("image_before", "image") dataset = dataset.cast_column("image", HFImage()) # 3. Create a small subset for demo (just 5 examples). Remove this if you want the full data. train_subset = dataset["train"].select(range(5)) # 4. Load the BLIP base model and processor processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large") # 5. Define a collate_fn that transforms images+text on-the-fly def collate_fn(examples): # examples is a list of dicts, each dict with keys: # 'task', 'image', 'image_after', 'action', 'thought', 'bad_thought', 'subtask', 'bad_subtask', etc. # We'll use 'image' (PIL) and 'subtask' (string) as the caption. images = [ex["image"] for ex in examples] # PIL images texts = [ex["subtask"] for ex in examples] # or whichever text column you want inputs = processor(images=images, text=texts, return_tensors="pt", padding=True) # Add labels so the model can compute cross-entropy loss # For a basic approach: labels = input_ids inputs["labels"] = inputs["input_ids"].clone() return inputs # 6. Define training arguments training_args = TrainingArguments( output_dir="./my_blip_computer_thoughts", num_train_epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=4, # effectively batch size 4 per device logging_steps=5, save_steps=20, save_total_limit=2, remove_unused_columns=False # important when custom columns are in the dataset ) # 6. Create Trainer trainer = Trainer( model=model, args=training_args, train_dataset=train_subset, # or dataset["train"] for the full set data_collator=collate_fn, ) # 7. Train trainer.train() # 9. Push the final model + processor to Hugging Face Hub # (Make sure you're logged in: huggingface-cli login) model.push_to_hub("zeddotes/blip-computer-thoughts") processor.push_to_hub("zeddotes/blip-computer-thoughts") print("Done training and pushed model to zeddotes/blip-computer-thoughts!")