# meta-demo-app / ML-SLRC / Util_funs.py
from ML_SLRC import *
import os
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from torch.optim import Adam
import gc
from torchmetrics import functional as fn
import random
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import warnings
import torch
import time
from sklearn.manifold import TSNE
from copy import deepcopy
import seaborn as sns
import json
from pathlib import Path
import re
from collections import defaultdict
# SEED = 2222
# gen_seed = torch.Generator().manual_seed(SEED)
# Random seed function
def random_seed(value):
torch.backends.cudnn.deterministic=True
torch.manual_seed(value)
torch.cuda.manual_seed(value)
np.random.seed(value)
random.seed(value)
# Yield batches of tasks for the meta-learner
def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
    idxs = list(range(0, len(taskset)))
    if is_shuffle:
        random.shuffle(idxs)
    for i in range(0, len(idxs), batch_size):
        yield [taskset[idxs[j]] for j in range(i, min(i + batch_size, len(taskset)))]
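# Usage sketch (hypothetical toy taskset; a sequence of MetaTask items
# batches the same way):
#   tasks = ["task_a", "task_b", "task_c", "task_d", "task_e"]
#   for batch in create_batch_of_tasks(tasks, is_shuffle=True, batch_size=2):
#       print(batch)  # e.g. ['task_d', 'task_a'], then ['task_c', 'task_e'], ...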
# Prepare data to process by Domain-learner
def prepare_data(data, batch_size, tokenizer,max_seq_length,
input = 'text', output = 'label',
train_size_per_class = 5, global_datasets = False,
treat_text_fun =None):
data = data.reset_index().drop("index", axis=1)
if global_datasets:
global data_train, data_test
    # Sample the training examples (train_size_per_class per class)
    data_train = data.groupby('label').sample(train_size_per_class, replace=False)
    # The test set to be labeled by the model (the full dataset, train included)
    data_test = data
# Transform in dataset to model
## Train
dataset_train = SLR_DataSet(
data = data_train.sample(frac=1),
input = input,
output = output,
tokenizer=tokenizer,
max_seq_length =max_seq_length,
treat_text =treat_text_fun)
## Test
dataset_test = SLR_DataSet(
data = data_test,
input = input,
output = output,
tokenizer=tokenizer,
max_seq_length =max_seq_length,
treat_text =treat_text_fun)
# Dataloaders
## Train
data_train_loader = DataLoader(dataset_train,
shuffle=True,
batch_size=batch_size['train']
)
    ## Test
    # Drop the last batch only when it would hold a single example,
    # which breaks downstream squeeze()/metric calls
    drop_last = (len(dataset_test) % batch_size['test'] == 1)
    data_test_loader = DataLoader(dataset_test,
                                  batch_size=batch_size['test'],
                                  drop_last=drop_last)
return data_train_loader, data_test_loader, data_train, data_test
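# Usage sketch (assumes a HuggingFace 'tokenizer' and a DataFrame 'df' with
# 'text' and 'label' columns; SLR_DataSet comes from ML_SLRC):
#   train_loader, test_loader, df_train, df_test = prepare_data(
#       df, batch_size={'train': 4, 'test': 100},
#       tokenizer=tokenizer, max_seq_length=512,
#       train_size_per_class=5)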
# Meta trainer
def meta_train(data, model, device, Info,
print_epoch =True,
Test_resource =None,
treat_text_fun =None):
# Meta-learner model
learner = Learner(model = model, device = device, **Info)
# Testing tasks
if isinstance(Test_resource, pd.DataFrame):
test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
training=False,treat_text =treat_text_fun, **Info)
torch.clear_autocast_cache()
gc.collect()
torch.cuda.empty_cache()
# Meta epoch (Outer epoch)
for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
# Train tasks
        train = MetaTask(data,
                         num_task = Info['num_task_train'],
                         k_support=Info['k_spt'],
                         k_query=Info['k_qry'],
                         treat_text =treat_text_fun, **Info)
# Batch of train tasks
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
if print_epoch:
            # Outer-loop batch training
            for step, task_batch in enumerate(db):
                print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
                # Meta-feedforward (outer feedforward)
                acc = learner(task_batch, valid_train= print_epoch)
                print('Step:', step, '\ttraining Acc:', acc)
if isinstance(Test_resource, pd.DataFrame):
                    # Validate at the first step of every 4th meta-epoch
                    if ((epoch + 1) % 4) + step == 0:
random_seed(123)
print("\n-----------------Testing Mode-----------------\n")
# Batch of test tasks
db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
acc_all_test = []
# Looping testing tasks
for test_batch in db_test:
acc = learner(test_batch, training = False)
acc_all_test.append(acc)
print('Test acc:', np.mean(acc_all_test))
del acc_all_test, db_test
                        # Re-seed pseudo-randomly before resuming training
                        random_seed(int(time.time() % 10))
else:
            for step, task_batch in enumerate(db):
                # Meta-feedforward (outer feedforward)
                acc = learner(task_batch, valid_train= print_epoch)
torch.clear_autocast_cache()
gc.collect()
torch.cuda.empty_cache()
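# Sketch of the Info dict meta_train consumes (keys inferred from usage in
# this file; the values here are illustrative, not the authors' defaults):
#   Info = {
#       'meta_epoch': 10, 'num_task_train': 8,
#       'k_spt': 5, 'k_qry': 10,
#       'outer_batch_size': 4, 'inner_batch_size': 4,
#       'inner_update_step': 4, 'inner_update_lr': 5e-5,
#       'max_seq_length': 512, 'tokenizer': tokenizer, 'threshold': 0.9,
#       # ...plus whatever extra kwargs Learner and MetaTask accept
#   }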
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name', weight_decay = 1):
# Start the model's parameters
model_meta = deepcopy(model)
optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay = weight_decay)
model_meta.to(device)
model_meta.train()
# Task epoch (Inner epoch)
for i in range(0, epoch):
all_loss = []
# Inner training batch (support set)
for inner_step, batch in enumerate(data_train_loader):
batch = tuple(t.to(device) for t in batch)
input_ids, attention_mask,q_token_type_ids, label_id = batch
            # Inner feedforward
loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
# compute grads
loss.backward()
# update parameters
optimizer.step()
optimizer.zero_grad()
all_loss.append(loss.item())
        if (i % 2 == 0) and print_info:
            print("Loss: ", np.mean(all_loss))
    # Test evaluation
    model_meta.eval()
    features = []
    labels = []
    predi_logit = []
    with torch.no_grad():
        # Test batch loop
        for inner_step, batch in enumerate(tqdm(data_test_loader,
                                                desc="Test validation | " + name,
                                                ncols=80)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch
            # Forward pass; the model returns (loss, feature, prediction)
            _, feature, _ = model_meta(input_ids, attention_mask, q_token_type_ids, labels = label_id.squeeze())
            # Collect logits, latent features and labels on CPU
            logit = feature[1].detach().cpu()
            feature_lat = feature[0].detach().cpu()
            labels.append(label_id.detach().cpu().numpy().squeeze())
            features.append(feature_lat.numpy())
            predi_logit.append(logit.numpy())
            del input_ids, attention_mask, label_id, batch
    model_meta.to('cpu')
    gc.collect()
    torch.cuda.empty_cache()
    del model_meta, optimizer
    # Project the latent features with t-SNE and return
    # (logits, X_embedded, labels, features), as the callers below expect
    return map_feature_tsne(features, labels, predi_logit)
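# Usage sketch (with loaders from prepare_data and a meta-trained model):
#   logits, X_embedded, labels, features = train_loop(
#       train_loader, test_loader, model, device='cuda',
#       epoch=4, lr=5e-5, weight_decay=1e-4, name='my_domain')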
# Process predictions and project the feature map with t-SNE
def map_feature_tsne(features, labels, predi_logit):
features = np.concatenate(np.array(features,dtype=object))
features = torch.tensor(features.astype(np.float32)).detach().clone()
labels = np.concatenate(np.array(labels,dtype=object))
labels = torch.tensor(labels.astype(int)).detach().clone()
logits = np.concatenate(np.array(predi_logit,dtype=object))
logits = torch.tensor(logits.astype(np.float32)).detach().clone()
    # Dimensionality reduction for visualization
X_embedded = TSNE(n_components=2, learning_rate='auto',
init='random').fit_transform(features.detach().clone())
return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
# Work Saved over Sampling (WSS): the fraction of screening work saved at a
# given recall, plus its class-balance-adjusted variant (AWSS)
def wss_calc(logit, labels, thresh = 0.5):
    # Predicted label given the threshold
    predict_thresh = torch.sigmoid(logit).squeeze() >= thresh
    # Confusion-matrix values (fne = false negatives)
    CM = confusion_matrix(labels, predict_thresh.to(int))
    tn, fp, fne, tp = CM.ravel()
    P = (tp + fne)
    N = (tn + fp)
    recall = tp/(tp + fne)
    # WSS
    wss = (tn + fne)/len(labels) - (1 - recall)
    # AWSS
    awss = (tn/N - fne/P)
    return {
        "wss": round(wss, 4),
        "awss": round(awss, 4),
        "R": round(recall, 4),
        "CM": CM
    }
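# Worked example (hypothetical numbers): with 100 documents and tn=60, fp=10,
# fne=2, tp=28, recall = 28/30 ≈ 0.933 and
#   wss = (60 + 2)/100 - (1 - 0.933) = 0.62 - 0.067 ≈ 0.553,
# i.e. a screener reading only the predicted positives skips ~55% of the
# collection while still catching ~93% of the relevant documents.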
# Compute the metrics
def plot(logits, X_embedded, labels, threshold, show = True,
namefig = "plot", make_plot = True, print_stats = True, save = True):
col = pd.MultiIndex.from_tuples([
("Predict", "0"),
("Predict", "1")
])
index = pd.MultiIndex.from_tuples([
("Real", "0"),
("Real", "1")
])
predict = torch.sigmoid(logits).detach().clone()
    # ROC curve
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
    # Metrics at a fixed recall of 95% (threshold derived from the ROC curve)
    ## WSS
    ### Index of the first threshold reaching 95% recall
    idx_wss95 = sum(tpr < 0.95)
    ### Threshold
    thresholds95 = thresholds[idx_wss95]
    ### Compute the metrics
    wss95_info = wss_calc(logits, labels, thresholds95)
    acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
    f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
    # Metrics at the user-supplied threshold (recall derived)
    ### Compute the metrics
    wss_info = wss_calc(logits, labels, threshold)
acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
f1_wssR = fn.f1_score(predict, labels, threshold=threshold)
metrics= {
# WSS
"WSS@95": wss95_info['wss'],
"AWSS@95": wss95_info['awss'],
"WSS@R": wss_info['wss'],
"AWSS@R": wss_info['awss'],
# Recall
"Recall_WSS@95": wss95_info['R'],
"Recall_WSS@R": wss_info['R'],
# acc
"acc@95": acc_wss95.item(),
"acc@R": acc_wssR.item(),
# f1
"f1@95": f1_wss95.item(),
"f1@R": f1_wssR.item(),
# threshold 95
"threshold@95": thresholds95
}
# Print stats
if print_stats:
wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
        wss95_adj= f"AWSS@95:{wss95_info['awss']}"
print(wss95)
print(wss95_adj)
print('Acc.:', round(acc_wss95.item(), 4))
print('F1-score:', round(f1_wss95.item(), 4))
print(f"threshold to wss95: {round(thresholds95, 4)}")
cm = pd.DataFrame(wss95_info['CM'],
index=index,
columns=col)
print("\nConfusion matrix:")
print(cm)
print("\n---Metrics with threshold:", threshold, "----\n")
wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
print(wss)
wss_adj= f"AWSS@R:{wss_info['awss']}"
print(wss_adj)
print('Acc.:', round(acc_wssR.item(), 4))
print('F1-score:', round(f1_wssR.item(), 4))
cm = pd.DataFrame(wss_info['CM'],
index=index,
columns=col)
print("\nConfusion matrix:")
print(cm)
# Plots
if make_plot:
fig, axes = plt.subplots(1, 4, figsize=(25,10))
alpha = torch.squeeze(predict).numpy()
# TSNE
p1 = sns.scatterplot(x=X_embedded[:, 0],
y=X_embedded[:, 1],
hue=labels,
alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE', size=20)
# WSS@95
t_wss = predict >= thresholds95
t_wss = t_wss.squeeze().numpy()
p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
y=X_embedded[t_wss, 1],
hue=labels[t_wss],
alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95', size=20)
# WSS@R
t = predict >= threshold
t = t.squeeze().numpy()
p3 = sns.scatterplot(x=X_embedded[t, 0],
y=X_embedded[t, 1],
hue=labels[t],
alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)
# ROC-Curve
roc_auc = auc(fpr, tpr)
lw = 2
axes[3].plot(
fpr,
tpr,
color="darkorange",
lw=lw,
label="ROC curve (area = %0.2f)" % roc_auc)
axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
axes[3].axhline(y=0.95, color='r', linestyle='-')
# axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate")
axes[3].legend(loc="lower right")
axes[3].set_title(label= "ROC", size = 20)
axes[3].set_ylabel("True Positive Rate", fontsize = 15)
axes[3].set_xlabel("False Positive Rate", fontsize = 15)
if show:
plt.show()
if save:
fig.savefig(namefig, dpi=fig.dpi)
return metrics
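# Usage sketch (with the outputs of train_loop above):
#   metrics = plot(logits, X_embedded, labels, threshold=0.9,
#                  namefig='domain_plots.png', show=False,
#                  make_plot=True, print_stats=True, save=True)
#   print(metrics['WSS@95'], metrics['threshold@95'])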
def auc_plot(logits,labels, color = "darkorange", label = "test"):
predict = torch.sigmoid(logits).detach().clone()
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
roc_auc = auc(fpr, tpr)
lw = 2
    label = f"{label} (AUC = {round(roc_auc, 2)})"
# print(label)
plt.plot(
fpr,
tpr,
color=color,
lw=lw,
label= label
)
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.axhline(y=0.95, color='r', linestyle='-')
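# Usage sketch (overlaying the ROC curves of two runs on one figure;
# logits_a/logits_b are hypothetical):
#   plt.figure(figsize=(6, 6))
#   auc_plot(logits_a, labels, color="darkorange", label="model A")
#   auc_plot(logits_b, labels, color="green", label="model B")
#   plt.legend(loc="lower right"); plt.show()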
# Interactive evaluation interface (ipywidgets)
class diagnosis():
def __init__(self, names, Valid_resource, batch_size_test,
model,Info, device,treat_text_fun=None,start = 0):
self.names=names
self.Valid_resource=Valid_resource
self.batch_size_test=batch_size_test
self.model=model
self.start=start
self.Info = Info
self.device = device
self.treat_text_fun = treat_text_fun
        # Input boxes
        ## Threshold for evaluation
        self.value_thresh = widgets.FloatText(
            value=0.95,
            description='threshold',
            disabled=False
        )
self.valueb = widgets.IntText(
value=10,
description='size',
disabled=False
)
# Buttons
self.train_b = widgets.Button(description="Train")
self.next_b = widgets.Button(description="Next")
self.eval_b = widgets.Button(description="Evaluation")
self.hbox = widgets.HBox([self.train_b, self.valueb])
        # Button click handlers
self.next_b.on_click(self.Next_button)
self.train_b.on_click(self.Train_button)
self.eval_b.on_click(self.Evaluation_button)
# Next button
def Next_button(self,p):
clear_output()
self.i=self.i+1
# Select the domain data
self.domain = self.names[self.i]
self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
print("Name:", self.domain)
print(self.data['label'].value_counts())
display(self.hbox)
display(self.next_b)
# Train button
def Train_button(self, y):
clear_output()
print(self.domain)
# Prepare data for training (domain-learner)
self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
train_size_per_class = self.valueb.value,
batch_size = {'train': self.Info['inner_batch_size'],
'test': self.batch_size_test},
max_seq_length = self.Info['max_seq_length'],
tokenizer = self.Info['tokenizer'],
input = "text",
output = "label",
treat_text_fun=self.treat_text_fun)
# Train the model and predict in the test set
self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
self.model, self.device,
epoch = self.Info['inner_update_step'],
lr=self.Info['inner_update_lr'],
print_info=True,
name = self.domain)
        thresh_box = widgets.HBox([self.eval_b, self.value_thresh])
        display(self.hbox)
        display(thresh_box)
        display(self.next_b)
# Evaluation button
def Evaluation_button(self, te):
        clear_output()
        thresh_box = widgets.HBox([self.eval_b, self.value_thresh])
        print(self.domain)
        print("-------Train data-------")
        print(self.data_train['label'].value_counts())
        print("-------Test data-------")
        print(self.data_test['label'].value_counts())
        display(self.next_b)
        display(thresh_box)
        display(self.hbox)
# Compute metrics
metrics = plot(self.logits, self.X_embedded, self.labels,
threshold=self.Info['threshold'], show = True,
namefig= 'test',
make_plot = True,
print_stats = True,
save=False)
def __call__(self):
self.i= self.start-1
clear_output()
display(self.next_b)
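# Usage sketch (in a notebook; Valid_resource is a DataFrame with 'domain',
# 'text' and 'label' columns):
#   diag = diagnosis(names, Valid_resource, batch_size_test=100,
#                    model=model, Info=Info, device='cuda')
#   diag()  # shows the "Next" button; then Train / Evaluation per domain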
# Simulation attempts of the domain learner
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
model, Info, device, initializer_model,
treat_text_fun=None):
n_attempt = 5
batch_test = 100
    # Create a directory per domain to save outputs
for name in names_to_valid:
        name = re.sub(r"\.csv", "", name)
Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)
    # Dict to save the ROC curves
roc_stats = defaultdict(lambda: defaultdict(
lambda: defaultdict(
list
)
)
)
all_metrics = []
# Loop over a list of domains
for name in names_to_valid:
# Select a domain dataset
data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
# Attempts simulation
for attempt in range(n_attempt):
print("---"*4,"attempt", attempt, "---"*4)
# Prepare data to pass to the model
data_train_loader, data_test_loader, _ , _ = prepare_data(data,
train_size_per_class = Info['k_spt'],
batch_size = {'train': Info['inner_batch_size'],
'test': batch_test},
max_seq_length = Info['max_seq_length'],
tokenizer = Info['tokenizer'],
input = "text",
output = "label",
treat_text_fun=treat_text_fun)
# Train the model and evaluate on the test set of the domain
logits, X_embedded, labels, features = train_loop(data_train_loader, data_test_loader,
model, device,
epoch = Info['inner_update_step'],
lr=Info['inner_update_lr'],
print_info=False,
name = name)
            name_domain = re.sub(r"\.csv", "", name)
# Compute the metrics
metrics = plot(logits, X_embedded, labels,
threshold=Info['threshold'], show = False,
namefig= path_save + name_domain + "/img/" + str(attempt) + 'plots',
make_plot = True, print_stats = False, save = True)
# Compute the roc-curve
fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())
            # Save the corresponding information for the domain
metrics['name'] = name_domain
metrics['layer_size'] = Info['bert_layers']
metrics['attempt'] = attempt
roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
all_metrics.append(metrics)
            # Save the metrics and the ROC curve of the attempt
pd.DataFrame(all_metrics).to_csv(path_save+ "metrics.csv")
roc_path = path_save + "roc_stats.json"
with open(roc_path, 'w') as fp:
json.dump(roc_stats, fp)
del fpr, tpr, logits, X_embedded, labels
del features, metrics, _
# Save the information used to evaluate the validation resource
save_info = Info.copy()
save_info['model'] = initializer_model.tokenizer.name_or_path
save_info.pop("tokenizer")
save_info.pop("bert_layers")
info_path = path_save+"info.json"
with open(info_path, 'w') as fp:
json.dump(save_info, fp)
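# Usage sketch (assumes initializer_model exposes the tokenizer, as used in
# the save step above):
#   pipeline_simulation(Valid_resource, names_to_valid=['domain_a.csv'],
#                       path_save='results/', model=model, Info=Info,
#                       device='cuda', initializer_model=initializer_model)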
# Loading dataset statistics
def load_data_statistics(paths, names):
size = []
pos = []
neg = []
for p in paths:
data = pd.read_csv(p)
data = data.dropna()
        # Dataset size
        size.append(len(data))
        # Number of positive labels (0 if the class is absent)
        pos.append(data['labels'].value_counts().get(1, 0))
        # Number of negative labels
        neg.append(data['labels'].value_counts().get(0, 0))
del data
info_load = pd.DataFrame({
"size":size,
"pos":pos,
"neg":neg,
"names":names,
"paths": paths })
return info_load
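# Usage sketch (hypothetical CSV files with a binary 'labels' column):
#   paths = ['data/domain_a.csv', 'data/domain_b.csv']
#   names = ['domain_a', 'domain_b']
#   info_load = load_data_statistics(paths, names)
#   print(info_load[['names', 'size', 'pos', 'neg']])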
# Loading the datasets
def load_data(train_info_load):
col = ['abstract','title', 'labels', 'domain']
data_train = pd.DataFrame(columns=col)
    for p in train_info_load['paths']:
        data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
        data_temp['domain'] = os.path.basename(p)
        data_train = pd.concat([data_train, data_temp])
    # Concatenate title and abstract (missing values become empty strings)
    data_train['text'] = data_train['title'].fillna('') + ' ' + data_train['abstract'].fillna('')
return( data_train \
.replace({"labels":{0:"negative", 1:'positive'}})\
.rename({"labels":"label"} , axis=1)\
.loc[ :,("text","domain","label")]
)
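# Usage sketch (chaining the two loaders above):
#   info_load = load_data_statistics(paths, names)
#   data_train = load_data(info_load)
#   data_train.head()  # columns: text, domain, label ('negative'/'positive')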