# ProtHGT / run_prothgt_app.py
# Last commit by Erva Ulusoy: 126dbe4 "updated requirements" (7.68 kB)
import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd
import yaml
import os
from datasets import load_dataset
import gdown
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores Protein-to-target node pairs.

    Pipeline:
      1. Every node type's feature matrix is projected to ``hidden_channels``
         by its own ``Linear`` layer, followed by an in-place ReLU.
      2. The projections are refined by ``num_layers`` stacked ``HGTConv``
         layers built from ``data.metadata()``.
      3. A candidate (protein, target) pair is scored by an ``MLP`` applied to
         the concatenation of the two node embeddings.

    The attribute names ``lin_dict``, ``convs`` and ``mlp`` form the prefixes
    of ``state_dict`` keys, so saved checkpoints depend on them.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        super().__init__()
        # One input projection per node type; the input width is read from
        # that node type's feature tensor.
        self.lin_dict = torch.nn.ModuleDict({
            node_type: Linear(data[node_type].x.size(1), hidden_channels)
            for node_type in data.node_types
        })
        graph_metadata = data.metadata()
        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, graph_metadata, num_heads, group='sum')
            for _ in range(num_layers)
        ])
        # mlp_hidden_layers is passed straight to MLP as its channel list.
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return node embeddings per node type after projection and all HGT layers."""
        hidden = {}
        for node_type, features in x_dict.items():
            hidden[node_type] = self.lin_dict[node_type](features).relu_()
        for conv in self.convs:
            hidden = conv(hidden, edge_index_dict)
        return hidden

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score the pairs in ``tr_edge_label_index``.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Message-passing edges keyed by edge type.
            tr_edge_label_index: Pair of index tensors (protein rows, target rows).
            target_type: Node type of the pair's second element.
            test: Unused; kept for call-site compatibility.

        Returns:
            ``(scores, embeddings)``: the flattened raw MLP outputs (one per
            pair) and the embedding dict from ``generate_embeddings``.
        """
        embeddings = self.generate_embeddings(x_dict, edge_index_dict)
        protein_idx, target_idx = tr_edge_label_index
        pair_features = torch.cat(
            (embeddings["Protein"][protein_idx], embeddings[target_type][target_idx]),
            dim=-1,
        )
        scores = self.mlp(pair_features).view(-1)
        return scores, embeddings
def _load_data(heterodata, protein_ids, go_category=None):
"""Process the loaded heterodata for specific proteins and GO categories."""
# Get protein indices for all input proteins
protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
# Create edge indices for prediction
categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
for category in categories:
# Create pairs for all proteins with all GO terms
n_terms = len(heterodata[category]['id_mapping'])
protein_indices_repeated = torch.tensor(protein_indices).repeat_interleave(n_terms)
term_indices = torch.arange(n_terms).repeat(len(protein_indices))
edge_index = torch.stack([protein_indices_repeated, term_indices])
heterodata.edge_index_dict[('Protein', 'protein_function', category)] = edge_index
return heterodata
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Return the protein IDs listed one per line in ``protein_list_file``.

    Surrounding whitespace is stripped from each line. Blank lines are
    skipped, so a stray empty line never produces an empty-string protein ID.
    """
    with open(protein_list_file, 'r', encoding='utf-8') as file:
        stripped = (line.strip() for line in file)
        return [protein_id for protein_id in stripped if protein_id]
def _generate_predictions(heterodata, model, target_type):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
heterodata = heterodata.to(device)
with torch.no_grad():
edge_label_index = heterodata.edge_index_dict[('Protein', 'protein_function', target_type)]
predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
predictions = torch.sigmoid(predictions)
return predictions.cpu()
def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
go_category_dict = {
'GO_term_F': 'Molecular Function',
'GO_term_P': 'Biological Process',
'GO_term_C': 'Cellular Component'
}
# Create a list to store individual protein predictions
all_predictions = []
# Number of GO terms for this category
n_go_terms = len(heterodata[go_category]['id_mapping'])
# Process predictions for each protein
for i, protein_id in enumerate(protein_ids):
# Get the slice of predictions for this protein
protein_predictions = predictions[i * n_go_terms:(i + 1) * n_go_terms]
prediction_df = pd.DataFrame({
'Protein': protein_id,
'GO_category': go_category_dict[go_category],
'GO_term': list(heterodata[go_category]['id_mapping'].keys()),
'Probability': protein_predictions.numpy()
})
all_predictions.append(prediction_df)
# Combine all predictions
combined_df = pd.concat(all_predictions, ignore_index=True)
combined_df.sort_values(by=['Protein', 'Probability'], ascending=[True, False], inplace=True)
combined_df.reset_index(drop=True, inplace=True)
return combined_df
def _download_knowledge_graph(output="data/prothgt-kg.pt"):
    """Download the ProtHGT knowledge-graph file from Google Drive.

    Args:
        output: Local path to write to; its parent directory is created if missing.

    Returns:
        ``output`` once the download has succeeded.

    Raises:
        Whatever ``gdown.download`` raises, or ``RuntimeError`` if it returns
        ``None``. The error is printed before being re-raised.
    """
    file_id = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
    url = f"https://drive.google.com/uc?id={file_id}"
    # The target directory may not exist yet on a fresh checkout or container.
    output_dir = os.path.dirname(output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    print(f"Downloading file from {url}...")
    try:
        downloaded = gdown.download(url, output, quiet=False)
        # Older gdown releases may return None on failure instead of raising;
        # fail here rather than trying to torch.load a missing file later.
        if downloaded is None:
            raise RuntimeError(f"gdown returned no file for {url}")
        print(f"File downloaded to {output}")
    except Exception as e:
        print(f"Error downloading file: {e}")
        raise
    return output


def _build_model(data, model_config_path, model_path, device):
    """Build a ProtHGT from a YAML config and load its checkpoint with ``map_location=device``."""
    with open(model_config_path, 'r') as file:
        model_config = yaml.safe_load(file)
    # 'hidden_channels' holds [HGT width, MLP channel list].
    model = ProtHGT(
        data,
        hidden_channels=model_config['hidden_channels'][0],
        num_heads=model_config['num_heads'],
        num_layers=model_config['num_layers'],
        mlp_hidden_layers=model_config['hidden_channels'][1],
        mlp_dropout=model_config['mlp_dropout']
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f'Loaded model weights from {model_path}')
    return model


def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category):
    """Predict GO-term probabilities for proteins, one ProtHGT model per GO category.

    The knowledge graph is downloaded and loaded once. Then, for each
    ``(go_category[i], model_config_paths[i], model_paths[i])`` triple, query
    edges are built from every protein to every term of that category, the
    model is loaded, the pairs are scored, and the results are collected.

    Args:
        protein_ids: A protein ID or a list of IDs present in the graph's
            ``'Protein'`` ``id_mapping``.
        model_paths: Checkpoint path(s), one per GO category.
        model_config_paths: YAML config path(s), one per GO category.
        go_category: GO category key(s), e.g. ``'GO_term_F'``.

    Returns:
        A DataFrame with columns ``Protein``, ``GO_category``, ``GO_term`` and
        ``Probability`` (the per-category frames concatenated). When no
        category is requested it is empty, with those columns.

    Raises:
        ValueError: If the three per-category arguments differ in length.
    """
    # Accept single values as well as lists; a bare string would otherwise be
    # zipped character by character.
    if isinstance(protein_ids, str):
        protein_ids = [protein_ids]
    if isinstance(go_category, str):
        go_category = [go_category]
    if isinstance(model_paths, str):
        model_paths = [model_paths]
    if isinstance(model_config_paths, str):
        model_config_paths = [model_config_paths]
    go_category = list(go_category)
    model_paths = list(model_paths)
    model_config_paths = list(model_config_paths)

    # zip() would silently drop any unmatched category/model/config.
    if not len(go_category) == len(model_paths) == len(model_config_paths):
        raise ValueError(
            f'go_category ({len(go_category)}), model_paths ({len(model_paths)}) and '
            f'model_config_paths ({len(model_config_paths)}) must have the same length'
        )
    if not go_category:
        return pd.DataFrame(columns=['Protein', 'GO_category', 'GO_term', 'Probability'])

    print('Loading data...')
    output = _download_knowledge_graph("data/prothgt-kg.pt")
    # NOTE: torch.load unpickles the file, so only use it on this trusted source.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse
    # to unpickle a HeteroData object; pass weights_only=False if loading fails.
    heterodata = torch.load(output)
    print(heterodata.edge_types)

    # Remove unnecessary edge types
    edge_types_to_remove = [
        ('Protein', 'protein_function', 'GO_term_F'),
        ('Protein', 'protein_function', 'GO_term_P'),
        ('Protein', 'protein_function', 'GO_term_C'),
        ('GO_term_F', 'rev_protein_function', 'Protein'),
        ('GO_term_P', 'rev_protein_function', 'Protein'),
        ('GO_term_C', 'rev_protein_function', 'Protein')
    ]
    # NOTE(review): this loop currently has no effect. In PyG,
    # `HeteroData.edge_index_dict` builds a new dict on every access, so the
    # `del` only edits that temporary copy. It is deliberately left unchanged.
    # Actually removing these edge types (`del heterodata[edge_type]`) changes
    # `heterodata.metadata()`, which ProtHGT passes to HGTConv. That metadata
    # must match what the checkpoints were trained with, so confirm the
    # training setup before making the removal effective.
    for edge_type in edge_types_to_remove:
        if edge_type in heterodata.edge_index_dict:
            del heterodata.edge_index_dict[edge_type]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    all_predictions = []
    for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
        print(f'Generating predictions for {go_cat}...')
        query_edge_type = ('Protein', 'protein_function', go_cat)
        # _load_data mutates the shared graph in place and may replace this edge
        # type with the current category's query edges. Snapshot it so the next
        # category's model sees the graph as loaded, not these queries.
        existed = query_edge_type in heterodata.edge_types
        saved_edge_index = getattr(heterodata[query_edge_type], 'edge_index', None) if existed else None
        try:
            processed_data = _load_data(heterodata, protein_ids, go_cat)
            model = _build_model(processed_data, model_config_path, model_path, device)
            predictions = _generate_predictions(processed_data, model, go_cat)
            all_predictions.append(_create_prediction_df(predictions, processed_data, protein_ids, go_cat))
            # Clean up memory
            del processed_data, model, predictions
        finally:
            if not existed:
                if query_edge_type in heterodata.edge_types:
                    del heterodata[query_edge_type]
            elif saved_edge_index is not None:
                heterodata[query_edge_type].edge_index = saved_edge_index
            torch.cuda.empty_cache()  # Clear CUDA cache if using GPU

    del heterodata
    # Combine all predictions
    final_df = pd.concat(all_predictions, ignore_index=True)
    del all_predictions
    torch.cuda.empty_cache()
    return final_df