klinic / calculate_smilar_nodes.py
ACMCMC
WIP
a6bd112
raw
history blame contribute delete
No virus
2.47 kB
# %%
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the TransE residual ``head + relation - tail`` for one triple.

    Args:
        head: Key used to look up the head entity in ``entity_embeddings``.
        tail: Key used to look up the tail entity in ``entity_embeddings``.
        relation: Key used to look up the relation in ``relation_embeddings``.
        entity_embeddings: Any container supporting ``container[key]``.
        relation_embeddings: Any container supporting ``container[key]``.

    Returns:
        The raw difference ``h + r - t``. The result is NOT reduced to a
        scalar here; callers in this file take ``.norm()`` of it themselves.
    """
    # Lookups are done head, tail, relation so a bad key fails in that order.
    h = entity_embeddings[head]
    t = entity_embeddings[tail]
    r = relation_embeddings[relation]
    return h + r - t
def calculate_similar_nodes(
    node, entity_embeddings, relation_embeddings, top_n=10, relation=0
):
    """Rank all entities by how well they complete ``(node, relation, ?)``.

    Every position ``0 .. len(entity_embeddings) - 1`` is tried as the tail of
    a triple whose head is ``node``; candidates are ordered by the norm of the
    TransE residual returned by ``transe_distance`` (smallest first).

    Args:
        node: Key of the head entity in ``entity_embeddings``.
        entity_embeddings: Indexable collection of entity embeddings.
        relation_embeddings: Indexable collection of relation embeddings.
        top_n: Number of best-ranked candidates to return.
        relation: Key of the relation to translate by. Defaults to ``0``,
            which is the value that used to be hard-coded.

    Returns:
        A list of at most ``top_n`` tuples ``(index, residual)`` where
        ``residual`` is the un-normed vector from ``transe_distance``.
        ``node`` itself is not excluded from the candidates.
    """
    # NOTE(review): candidates are the positions 0..len-1, used as keys with
    # ``[]``. For a pandas Series whose index is not 0..len-1 this would not
    # line up with label-based keys such as ``node`` -- confirm with callers.
    candidates = [
        (
            tail,
            transe_distance(
                node, tail, relation, entity_embeddings, relation_embeddings
            ),
        )
        for tail in range(len(entity_embeddings))
    ]
    # Residuals are vectors; rank them by their scalar norm.
    candidates.sort(key=lambda pair: pair[1].norm().item())
    return candidates[:top_n]
# %%
import ast

import pandas as pd
import torch


def _parse_embedding(text):
    """Convert one CSV cell holding an embedding literal into a tensor.

    ``ast.literal_eval`` is used instead of ``eval``: it only accepts Python
    literals, so a crafted cell in the CSV cannot execute arbitrary code.
    Assumes each cell is a plain literal such as ``"[0.1, -0.2]"`` -- TODO
    confirm against the code that wrote the CSV files.
    """
    return torch.tensor(ast.literal_eval(text))


# Load the entity embeddings; the first CSV column becomes the row index.
entity_embeddings = pd.read_csv("entity_embeddings.csv", index_col=0)
# The embedding column is stored as text, so convert each cell to a tensor.
entity_embeddings["embedding"] = entity_embeddings["embedding"].apply(
    _parse_embedding
)
# NOTE(review): as a bare expression in the middle of a cell this is
# presumably not echoed by the notebook front end -- confirm whether
# display(...) was intended here, as is done for the relations below.
entity_embeddings.head()
# Now load the relation embeddings the same way.
relation_embeddings = pd.read_csv("relation_embeddings.csv", index_col=0)
relation_embeddings["embedding"] = relation_embeddings["embedding"].apply(
    _parse_embedding
)
# `display` is not defined in this file; it is presumably the IPython/Jupyter
# built-in, so this cell is expected to run in an interactive session.
display(relation_embeddings.head())
# %%
# Row label of the entity whose uri is "http://identifiers.org/medgen/C0002395"
# (the first matching row; raises IndexError when no row matches).
head = entity_embeddings.loc[
    entity_embeddings["uri"] == "http://identifiers.org/medgen/C0002395"
].index[0]
# Row label of the entity whose uri is "http://identifiers.org/medgen/C1843013"
tail = entity_embeddings.loc[
    entity_embeddings["uri"] == "http://identifiers.org/medgen/C1843013"
].index[0]
# The triple is scored through the relation stored under key 0.
relation = 0
distance = transe_distance(
    head=head,
    tail=tail,
    relation=relation,
    entity_embeddings=entity_embeddings["embedding"],
    relation_embeddings=relation_embeddings["embedding"],
)
# `distance` is the raw residual vector; its norm is what gets reported.
print(
    f'Distance between {entity_embeddings["label"][head]} ({head}) and {entity_embeddings["label"][tail]} ({tail}) via relation {relation_embeddings["label"][relation]} is {distance.norm().item()}'
)
# %%
# Rank the entities against the head entity found above, using the
# function's default `top_n`.
similar_nodes = calculate_similar_nodes(
    node=head,
    entity_embeddings=entity_embeddings["embedding"],
    relation_embeddings=relation_embeddings["embedding"],
)
print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
# One output line per result: rank, label, entity index and residual norm.
# The loop targets are module-level names; `distance` is rebound on each pass.
for i, (node, distance) in enumerate(similar_nodes):
    print(
        f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}"
    )
# %%