omics-plip-1 / .ipynb_checkpoints /train_omics_plip_model-checkpoint.py
VatsalPatel18's picture
Revised docker file
49f4666
raw
history blame
4.17 kB
import torch
import argparse
from torch import optim
from torch.utils.data import DataLoader
from scripts.genomic_plip_model import GenomicPLIPModel
from scripts.tile_file_dataloader import FlatTileDataset
from transformers import CLIPVisionModel
def _pair_loss(model, criterion, img, score):
    """Return ``1 - criterion(...)`` averaged for one (image, score) pair.

    Both inputs get a leading dimension of size 1 via ``unsqueeze(0)`` before
    being passed to the model, which returns ``(vision_features,
    score_features)``.
    """
    vision_features, score_features = model(img.unsqueeze(0), score.unsqueeze(0))
    cos_sim = criterion(score_features, vision_features)
    return 1 - cos_sim.mean()


def _run_pass(model, criterion, loader, batch_label, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, gradients are accumulated one sample at a
    time and a single optimizer step is taken per batch; when it is ``None``
    the pass is evaluation-only (the caller is responsible for ``no_grad``).

    Returns ``nan`` for an empty loader instead of dividing by zero.
    """
    is_training = optimizer is not None
    running_loss = 0.0

    for batch_images, batch_scores in loader:
        if is_training:
            optimizer.zero_grad()

        batch_loss = 0
        for img, score in zip(batch_images, batch_scores):
            loss = _pair_loss(model, criterion, img, score)
            batch_loss += loss.item()
            if is_training:
                # Accumulates into .grad; the step below applies the sum.
                loss.backward()

        if is_training:
            optimizer.step()

        running_loss += batch_loss
        # NOTE: the label says "Cosine Similarity" but the value is the sum of
        # (1 - cosine similarity) over the batch. The label is kept unchanged
        # so existing log consumers keep working.
        print(f"{batch_label} {batch_loss:.4f}")

    num_batches = len(loader)
    if num_batches == 0:
        # An empty split used to raise ZeroDivisionError here, aborting the
        # run before the model was saved.
        return float("nan")
    return running_loss / num_batches


def train_model(
    data_dir: str,
    model_save_path: str,
    pretrained_model_path: str,
    lr: float,
    num_epochs: int,
    train_batch_size: int,
    validation_batch_size: int,
    num_workers: int,
) -> None:
    """Train a ``GenomicPLIPModel`` and save its ``state_dict``.

    Each epoch runs a training pass over ``{data_dir}/train`` followed by a
    gradient-free validation pass over ``{data_dir}/validate``. The per-sample
    loss is ``1 - cosine_similarity(score_features, vision_features)``.

    Args:
        data_dir: Directory holding the ``train`` and ``validate`` sub-folders
            read by ``FlatTileDataset``.
        model_save_path: Destination passed to ``torch.save`` for the trained
            model's ``state_dict``.
        pretrained_model_path: Location passed to
            ``CLIPVisionModel.from_pretrained`` for the base vision model.
        lr: Learning rate for the Adam optimizer.
        num_epochs: Number of train + validate cycles to run.
        train_batch_size: Batch size of the (shuffled) training loader.
        validation_batch_size: Batch size of the (unshuffled) validation loader.
        num_workers: ``num_workers`` value given to both data loaders.
    """
    # Load datasets
    train_dataset = FlatTileDataset(data_dir=f'{data_dir}/train')
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    validation_dataset = FlatTileDataset(data_dir=f'{data_dir}/validate')
    validation_data_loader = DataLoader(
        validation_dataset,
        batch_size=validation_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    # Initialize the model
    base_model = CLIPVisionModel.from_pretrained(pretrained_model_path)
    custom_model = GenomicPLIPModel(base_model)

    criterion = torch.nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(custom_model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # Training loop
        custom_model.train()
        avg_train_loss = _run_pass(
            custom_model,
            criterion,
            train_data_loader,
            "Batch Cosine Similarity",
            optimizer=optimizer,
        )
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Cosine Similarity: {avg_train_loss:.4f}")

        # Validation loop
        custom_model.eval()
        with torch.no_grad():
            avg_validation_loss = _run_pass(
                custom_model,
                criterion,
                validation_data_loader,
                "Validation Batch Cosine Similarity",
            )
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Cosine Similarity: {avg_validation_loss:.4f}")

    # Save the trained model
    torch.save(custom_model.state_dict(), model_save_path)
if __name__ == "__main__":
    # Command-line entry point: each train_model parameter is exposed as a
    # flag of the same name, declared below as (flag, type, default, help).
    cli = argparse.ArgumentParser(description='Train the Genomic PLIP Model')

    option_specs = (
        ('--data_dir', str, 'Datasets/train_03', 'Directory containing the train, validate, and test datasets.'),
        ('--model_save_path', str, 'genomic_plip.pth', 'Path to save the trained model.'),
        ('--pretrained_model_path', str, './plip', 'Path to the pretrained CLIP model.'),
        ('--lr', float, 0.00001, 'Learning rate for the optimizer.'),
        ('--num_epochs', int, 1, 'Number of epochs to train for.'),
        ('--train_batch_size', int, 128, 'Batch size for the training data loader.'),
        ('--validation_batch_size', int, 128, 'Batch size for the validation data loader.'),
        ('--num_workers', int, 32, 'Number of worker threads for data loading.'),
    )
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    opts = cli.parse_args()

    train_model(
        data_dir=opts.data_dir,
        model_save_path=opts.model_save_path,
        pretrained_model_path=opts.pretrained_model_path,
        lr=opts.lr,
        num_epochs=opts.num_epochs,
        train_batch_size=opts.train_batch_size,
        validation_batch_size=opts.validation_batch_size,
        num_workers=opts.num_workers,
    )