import os | |
import pickle | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import torch | |
from tqdm import tqdm | |
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS | |
from Utility.utils import load_json_from_path | |
# ---- Visualisation settings ----
# Distance measures this script can visualise; each one has its own
# `if distance_type == ...` loading section further down.
distance_types = ["tree", "asp", "map", "learned", "l1"]
# "plot_all" draws every pair closer than `edge_threshold`;
# "plot_neighbors" draws only the closest languages around `neighbor`.
modes = ["plot_all", "plot_neighbors"]

neighbor = "Latin"  # full language name at the centre of the "plot_neighbors" graph
num_neighbors = 12  # position of the distance cut-off in the sorted neighbour distances
distance_type = distance_types[0]  # switch here
mode = modes[1]
edge_threshold = 0.01  # "plot_all" only: normalised distances below this get an edge
# TODO histograms to figure out a good threshold

# Directory holding the JSON / pickle lookup files read by this script.
cache_root = "."
supervised_languages_path = os.path.join(cache_root, "supervised_languages.json")
supervised_iso_codes = load_json_from_path(supervised_languages_path)
if distance_type == "l1":
    # Distance between the language embeddings of a ToucanTTS checkpoint, computed
    # for every unordered pair of supervised languages.
    # The last entry of iso_lookup.json is used as the ISO-code -> embedding-index table.
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    # NOTE(review): this path ends in a directory separator while torch.load is given
    # it as the checkpoint to read -- confirm whether a file name is missing here.
    model_path = "../../Models/ToucanTTS_Meta/"
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)

    # De-duplicate while keeping order, then look each embedding up exactly once
    # instead of twice per language pair.
    unique_langs = list(dict.fromkeys(supervised_iso_codes))
    lang_embeddings = {
        lang: embedding_provider(torch.LongTensor([iso_codes_to_ids[lang]])).squeeze()
        for lang in unique_langs
    }

    # Only pairs (lang_1, later lang_2) are stored because the measure is symmetric.
    # Despite the "l1" label, the value is the mean squared error between the two
    # embeddings. .item() turns it into a plain float so no tensors end up as
    # graph weights or edge labels later on.
    l1_dist = dict()
    for position, lang_1 in enumerate(unique_langs):
        l1_dist[lang_1] = {
            lang_2: torch.nn.functional.mse_loss(lang_embeddings[lang_1], lang_embeddings[lang_2]).item()
            for lang_2 in unique_langs[position + 1:]
        }

    # Scale by the largest distance; skipped when there is nothing to scale
    # (no pairs, or all distances zero) to avoid a division by zero.
    largest_value_l1_dist = max((value for row in l1_dist.values() for value in row.values()), default=0.0)
    if largest_value_l1_dist > 0:
        for row in l1_dist.values():
            for lang_2 in row:
                row[lang_2] = row[lang_2] / largest_value_l1_dist
    distance_measure = l1_dist
if distance_type == "tree":
    # Pre-computed tree distances are read from the cache and used without rescaling.
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    distance_measure = tree_dist = load_json_from_path(tree_lookup_path)
if distance_type == "map":
    # Pre-computed map distances: a nested lang_1 -> lang_2 -> distance lookup.
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    map_dist = load_json_from_path(map_lookup_path)

    # Scale by the largest distance; skipped when the table is empty or all
    # distances are zero, which would otherwise divide by zero.
    largest_value_map_dist = max((value for row in map_dist.values() for value in row.values()), default=0.0)
    if largest_value_map_dist > 0:
        for row in map_dist.values():
            for lang_2 in row:
                row[lang_2] = row[lang_2] / largest_value_map_dist
    distance_measure = map_dist
if distance_type == "learned":
    # The original code loaded the *map* distance file here, so "learned" plots
    # silently showed map distances. The name below follows the pattern of the
    # tree/map lookups.
    # NOTE(review): confirm the learned-distance cache is really stored under this name.
    learned_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_learned_dist.json")
    learned_dist = load_json_from_path(learned_lookup_path)

    # Scale by the largest distance; skipped when the table is empty or all
    # distances are zero, which would otherwise divide by zero.
    largest_value_learned_dist = max((value for row in learned_dist.values() for value in row.values()), default=0.0)
    if largest_value_learned_dist > 0:
        for row in learned_dist.values():
            for lang_2 in row:
                row[lang_2] = row[lang_2] / largest_value_learned_dist
    distance_measure = learned_dist
if distance_type == "asp":
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    # NOTE: unpickling can execute arbitrary code; only read cache files you trust.
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    lang_list = list(asp_sim.keys())

    # asp_sim[lang] is indexed by position in lang_list. The stored values are
    # presumably similarities, hence distance = 1 - value.
    # Only entries with column > row are kept, because the measure is symmetric.
    asp_dist = {
        lang_1: {
            lang_2: 1 - asp_sim[lang_1][column]
            for column, lang_2 in enumerate(lang_list)
            if column > row
        }
        for row, lang_1 in enumerate(lang_list)
    }
    distance_measure = asp_dist
# Flatten the nested lookup into (name_1, name_2, distance) triples, keeping only
# languages that are supervised or that are the configured `neighbor`.
iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
# Set for O(1) membership tests inside the nested loop below.
supervised_code_set = set(supervised_iso_codes)
distances = list()
for lang_1, row in distance_measure.items():
    name_1 = iso_codes_to_names.get(lang_1)
    # Skip codes without a full name and codes that are neither supervised nor the neighbor.
    if name_1 is None or (lang_1 not in supervised_code_set and name_1 != neighbor):
        continue
    for lang_2, dist in row.items():
        # .get() also covers supervised codes that lack a full name; the original
        # try/except did not, and crashed on the later name lookup.
        name_2 = iso_codes_to_names.get(lang_2)
        if name_2 is None or (lang_2 not in supervised_code_set and name_2 != neighbor):
            continue
        distances.append((name_1, name_2, dist))
# Create a graph
G = nx.Graph()
# Min-max scale the collected distances to [0, 1] before they become edge weights.
if not distances:
    raise ValueError(f"No usable {distance_type} distances were found for the supervised languages and '{neighbor}'.")
min_dist = min(d for _, _, d in distances)
max_dist = max(d for _, _, d in distances)
dist_span = max_dist - min_dist
if dist_span > 0:
    normalized_distances = [(entity1, entity2, (d - min_dist) / dist_span) for entity1, entity2, d in distances]
else:
    # All distances are identical: map them to 0 instead of dividing by zero.
    normalized_distances = [(entity1, entity2, 0.0) for entity1, entity2, _ in distances]
if mode == "plot_neighbors":
    # Draw `neighbor` with its closest languages: red edges from the centre to each
    # neighbour, plus unlabelled layout-only edges between the neighbours themselves.

    # Full names that may appear in the plot: the centre language plus every
    # supervised language that has a full name.
    fullnames = [neighbor]
    fullnames.extend(iso_codes_to_names[code] for code in supervised_iso_codes if code in iso_codes_to_names)
    supervised_iso_codes = fullnames  # as in the original script, this now holds full names
    plottable_names = set(fullnames)  # O(1) membership for the passes below

    # All pairs that touch `neighbor`, collected in a single pass.
    neighbor_pairs = [
        (entity1, entity2, d)
        for entity1, entity2, d in tqdm(normalized_distances)
        if neighbor in (entity1, entity2) and entity1 in plottable_names and entity2 in plottable_names
    ]
    d_dist = sorted(d for entity1, entity2, d in neighbor_pairs if entity1 != entity2)
    if not d_dist:
        raise ValueError(f"No distances involving '{neighbor}' are available, so no neighbours can be plotted.")
    if len(d_dist) > num_neighbors:
        # Pairs strictly closer than the entry at index num_neighbors are kept.
        thresh = d_dist[num_neighbors]
    else:
        # Not more candidates than requested: keep all of them by placing the
        # cut-off just above the largest distance (previously an IndexError).
        thresh = d_dist[-1] + 1e-6

    # Languages within the cut-off; discard() tolerates `neighbor` being absent.
    unique_neighbors = {name for entity1, entity2, d in neighbor_pairs if d < thresh for name in (entity1, entity2)}
    unique_neighbors.discard(neighbor)

    # Edges from the centre language: closer pairs get a much stronger spring.
    for entity1, entity2, d in neighbor_pairs:
        if entity1 != entity2 and d < thresh:
            spring_tension = ((thresh - d) ** 2) * 20000  # for vis purposes
            print(f"{d}-->{spring_tension}")
            G.add_edge(entity1, entity2, weight=spring_tension)
    # Edges among the neighbours themselves; these only influence the layout.
    for entity1, entity2, d in tqdm(normalized_distances):
        if entity1 != entity2 and entity1 in unique_neighbors and entity2 in unique_neighbors:
            G.add_edge(entity1, entity2, weight=1 - d)

    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    # Draw only the edges that touch the centre language, together with their labels
    edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor]
    nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='red', alpha=0.3, width=3)
    # Undo the spring-tension transform so each label shows the normalised distance times 10.
    neighbor_edge_labels = {
        (u, v): round((thresh - (G[u][v]['weight'] / 20000) ** (1 / 2)) * 10, 2)
        for u, v in edges_connected_to_specific_node
    }
    nx.draw_networkx_edge_labels(G, pos, edge_labels=neighbor_edge_labels, font_color="red", alpha=0.3)
    # Draw node labels: all nodes in green, then the centre language again in red
    nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
    nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.savefig("avg.png", dpi=300)
elif mode == "plot_all":
    # Connect every pair of distinct languages closer than `edge_threshold`;
    # closer pairs get a stronger spring.
    for entity1, entity2, d in tqdm(normalized_distances):
        if d < edge_threshold and entity1 != entity2:
            G.add_edge(entity1, entity2, weight=edge_threshold - d)
    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    # Draw edges (unlabelled)
    nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color="blue")
    # Draw node labels
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # NOTE(review): no savefig/show call is visible for this branch in the reviewed lines.
    plt.tight_layout(pad=0)