## The following code is adapted from
## https://docs.mystic.ai/docs/llama-2-with-vllm-7b-13b-multi-gpu-70b

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from arguments import get_args
from dataset_conv import get_chatqa2_input, preprocess
from tqdm import tqdm
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['VLLM_NCCL_SO_PATH'] = '/usr/local/lib/python3.8/dist-packages/nvidia/nccl/lib/libnccl.so.2'


def get_prompt_list(args):
    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    ## load and preprocess the evaluation data
    data_list = preprocess(args.sample_input_file, inference_only=True,
                           retrieved_neighbours=args.use_retrieved_neighbours)
    print("number of total data_list:", len(data_list))

    ## optionally evaluate only a slice of the dataset
    if args.start_idx != -1 and args.end_idx != -1:
        print("getting data from %d to %d" % (args.start_idx, args.end_idx))
        data_list = data_list[args.start_idx:args.end_idx]
    print("number of test samples in the dataset:", len(data_list))

    ## build ChatQA-2 style prompts from the preprocessed samples
    prompt_list = get_chatqa2_input(data_list, args.eval_dataset, tokenizer,
                                    num_ctx=args.num_ctx,
                                    max_output_len=args.max_tokens,
                                    max_seq_length=args.max_seq_length)

    return prompt_list


def main():
    args = get_args()

    ## bos token for llama-3
    bos_token = "<|begin_of_text|>"

    ## get model_path
    model_path = args.model_folder

    ## get prompt_list
    prompt_list = get_prompt_list(args)

    ## outputs are written to a folder next to the model checkpoint
    output_path = os.path.join(model_path, "outputs")
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    ## get output_datapath
    if args.start_idx != -1 and args.end_idx != -1:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d_ctx%d.txt" % \
                (args.eval_dataset, args.start_idx, args.end_idx, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d.txt" % \
                (args.eval_dataset, args.start_idx, args.end_idx))
    else:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_ctx%d.txt" % \
                (args.eval_dataset, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output.txt" % (args.eval_dataset))

    ## run inference with greedy decoding (temperature=0, top_k=1)
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=args.max_tokens)

    ## shard the model across 8 GPUs via tensor parallelism
    model_vllm = LLM(model_path, tensor_parallel_size=8, dtype=torch.bfloat16)
    print(model_vllm)

    output_list = []
    for prompt in tqdm(prompt_list):
        ## prepend the llama-3 bos token to each prompt
        prompt = bos_token + prompt
        output = model_vllm.generate([prompt], sampling_params)[0]

        generated_text = output.outputs[0].text
        generated_text = generated_text.strip().replace("\n", " ")

        ## strip llama-3 end-of-turn / end-of-text tokens if they appear
        if "<|eot_id|>" in generated_text:
            idx = generated_text.index("<|eot_id|>")
            generated_text = generated_text[:idx]
        if "<|end_of_text|>" in generated_text:
            idx = generated_text.index("<|end_of_text|>")
            generated_text = generated_text[:idx]

        print("=" * 80)
        print("prompt:", prompt)
        print("-" * 80)
        print("generated_text:", generated_text)
        print("=" * 80)

        output_list.append(generated_text)

    ## write one generation per line
    print("writing to %s" % output_datapath)
    with open(output_datapath, "w", encoding="utf-8") as f:
        for output in output_list:
            f.write(output + "\n")


if __name__ == "__main__":
    main()
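
## --------------------------------------------------------------------------
## Example invocation (a sketch only): the real flag names are defined in
## arguments.get_args(), which is not shown here, and the script filename is
## illustrative. The flags below are hypothetical and simply mirror the
## attribute names used above (model_folder, tokenizer_path, sample_input_file,
## eval_dataset, num_ctx, max_tokens, max_seq_length).
##
##   python run_generation_vllm.py \
##       --model-folder /path/to/llama3-chatqa-2-model \
##       --tokenizer-path /path/to/llama3-chatqa-2-model \
##       --sample-input-file /path/to/eval_data.json \
##       --eval-dataset <dataset_name> \
##       --max-tokens 128 --max-seq-length 32768 --num-ctx 5
## --------------------------------------------------------------------------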