import argparse
import glob
import json
import locale
import os
import re
from functools import reduce
from multiprocessing import freeze_support

import deepspeed
import torch
import torch.distributed as dist
from dotenv import load_dotenv
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from dataset.dataset import PersonaChatDataset
from utils.dist_helper import setup
from utils.format_inputs import TASK_TYPE
from utils.parser_helper import str2bool

os.environ["PYTHONIOENCODING"] = "utf-8"
myLocale = locale.setlocale(category=locale.LC_ALL, locale="C.UTF-8")
load_dotenv()
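
# Command-line flags: --model_path points at a run directory or a specific
# checkpoint; --path_pattern narrows the glob over run directories;
# --skip_exists skips runs that already have an evaluation_result.pkl;
# --selection_noise is forwarded to the model's test_step and recorded in
# the output filename.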
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--model_path', type=str, default=None)
arg_parser.add_argument('--path_pattern', type=str, default=None)
arg_parser.add_argument('--batch_size', type=int)
arg_parser.add_argument('--valid_path', type=str, default=None)
arg_parser.add_argument('--local_rank', type=int, default=-1)
arg_parser.add_argument('--skip_exists', type=str2bool, default=False)
arg_parser.add_argument('--selection_noise', type=float, default=None)
parser = deepspeed.add_config_arguments(arg_parser)
# Both namespaces come from the same fully populated parser: `args` serves
# this script's flags, `_cmd_args` is handed to DeepSpeed.
args = parser.parse_args()
_cmd_args = parser.parse_args()
freeze_support()
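
# Placeholder: point this at the local directory holding the Vicuna weights.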
VICUNA_PREFIX = 'PATH_TO_VICUNA'


def test_process(model_paths, batch_size, valid_path, skip_exists, selection_noise, cmd_args):
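    """Evaluate each DeepSpeed checkpoint in `model_paths` on PersonaChat,
    compute BLEU, and write the results next to the checkpoint."""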
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    with open(cmd_args.deepspeed_config) as json_file:
        ds_config = json.load(json_file)
    # deepspeed.initialize must not receive both a --deepspeed_config flag and
    # an explicit config dict, so drop the flag after reading it.
    del cmd_args.deepspeed_config

    setup()
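
    # Evaluate each checkpoint in turn; per-checkpoint failures are logged
    # instead of aborting the whole sweep.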
    for model_path in model_paths:
        try:
            run_dir = os.sep.join(model_path.split(os.sep)[:-1])
            # Results are written next to the checkpoint directory.
            if selection_noise is not None:
                save_dir = run_dir + os.sep + f'evaluation_result_selection_noise={selection_noise}.pkl'
            else:
                save_dir = run_dir + os.sep + 'evaluation_result.pkl'
            if os.path.exists(save_dir) and skip_exists:
                continue
            print(
                f"Start setup rank {deepspeed.comm.get_local_rank()} of {world_size} on GPU {torch.cuda.current_device()}")

            # The training run stores its config inside checkpoint_best.pth.
            ckpt = torch.load(run_dir + os.sep + 'checkpoint_best.pth', map_location='cpu')
            config = ckpt['config']
            ds_config['train_micro_batch_size_per_gpu'] = batch_size
            load_precision = '32'
            if config.model.load_bit == 16:
                ds_config['float16']['enabled'] = True
                load_precision = 'fp16'
                if 'llama' in config.model.model_name.lower():
                    # LLaMA-family weights are bf16-native; fp16 is prone to overflow.
                    ds_config['float16']['enabled'] = False
                    ds_config['bf16']['enabled'] = True
                    load_precision = 'bf16'
            load_bit_map = {
                'fp16': torch.float16,
                'bf16': torch.bfloat16,
                '32': torch.float32}
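
            # The checkpoint's config records which wrapper was trained; import
            # the matching class (prompt-selection variant vs. plain LLM chat).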
            if config.model.model_type == 'selective_pt':
                from models.selective_llm_chat import SelectLLMChat as LLMChat
            else:
                from models.llm_chat import LLMChat
            # Fall back to VICUNA_PREFIX when the recorded model path does not
            # exist on this machine.
            if 'vicuna' in config.model.model_name and not os.path.exists(config.model.model_name):
                config.model.model_name = VICUNA_PREFIX + os.sep + config.model.model_name.split(os.sep)[-1]
            _model = LLMChat(config, batch_size)
            left_tokenizer = _model.left_tokenizer
            right_tokenizer = _model.right_tokenizer
            print(f'LOADING {model_path} with {load_precision} precision')
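
            # Wrap the model in a DeepSpeed engine using the (possibly
            # precision-patched) config, then restore the trained weights.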
            model_engine, _, _, _ = deepspeed.initialize(args=cmd_args,
                                                         model=_model,
                                                         config=ds_config)
            # Only the module weights are needed for evaluation, so skip the
            # optimizer and LR-scheduler states.
            model_engine.load_checkpoint(model_path, load_module_strict=False,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         load_module_only=True)
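
            # Dataset precedence: an explicit --valid_path wins; otherwise the
            # config's test split is used when present, falling back to valid.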
            valid_path_file = valid_path
            if valid_path_file is None:
                valid_path_file = config.dataset.valid
                if isinstance(config.dataset.test, str):
                    valid_path_file = config.dataset.test
                    print('using test split from personachat')
            task_type = TASK_TYPE(config.training.task_type)
            valid_dataset = PersonaChatDataset(valid_path_file, max_context_turns=config.dataset.max_context_turns)
            from dataset.dataset import get_dataloader
            max_new_token = 32
            # Shard the evaluation set across ranks deterministically (no shuffle).
            valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
                                               drop_last=False)
            valid_dataloader = get_dataloader(valid_dataset, batch_size, num_workers=0, sampler=valid_sampler)
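
            # Accumulators for this rank; per-rank outputs are gathered into
            # dist_pred_text / dist_gt_text after the generation loop.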
            context_input = []
            persona_list = []
            dist_pred_text = [None for _ in range(world_size)]
            dist_gt_text = [None for _ in range(world_size)]
            pred_text = []
            gt_text = []
            tqdm_iterator = tqdm(valid_dataloader, total=len(valid_dataloader))
            selected_prompts = []
            for data in tqdm_iterator:
                _, text, batch_selected_prompts = LLMChat.test_step(model_engine, data, left_tokenizer,
                                                                    right_tokenizer,
                                                                    config, max_new_tokens=max_new_token,
                                                                    tqdm_instance=tqdm_iterator,
                                                                    selection_noise=selection_noise)
                # When the model performs prompt selection, the chosen indices
                # arrive as a tensor; a plain list means nothing was selected.
                if not isinstance(batch_selected_prompts, list):
                    selected_prompts += batch_selected_prompts.detach().cpu().tolist()

                context_input += data['context_input']
                persona_list += data['persona_list']
                pred_text += text
                gt_text += data['target']
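
            # Post-process generations: truncate at the first marker of a new
            # turn ("R:", "Q:", "Summary:"), a newline, or a stray colon.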
            clean_preds = []
            for pred in pred_text:
                search_result = re.search(r'R:|Q:|Summary:|\n|:', pred)
                if search_result is not None:
                    clean_preds.append(pred[:search_result.span()[0]])
                else:
                    clean_preds.append(pred)
            pred_text = clean_preds
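
            # Gather every rank's predictions and references, then flatten them
            # into single rank-ordered lists for scoring.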
            dist.all_gather_object(dist_pred_text, pred_text)
            dist.all_gather_object(dist_gt_text, gt_text)
            pred_text = reduce(lambda x, y: x + y, dist_pred_text)
            gt_text = reduce(lambda x, y: x + y, dist_gt_text)
            from evaluation import bleu_score, f1_score, normalize_answer
            bleu = bleu_score(pred_text, [gt_text])
            import pickle

            result = {
                'context_input': context_input,
                'persona_list': persona_list,
                'pred_text': pred_text,
                'gt_text': gt_text,
                'bleu': bleu,
            }
            from collections import Counter
            counter = Counter(selected_prompts)
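
            # Rank 0 persists the raw results as a pickle plus a human-readable
            # transcript with one block per example.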
            if deepspeed.comm.get_local_rank() == 0:
                print('bleu: ', bleu)
                with open(save_dir, 'wb') as file:
                    pickle.dump(result, file)
                with open(save_dir.replace('.pkl', '.txt'), 'w', encoding='utf-8') as file:
                    file.write('bleu: ' + str(bleu) + '\n')
                    if len(selected_prompts) > 0:
                        file.write('selected prompt: ' + str(counter) + '\n')
                    for i in range(len(context_input)):
                        # The file is opened as UTF-8, so write strings directly.
                        if isinstance(context_input[i], list):
                            file.write('context: ' + ' '.join(context_input[i]) + '\n')
                        else:
                            file.write('context: ' + context_input[i] + '\n')
                        file.write('persona: ' + ' '.join(persona_list[i]) + '\n')
                        file.write('pred: ' + pred_text[i] + '\n')
                        file.write('gt: ' + gt_text[i] + '\n')
                        if len(selected_prompts) > 0:
                            file.write('selected prompt: ' + str(selected_prompts[i]) + '\n')
                        file.write('\n')
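
        # On error, leave a test_error.txt next to the checkpoint for post-mortem.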
        except Exception as e:
            error_path = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + 'test_error.txt'
            print(f'WRITING TESTING ERROR! ERROR: {str(e)}')
            with open(error_path, 'w') as file:
                file.write(str(e))
        # Synchronize all ranks before moving to the next checkpoint.
        deepspeed.comm.barrier()
    deepspeed.comm.barrier()
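

# Resolve --model_path into the list of checkpoint directories to evaluate.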
model_path_arg = args.model_path
model_paths = [model_path_arg]
if len(glob.glob(model_path_arg + os.sep + 'ds_ckpt' + os.sep + '*')):
    # --model_path is a single run directory containing a DeepSpeed checkpoint.
    model_paths = [model_path_arg + os.sep + 'ds_ckpt']
elif not model_path_arg.endswith('.pth'):
    # --model_path is a parent directory: glob for every run with a checkpoint.
    path_pattern = args.path_pattern
    if path_pattern is not None:
        model_paths = glob.glob(f'{model_path_arg}/{path_pattern}/ds_ckpt/*/*.pt')
    else:
        model_paths = glob.glob(f'{model_path_arg}/*/ds_ckpt/*/*.pt')
    # Drop the last two path components (checkpoint tag and shard file) to
    # recover the unique ds_ckpt directories.
    model_paths = list(set(os.sep.join(p.split(os.sep)[:-2]) for p in model_paths))
print(model_paths)

num_of_gpus = torch.cuda.device_count()
print(f"{num_of_gpus} GPUs available")
test_process(model_paths, args.batch_size, args.valid_path,
             args.skip_exists, args.selection_noise, cmd_args=_cmd_args)
deepspeed.comm.barrier()
deepspeed.comm.destroy_process_group()

print('Test Ends')