VatsalPatel18 committed on
Commit
b814581
·
verified ·
1 Parent(s): 2597eee

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -105
train.py DELETED
@@ -1,105 +0,0 @@
1
- import torch
2
- from torch.utils.data import DataLoader
3
- from torch_geometric.data import Batch
4
- from sklearn.model_selection import train_test_split
5
- import pickle
6
-
7
- from OmicsConfig import OmicsConfig
8
- from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
9
- from GATv2EncoderModel import GATv2EncoderModel
10
- from GATv2DecoderModel import GATv2DecoderModel
11
- from EdgeWeightPredictorModel import EdgeWeightPredictorModel
12
-
13
- def collate_graph_data(batch):
14
- return Batch.from_data_list(batch)
15
-
16
- def create_data_loader(graph_data_dict, batch_size=1, shuffle=True):
17
- graph_data = list(graph_data_dict.values())
18
- return DataLoader(graph_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_graph_data)
19
-
20
- # Load your data
21
- graph_data_dict = torch.load('data/graph_data_dictN.pth')
22
-
23
- # Split the data
24
- train_data, temp_data = train_test_split(list(graph_data_dict.items()), train_size=0.6, random_state=42)
25
- val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
26
-
27
- # Convert lists back into dictionaries
28
- train_data = dict(train_data)
29
- val_data = dict(val_data)
30
- test_data = dict(test_data)
31
-
32
- # Define the configuration for the model
33
- autoencoder_config = OmicsConfig(
34
- in_channels=17,
35
- edge_attr_channels=1,
36
- out_channels=1,
37
- original_feature_size=17,
38
- learning_rate=0.01,
39
- num_layers=2,
40
- edge_decoder_hidden_sizes=[128, 64],
41
- edge_decoder_activations=['ReLU', 'ReLU']
42
- )
43
-
44
- # Initialize the model
45
- autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config)
46
-
47
- # Create data loaders
48
- train_loader = create_data_loader(train_data, batch_size=4, shuffle=True)
49
- val_loader = create_data_loader(val_data, batch_size=4, shuffle=False)
50
- test_loader = create_data_loader(test_data, batch_size=4, shuffle=False)
51
-
52
- # Define the device
53
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
-
55
- # Training process
56
- def train_autoencoder(autoencoder_model, train_loader, validation_loader, epochs, device):
57
- autoencoder_model.to(device)
58
- train_losses = []
59
- val_losses = []
60
-
61
- for epoch in range(epochs):
62
- # Train
63
- autoencoder_model.train()
64
- train_loss, train_cosine_similarity = autoencoder_model.train_model(train_loader, device)
65
- print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Cosine Similarity: {train_cosine_similarity:.4f}")
66
- train_losses.append(train_loss)
67
-
68
- # Validate
69
- autoencoder_model.eval()
70
- val_loss, val_cosine_similarity = autoencoder_model.validate(validation_loader, device)
71
- print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss:.4f}, Validation Cosine Similarity: {val_cosine_similarity:.4f}")
72
- val_losses.append(val_loss)
73
-
74
- # Save the trained encoder weights
75
- trained_encoder_path = "lc_models/MultiOmicsAutoencoder/trained_encoder"
76
- autoencoder_model.encoder.save_pretrained(trained_encoder_path)
77
-
78
- # Save the trained decoder weights
79
- trained_decoder_path = "lc_models/MultiOmicsAutoencoder/trained_decoder"
80
- autoencoder_model.decoder.save_pretrained(trained_decoder_path)
81
-
82
- # Save the trained edge weight predictor weights (if needed separately)
83
- trained_edge_weight_predictor_path = "lc_models/MultiOmicsAutoencoder/trained_edge_weight_predictor"
84
- autoencoder_model.decoder.edge_weight_predictor.save_pretrained(trained_edge_weight_predictor_path)
85
-
86
- # Optionally save the entire autoencoder again if you want to have a complete package
87
- trained_autoencoder_path = "lc_models/MultiOmicsAutoencoder/trained_autoencoder"
88
- autoencoder_model.save_pretrained(trained_autoencoder_path)
89
-
90
- return train_losses, val_losses
91
-
92
- # Train and save the model
93
- train_losses, val_losses = train_autoencoder(autoencoder_model, train_loader, val_loader, epochs=10, device=device)
94
-
95
- # Evaluate the model
96
- test_loss, test_accuracy = autoencoder_model.evaluate(test_loader, device)
97
- print(f"Test Loss: {test_loss:.4f}")
98
- print(f"Test Accuracy: {test_accuracy:.4%}")
99
-
100
- # Save the training and validation losses
101
- with open('./results/train_loss.pkl', 'wb') as f:
102
- pickle.dump(train_losses, f)
103
-
104
- with open('./results/val_loss.pkl', 'wb') as f:
105
- pickle.dump(val_losses, f)