Image-Text-to-Text
PEFT
Safetensors
English
File size: 4,629 Bytes
2b6a5b3
 
 
297cc58
 
 
 
 
 
 
2b6a5b3
 
297cc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b6a5b3
 
297cc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1392114
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (C) 2024 Ronan Le Meillat
# License: Apache License 2.0
# Description: Train the model on the dataset
import os
import torch

from huggingface_hub import login as hf_login
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration, TrainingArguments, Trainer
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

# Read the Hugging Face access token from the environment; fall back to an
# empty token when the variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN is None:
    HF_TOKEN = ""
else:
    print("Hugging Face token found in environment variable")

# Authenticate against the Hub with that token.
hf_login(token=HF_TOKEN, add_to_git_credential=True)
# --- Run configuration ------------------------------------------------------
# Dataset passed to load_dataset() and base checkpoint passed to
# from_pretrained().
dataset_id = "eltorio/ROCO-radiology"
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
# NOTE(review): not referenced anywhere else in the visible script.
destination_model_id = "eltorio/ROCO-idefics3-8B"

# output_dir is handed to TrainingArguments; cache_dir to load_dataset().
output_dir = "IDEFICS3_ROCO"
cache_dir = "/workspace/data"

# Text used for the system turn of every training conversation.
prompt = "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"

# Device used by the full fine-tuning path, and the adapter mode switches.
DEVICE = "cuda:0"
USE_QLORA = True
USE_LORA = False

# --- Data and processor -----------------------------------------------------
train_dataset = load_dataset(dataset_id, cache_dir=cache_dir, split="train")

processor = AutoProcessor.from_pretrained(source_model_id, do_image_splitting=False)

# Build the model: either a full float16 fine-tune on DEVICE, or the base
# model (optionally 4-bit quantized) with a LoRA adapter attached.
if not (USE_LORA or USE_QLORA):
    # Full fine-tuning path: load in float16 and move to the target device.
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # This works for A100 or H100
    ).to(DEVICE)
else:
    # Adapter path: the regex restricts LoRA to the projection layers inside
    # text_model / modality_projection / perceiver_resampler.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=not USE_QLORA,  # DoRA only for plain LoRA, never with QLoRA
        init_lora_weights="gaussian",
    )
    if USE_QLORA:
        # 4-bit NF4 quantization of the base weights, float16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    # Attach the LoRA adapter to the loaded model and switch it on.
    model.add_adapter(lora_config)
    model.enable_adapters()
    
class MyDataCollator:
    """Turn raw dataset samples into a padded training batch.

    Each sample is rendered as a three-turn conversation (system prompt,
    user image, assistant caption). The rendered texts and the RGB images
    are passed through the processor, and a ``labels`` entry is derived
    from the resulting ``input_ids``.

    Args:
        processor: Object exposing ``tokenizer``, ``apply_chat_template``
            and a ``__call__`` that returns a mapping with ``input_ids``.
        system_prompt: Text for the system turn. When ``None`` (default),
            the module-level ``prompt`` is used.
    """

    def __init__(self, processor, system_prompt=None):
        self.processor = processor
        # Fall back to the script-wide prompt so existing callers that only
        # pass the processor keep working.
        self.system_prompt = prompt if system_prompt is None else system_prompt
        tokenizer = processor.tokenizer
        # Id of the "<image>" special token, looked up by its position in
        # the tokenizer's additional special tokens.
        self.image_token_id = tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        """Collate ``samples`` into a single batch.

        Args:
            samples: Iterable of mappings with an ``"image"`` entry (an
                object with ``convert``) and a ``"caption"`` string.

        Returns:
            The processor output with an added ``"labels"`` entry.
        """
        texts = []
        images = []
        for sample in samples:
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_prompt}],
                },
                {
                    "role": "user",
                    "content": [{"type": "image"}],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": sample["caption"]}],
                },
            ]
            # Use the processor this collator was built with, not a global.
            text = self.processor.apply_chat_template(
                messages, add_generation_prompt=False
            )
            texts.append(text.strip())
            # One image per conversation, forced to RGB.
            images.append([sample["image"].convert('RGB')])

        batch = self.processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )

        # Labels are a copy of the inputs with padding positions replaced by
        # the image token id.
        # NOTE(review): presumably the model excludes image-token labels from
        # the loss -- confirm against the model's documented label convention.
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch

# Collator that turns raw dataset rows into processor batches.
data_collator = MyDataCollator(processor)


training_args = TrainingArguments(
    # Output location and Hub publishing.
    output_dir=output_dir,
    overwrite_output_dir=False,
    push_to_hub=True,
    report_to="none",
    # Optimisation.
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    # Batching.
    auto_find_batch_size=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    # The collator reads the raw "image" and "caption" columns itself.
    remove_unused_columns=False,
    label_names=["labels"],
    # Checkpointing.
    save_strategy="steps",
    save_steps=10,  # checkpoint each 10 steps
    save_total_limit=3,
    resume_from_checkpoint=True,
    load_best_model_at_end=False,
    # Evaluation settings.
    evaluation_strategy=None,
    eval_steps=100,
    # Logging.
    logging_steps=5,
)

# Local import: helper that finds the newest "checkpoint-*" folder.
from transformers.trainer_utils import get_last_checkpoint

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

# TrainingArguments.resume_from_checkpoint is not consumed by Trainer itself;
# resuming only happens when a checkpoint is handed to train(). Honour the
# flag here: continue from the newest checkpoint in the output directory when
# one exists, otherwise start from scratch.
last_checkpoint = None
if training_args.resume_from_checkpoint and os.path.isdir(training_args.output_dir):
    # Returns None when the directory holds no checkpoint folder.
    last_checkpoint = get_last_checkpoint(training_args.output_dir)

trainer.train(resume_from_checkpoint=last_checkpoint)