zeddotes committed on
Commit 5c85d5d · 1 Parent(s): 66db8ea
Files changed (3)
  1. .gitignore +2 -0
  2. .python-version +2 -0
  3. train.py +75 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ my_blip_computer_thoughts/
.python-version ADDED
@@ -0,0 +1,2 @@
+
+ 3.9.21
train.py ADDED
@@ -0,0 +1,75 @@
+ """
+ train.py
+
+ A complete example of fine-tuning BLIP on 'agentsea/computer-thoughts' for captioning.
+ All processing is done in the collate function. This is simpler and avoids shape mismatches.
+ """
+
+ import torch
+ from datasets import load_dataset, Image as HFImage
+ from transformers import (
+     BlipProcessor,
+     BlipForConditionalGeneration,
+     TrainingArguments,
+     Trainer
+ )
+
+ # 1. Load the dataset
+ dataset = load_dataset("agentsea/computer-thoughts")
+
+ # 2. Rename "image_before" -> "image" and cast to HFImage so it becomes a PIL Image
+ dataset = dataset.rename_column("image_before", "image")
+ dataset = dataset.cast_column("image", HFImage())
+
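Note: after the cast_column call above, the "image" column decodes to a PIL image on access (standard behavior of the datasets Image feature). A quick sanity check, assuming the dataset object defined above:

    from PIL import Image as PILImage

    sample = dataset["train"][0]
    assert isinstance(sample["image"], PILImage.Image)   # decoded lazily by the Image feature
    print(sample["image"].size, sample["subtask"][:80])  # image size and the caption text used below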
+ # 3. Create a small subset for the demo (just 5 examples). Remove this if you want the full data.
+ train_subset = dataset["train"].select(range(5))
+
+ # 4. Load the BLIP image-captioning model and processor
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ # 5. Define a collate_fn that transforms images+text on-the-fly
+ def collate_fn(examples):
+     # examples is a list of dicts, each dict with keys:
+     # 'task', 'image', 'image_after', 'action', 'thought', 'bad_thought', 'subtask', 'bad_subtask', etc.
+     # We'll use 'image' (PIL) and 'subtask' (string) as the caption.
+     images = [ex["image"] for ex in examples]   # PIL images
+     texts = [ex["subtask"] for ex in examples]  # or whichever text column you want
+
+     inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)
+
+     # Add labels so the model can compute cross-entropy loss
+     # For a basic approach: labels = input_ids
+     inputs["labels"] = inputs["input_ids"].clone()
+
+     return inputs
+
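Note: cloning input_ids as labels means padding tokens also contribute to the loss. A common refinement (not part of the committed script) is to mask them with -100, the default ignore_index of the underlying cross-entropy loss; a minimal sketch reusing the same processor:

    def collate_fn_masked(examples):
        images = [ex["image"] for ex in examples]
        texts = [ex["subtask"] for ex in examples]
        inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)
        labels = inputs["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100  # padding positions are ignored in the loss
        inputs["labels"] = labels
        return inputs

To use it, pass data_collator=collate_fn_masked to the Trainer below instead of collate_fn.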
+ # 6. Define training arguments
+ training_args = TrainingArguments(
+     output_dir="./my_blip_computer_thoughts",
+     num_train_epochs=1,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=4,  # effectively batch size 4 per device
+     logging_steps=5,
+     save_steps=20,
+     save_total_limit=2,
+     remove_unused_columns=False     # important when custom columns are in the dataset
+ )
+
+ # 7. Create the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_subset,  # or dataset["train"] for the full set
+     data_collator=collate_fn,
+ )
+
+ # 8. Train
+ trainer.train()
+
+ # 9. Push the final model + processor to the Hugging Face Hub
+ # (Make sure you're logged in: huggingface-cli login)
+ model.push_to_hub("zeddotes/blip-computer-thoughts")
+ processor.push_to_hub("zeddotes/blip-computer-thoughts")
+
+ print("Done training and pushed model to zeddotes/blip-computer-thoughts!")
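Note: once pushed, the checkpoint can be loaded back from the Hub for captioning. A minimal inference sketch, assuming the repo name above; the image path is just a placeholder:

    from PIL import Image
    from transformers import BlipProcessor, BlipForConditionalGeneration

    processor = BlipProcessor.from_pretrained("zeddotes/blip-computer-thoughts")
    model = BlipForConditionalGeneration.from_pretrained("zeddotes/blip-computer-thoughts")

    image = Image.open("screenshot.png").convert("RGB")  # placeholder path to any screenshot-style image
    inputs = processor(images=image, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=30)
    print(processor.decode(out[0], skip_special_tokens=True))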