# Copyright (C) 2024 Ronan Le Meillat
# License: Apache License 2.0
# Description: Train the model on the dataset
import os
import torch
from huggingface_hub import login as hf_login
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration, TrainingArguments, Trainer
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
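
# Authenticate against the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable when it is set; it is needed for push_to_hub and for any
# private resources.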
HF_TOKEN = ""
if os.environ.get('HF_TOKEN') is not None:
    HF_TOKEN = os.environ.get('HF_TOKEN')
    print("Hugging Face token found in environment variable")
hf_login(
    token=HF_TOKEN,
    add_to_git_credential=True
)
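
# Identifiers for the source dataset, the base model, and where to write/push the fine-tuned model.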
dataset_id = "eltorio/ROCO-radiology"
prompt= "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
destination_model_id = "eltorio/ROCO-idefics3-8B"
output_dir = "IDEFICS3_ROCO"
cache_dir = "/workspace/data"
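
# Load the training split of the ROCO radiology dataset (image + caption pairs).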
train_dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True
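
# Load the Idefics3 processor; image splitting is disabled so each image is processed
# as a single tile instead of being split into sub-images.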
processor = AutoProcessor.from_pretrained(
    source_model_id,
    do_image_splitting=False
)
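
# Parameter-efficient fine-tuning setup: LoRA adapters on the attention and MLP projection
# layers, optionally on top of a 4-bit (NF4) quantized base model when USE_QLORA is enabled.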
if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian"
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # This works for A100 or H100
    ).to(DEVICE)
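
# Custom data collator: builds a system/user/assistant chat for every (image, caption) pair,
# applies the chat template, and produces a padded batch with input_ids, pixel_values and labels.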
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        texts = []
        images = []
        for sample in samples:
            image = sample["image"]
            answer = sample["caption"]
            messages = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": prompt}
                    ]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image.convert('RGB')])
        # Tokenize the chat texts and preprocess the images into a padded batch
        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        # Replace padding token ids in the labels with the <image> token id
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels
        return batch
data_collator = MyDataCollator(processor)
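
# Training hyper-parameters: fp16 and paged 8-bit AdamW to fit the QLoRA setup, gradient
# accumulation for an effective batch size of 16 per device (2 x 8, unless auto_find_batch_size
# adjusts it), frequent checkpointing, and push_to_hub enabled.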
training_args = TrainingArguments(
    output_dir = output_dir,
    overwrite_output_dir = False,
    auto_find_batch_size = True,
    learning_rate = 2e-4,
    fp16 = True,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    gradient_accumulation_steps = 8,
    dataloader_pin_memory = False,
    save_total_limit = 3,
    evaluation_strategy = None,
    save_strategy = "steps",
    eval_steps = 100,
    save_steps = 10,  # save a checkpoint every 10 steps
    resume_from_checkpoint = True,
    logging_steps = 5,
    remove_unused_columns = False,
    push_to_hub = True,
    label_names = ["labels"],
    load_best_model_at_end = False,
    report_to = "none",
    optim = "paged_adamw_8bit",
)
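
# Standard Trainer run on the training split only (no evaluation dataset is provided).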
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = train_dataset,
)
trainer.train()
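
# Optional follow-up (not in the original script): a minimal sketch of how the fine-tuned
# adapter and processor could also be pushed explicitly to destination_model_id, assuming
# the logged-in account has write access to that repository:
# trainer.save_model(output_dir)
# model.push_to_hub(destination_model_id)
# processor.push_to_hub(destination_model_id)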