# Web page residue captured together with the file (not Python code):
# Flux9665's picture
# update to the current version
# 70399da
# raw
# history blame
# 8.3 kB
import json
import os
import pickle
import random
import kan
import torch
from tqdm import tqdm
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path
class MetricsCombiner(torch.nn.Module):
    """
    Learned scoring function that turns a vector of three distance metrics
    into a single score, implemented as a KAN with layer widths [3, 5, 1].
    """

    def __init__(self, m):
        """
        Args:
            m: Seed that is forwarded to the ``kan.KAN`` constructor.
        """
        super().__init__()
        self.scoring_function = kan.KAN(width=[3, 5, 1], grid=5, k=5, seed=m)

    def forward(self, x):
        """
        Score one or more metric vectors.

        Args:
            x: Tensor of metric vectors. Singleton dimensions are removed, but
                a leading batch dimension is always present when the tensor is
                handed to the KAN.

        Returns:
            The output of the wrapped KAN for the two-dimensional input.
        """
        x = x.squeeze()
        if x.dim() == 1:
            # squeeze() also strips a batch dimension of size one, e.g. a
            # (1, 3) input becomes (3,). Restore it so that a single sample is
            # treated the same way as a larger batch.
            x = x.unsqueeze(0)
        return self.scoring_function(x)
class EnsembleModel(torch.nn.Module):
    """
    Averages the outputs of several models that all receive the same input.
    """

    def __init__(self, models):
        """
        Args:
            models: Iterable of ``torch.nn.Module`` instances to be averaged.
        """
        super().__init__()
        # A ModuleList registers the members as submodules, so that
        # parameters(), state_dict(), to() and eval() reach them. A plain
        # Python list would be invisible to all of these.
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        """
        Args:
            x: Input that is passed unchanged to every member model.

        Returns:
            The element-wise mean of all member outputs.
        """
        predictions = [model(x) for model in self.models]
        return sum(predictions) / len(predictions)
def _symmetric_lookup(table, lang_1, lang_2):
    """
    Fetch the entry of a language pair from a nested dict that may hold the
    pair in either orientation.

    ``table[lang_2][lang_1]`` is tried first, ``table[lang_1][lang_2]`` is the
    fallback. A KeyError raised by the fallback propagates to the caller.
    """
    try:
        return table[lang_2][lang_1]
    except KeyError:
        return table[lang_1][lang_2]


def create_learned_cache(model_path, cache_root="."):
    """
    Train an ensemble of MetricsCombiner models and store the distances it
    predicts for language pairs as a JSON lookup.

    Each model receives three inputs per language pair: the tree distance, the
    map distance divided by the largest map distance in the lookup, and one
    minus the ASP similarity. The training target is the mean squared error
    between the language embeddings of the two languages, taken from the
    ToucanTTS checkpoint. The loss is the squared error divided by
    (target + 0.0001). After training, the averaged ensemble is evaluated on
    random batches and then applied to every pair of keys of the tree distance
    lookup; pairs with missing lookup entries are skipped.

    Args:
        model_path: Path to a checkpoint that holds a "model" and a "config" entry.
        cache_root: Directory that contains supervised_languages.json,
            iso_lookup.json, lang_1_to_lang_2_to_tree_dist.json,
            lang_1_to_lang_2_to_map_dist.json and asp_dict.pkl. The result is
            written to lang_1_to_lang_2_to_learned_dist.json in the same directory.

    Raises:
        FileNotFoundError: If the tree or map distance lookup or the ASP
            dictionary does not exist in cache_root.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)  # the embeddings only provide targets, they are not trained here
    language_list = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    if not os.path.exists(tree_lookup_path) or not os.path.exists(map_lookup_path):
        raise FileNotFoundError("Please ensure the caches exist!")
    if not os.path.exists(asp_dict_path):
        raise FileNotFoundError(f"{asp_dict_path} must be downloaded separately.")
    tree_dist = load_json_from_path(tree_lookup_path)
    map_dist = load_json_from_path(map_lookup_path)
    # NOTE: unpickling can execute arbitrary code, so asp_dict.pkl must come from a trusted source.
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    # Position of every language within the rows of asp_sim. Building this once
    # avoids a linear list search for every sampled language pair.
    asp_column = {lang: position for position, lang in enumerate(asp_sim.keys())}
    largest_value_map_dist = 0.0
    for row in map_dist.values():
        for value in row.values():
            largest_value_map_dist = max(largest_value_map_dist, value)
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    train_set = language_list
    batch_size = 128
    model_list = list()
    print_intermediate_results = False

    def metric_features(lang_1, lang_2):
        """Build the [tree, normalised map, 1 - ASP similarity] input vector for one language pair."""
        tree = _symmetric_lookup(tree_dist, lang_1, lang_2)
        geo = _symmetric_lookup(map_dist, lang_1, lang_2) / largest_value_map_dist
        asp_row = asp_sim[lang_1]
        try:
            column = asp_column[lang_2]
        except KeyError:
            # keep signalling a language that is unknown to the ASP dictionary as a ValueError
            raise ValueError(f"{lang_2!r} is not in the ASP dictionary") from None
        return torch.tensor([tree, geo, 1.0 - asp_row[column]], dtype=torch.float32)

    def sample_batch():
        """Draw batch_size random language pairs and return (metric inputs, embedding distance targets)."""
        embedding_distances = list()
        metric_distances = list()
        for _ in range(batch_size):
            lang_1 = random.choice(train_set)
            lang_2 = random.choice(train_set)
            embedding_1 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze()
            embedding_2 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()
            embedding_distances.append(torch.nn.functional.mse_loss(embedding_1, embedding_2))
            metric_distances.append(metric_features(lang_1, lang_2))
        return torch.stack(metric_distances).squeeze(), torch.stack(embedding_distances).squeeze()

    # ensemble preparation
    n_models = 5
    print(f"Training ensemble of {n_models} models for learned distance metric.")
    for m in range(n_models):
        model = MetricsCombiner(m)
        model_list.append(model)
        optim = torch.optim.Adam(model.parameters(), lr=0.0005)
        for _ in tqdm(range(35), desc=f"MetricsCombiner {m + 1}/{n_models} - Epoch"):
            running_loss = list()
            for _ in range(1000):
                # we have no dataloader, so a batch is sampled on the fly
                metric_batch, target_batch = sample_batch()
                scores = model(metric_batch)
                if print_intermediate_results:
                    print("==================================")
                    print(scores.detach().squeeze()[:9])
                    print(target_batch[:9])
                loss = torch.nn.functional.mse_loss(scores.squeeze(), target_batch, reduction="none")
                # relative error, so that pairs with a small target distance are not drowned out
                loss = loss / (target_batch + 0.0001)
                loss = loss.mean()
                running_loss.append(loss.item())
                optim.zero_grad()
                loss.backward()
                optim.step()
            print("\n\n")
            print(sum(running_loss) / len(running_loss))
            print("\n\n")
        # To inspect a trained model: model.scoring_function.plot(folder=f"kan_vis_{m}", beta=5000)

    # Time to see if the final ensemble is any good
    ensemble = EnsembleModel(model_list)
    running_loss = list()
    with torch.no_grad():  # pure evaluation, no autograd graph required
        for _ in range(100):
            metric_batch, target_batch = sample_batch()
            scores = ensemble(metric_batch)
            print("==================================")
            print(scores.detach().squeeze()[:9])
            print(target_batch[:9])
            loss = torch.nn.functional.mse_loss(scores.squeeze(), target_batch)
            running_loss.append(loss.item())
    print("\n\n")
    print(sum(running_loss) / len(running_loss))

    language_to_language_to_learned_distance = dict()
    for lang_1 in tqdm(tree_dist):
        for lang_2 in tree_dist:
            try:
                if lang_1 in language_to_language_to_learned_distance.get(lang_2, ()):
                    continue  # it's symmetric
                language_to_language_to_learned_distance.setdefault(lang_1, dict())
                metric_distance = metric_features(lang_1, lang_2)
                with torch.inference_mode():
                    predicted_distance = ensemble(metric_distance.unsqueeze(0)).squeeze()
                language_to_language_to_learned_distance[lang_1][lang_2] = predicted_distance.item()
            except (ValueError, KeyError):
                # pairs for which one of the lookups has no entry are skipped
                continue
    with open(os.path.join(cache_root, 'lang_1_to_lang_2_to_learned_dist.json'), 'w', encoding='utf-8') as f:
        json.dump(language_to_language_to_learned_distance, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
    # Script entry point: build the learned distance lookup from the given
    # checkpoint, reading and writing the caches in the default cache_root.
    create_learned_cache(model_path="../../Models/ToucanTTS_Meta/best.pt")