ProtHGT / visualize_kg.py
Erva Ulusoy
updated node titles to contain node name instead of id
e0fbc94
from pyvis.network import Network
import os
import json
import gzip
# Colour of every node type: used for the pyvis group styling and for the
# legend swatches (the legend lists the types in this dict's order).
NODE_TYPE_COLORS = {
    # Diseases and phenotypes
    'Disease': '#079dbb',
    'HPO': '#58d0e8',
    # Drugs and compounds
    'Drug': '#815ac0',
    'Compound': '#d2b7e5',
    # Protein domains
    'Domain': '#6bbf59',
    # Gene Ontology terms: biological process / molecular function / cellular component
    'GO_term_P': '#ff8800',
    'GO_term_F': '#ffaa00',
    'GO_term_C': '#ffc300',
    # Both pathway node types share one colour
    'Pathway': '#720026',
    'kegg_Pathway': '#720026',
    # Enzyme classes and the proteins themselves
    'EC_number': '#ce4257',
    'Protein': '#3aa6a4',
}
# Human-readable verb shown as the label/tooltip of each edge.
# Most relations are keyed by their relation name alone; 'protein_function'
# edges are keyed by (relation, GO target type) because the verb depends on
# which GO branch the target term belongs to.
EDGE_LABEL_TRANSLATION = {
    'Orthology': 'is ortholog to',
    'Pathway': 'takes part in',
    'kegg_path_prot': 'takes part in',
    'protein_domain': 'has',
    'PPI': 'interacts with',
    'HPO': 'is associated with',
    'kegg_dis_prot': 'is related to',
    'Disease': 'is related to',
    'Drug': 'targets',
    'protein_ec': 'catalyzes',
    'Chembl': 'targets',
    # GO annotations
    ('protein_function', 'GO_term_F'): 'enables',
    ('protein_function', 'GO_term_P'): 'is involved in',
    ('protein_function', 'GO_term_C'): 'localizes to',
}
# Display name for node types whose internal name is not reader-friendly;
# node types missing from this dict are displayed under their internal name.
NODE_LABEL_TRANSLATION = {
    'HPO': 'Phenotype',
    'GO_term_P': 'Biological Process',
    'GO_term_F': 'Molecular Function',
    'GO_term_C': 'Cellular Component',
    'kegg_Pathway': 'Pathway',
    'EC_number': 'EC Number',
}

# GO category name (as found in the predictions' 'GO_category' column) ->
# GO node type. Derived from the display names above so both stay in sync.
GO_CATEGORY_MAPPING = {
    NODE_LABEL_TRANSLATION[go_node_type]: go_node_type
    for go_node_type in ('GO_term_P', 'GO_term_F', 'GO_term_C')
}
def get_node_url(node_type, node_id):
    """Return the external database page for a node, or ``None`` if none is known.

    Args:
        node_type: Knowledge-graph node type, e.g. ``'Protein'``, ``'Disease'``
            or any type starting with ``'GO_term'``.
        node_id: The node's original (database) identifier.

    Disease IDs are routed by their ontology prefix: ``EFO``, ``MONDO`` and
    ``Orphanet`` have dedicated pages, any other prefix yields ``None``, and
    an ID without a ``':'`` is looked up on genome.jp. Only the text between
    the first and second ``':'`` is used as the local disease ID.
    """
    if node_type.startswith('GO_term'):
        return f"https://www.ebi.ac.uk/QuickGO/term/{node_id}"

    if node_type == 'Disease':
        if ':' not in node_id:
            return f"https://www.genome.jp/entry/{node_id}"
        id_parts = node_id.split(':')
        ontology, local_id = id_parts[0], id_parts[1]
        ontology_templates = {
            'EFO': "http://www.ebi.ac.uk/efo/EFO_{}",
            'MONDO': 'http://purl.obolibrary.org/obo/MONDO_{}',
            'Orphanet': "http://www.orpha.net/ORDO/Orphanet_{}",
        }
        template = ontology_templates.get(ontology)
        return None if template is None else template.format(local_id)

    # Every remaining node type maps its ID straight into a fixed URL pattern.
    type_templates = {
        'Protein': "https://www.uniprot.org/uniprotkb/{}/entry",
        'HPO': "https://hpo.jax.org/browse/term/{}",
        'Drug': "https://go.drugbank.com/drugs/{}",
        'Compound': "https://www.ebi.ac.uk/chembl/explore/compound/{}",
        'Domain': "https://www.ebi.ac.uk/interpro/entry/InterPro/{}",
        'Pathway': "https://reactome.org/content/detail/{}",
        'kegg_Pathway': "https://www.genome.jp/pathway/{}",
        'EC_number': "https://enzyme.expasy.org/EC/{}",
    }
    template = type_templates.get(node_type)
    return None if template is None else template.format(node_id)
def _gather_protein_edges(data, protein_id):
protein_idx = data['Protein']['id_mapping'][protein_id]
reverse_id_mapping = {}
for node_type in data.node_types:
reverse_id_mapping[node_type] = {v:k for k, v in data[node_type]['id_mapping'].items()}
protein_edges = {}
print(f'Gathering edges for {protein_id}...')
for edge_type in data.edge_types:
if 'rev' not in edge_type[1]:
if edge_type not in protein_edges:
protein_edges[edge_type] = []
if edge_type[0] == 'Protein':
print(f'Gathering edges for {edge_type}...')
# append the edges with protein_idx as source node
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == protein_idx]
protein_edges[edge_type].extend(edges.T.tolist())
elif edge_type[2] == 'Protein':
print(f'Gathering edges for {edge_type}...')
# append the edges with protein_idx as target node
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == protein_idx]
protein_edges[edge_type].extend(edges.T.tolist())
for edge_type in protein_edges.keys():
if protein_edges[edge_type]:
mapped_edges = set()
for edge in protein_edges[edge_type]:
# Get source and target node types from edge_type
source_type, _, target_type = edge_type
# Map indices back to original IDs
source_id = reverse_id_mapping[source_type][edge[0]]
target_id = reverse_id_mapping[target_type][edge[1]]
mapped_edges.add((source_id, target_id))
protein_edges[edge_type] = mapped_edges
return protein_edges
def _ordered_edges(edges):
    """Return ``edges`` as a list in a reproducible order.

    Edges typically arrive as a set of string tuples, whose iteration order
    depends on per-process hash randomization; sorting makes any ``[:limit]``
    selection identical from run to run.
    """
    return sorted(edges, key=lambda edge: tuple(str(part) for part in edge))


def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
    """Select, per edge type, the edges around ``protein_id`` that will be drawn.

    Args:
        protein_id: ID of the query protein. It is matched against
            ``prediction_df['UniProt_ID']`` and used as the source of
            predicted GO edges.
        protein_edges: Mapping of edge type ``(source_type, relation,
            target_type)`` to an iterable of ``(source_id, target_id)`` pairs,
            e.g. the output of ``_gather_protein_edges``. Values may be empty
            or ``None``.
        prediction_df: Predictions with ``UniProt_ID``, ``GO_category``,
            ``GO_ID`` and ``Probability`` columns. Every ``GO_category`` value
            must be a key of ``GO_CATEGORY_MAPPING``.
        limit: Maximum number of edges kept per edge type.

    Returns:
        Dict mapping edge type to a list of ``(edge, probability,
        is_ground_truth)`` tuples:

        * GO categories that have predictions for this protein: the ``limit``
          most probable predictions, with ``is_ground_truth`` telling whether
          the predicted edge is also in ``protein_edges``. This is returned
          even when the protein has no graph edges of that type.
        * other GO edges: ``probability`` is ``'no_pred'`` and
          ``is_ground_truth`` is ``True``;
        * non-GO edges: ``probability`` is ``None`` and ``is_ground_truth`` is
          ``True``.

        Graph edges are chosen in a deterministic (sorted) order. Edge types
        with neither graph edges nor predictions are omitted.

    Raises:
        KeyError: if ``prediction_df`` contains a ``GO_category`` missing from
            ``GO_CATEGORY_MAPPING``.
    """
    go_type_to_category = {go_type: category for category, go_type in GO_CATEGORY_MAPPING.items()}
    predicted_go_types = {GO_CATEGORY_MAPPING[category] for category in prediction_df['GO_category'].unique()}
    is_query_protein = prediction_df['UniProt_ID'] == protein_id

    filtered_edges = {}
    for edge_type, edges in protein_edges.items():
        target_type = edge_type[2]

        if target_type in predicted_go_types:
            # Checked before the empty-edges guard: a protein without known
            # annotations in this GO category must still show its predictions.
            category_mask = is_query_protein & (prediction_df['GO_category'] == go_type_to_category[target_type])
            category_predictions = prediction_df[category_mask]
            if len(category_predictions) > 0:
                top_predictions = category_predictions.sort_values(by='Probability', ascending=False).head(limit)
                known_edges = set(edges or ())
                filtered_edges[edge_type] = [
                    ((protein_id, go_id), probability, (protein_id, go_id) in known_edges)
                    # .tolist() yields plain Python scalars rather than numpy ones.
                    for go_id, probability in zip(top_predictions['GO_ID'].tolist(),
                                                  top_predictions['Probability'].tolist())
                ]
                continue

        if not edges:
            continue

        # GO edges taken straight from the graph are tagged 'no_pred'
        # (ground truth without a prediction); other edges carry no probability.
        probability = 'no_pred' if target_type.startswith('GO_term') else None
        filtered_edges[edge_type] = [(edge, probability, True) for edge in _ordered_edges(edges)[:limit]]

    return filtered_edges
def _node_title(node_type, node_id, name_info, type_label=None):
    """Build a node's hover tooltip: "<name> (<type label>)", linked to its external page if known.

    Names come from ``name_info[node_type]``; GO terms of every category share
    the ``'GO_term'`` table. A node missing from ``name_info`` falls back to
    its ID instead of aborting the whole visualization. ``type_label``
    overrides the displayed type (defaults to the translated node type).
    """
    name_table = 'GO_term' if node_type.startswith('GO_term') else node_type
    node_name = name_info.get(name_table, {}).get(node_id, node_id)
    if type_label is None:
        type_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
    title = f"{node_name} ({type_label})"
    url = get_node_url(node_type, node_id)
    if url:
        title = f'<a href="{url}" target="_blank">{title}</a>'
    return title


def _add_context_node(net, node_id, node_type, name_info):
    """Add a neighbour node (any node other than the query protein), styled by its node-type group."""
    net.add_node(node_id,
                 label=node_id,
                 shape="dot",
                 font={'color': '#000000', 'size': 12},
                 title=_node_title(node_type, node_id, name_info),
                 group=node_type,
                 size=15,
                 mass=1.5)


def _edge_style(relation_label, probability, is_ground_truth):
    """Return ``(label, colour)`` for an edge, matching the legend's edge colours.

    * ``probability is None`` (non-GO edge): grey, relation label only.
    * ``probability == 'no_pred'`` (GO annotation without a prediction): blue.
    * numeric probability: purple when the prediction is also a ground-truth
      edge (validated), red otherwise; the probability is appended.
    """
    if probability is None:
        return relation_label, '#666666'
    if probability == 'no_pred':
        return f"{relation_label} (P=Not generated)", '#219ebc'
    color = '#8338ec' if is_ground_truth else '#c1121f'
    return f"{relation_label} (P={probability:.2f})", color


def _create_network():
    """Create the pyvis network with physics/interaction options and one colour group per node type."""
    net = Network(height="600px", width="100%", directed=True, notebook=False)
    options = {
        "physics": {
            "enabled": True,
            "barnesHut": {
                "gravitationalConstant": -1000,
                "springLength": 250,
                "springConstant": 0.001,
                "damping": 0.09,
                "avoidOverlap": 0,
            },
            # Not the active solver (see "solver" below); kept for tuning.
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 100,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 0,
            },
            "solver": "barnesHut",
            "stabilization": {
                "enabled": True,
                "iterations": 1000,
                "updateInterval": 25,
            },
        },
        "layout": {
            "improvedLayout": True,
            "hierarchical": {"enabled": False},
        },
        "interaction": {
            "hover": True,
            "navigationButtons": True,
            "multiselect": True,
        },
        "configure": {
            "enabled": False,
            "filter": ["physics", "layout", "manipulation"],
            "showButton": True,
        },
        # Nodes added with group=<node type> pick up that type's colour.
        "groups": {
            node_type: {"color": {"background": color, "border": color}}
            for node_type, color in NODE_TYPE_COLORS.items()
        },
    }
    net.set_options(json.dumps(options))
    return net


def _legend_html():
    """Return the HTML/CSS legend (node-type colours and edge colours) shown under the graph."""
    header = """
<style>
    .kg-legend {
        margin-top: 20px;
        padding: 20px;
        border: 1px solid #ddd;
        border-radius: 5px;
        font-family: Arial, sans-serif;
        display: flex;
        gap: 20px;
    }
    .legend-section-nodes {
        flex: 2; /* Takes up 2/3 of the space */
    }
    .legend-section-edges {
        flex: 1; /* Takes up 1/3 of the space */
    }
    .legend-title {
        margin-bottom: 15px;
        color: #333;
        font-size: 16px;
        font-weight: bold;
    }
    .nodes-grid {
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        gap: 12px;
    }
    .edges-grid {
        display: grid;
        grid-template-columns: 1fr;
        gap: 12px;
    }
    .legend-item {
        display: flex;
        align-items: center;
        padding: 4px;
    }
    .node-indicator {
        width: 15px;
        height: 15px;
        border-radius: 50%;
        margin-right: 10px;
        flex-shrink: 0;
    }
    .edge-indicator {
        width: 40px;
        height: 3px;
        margin-right: 10px;
        flex-shrink: 0;
    }
    .legend-label {
        font-size: 14px;
    }
</style>
<div class="kg-legend">
    <div class="legend-section-nodes">
        <div class="legend-title">Node Types</div>
        <div class="nodes-grid">"""

    # kegg_Pathway has the same label ('Pathway') and colour as Pathway, so it
    # is listed only once.
    node_items = ''.join(
        f"""
            <div class="legend-item">
                <div class="node-indicator" style="background-color: {color};"></div>
                <span class="legend-label">{NODE_LABEL_TRANSLATION.get(node_type, node_type)}</span>
            </div>"""
        for node_type, color in NODE_TYPE_COLORS.items()
        if node_type != 'kegg_Pathway'
    )

    footer = """
        </div>
    </div>
    <div class="legend-section-edges">
        <div class="legend-title">Edge Colors</div>
        <div class="edges-grid">
            <div class="legend-item">
                <div class="edge-indicator" style="background-color: #8338ec;"></div>
                <span class="legend-label">Validated GO Prediction</span>
            </div>
            <div class="legend-item">
                <div class="edge-indicator" style="background-color: #c1121f;"></div>
                <span class="legend-label">Non-validated GO Prediction</span>
            </div>
            <div class="legend-item">
                <div class="edge-indicator" style="background-color: #219ebc;"></div>
                <span class="legend-label">Ground Truth GO Annotation</span>
            </div>
            <div class="legend-item">
                <div class="edge-indicator" style="background-color: #666666;"></div>
                <span class="legend-label">Other Relationships</span>
            </div>
        </div>
    </div>
</div>
"""
    return header + node_items + footer


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10, *,
                               name_info_path='data/name_info.json.gz',
                               output_dir='temp_viz'):
    """Render the neighbourhood of ``protein_id`` as an interactive pyvis HTML page.

    Args:
        data: Heterogeneous graph passed to ``_gather_protein_edges``
            (presumably a PyTorch Geometric ``HeteroData``; confirm with callers).
        protein_id: ID of the query protein.
        prediction_df: GO predictions, as described in ``_filter_edges``.
        limit: Maximum number of edges drawn per edge type.
        name_info_path: Gzipped JSON mapping node type -> {node ID: display name}.
        output_dir: Directory the HTML file is written to (created if needed).

    Returns:
        ``(file_path, visualized_edges)``: the path of the written
        ``<output_dir>/<protein_id>_graph.html`` file and the per-edge-type
        selection from ``_filter_edges`` that was drawn.
    """
    with gzip.open(name_info_path, 'rt', encoding='utf-8') as file:
        name_info = json.load(file)

    protein_edges = _gather_protein_edges(data, protein_id)
    visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
    print(f'Edges to be visualized: {visualized_edges}')

    net = _create_network()

    # The query protein is drawn larger, white-filled and with a red border.
    net.add_node(protein_id,
                 label=protein_id,
                 title=_node_title('Protein', protein_id, name_info, type_label='Query Protein'),
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)
    added_nodes = {protein_id}  # avoid adding a node twice

    for edge_type, edges in visualized_edges.items():
        source_type, relation, target_type = edge_type
        # GO annotation verbs depend on the GO branch of the target.
        relation_key = (relation, target_type) if relation == 'protein_function' else relation
        relation_label = EDGE_LABEL_TRANSLATION.get(relation_key, relation)

        for edge, probability, is_ground_truth in edges:
            source_str, target_str = str(edge[0]), str(edge[1])
            for node_id, node_type in ((source_str, source_type), (target_str, target_type)):
                if node_id not in added_nodes:
                    _add_context_node(net, node_id, node_type, name_info)
                    added_nodes.add(node_id)

            edge_label, edge_color = _edge_style(relation_label, probability, is_ground_truth)
            net.add_edge(source_str, target_str,
                         label=edge_label,
                         font={'size': 0},  # label text drawn at size 0; the same text is the hover title
                         color=edge_color,
                         title=edge_label,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, f'{protein_id}_graph.html')
    net.save_graph(file_path)

    # Insert the legend before the closing body tag of the page pyvis wrote.
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    content = content.replace('</body>', f'{_legend_html()}</body>')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)

    return file_path, visualized_edges