Spaces:
Runtime error
Runtime error
VatsalPatel18
committed on
Delete train.py
Browse files
train.py
DELETED
@@ -1,105 +0,0 @@
import os
import pickle

import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from sklearn.model_selection import train_test_split

from OmicsConfig import OmicsConfig
from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from GATv2EncoderModel import GATv2EncoderModel
from GATv2DecoderModel import GATv2DecoderModel
from EdgeWeightPredictorModel import EdgeWeightPredictorModel
13 |
-
def collate_graph_data(batch):
    """Merge a list of PyG graph objects into one :class:`Batch`."""
    merged = Batch.from_data_list(batch)
    return merged
|
15 |
-
|
16 |
-
def create_data_loader(graph_data_dict, batch_size=1, shuffle=True):
    """Wrap the graphs stored as the dict's values in a ``DataLoader``.

    Keys of *graph_data_dict* are discarded; batches are assembled with
    :func:`collate_graph_data`.
    """
    graphs = list(graph_data_dict.values())
    loader = DataLoader(
        graphs,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_graph_data,
    )
    return loader
|
19 |
-
|
20 |
-
# Load the pre-built {sample_id: graph} dictionary.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
graph_data_dict = torch.load('data/graph_data_dictN.pth')

# 60/20/20 train/val/test split over (key, graph) pairs; fixed seed so the
# split is reproducible across runs.
train_data, temp_data = train_test_split(
    list(graph_data_dict.items()), train_size=0.6, random_state=42
)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

# The split functions return lists of pairs; restore dictionary form.
train_data = dict(train_data)
val_data = dict(val_data)
test_data = dict(test_data)
|
31 |
-
|
32 |
-
# Hyper-parameters for the graph-attention autoencoder.
# NOTE(review): in_channels matches original_feature_size (17 input features
# per node) — confirm against the upstream feature matrix.
autoencoder_config = OmicsConfig(
    in_channels=17,
    edge_attr_channels=1,
    out_channels=1,
    original_feature_size=17,
    learning_rate=0.01,
    num_layers=2,
    edge_decoder_hidden_sizes=[128, 64],
    edge_decoder_activations=['ReLU', 'ReLU'],
)

# Build the autoencoder from the configuration.
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config)

# One loader per split; only the training loader shuffles.
train_loader = create_data_loader(train_data, batch_size=4, shuffle=True)
val_loader = create_data_loader(val_data, batch_size=4, shuffle=False)
test_loader = create_data_loader(test_data, batch_size=4, shuffle=False)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
54 |
-
|
55 |
-
# Training process
|
56 |
-
def train_autoencoder(autoencoder_model, train_loader, validation_loader, epochs, device,
                      save_dir="lc_models/MultiOmicsAutoencoder"):
    """Train the autoencoder, validating once per epoch, then persist weights.

    Parameters
    ----------
    autoencoder_model : model exposing ``train_model``, ``validate``,
        ``save_pretrained``, plus ``encoder``/``decoder`` sub-modules.
    train_loader, validation_loader : iterables of graph batches.
    epochs : int
        Number of full passes over ``train_loader``.
    device : torch.device or str
        Where the model is moved before training.
    save_dir : str, optional
        Root directory for the saved components. Defaults to the original
        hard-coded location, so existing callers are unaffected.

    Returns
    -------
    (train_losses, val_losses) : tuple of lists, one entry per epoch.
    """
    autoencoder_model.to(device)
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        # Training pass; train_model is expected to return (loss, cosine sim).
        autoencoder_model.train()
        train_loss, train_cosine_similarity = autoencoder_model.train_model(train_loader, device)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Cosine Similarity: {train_cosine_similarity:.4f}")
        train_losses.append(train_loss)

        # Validation pass in eval mode (dropout/batch-norm disabled).
        autoencoder_model.eval()
        val_loss, val_cosine_similarity = autoencoder_model.validate(validation_loader, device)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss:.4f}, Validation Cosine Similarity: {val_cosine_similarity:.4f}")
        val_losses.append(val_loss)

    # Persist each component separately so they can be reloaded independently.
    autoencoder_model.encoder.save_pretrained(f"{save_dir}/trained_encoder")
    autoencoder_model.decoder.save_pretrained(f"{save_dir}/trained_decoder")
    # Edge-weight predictor saved on its own in case only it is needed later.
    autoencoder_model.decoder.edge_weight_predictor.save_pretrained(f"{save_dir}/trained_edge_weight_predictor")
    # Also save the complete autoencoder as a single package.
    autoencoder_model.save_pretrained(f"{save_dir}/trained_autoencoder")

    return train_losses, val_losses
|
91 |
-
|
92 |
-
# Train the model and persist its weights (train_autoencoder saves internally).
train_losses, val_losses = train_autoencoder(autoencoder_model, train_loader, val_loader, epochs=10, device=device)

# Evaluate on the held-out test split.
test_loss, test_accuracy = autoencoder_model.evaluate(test_loader, device)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4%}")

# FIX: create the output directory first — pickle.dump would otherwise raise
# FileNotFoundError on a fresh checkout where ./results does not exist.
os.makedirs('./results', exist_ok=True)

# Save the per-epoch loss curves for later plotting/analysis.
with open('./results/train_loss.pkl', 'wb') as f:
    pickle.dump(train_losses, f)

with open('./results/val_loss.pkl', 'wb') as f:
    pickle.dump(val_losses, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|