import torch
import transformers
import warnings

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

warnings.filterwarnings("ignore")

base_model_id = "google/gemma-2b"

torch.cuda.set_device(0)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load the short-jokes dataset and keep a reproducible 20% subset for fine-tuning.
dataset = load_dataset("ysharma/short_jokes")
train_data = dataset['train']

twenty_percent_size = int(0.2 * len(train_data))
subset = train_data.shuffle(seed=42).select(range(twenty_percent_size))
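# Optional check (not in the original script): confirm how many jokes are in the subset.
print("Subset size:", len(subset))
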
print("Available devices:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device())

from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
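# Tokenizer setup: left padding with explicit BOS/EOS tokens; the EOS token is reused as the pad token.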
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token
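# 4-bit NF4 quantization with double quantization; matmuls are computed in bfloat16.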
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
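# Optional check (not in the original script): report the quantized model's memory
# footprint; get_memory_footprint() is a standard transformers PreTrainedModel helper.
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
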
def tokenize_function(examples):
    return tokenizer(examples["Joke"], padding="max_length", truncation=True, max_length=50)

# Split the 20% subset into train/test sets and tokenize both splits.
train_test_split = subset.train_test_split(test_size=0.1)
train_data = train_test_split['train']
test_data = train_test_split['test']

tokenized_train_data = train_data.map(tokenize_function, batched=True)
tokenized_test_data = test_data.map(tokenize_function, batched=True)
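# Optional check (not in the original script): with the padding/truncation settings
# above, every tokenized joke should be a fixed-length sequence of 50 token ids.
sample = tokenized_train_data[0]
print(len(sample["input_ids"]), sample["input_ids"][:10])
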
# Baseline generation from the base model, before any fine-tuning.
eval_prompt = " why man are "

eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
)

model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to(device)

model.eval()
with torch.no_grad():
    print(eval_tokenizer.decode(model.generate(**model_input, max_new_tokens=50, repetition_penalty=1.15)[0], skip_special_tokens=True))

from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)
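# Note: PEFT models also expose a built-in print_trainable_parameters() method that reports the same numbers.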
# The 4-bit model is already placed on the GPU by device_map="auto"; calling .to()
# on a bitsandbytes-quantized model is not supported, so it is omitted here.
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
print("Accelerator device:", accelerator.device)
model = accelerator.prepare_model(model)

from datetime import datetime

project = "jokes-gemma"
base_model_name = "gemma"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        max_steps=500,
        learning_rate=2.5e-5,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_steps=25,
        logging_dir="./logs",
        save_strategy="steps",
        save_steps=25,
        evaluation_strategy="steps",
        eval_steps=25,
        do_eval=True,
        report_to="wandb",
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# The KV cache conflicts with gradient checkpointing, so disable it during training.
model.config.use_cache = False
trainer.train()
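
# --- Not part of the original script: a minimal post-training sketch, assuming a
# single-GPU run. It saves the LoRA adapter and samples from the fine-tuned model;
# the adapter directory name is arbitrary. ---
adapter_dir = output_dir + "/final-adapter"
trainer.save_model(adapter_dir)  # for a PEFT-wrapped model this saves only the adapter weights

model.config.use_cache = True  # re-enable the cache for faster generation
model.eval()
with torch.no_grad():
    ft_input = eval_tokenizer(eval_prompt, return_tensors="pt").to(device)
    print(eval_tokenizer.decode(
        model.generate(**ft_input, max_new_tokens=50, repetition_penalty=1.15)[0],
        skip_special_tokens=True,
    ))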