import argparse

import torch
from torch import optim
from torch.utils.data import DataLoader
from transformers import CLIPVisionModel

from scripts.genomic_plip_model import GenomicPLIPModel
from scripts.tile_file_dataloader import FlatTileDataset


def train_model(data_dir, model_save_path, pretrained_model_path, lr, num_epochs, train_batch_size, validation_batch_size, num_workers):
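    # Tile datasets are expected under <data_dir>/train and <data_dir>/validate.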
    train_dataset = FlatTileDataset(data_dir=f'{data_dir}/train')
    train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)

    validation_dataset = FlatTileDataset(data_dir=f'{data_dir}/validate')
    validation_data_loader = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=False, num_workers=num_workers)

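    # Load the pretrained CLIP vision backbone and wrap it in the GenomicPLIP head.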
    base_model = CLIPVisionModel.from_pretrained(pretrained_model_path)
    custom_model = GenomicPLIPModel(base_model)

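    # Objective: align score embeddings with vision embeddings by
    # minimising 1 - cosine_similarity between the two feature vectors.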
    criterion = torch.nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(custom_model.parameters(), lr=lr)

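    # Each epoch runs a full training pass followed by a validation pass.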
    for epoch in range(num_epochs):
        custom_model.train()
        train_loss = 0.0

        for batch_images, batch_scores in train_data_loader:
            optimizer.zero_grad()
            batch_loss = 0.0

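            # Samples are processed one at a time: each backward() call
            # accumulates gradients, and the single optimizer.step() below
            # applies one update per batch.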
            for img, score in zip(batch_images, batch_scores):
                vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
                cos_sim = criterion(score_features, vision_features)
                loss = 1 - cos_sim.mean()
                batch_loss += loss.item()
                loss.backward()

            optimizer.step()
            train_loss += batch_loss
            print(f"Training Batch Loss: {batch_loss:.4f}")

        avg_train_loss = train_loss / len(train_data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

        custom_model.eval()
        validation_loss = 0.0

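        # Validation pass: same loss computation, but with gradients disabled.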
        with torch.no_grad():
            for batch_images, batch_scores in validation_data_loader:
                batch_loss = 0.0
                for img, score in zip(batch_images, batch_scores):
                    vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))
                    cos_sim = criterion(score_features, vision_features)
                    loss = 1 - cos_sim.mean()
                    batch_loss += loss.item()

                validation_loss += batch_loss
                print(f"Validation Batch Loss: {batch_loss:.4f}")

        avg_validation_loss = validation_loss / len(validation_data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_validation_loss:.4f}")

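    # Save the trained weights once all epochs are complete.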
    torch.save(custom_model.state_dict(), model_save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train the Genomic PLIP Model')
    parser.add_argument('--data_dir', type=str, default='Datasets/train_03', help='Directory containing the train, validate, and test datasets.')
    parser.add_argument('--model_save_path', type=str, default='genomic_plip.pth', help='Path to save the trained model.')
    parser.add_argument('--pretrained_model_path', type=str, default='./plip', help='Path to the pretrained CLIP model.')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate for the optimizer.')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs to train for.')
    parser.add_argument('--train_batch_size', type=int, default=128, help='Batch size for the training data loader.')
    parser.add_argument('--validation_batch_size', type=int, default=128, help='Batch size for the validation data loader.')
    parser.add_argument('--num_workers', type=int, default=32, help='Number of worker processes for data loading.')

    args = parser.parse_args()
    train_model(args.data_dir, args.model_save_path, args.pretrained_model_path, args.lr, args.num_epochs, args.train_batch_size, args.validation_batch_size, args.num_workers)