# SPT / test.py
# Uploaded by hqsiswiliam ("Upload 43 files"), commit 8359bb1 (verified).
import argparse
import glob
import json
import locale
import os
import re
from functools import reduce
from multiprocessing import freeze_support
import deepspeed
import torch
import torch.distributed as dist
from dotenv import load_dotenv
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from dataset.dataset import PersonaChatDataset
from utils.dist_helper import setup
from utils.format_inputs import TASK_TYPE
from utils.parser_helper import str2bool
os.environ["PYTHONIOENCODING"] = "utf-8"
try:
    myLocale = locale.setlocale(category=locale.LC_ALL, locale="C.UTF-8")
except locale.Error:
    # "C.UTF-8" may not be installed on every platform; keep the inherited
    # locale instead of aborting the whole evaluation run.
    myLocale = locale.setlocale(locale.LC_ALL)
load_dotenv()

# Named ``arg_parser`` so the ``argparse`` module itself is not shadowed.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--model_path', type=str, default=None)
arg_parser.add_argument('--path_pattern', type=str, default=None)
arg_parser.add_argument('--batch_size', type=int)
arg_parser.add_argument('--valid_path', type=str, default=None)
arg_parser.add_argument('--local_rank', type=int, default=-1)
arg_parser.add_argument('--skip_exists', type=str2bool, default=False)
arg_parser.add_argument('--selection_noise', type=float, default=None)
# DeepSpeed registers its own options (e.g. --deepspeed_config) and returns
# the updated parser.
parser = deepspeed.add_config_arguments(arg_parser)
args = parser.parse_args()
# Parsed a second time on purpose: test_process() deletes ``deepspeed_config``
# from the namespace it receives, so it gets its own independent Namespace.
_cmd_args = parser.parse_args()
freeze_support()
# Placeholder: directory that holds local Vicuna weights (see test_process).
VICUNA_PREFIX = 'PATH_TO_VICUNA'
# Generated replies are cut at the first of these markers: a speaker tag
# ("R:", "Q:"), a "Summary:" section, a newline, or any colon. Raw string so
# the pattern does not depend on Python keeping the invalid "\:" escape.
_PRED_STOP_PATTERN = re.compile(r'R:|Q:|Summary:|\n|:')


def _truncate_prediction(pred):
    """Return *pred* cut just before the first ``_PRED_STOP_PATTERN`` match."""
    match = _PRED_STOP_PATTERN.search(pred)
    return pred if match is None else pred[:match.start()]


def _write_text_report(path, bleu, counter, context_input, persona_list,
                       pred_text, gt_text, selected_prompts):
    """Write the human-readable evaluation report to *path* (UTF-8).

    Layout: a ``bleu:`` line, an optional ``selected prompt:`` histogram line
    (only when prompts were recorded), then one block per entry of
    *context_input* with its context, persona, prediction, ground truth and
    (optionally) selected prompt, separated by blank lines. Context and persona
    are written as ``str(bytes)`` (``b'...'``), the same form as earlier reports.
    """
    with open(path, 'w', encoding='utf-8') as file:
        file.write('bleu: ' + str(bleu) + '\n')
        if len(selected_prompts) > 0:
            file.write('selected prompt: ' + str(counter) + '\n')
        for i, context in enumerate(context_input):
            if isinstance(context, list):
                context = ' '.join(context)
            file.write('context: ' + str(context.encode('utf-8')) + '\n')
            file.write('persona: ' + str(' '.join(persona_list[i]).encode('utf-8')) + '\n')
            file.write('pred: ' + pred_text[i] + '\n')
            file.write('gt: ' + gt_text[i] + '\n')
            if len(selected_prompts) > 0:
                file.write('selected prompt: ' + str(selected_prompts[i]) + '\n')
            file.write('\n')


def test_process(model_paths, batch_size, valid_path, skip_exists, selection_noise, cmd_args):
    """Generate replies with every checkpoint in *model_paths* and score them with BLEU.

    For each checkpoint path, ``checkpoint_best.pth`` from the same folder
    provides the training ``config``. The matching chat model is built and
    wrapped by ``deepspeed.initialize``, and its module weights are loaded from
    *model_path*. Replies are then generated on the evaluation split, gathered
    from all ranks and scored. Global rank 0 writes
    ``evaluation_result[_selection_noise=<x>].pkl`` plus a ``.txt`` report next
    to the checkpoint. Any exception for one checkpoint is written to
    ``test_error.txt`` in that folder, and the loop continues with the next path.

    Args:
        model_paths: DeepSpeed checkpoint directories (or checkpoint paths)
            whose parent folder also contains ``checkpoint_best.pth``.
        batch_size: per-GPU micro batch size for DeepSpeed, the model and the
            dataloader.
        valid_path: dataset file to evaluate; ``None`` falls back to
            ``config.dataset.test`` (when it is a str) or ``config.dataset.valid``.
        skip_exists: skip checkpoints whose result pickle already exists.
        selection_noise: forwarded to ``LLMChat.test_step``; also tags the
            output file name when not ``None``.
        cmd_args: parsed CLI namespace holding ``deepspeed_config``; that
            attribute is deleted from the namespace by this function.
    """
    import copy
    import pickle
    import traceback
    from collections import Counter

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    with open(cmd_args.deepspeed_config) as json_file:
        base_ds_config = json.load(json_file)
    # The config dict is passed to deepspeed.initialize explicitly below; the
    # path is removed from the namespace, presumably so DeepSpeed does not see
    # two config sources.
    del cmd_args.deepspeed_config
    setup()
    for model_path in model_paths:
        run_dir = os.path.dirname(model_path)
        try:
            if selection_noise is not None:
                result_name = f'evaluation_result_selection_noise={selection_noise}.pkl'
            else:
                result_name = 'evaluation_result.pkl'
            save_dir = os.path.join(run_dir, result_name)
            if os.path.exists(save_dir) and skip_exists:
                continue
            print(
                f"Start setup rank {deepspeed.comm.get_local_rank()} of {world_size} on GPU {torch.cuda.current_device()}")
            ckpt = torch.load(os.path.join(run_dir, 'checkpoint_best.pth'), map_location='cpu')
            config = ckpt['config']

            # Fresh copy per checkpoint: the precision switches below mutate the
            # dict and must not leak into the next model in the loop.
            ds_config = copy.deepcopy(base_ds_config)
            ds_config['train_micro_batch_size_per_gpu'] = batch_size
            load_precision = '32'
            if config.model.load_bit == 16:
                ds_config['float16']['enabled'] = True
                load_precision = 'fp16'
            if 'llama' in config.model.model_name.lower():
                ds_config['float16']['enabled'] = False
                ds_config['bf16']['enabled'] = True
                load_precision = 'bf16'

            if config.model.model_type == 'selective_pt':
                from models.selective_llm_chat import SelectLLMChat as LLMChat
            else:
                from models.llm_chat import LLMChat
            if 'vicuna' in config.model.model_name and not os.path.exists(config.model.model_name):
                # Re-root the Vicuna weights under the local VICUNA_PREFIX folder.
                config.model.model_name = VICUNA_PREFIX + os.sep + config.model.model_name.split(os.sep)[-1]
            _model = LLMChat(config, batch_size)
            left_tokenizer = _model.left_tokenizer
            right_tokenizer = _model.right_tokenizer
            print(f'LOADING {model_path} with {load_precision} precision')
            model_engine, _, _, _ = deepspeed.initialize(args=cmd_args,
                                                         model=_model,
                                                         config=ds_config,
                                                         )
            model_engine.load_checkpoint(model_path, load_module_strict=False, load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         load_module_only=True)

            valid_path_file = valid_path
            if valid_path_file is None:
                valid_path_file = config.dataset.valid
                if isinstance(config.dataset.test, str):
                    valid_path_file = config.dataset.test
                    print('using train split from personachat')
            # Result unused; kept because constructing it presumably rejects an
            # unknown task type early (TODO confirm against utils.format_inputs).
            TASK_TYPE(config.training.task_type)
            valid_dataset = PersonaChatDataset(valid_path_file, max_context_turns=config.dataset.max_context_turns)
            from dataset.dataset import get_dataloader
            max_new_token = 32
            # drop_last=False: DistributedSampler pads with repeated samples so
            # every rank gets the same number of items.
            valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
                                               drop_last=False)
            valid_dataloader = get_dataloader(valid_dataset, batch_size, num_workers=0, sampler=valid_sampler)

            context_input = []
            persona_list = []
            pred_text = []
            gt_text = []
            selected_prompts = []
            tqdm_iterator = tqdm(valid_dataloader, total=len(valid_dataloader))
            for data in tqdm_iterator:
                _, text, batch_selected_prompts = LLMChat.test_step(model_engine, data, left_tokenizer,
                                                                    right_tokenizer,
                                                                    config, max_new_tokens=max_new_token,
                                                                    tqdm_instance=tqdm_iterator,
                                                                    selection_noise=selection_noise)
                if not isinstance(batch_selected_prompts, list):
                    selected_prompts += batch_selected_prompts.detach().cpu().tolist()
                context_input += data['context_input']
                persona_list += data['persona_list']
                pred_text += text
                gt_text += data['target']
            pred_text = [_truncate_prediction(pred) for pred in pred_text]

            # all_gather_object fills one slot per rank, in rank order; the
            # concatenation therefore starts with rank 0's entries. Note that
            # context_input / persona_list / selected_prompts stay rank-local.
            dist_pred_text = [None for _ in range(world_size)]
            dist_gt_text = [None for _ in range(world_size)]
            dist.all_gather_object(dist_pred_text, pred_text)
            dist.all_gather_object(dist_gt_text, gt_text)
            pred_text = [pred for rank_preds in dist_pred_text for pred in rank_preds]
            gt_text = [gt for rank_gts in dist_gt_text for gt in rank_gts]

            from evaluation import bleu_score
            bleu = bleu_score(pred_text, [gt_text])
            result = {
                'context_input': context_input,
                'persona_list': persona_list,
                'pred_text': pred_text,
                'gt_text': gt_text,
                'bleu': bleu,
            }
            counter = Counter(selected_prompts)
            # Global rank 0 only: its rank-local contexts line up with the first
            # entries of the gathered pred/gt lists, and only one process writes.
            if dist.get_rank() == 0:
                print('bleu: ', bleu)
                with open(save_dir, 'wb') as file:
                    pickle.dump(result, file)
                report_path = os.path.splitext(save_dir)[0] + '.txt'
                _write_text_report(report_path, bleu, counter, context_input, persona_list,
                                   pred_text, gt_text, selected_prompts)
        except Exception as e:
            # Best effort: record the failure next to the checkpoint (message on
            # the first line, traceback after it) and continue with the rest.
            error_path = os.path.join(run_dir, "test_error.txt")
            print(f'WRITING TESTING ERROR! ERROR: {str(e)}')
            with open(error_path, 'w', encoding='utf-8') as file:
                file.write(str(e) + '\n' + traceback.format_exc())
        deepspeed.comm.barrier()
    deepspeed.comm.barrier()
def resolve_model_paths(model_path_arg, path_pattern=None):
    """Expand ``--model_path`` / ``--path_pattern`` into checkpoint paths to evaluate.

    Resolution order:
      1. ``<model_path_arg>/ds_ckpt`` when that folder has any entries;
      2. ``model_path_arg`` itself when it ends with ``.pth``;
      3. otherwise every ``<model_path_arg>/<run>/ds_ckpt`` folder that contains
         ``*/*.pt`` files, where ``<run>`` matches *path_pattern* (default ``*``).

    The result is sorted so every distributed rank iterates the same models in
    the same order (iteration order of a set of str can differ between
    processes because of hash randomization).

    Raises:
        ValueError: if *model_path_arg* is ``None``.
    """
    if model_path_arg is None:
        raise ValueError('--model_path is required')
    direct_ckpt = os.path.join(model_path_arg, 'ds_ckpt')
    if glob.glob(os.path.join(direct_ckpt, '*')):
        return [direct_ckpt]
    if model_path_arg.endswith('.pth'):
        return [model_path_arg]
    run_glob = '*' if path_pattern is None else path_pattern
    ckpt_files = glob.glob(f'{model_path_arg}/{run_glob}/ds_ckpt/*/*.pt')
    # Strip "<step>/<file>.pt" to get back to each run's ds_ckpt folder.
    return sorted({os.path.dirname(os.path.dirname(p)) for p in ckpt_files})


if __name__ == "__main__":
    model_paths = resolve_model_paths(args.model_path, args.path_pattern)
    print(model_paths)
    num_of_gpus = torch.cuda.device_count()
    print(f"{num_of_gpus} GPUs available")
    test_process(model_paths, args.batch_size, args.valid_path,
                 args.skip_exists, args.selection_noise, cmd_args=_cmd_args)
    deepspeed.comm.barrier()
    deepspeed.comm.destroy_process_group()
    print('Test Ends')