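# Page-header residue from the file viewer this source was copied from
# (kept as a comment so the module parses): Alejandro Velez | tdc geneformer | 47990ca | 12.8 kB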
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from .imports import *
from .model import GeneformerMultiTask
from .utils import calculate_task_specific_metrics, get_layer_freeze_range
def set_seed(seed):
    """Seed every RNG used during training and request deterministic cuDNN.

    The same ``seed`` is fed to Python's ``random``, NumPy's global RNG, the
    torch CPU generator and all CUDA generators. cuDNN is then switched to
    deterministic kernels with benchmark (autotuning) mode disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def initialize_wandb(config):
    """Start a Weights & Biases run when ``config["use_wandb"]`` is truthy.

    The run is created under ``config["wandb_project"]`` with the whole
    ``config`` dict attached. ``wandb`` is imported lazily, so it is only
    required when logging is enabled. Either way a status line is printed.
    """
    if not config.get("use_wandb", False):
        print(
            "Weights & Biases (wandb) is not enabled. Logging will use other methods."
        )
        return

    import wandb

    wandb.init(project=config["wandb_project"], config=config)
    print("Weights & Biases (wandb) initialized and will be used for logging.")
def create_model(config, num_labels_list, device):
    """Build the multi-task Geneformer model described by ``config`` on ``device``.

    Args:
        config: Settings dict. Reads ``pretrained_path``, ``dropout_rate``,
            ``use_task_weights``, ``task_weights``, ``max_layers_to_freeze``,
            ``use_attention_pooling`` and ``use_data_parallel``.
        num_labels_list: Per-task label counts, forwarded to GeneformerMultiTask.
        device: Target passed to ``.to()``.

    Returns:
        The model moved to ``device``. It is wrapped in ``nn.DataParallel``
        first when ``config["use_data_parallel"]`` is truthy.
    """
    pretrained_path = config["pretrained_path"]
    model_options = dict(
        dropout_rate=config["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=config["task_weights"],
        max_layers_to_freeze=config["max_layers_to_freeze"],
        use_attention_pooling=config["use_attention_pooling"],
    )
    base_model = GeneformerMultiTask(pretrained_path, num_labels_list, **model_options)

    if config["use_data_parallel"]:
        return nn.DataParallel(base_model).to(device)
    return base_model.to(device)
def setup_optimizer_and_scheduler(model, config, total_steps):
    """Create an AdamW optimizer and a warmup learning-rate scheduler.

    Args:
        model: Module whose parameters are optimised.
        config: Reads ``learning_rate``, ``weight_decay``, ``warmup_ratio`` and
            ``lr_scheduler_type`` (``"linear"`` or ``"cosine"``).
        total_steps: Total number of optimizer steps. Warmup lasts
            ``int(warmup_ratio * total_steps)`` of them.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: If ``lr_scheduler_type`` is neither ``"linear"`` nor
            ``"cosine"``. Without this check ``scheduler`` would be unbound
            and the caller would get an opaque UnboundLocalError.
    """
    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    warmup_steps = int(config["warmup_ratio"] * total_steps)

    scheduler_type = config["lr_scheduler_type"]
    if scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
    elif scheduler_type == "cosine":
        # num_cycles=0.5 -> a single half-cosine decay after warmup.
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            num_cycles=0.5,
        )
    else:
        raise ValueError(
            f"Unsupported lr_scheduler_type {scheduler_type!r}; "
            "expected 'linear' or 'cosine'."
        )
    return optimizer, scheduler
def train_epoch(
    model, train_loader, optimizer, scheduler, device, config, writer, epoch
):
    """Run one training epoch over ``train_loader``.

    Each batch is moved to ``device``, passed through ``model`` with one label
    tensor per entry of ``config["task_names"]``, and back-propagated. When
    ``config["gradient_clipping"]`` is truthy, gradients are clipped to
    ``config["max_grad_norm"]``. The optimizer and scheduler are then stepped.
    The batch loss is logged to TensorBoard (global step
    ``epoch * len(train_loader) + batch_idx``), to wandb when
    ``config["use_wandb"]`` is truthy, and to the progress bar.

    Returns:
        The loss of the final batch, as a Python float.

    Raises:
        ValueError: If ``train_loader`` yields no batches. Otherwise there is
            no loss to return.
    """
    model.train()
    use_wandb = config.get("use_wandb", False)
    if use_wandb:
        import wandb  # imported once per epoch instead of once per batch

    steps_per_epoch = len(train_loader)
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    last_loss = None
    for batch_idx, batch in enumerate(progress_bar):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = [
            batch["labels"][task_name].to(device) for task_name in config["task_names"]
        ]
        loss, _, _ = model(input_ids, attention_mask, labels)
        # When the model is wrapped in nn.DataParallel (see create_model), the
        # per-replica scalar losses are gathered into a 1-D tensor and
        # backward() would reject it. mean() reduces that to a scalar and is
        # a no-op for a 0-dim loss.
        loss = loss.mean()
        loss.backward()
        if config["gradient_clipping"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
        optimizer.step()
        scheduler.step()

        # One device->host sync per step, reused by every logger below.
        last_loss = loss.item()
        writer.add_scalar(
            "Training Loss", last_loss, epoch * steps_per_epoch + batch_idx
        )
        if use_wandb:
            wandb.log({"Training Loss": last_loss})
        # Update progress bar
        progress_bar.set_postfix({"loss": f"{last_loss:.4f}"})

    if last_loss is None:
        raise ValueError("train_loader yielded no batches; cannot train an epoch.")
    return last_loss  # Return the last batch loss
def validate_model(model, val_loader, device, config):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels)``,
            returning ``(loss, logits, _)``. ``logits[i]`` holds the scores
            for ``config["task_names"][i]``.
        val_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and a ``labels`` dict keyed by task name.
        device: Device the inputs are moved to.
        config: Reads ``task_names``.

    Returns:
        ``(val_loss, task_true_labels, task_pred_labels, task_pred_probs)``:
        - ``val_loss``: loss averaged over batches.
        - ``task_true_labels``: dict mapping each task name to a list of true
          labels (Python ints).
        - ``task_pred_labels``: dict mapping each task name to the arg-max
          class index per sample.
        - ``task_pred_probs``: dict mapping each task name to a softmax
          probability vector (NumPy array) per sample.

    Raises:
        ValueError: If ``val_loader`` yields no batches, so no mean loss exists.
    """
    model.eval()
    task_names = config["task_names"]
    val_loss = 0.0
    num_batches = 0
    task_true_labels = {task_name: [] for task_name in task_names}
    task_pred_labels = {task_name: [] for task_name in task_names}
    task_pred_probs = {task_name: [] for task_name in task_names}

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [batch["labels"][task_name].to(device) for task_name in task_names]
            loss, logits, _ = model(input_ids, attention_mask, labels)
            # mean() collapses the per-replica vector nn.DataParallel returns;
            # it is a no-op for a scalar loss.
            val_loss += loss.mean().item()
            num_batches += 1

            # Vectorised per task: one argmax/softmax/transfer per batch
            # instead of one per sample.
            for i, task_name in enumerate(task_names):
                task_logits = logits[i]
                task_true_labels[task_name].extend(
                    batch["labels"][task_name].tolist()
                )
                task_pred_labels[task_name].extend(
                    torch.argmax(task_logits, dim=-1).tolist()
                )
                # Iterating the 2-D array yields one probability row per sample.
                task_pred_probs[task_name].extend(
                    torch.softmax(task_logits, dim=-1).cpu().numpy()
                )

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation loss.")
    val_loss /= num_batches
    return val_loss, task_true_labels, task_pred_labels, task_pred_probs
def log_metrics(task_metrics, val_loss, config, writer, epochs):
    """Report validation results to stdout, wandb (optional) and TensorBoard.

    Args:
        task_metrics: Mapping of task name to a dict with ``"f1"`` and
            ``"accuracy"`` entries.
        val_loss: Validation loss written as ``"Validation Loss"``.
        config: Only ``use_wandb`` is read.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        epochs: Global step used for every TensorBoard scalar.
    """
    wandb_enabled = config.get("use_wandb", False)

    # Pass 1: console (and wandb) summary per task.
    for name, scores in task_metrics.items():
        f1, accuracy = scores["f1"], scores["accuracy"]
        print(
            f"{name} - Validation F1 Macro: {f1:.4f}, Validation Accuracy: {accuracy:.4f}"
        )
        if wandb_enabled:
            import wandb

            wandb.log(
                {
                    f"{name} Validation F1 Macro": f1,
                    f"{name} Validation Accuracy": accuracy,
                }
            )

    # Pass 2: TensorBoard scalars, loss first and then per-task metrics.
    writer.add_scalar("Validation Loss", val_loss, epochs)
    for name, scores in task_metrics.items():
        writer.add_scalar(f"{name} - Validation F1 Macro", scores["f1"], epochs)
        writer.add_scalar(f"{name} - Validation Accuracy", scores["accuracy"], epochs)
def save_validation_predictions(
    val_cell_id_mapping,
    task_true_labels,
    task_pred_labels,
    task_pred_probs,
    config,
    trial_number=None,
):
    """Write per-cell validation labels, predictions and probabilities to CSV.

    The output path is ``<results_dir>/trial_<trial_number>/val_preds.csv``
    when ``trial_number`` is given, otherwise
    ``<results_dir>/manual_run_val_preds.csv``. The target directory is
    created if it does not exist.

    Args:
        val_cell_id_mapping: Indexed by sample position ``0..len-1`` to
            obtain each row's cell ID.
        task_true_labels / task_pred_labels / task_pred_probs: Dicts keyed by
            task name, each holding one entry per sample in the same order.
        config: Reads ``results_dir`` and ``task_names``.
        trial_number: Optional Optuna trial number selecting the per-trial
            directory.

    The CSV has a ``Cell ID`` column followed by ``<task> True``,
    ``<task> Pred`` and ``<task> Probabilities`` for each task. The
    probabilities are comma-joined.
    """
    if trial_number is not None:
        output_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
        filename = "val_preds.csv"
    else:
        output_dir = config["results_dir"]
        filename = "manual_run_val_preds.csv"
    # The manual-run path previously assumed results_dir already existed and
    # failed in to_csv when it did not. An empty results_dir means the CWD,
    # which needs no creation (os.makedirs("") would raise).
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    val_preds_file = os.path.join(output_dir, filename)

    rows = []
    # Positional indexing (not iteration) so a dict keyed by sample index
    # works as well as a list.
    for sample_idx in range(len(val_cell_id_mapping)):
        row = {"Cell ID": val_cell_id_mapping[sample_idx]}
        for task_name in config["task_names"]:
            row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
            row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(val_preds_file, index=False)
    print(f"Validation predictions saved to {val_preds_file}")
def train_model(
    config,
    device,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
):
    """Train with the fixed hyperparameters in ``config`` and validate once at the end.

    Seeds all RNGs, optionally starts wandb, and builds the model, optimizer
    and scheduler. It then trains for ``config["epochs"]`` epochs, logging to
    TensorBoard under ``<tensorboard_log_dir>/manual_run``. After the final
    epoch it validates, logs metrics and saves validation predictions to CSV.

    ``train_cell_id_mapping`` is accepted for interface compatibility but is
    not used here.

    Returns:
        ``(val_loss, model)``: the final validation loss and the trained model.
    """
    set_seed(config["seed"])
    initialize_wandb(config)
    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
    # The context manager closes the writer even if training or validation
    # raises, so event files are flushed and file handles are released.
    with SummaryWriter(log_dir=log_dir) as writer:
        epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
        for epoch in epoch_progress:
            last_loss = train_epoch(
                model, train_loader, optimizer, scheduler, device, config, writer, epoch
            )
            epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})

        val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
            model, val_loader, device, config
        )
        task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
        # Validation runs once, so metrics are logged at step == number of epochs.
        log_metrics(task_metrics, val_loss, config, writer, config["epochs"])

    save_validation_predictions(
        val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
    )

    if config.get("use_wandb", False):
        import wandb

        wandb.finish()

    print(f"\nFinal Validation Loss: {val_loss:.4f}")
    return val_loss, model  # Return both the validation loss and the trained model
def _suggest_hyperparameters(trial, config, num_labels_list):
    """Sample this trial's hyperparameters and write them into ``config`` in place.

    Search bounds come from ``config["hyperparameters"]``. The order of the
    ``trial.suggest_*`` calls is kept stable so parameter names and sampling
    sequence match across runs.
    """
    space = config["hyperparameters"]
    config["learning_rate"] = trial.suggest_float(
        "learning_rate",
        space["learning_rate"]["low"],
        space["learning_rate"]["high"],
        log=space["learning_rate"]["log"],
    )
    config["warmup_ratio"] = trial.suggest_float(
        "warmup_ratio",
        space["warmup_ratio"]["low"],
        space["warmup_ratio"]["high"],
    )
    config["weight_decay"] = trial.suggest_float(
        "weight_decay",
        space["weight_decay"]["low"],
        space["weight_decay"]["high"],
    )
    config["dropout_rate"] = trial.suggest_float(
        "dropout_rate",
        space["dropout_rate"]["low"],
        space["dropout_rate"]["high"],
    )
    config["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", space["lr_scheduler_type"]["choices"]
    )
    # Single-choice categorical: attention pooling is pinned off, but still
    # recorded as a trial parameter.
    config["use_attention_pooling"] = trial.suggest_categorical(
        "use_attention_pooling", [False]
    )

    if config["use_task_weights"]:
        raw_weights = [
            trial.suggest_float(
                f"task_weight_{i}",
                space["task_weights"]["low"],
                space["task_weights"]["high"],
            )
            for i in range(len(num_labels_list))
        ]
        weight_sum = sum(raw_weights)
        if weight_sum:
            config["task_weights"] = [weight / weight_sum for weight in raw_weights]
        else:
            # Weights summing to zero cannot be normalised. Fall back to
            # equal weighting rather than dividing by zero.
            n_tasks = len(raw_weights)
            config["task_weights"] = [1.0 / n_tasks for _ in raw_weights]
    else:
        config["task_weights"] = None

    # Dynamic range for max_layers_to_freeze
    freeze_range = get_layer_freeze_range(config["pretrained_path"])
    config["max_layers_to_freeze"] = trial.suggest_int(
        "max_layers_to_freeze",
        freeze_range["min"],
        freeze_range["max"],
    )


def _log_trial_to_wandb(trial, val_loss, task_metrics, config):
    """Log the trial's validation loss, per-task metrics and sampled hyperparameters to wandb."""
    import wandb

    logged_hparams = (
        "learning_rate",
        "warmup_ratio",
        "weight_decay",
        "dropout_rate",
        "lr_scheduler_type",
        "use_attention_pooling",
        "max_layers_to_freeze",
    )
    wandb.log(
        {
            "trial_number": trial.number,
            "val_loss": val_loss,
            **{
                f"{task_name}_f1": metrics["f1"]
                for task_name, metrics in task_metrics.items()
            },
            **{
                f"{task_name}_accuracy": metrics["accuracy"]
                for task_name, metrics in task_metrics.items()
            },
            **{k: v for k, v in config.items() if k in logged_hparams},
        }
    )


def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    """Optuna objective: train one hyperparameter configuration and return its validation loss.

    The trial's hyperparameters are sampled into ``config``, mutating it in
    place. The model is trained for ``config["epochs"]`` epochs (TensorBoard
    logs under ``<tensorboard_log_dir>/trial_<n>``) and validated once. The
    validation predictions, model ``state_dict`` and task weights are then
    stored, and the loss is reported to the trial.

    ``train_cell_id_mapping`` is accepted for interface compatibility but is
    not used here.

    Raises:
        optuna.TrialPruned: If the pruner rejects the reported loss. The
            wandb run is still finished in that case.
    """
    set_seed(config["seed"])  # Set the seed before each trial
    _suggest_hyperparameters(trial, config, num_labels_list)
    # Start wandb only after sampling, so the run's config records this
    # trial's values rather than the previous trial's.
    initialize_wandb(config)
    use_wandb = config.get("use_wandb", False)

    try:
        model = create_model(config, num_labels_list, device)
        total_steps = len(train_loader) * config["epochs"]
        optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

        log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
        with SummaryWriter(log_dir=log_dir) as writer:
            for epoch in range(config["epochs"]):
                train_epoch(
                    model, train_loader, optimizer, scheduler, device, config, writer, epoch
                )
            val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
                model, val_loader, device, config
            )
            task_metrics = calculate_task_specific_metrics(
                task_true_labels, task_pred_labels
            )
            log_metrics(task_metrics, val_loss, config, writer, config["epochs"])

        save_validation_predictions(
            val_cell_id_mapping,
            task_true_labels,
            task_pred_labels,
            task_pred_probs,
            config,
            trial.number,
        )

        # NOTE(review): user attrs are stored in the study's storage. A full
        # state_dict is likely not JSON-serialisable for RDB storages, and
        # with in-memory storage it keeps every trial's weights alive.
        # Confirm this is intended.
        trial.set_user_attr("model_state_dict", model.state_dict())
        trial.set_user_attr("task_weights", config["task_weights"])

        # Reported once, after training completes, so the pruner can only
        # act on the final loss.
        trial.report(val_loss, config["epochs"])
        if trial.should_prune():
            raise optuna.TrialPruned()

        if use_wandb:
            _log_trial_to_wandb(trial, val_loss, task_metrics, config)
    finally:
        # Always close the run. A pruned or failed trial used to leave it open
        # for the next trial's wandb.init.
        if use_wandb:
            import wandb

            wandb.finish()

    return val_loss