# Page-capture residue: "Spaces: Running" (hosting status banner, not part of the module).
import gzip
import json
import os

from pyvis.network import Network
# Hex colour assigned to each node type; used for both the vis.js group
# styling and the legend swatches. 'Pathway' and 'kegg_Pathway' deliberately
# share one colour.
NODE_TYPE_COLORS = {
    'Disease': '#079dbb',
    'HPO': '#58d0e8',
    'Drug': '#815ac0',
    'Compound': '#d2b7e5',
    'Domain': '#6bbf59',
    'GO_term_P': '#ff8800',
    'GO_term_F': '#ffaa00',
    'GO_term_C': '#ffc300',
    'Pathway': '#720026',
    'kegg_Pathway': '#720026',
    'EC_number': '#ce4257',
    'Protein': '#3aa6a4',
}
# Human-readable verb phrase shown on an edge, keyed by relation name.
# 'protein_function' edges are keyed by (relation, target node type) instead,
# because the wording depends on which GO namespace the target belongs to.
EDGE_LABEL_TRANSLATION = {
    'Orthology': 'is ortholog to',
    'Pathway': 'takes part in',
    'kegg_path_prot': 'takes part in',
    'protein_domain': 'has',
    'PPI': 'interacts with',
    'HPO': 'is associated with',
    'kegg_dis_prot': 'is related to',
    'Disease': 'is related to',
    'Drug': 'targets',
    'protein_ec': 'catalyzes',
    'Chembl': 'targets',
    ('protein_function', 'GO_term_F'): 'enables',
    ('protein_function', 'GO_term_P'): 'is involved in',
    ('protein_function', 'GO_term_C'): 'localizes to',
}
# Display names for node types whose internal identifier is not reader
# friendly. Types absent from this table are displayed under their own name.
NODE_LABEL_TRANSLATION = {
    'HPO': 'Phenotype',
    'GO_term_P': 'Biological Process',
    'GO_term_F': 'Molecular Function',
    'GO_term_C': 'Cellular Component',
    'kegg_Pathway': 'Pathway',
    'EC_number': 'EC Number',
}
# Maps the GO category names found in the prediction table's 'GO_category'
# column to the corresponding GO node type used in the graph.
GO_CATEGORY_MAPPING = {
    'Biological Process': 'GO_term_P',
    'Molecular Function': 'GO_term_F',
    'Cellular Component': 'GO_term_C',
}
# URL prefix for each disease ontology recognised in a "<ontology>:<id>" node ID.
_DISEASE_URL_PREFIXES = {
    'EFO': "http://www.ebi.ac.uk/efo/EFO_",
    'MONDO': "http://purl.obolibrary.org/obo/MONDO_",
    'Orphanet': "http://www.orpha.net/ORDO/Orphanet_",
}

# URL template for node types whose link is a plain substitution of the ID.
_NODE_URL_TEMPLATES = {
    'Protein': "https://www.uniprot.org/uniprotkb/{}/entry",
    'HPO': "https://hpo.jax.org/browse/term/{}",
    'Drug': "https://go.drugbank.com/drugs/{}",
    'Compound': "https://www.ebi.ac.uk/chembl/explore/compound/{}",
    'Domain': "https://www.ebi.ac.uk/interpro/entry/InterPro/{}",
    'Pathway': "https://reactome.org/content/detail/{}",
    'kegg_Pathway': "https://www.genome.jp/pathway/{}",
    'EC_number': "https://enzyme.expasy.org/EC/{}",
}


def get_node_url(node_type, node_id):
    """Get the URL for a node based on its type and ID.

    Args:
        node_type: Node type name (e.g. ``'Protein'``, ``'GO_term_P'``).
        node_id: Identifier of the node within that type.

    Returns:
        The URL string, or ``None`` when the node type is unknown or when a
        disease ID carries an ontology prefix other than EFO, MONDO or
        Orphanet.
    """
    # Every GO namespace (GO_term_P / GO_term_F / GO_term_C) links to QuickGO.
    if node_type.startswith('GO_term'):
        return f"https://www.ebi.ac.uk/QuickGO/term/{node_id}"

    if node_type == 'Disease':
        # Disease IDs without an ontology prefix link to the genome.jp entry page.
        if ':' not in node_id:
            return f"https://www.genome.jp/entry/{node_id}"
        parts = node_id.split(':')
        prefix = _DISEASE_URL_PREFIXES.get(parts[0])
        if prefix is None:
            # Prefixed ID from an ontology we have no link scheme for.
            return None
        return f"{prefix}{parts[1]}"

    template = _NODE_URL_TEMPLATES.get(node_type)
    if template is None:
        return None
    return template.format(node_id)
def _gather_protein_edges(data, protein_id): | |
protein_idx = data['Protein']['id_mapping'][protein_id] | |
reverse_id_mapping = {} | |
for node_type in data.node_types: | |
reverse_id_mapping[node_type] = {v:k for k, v in data[node_type]['id_mapping'].items()} | |
protein_edges = {} | |
print(f'Gathering edges for {protein_id}...') | |
for edge_type in data.edge_types: | |
if 'rev' not in edge_type[1]: | |
if edge_type not in protein_edges: | |
protein_edges[edge_type] = [] | |
if edge_type[0] == 'Protein': | |
print(f'Gathering edges for {edge_type}...') | |
# append the edges with protein_idx as source node | |
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == protein_idx] | |
protein_edges[edge_type].extend(edges.T.tolist()) | |
elif edge_type[2] == 'Protein': | |
print(f'Gathering edges for {edge_type}...') | |
# append the edges with protein_idx as target node | |
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == protein_idx] | |
protein_edges[edge_type].extend(edges.T.tolist()) | |
for edge_type in protein_edges.keys(): | |
if protein_edges[edge_type]: | |
mapped_edges = set() | |
for edge in protein_edges[edge_type]: | |
# Get source and target node types from edge_type | |
source_type, _, target_type = edge_type | |
# Map indices back to original IDs | |
source_id = reverse_id_mapping[source_type][edge[0]] | |
target_id = reverse_id_mapping[target_type][edge[1]] | |
mapped_edges.add((source_id, target_id)) | |
protein_edges[edge_type] = mapped_edges | |
return protein_edges | |
def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
    """Select at most ``limit`` edges per edge type and tag them for display.

    Args:
        protein_id: Original ID of the query protein.
        protein_edges: Dict of edge type -> collection of
            ``(source_id, target_id)`` tuples. Empty or ``None`` values are
            skipped.
        prediction_df: DataFrame with the columns ``GO_category``,
            ``UniProt_ID``, ``GO_ID`` and ``Probability``.
        limit: Maximum number of edges kept per edge type. It is applied as a
            slice bound in every branch, so ``limit=0`` keeps nothing.

    Returns:
        Dict of edge type -> list of ``(edge, probability, is_ground_truth)``:

        * non-GO edges: ``(edge, None, True)``;
        * GO edges with no predictions for this protein/category:
          ``(edge, 'no_pred', True)``;
        * GO edges with predictions: the top-``limit`` predicted terms by
          probability as ``((protein_id, go_id), probability, flag)``, where
          ``flag`` says whether that pair is also present in the input edges.

    Raises:
        KeyError: If ``GO_category`` holds a value missing from
            ``GO_CATEGORY_MAPPING``.
    """
    node_type_to_category = {
        node_type: category for category, node_type in GO_CATEGORY_MAPPING.items()
    }
    predicted_node_types = {
        GO_CATEGORY_MAPPING[category]
        for category in prediction_df['GO_category'].unique()
    }
    # Loop-invariant: rows that belong to the query protein.
    protein_mask = prediction_df['UniProt_ID'] == protein_id

    filtered_edges = {}
    for edge_type, edges in protein_edges.items():
        if edges is None or len(edges) == 0:
            continue

        target_type = edge_type[2]

        # NOTE: when ``edges`` is a set, which members survive the slice below
        # is arbitrary because sets are unordered.
        if not target_type.startswith('GO_term'):
            filtered_edges[edge_type] = [
                (edge, None, True) for edge in list(edges)[:limit]
            ]
            continue

        category_predictions = None
        if target_type in predicted_node_types:
            category_mask = protein_mask & (
                prediction_df['GO_category'] == node_type_to_category[target_type]
            )
            category_predictions = prediction_df[category_mask]

        if category_predictions is None or len(category_predictions) == 0:
            # No predictions for this GO namespace: show database annotations.
            filtered_edges[edge_type] = [
                (edge, 'no_pred', True) for edge in list(edges)[:limit]
            ]
            continue

        # head(limit) applies the same bound as the slices above; the previous
        # append-then-check loop let one edge through even for limit=0.
        top_predictions = category_predictions.sort_values(
            by='Probability', ascending=False
        ).head(limit)
        known_edges = set(edges)  # O(1) ground-truth membership checks
        filtered_edges[edge_type] = [
            ((protein_id, go_id), probability, (protein_id, go_id) in known_edges)
            for go_id, probability in zip(
                top_predictions['GO_ID'], top_predictions['Probability']
            )
        ]

    return filtered_edges
# vis.js options shared by every rendered graph; the per-node-type "groups"
# section is added at render time from NODE_TYPE_COLORS.
_VIS_OPTIONS = {
    "physics": {
        "enabled": True,
        "barnesHut": {
            "gravitationalConstant": -1000,
            "springLength": 250,
            "springConstant": 0.001,
            "damping": 0.09,
            "avoidOverlap": 0,
        },
        "forceAtlas2Based": {
            "gravitationalConstant": -50,
            "centralGravity": 0.01,
            "springLength": 100,
            "springConstant": 0.08,
            "damping": 0.4,
            "avoidOverlap": 0,
        },
        "solver": "barnesHut",
        "stabilization": {
            "enabled": True,
            "iterations": 1000,
            "updateInterval": 25,
        },
    },
    "layout": {
        "improvedLayout": True,
        "hierarchical": {
            "enabled": False,
        },
    },
    "interaction": {
        "hover": True,
        "navigationButtons": True,
        "multiselect": True,
    },
    "configure": {
        "enabled": False,
        "filter": ["physics", "layout", "manipulation"],
        "showButton": True,
    },
}


def _node_display_name(name_info, node_type, node_id):
    """Return the readable name of a node, falling back to its ID if unknown."""
    # All GO namespaces share the single 'GO_term' table in name_info.
    table_key = 'GO_term' if node_type.startswith('GO_term') else node_type
    return name_info.get(table_key, {}).get(node_id, node_id)


def _linked_title(text, url):
    """Wrap tooltip text in a new-tab hyperlink when a URL is available."""
    if url:
        return f'<a href="{url}" target="_blank">{text}</a>'
    return text


def _add_neighbor_node(net, name_info, node_type, node_id):
    """Add one non-query node to the network with its tooltip and group."""
    node_name = _node_display_name(name_info, node_type, node_id)
    type_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
    title = _linked_title(
        f"{node_name} ({type_label})", get_node_url(node_type, node_id)
    )
    net.add_node(node_id,
                 label=node_id,
                 shape="dot",
                 font={'color': '#000000', 'size': 12},
                 title=title,
                 group=node_type,
                 size=15,
                 mass=1.5)


def _edge_label_and_color(relation_label, probability, is_ground_truth):
    """Return the (label, colour) pair for an edge.

    Gray: non-GO edge (``probability`` is ``None``). Blue: GO annotation taken
    directly from the database (``'no_pred'``). Purple: prediction that is also
    a known annotation. Red: prediction that is not.
    """
    if probability is None:
        return relation_label, '#666666'
    if probability == 'no_pred':
        return f"{relation_label} (P=Not generated)", '#219ebc'
    edge_color = '#8338ec' if is_ground_truth else '#c1121f'
    return f"{relation_label} (P={probability:.2f})", edge_color


def _build_legend_html():
    """Build the HTML/CSS legend describing node colours and edge colours."""
    node_items = []
    for node_type, color in NODE_TYPE_COLORS.items():
        # 'kegg_Pathway' shares the colour and display label of 'Pathway';
        # skip it so the legend has no duplicate entry.
        if node_type == 'kegg_Pathway':
            continue
        node_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
        node_items.append(f"""
                <div class="legend-item">
                    <div class="node-indicator" style="background-color: {color};"></div>
                    <span class="legend-label">{node_label}</span>
                </div>""")

    header = """
    <style>
        .kg-legend {
            margin-top: 20px;
            padding: 20px;
            border: 1px solid #ddd;
            border-radius: 5px;
            font-family: Arial, sans-serif;
            display: flex;
            gap: 20px;
        }
        .legend-section-nodes {
            flex: 2; /* Takes up 2/3 of the space */
        }
        .legend-section-edges {
            flex: 1; /* Takes up 1/3 of the space */
        }
        .legend-title {
            margin-bottom: 15px;
            color: #333;
            font-size: 16px;
            font-weight: bold;
        }
        .nodes-grid {
            display: grid;
            grid-template-columns: repeat(2, 1fr);
            gap: 12px;
        }
        .edges-grid {
            display: grid;
            grid-template-columns: 1fr;
            gap: 12px;
        }
        .legend-item {
            display: flex;
            align-items: center;
            padding: 4px;
        }
        .node-indicator {
            width: 15px;
            height: 15px;
            border-radius: 50%;
            margin-right: 10px;
            flex-shrink: 0;
        }
        .edge-indicator {
            width: 40px;
            height: 3px;
            margin-right: 10px;
            flex-shrink: 0;
        }
        .legend-label {
            font-size: 14px;
        }
    </style>
    <div class="kg-legend">
        <div class="legend-section-nodes">
            <div class="legend-title">Node Types</div>
            <div class="nodes-grid">"""

    footer = """
            </div>
        </div>
        <div class="legend-section-edges">
            <div class="legend-title">Edge Colors</div>
            <div class="edges-grid">
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #8338ec;"></div>
                    <span class="legend-label">Validated GO Prediction</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #c1121f;"></div>
                    <span class="legend-label">Non-validated GO Prediction</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #219ebc;"></div>
                    <span class="legend-label">Ground Truth GO Annotation</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #666666;"></div>
                    <span class="legend-label">Other Relationships</span>
                </div>
            </div>
        </div>
    </div>
    """

    return header + "".join(node_items) + footer


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
    """Render the neighbourhood of a protein as an interactive HTML graph.

    Reads node names from ``data/name_info.json.gz``, gathers and filters the
    protein's edges, draws them with pyvis and writes the page (with a legend
    inserted before ``</body>``) to ``temp_viz/<protein_id>_graph.html``.

    Args:
        data: Heterogeneous graph object passed to ``_gather_protein_edges``.
        protein_id: Original ID of the query protein.
        prediction_df: GO prediction table passed to ``_filter_edges``.
        limit: Maximum number of edges drawn per edge type.

    Returns:
        Tuple ``(file_path, visualized_edges)``: the path of the written HTML
        file and the dict of edges that were drawn, as produced by
        ``_filter_edges``.
    """
    with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
        name_info = json.load(file)

    protein_edges = _gather_protein_edges(data, protein_id)
    visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
    print(f'Edges to be visualized: {visualized_edges}')

    net = Network(height="600px", width="100%", directed=True, notebook=False)

    # One vis.js group per node type so nodes pick up their colour by group.
    groups_config = {
        node_type: {"color": {"background": color, "border": color}}
        for node_type, color in NODE_TYPE_COLORS.items()
    }
    # Serialise the whole options object once instead of splicing JSON text.
    net.set_options(json.dumps(dict(_VIS_OPTIONS, groups=groups_config)))

    # Add the main protein node: larger, white fill, red border.
    query_node_title = _linked_title(
        f"{_node_display_name(name_info, 'Protein', protein_id)} (Query Protein)",
        get_node_url('Protein', protein_id),
    )
    net.add_node(protein_id,
                 label=protein_id,
                 title=query_node_title,
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)

    # Track added nodes to avoid duplication.
    added_nodes = {protein_id}

    for edge_type, edges in visualized_edges.items():
        source_type, relation_type, target_type = edge_type
        # protein_function labels depend on the GO namespace of the target.
        if relation_type == 'protein_function':
            label_key = (relation_type, target_type)
        else:
            label_key = relation_type
        # Fall back to the raw relation name when no translation exists.
        relation_label = EDGE_LABEL_TRANSLATION.get(label_key, relation_type)

        for edge, probability, is_ground_truth in edges:
            source_str = str(edge[0])
            target_str = str(edge[1])

            for node_type, node_id in ((source_type, source_str),
                                       (target_type, target_str)):
                if node_id not in added_nodes:
                    _add_neighbor_node(net, name_info, node_type, node_id)
                    added_nodes.add(node_id)

            edge_label, edge_color = _edge_label_and_color(
                relation_label, probability, is_ground_truth
            )
            # font size 0 hides the always-on label; the text is still shown
            # through the hover title.
            net.add_edge(source_str, target_str,
                         label=edge_label,
                         font={'size': 0},
                         color=edge_color,
                         title=edge_label,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    legend_html = _build_legend_html()

    # Save graph to a protein-specific file in a temporary directory.
    os.makedirs('temp_viz', exist_ok=True)
    file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
    net.save_graph(file_path)

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Insert the legend before the closing body tag.
    content = content.replace('</body>', f'{legend_html}</body>')

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)

    return file_path, visualized_edges