# Source: Flux9665, "update to the current version" (commit 70399da, 8.7 kB)
import os
import pickle
import matplotlib.pyplot as plt
import networkx as nx
import torch
from tqdm import tqdm
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path
# Distance measures this script can visualise, and the two plotting modes.
distance_types = ["tree", "asp", "map", "learned", "l1"]
modes = ["plot_all", "plot_neighbors"]

# Active selection: change the indices to switch measure / mode.
distance_type, mode = distance_types[0], modes[1]

# Focus language (by full name) and how many of its closest languages "plot_neighbors" shows.
neighbor = "Latin"
num_neighbors = 12

# "plot_all" only draws an edge for pairs whose normalized distance is below this value.
edge_threshold = 0.01
# TODO histograms to figure out a good threshold

# Directory holding the cached lookup files read throughout this script.
cache_root = "."
supervised_iso_codes = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
if distance_type == "l1":
    # Distances between the language embeddings of a trained ToucanTTS checkpoint.
    # NOTE(review): despite the "l1" name this uses mean squared error, not an L1 distance; confirm intended.
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    model_path = "../../Models/ToucanTTS_Meta/best.pt"
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)

    # Look every embedding up once instead of twice per language pair.
    with torch.no_grad():
        embeddings = {lang: embedding_provider(torch.LongTensor([iso_codes_to_ids[lang]])).squeeze()
                      for lang in supervised_iso_codes}

    l1_dist = dict()
    seen_langs = set()
    for lang_1 in supervised_iso_codes:
        if lang_1 in seen_langs:
            continue
        seen_langs.add(lang_1)
        # The measure is symmetric, so each pair is stored once (under the earlier language).
        # .item() stores plain floats instead of 0-d tensors, so later rounding/printing behaves normally.
        l1_dist[lang_1] = {lang_2: torch.nn.functional.mse_loss(embeddings[lang_1], embeddings[lang_2]).item()
                           for lang_2 in supervised_iso_codes if lang_2 not in seen_langs}

    # Rescale so the largest distance becomes 1 (skipped if every distance is 0, to avoid dividing by zero).
    largest_value_l1_dist = max((value for values in l1_dist.values() for value in values.values()), default=0.0)
    if largest_value_l1_dist > 0:
        for values in l1_dist.values():
            for key2 in values:
                values[key2] = values[key2] / largest_value_l1_dist
    distance_measure = l1_dist
if distance_type == "tree":
    # Precomputed pairwise language-tree distances, used as stored (no per-measure rescaling here).
    distance_measure = load_json_from_path(os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json"))
if distance_type == "map":
    # Precomputed pairwise "map" distances (presumably geographic; TODO confirm), rescaled so the largest is 1.
    map_dist = load_json_from_path(os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json"))
    largest_value_map_dist = max([0.0] + [value for inner in map_dist.values() for value in inner.values()])
    for inner in map_dist.values():
        for lang_2 in inner:
            inner[lang_2] /= largest_value_map_dist
    distance_measure = map_dist
if distance_type == "learned":
    # Precomputed pairwise learned distances, rescaled so the largest value becomes 1.
    # This must read the *learned* lookup; reading the map lookup would make "learned" identical to "map".
    learned_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_learned_dist.json")
    learned_dist = load_json_from_path(learned_lookup_path)
    largest_value_learned_dist = 0.0
    for values in learned_dist.values():
        for value in values.values():
            largest_value_learned_dist = max(largest_value_learned_dist, value)
    for values in learned_dist.values():
        for key2 in values:
            values[key2] = values[key2] / largest_value_learned_dist
    distance_measure = learned_dist
if distance_type == "asp":
    # asp_sim[lang_1][j] is read as the similarity between lang_1 and the j-th key of asp_sim;
    # it is turned into a distance as 1 - similarity.
    # NOTE: pickle.load can execute arbitrary code; only use a trusted local cache file here.
    with open(os.path.join(cache_root, "asp_dict.pkl"), 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    lang_list = list(asp_sim.keys())
    asp_dist = {}
    for i, lang_1 in enumerate(lang_list):
        similarities = asp_sim[lang_1]
        # Symmetric measure: keep only pairs with a later language, so each pair is stored once.
        asp_dist[lang_1] = {lang_2: 1 - similarities[j] for j, lang_2 in enumerate(lang_list) if j > i}
    distance_measure = asp_dist
# Map ISO codes to full language names; the graph uses full names, so codes without one are dropped.
iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
supervised_code_set = set(supervised_iso_codes)  # O(1) membership tests in the nested loop below


def _is_plottable(iso_code):
    """Return True if *iso_code* has a full name and is either supervised or the focus language."""
    if iso_code not in iso_codes_to_names:
        return False
    return iso_code in supervised_code_set or iso_codes_to_names[iso_code] == neighbor


# Collect (name_1, name_2, distance) triples for every plottable pair of the chosen distance measure.
# Both languages get the same explicit check. Relying on a KeyError let a supervised code without a
# full name slip through and crash when its name was looked up.
distances = list()
for lang_1, lang_2_to_dist in distance_measure.items():
    if not _is_plottable(lang_1):
        continue
    for lang_2, dist in lang_2_to_dist.items():
        if not _is_plottable(lang_2):
            continue
        distances.append((iso_codes_to_names[lang_1], iso_codes_to_names[lang_2], dist))
# Create a graph
G = nx.Graph()
# Min-max normalize all collected distances to [0, 1], the scale edge_threshold and the neighbor
# threshold below operate on.
if not distances:
    raise ValueError(f"No distances to plot for distance type '{distance_type}' and neighbor '{neighbor}'.")
min_dist = min(d for _, _, d in distances)
max_dist = max(d for _, _, d in distances)
dist_range = max_dist - min_dist
if dist_range == 0:
    # Every distance is identical; map them all to 0 instead of dividing by zero.
    normalized_distances = [(entity1, entity2, 0.0) for entity1, entity2, _ in distances]
else:
    normalized_distances = [(entity1, entity2, (d - min_dist) / dist_range) for entity1, entity2, d in distances]
if mode == "plot_neighbors":
    # Show the focus language with its num_neighbors closest supervised languages, plus the edges among them.

    # Full names of every language allowed in the plot: the focus language plus all supervised languages.
    # Codes without a full name are skipped, consistent with how the distance triples were collected.
    supervised_names = {neighbor}
    supervised_names.update(iso_codes_to_names[code] for code in supervised_iso_codes if code in iso_codes_to_names)

    # Candidate edges: pairs that touch the focus language and join two distinct plottable languages.
    neighbor_candidates = [(entity1, entity2, d) for entity1, entity2, d in tqdm(normalized_distances)
                           if entity1 != entity2
                           and neighbor in (entity1, entity2)
                           and entity1 in supervised_names and entity2 in supervised_names]
    if not neighbor_candidates:
        raise ValueError(f"No distances involving '{neighbor}' found for distance type '{distance_type}'.")
    sorted_candidate_dists = sorted(d for _, _, d in neighbor_candidates)
    if len(sorted_candidate_dists) > num_neighbors:
        # Keep everything strictly below the (num_neighbors + 1)-th smallest distance.
        # NOTE(review): if the lookup stores both directions of a pair, each pair counts twice here.
        thresh = sorted_candidate_dists[num_neighbors]
    else:
        # Fewer candidates than requested: keep all of them. The small margin lets the farthest one
        # still pass the strict `d < thresh` test.
        thresh = sorted_candidate_dists[-1] + 1e-3

    G.add_node(neighbor)  # guarantees the focus node exists for the red label even if no edge survives
    unique_neighbors = set()
    for entity1, entity2, d in neighbor_candidates:
        if d < thresh:
            spring_tension = ((thresh - d) ** 2) * 20000  # for vis purposes: closer languages pull harder
            print(f"{d}-->{spring_tension}")
            # Keep the raw normalized distance on the edge so labels need not reverse the tension formula.
            G.add_edge(entity1, entity2, weight=spring_tension, dist=d)
            unique_neighbors.add(entity2 if entity1 == neighbor else entity1)

    # Connect the selected neighbors among themselves so the layout also reflects their mutual distances.
    for entity1, entity2, d in tqdm(normalized_distances):
        if entity1 != entity2 and entity1 in unique_neighbors and entity2 in unique_neighbors:
            G.add_edge(entity1, entity2, weight=1 - d, dist=d)

    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    # Only edges touching the focus language are drawn (in red) and labelled.
    edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if neighbor in (u, v)]
    nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='red', alpha=0.3, width=3)
    # Label = normalized distance scaled by 10 (float() also accepts 0-d tensor values).
    focus_edge_labels = {(u, v): round(float(data["dist"]) * 10, 2)
                         for u, v, data in G.edges(data=True) if neighbor in (u, v)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=focus_edge_labels, font_color="red", alpha=0.3)
    # Draw node labels; the focus language is redrawn in red on top.
    nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
    nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.savefig("avg.png", dpi=300)
    plt.show()
elif mode == "plot_all":
    # Connect every pair of distinct languages whose normalized distance is below edge_threshold;
    # the closer the pair, the stronger the spring.
    for entity1, entity2, d in tqdm(normalized_distances):
        if d < edge_threshold and entity1 != entity2:
            G.add_edge(entity1, entity2, weight=edge_threshold - d)
    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color="blue")
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.show()
else:
    raise ValueError(f"Unknown mode '{mode}'; expected one of {modes}.")