HydroxApp_t2t / benchmark.py
yangzzay's picture
Upload folder using huggingface_hub
02a7b98
import os
import time
import argparse
from dotenv import load_dotenv
from distutils.util import strtobool
from memory_profiler import memory_usage
from tqdm import tqdm
from llama2_wrapper import LLAMA2_WRAPPER
def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    """Run one generation pass and measure its time, throughput and peak memory.

    Args:
        llama2_wrapper: An initialized ``LLAMA2_WRAPPER``. Its ``run`` method
            must return an iterable of responses; the last value yielded is
            treated as the complete response. Its ``get_token_length``
            method must count the tokens in that response.
        prompt_example: The user prompt to generate a reply for.
        DEFAULT_SYSTEM_PROMPT: System prompt forwarded to ``run``.
        DEFAULT_MAX_NEW_TOKENS: Generation length limit forwarded to ``run``.

    Returns:
        A tuple ``(generation_time, tokens_per_second, mem_usage,
        model_response)``:
        - ``generation_time`` is wall-clock seconds.
        - ``mem_usage`` is the peak value reported by
          ``memory_profiler.memory_usage(..., max_usage=True)``.
        - ``model_response`` is the last yielded response, or ``None`` if
          the generator yielded nothing (then 0 tokens are counted).
    """

    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],  # presumably the (empty) chat history — confirm against run()
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            # NOTE(review): presumably temperature, top_p, top_k — confirm
            1,
            0.95,
            50,
        )
        # Drain the generator, keeping only the last value. The first yield
        # is part of the same loop, so a single-yield generator is not lost.
        model_response = None
        for model_response in generator:
            pass
        if model_response is None:
            # Nothing was generated: report zero tokens instead of asking
            # the wrapper to count tokens in ``None``.
            return 0, model_response
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    generation_time = time.perf_counter() - tic
    # Guard against a zero-length interval on very coarse clocks.
    tokens_per_second = (
        output_token_length / generation_time if generation_time > 0 else 0.0
    )
    return generation_time, tokens_per_second, mem_usage, model_response
def _parse_bool(value):
    """Parse a command-line true/false string into a bool.

    Accepts the spellings understood by ``strtobool`` (e.g. ``true``/``false``,
    ``yes``/``no``, ``1``/``0``, case-insensitive). Anything else raises
    ``ValueError``, which argparse reports as a usage error.
    """
    return bool(strtobool(value))


def main():
    """Benchmark LLAMA2_WRAPPER generation from the command line.

    Configuration is read from the environment (after ``load_dotenv()``).
    The CLI flags below override it:

    * ``--iter``: number of timed runs (default 5), done after one untimed
      cold run.
    * ``--model_path`` / ``--backend_type``: override ``MODEL_PATH`` /
      ``BACKEND_TYPE``. Each must come from the CLI or the environment.
    * ``--load_in_8bit``: override ``LOAD_IN_8BIT`` with an explicit
      true/false value.

    Prints the initialization time and the averages of generation time,
    throughput and peak memory. Only timed runs that completed successfully
    are averaged. A failing run, or Ctrl-C, stops the loop early.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
    parser.add_argument("--model_path", type=str, default="", help="model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    parser.add_argument(
        "--load_in_8bit",
        # ``type=bool`` would turn any non-empty string (even "False") into
        # True, so parse the value explicitly. ``None`` means "not given".
        type=_parse_bool,
        default=None,
        help="Whether to use bitsandbytes 8 bit.",
    )
    args = parser.parse_args()

    load_dotenv()
    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))

    # CLI flags take precedence over the environment. Resolve them *before*
    # validating, so a value supplied only on the command line is accepted.
    MODEL_PATH = args.model_path or os.getenv("MODEL_PATH")
    BACKEND_TYPE = args.backend_type or os.getenv("BACKEND_TYPE")
    if args.load_in_8bit is not None:
        LOAD_IN_8BIT = args.load_in_8bit
    if not MODEL_PATH:
        parser.error(
            f"MODEL_PATH is required (env MODEL_PATH or --model_path), got: {MODEL_PATH}"
        )
    if not BACKEND_TYPE:
        parser.error(
            f"BACKEND_TYPE is required (env BACKEND_TYPE or --backend_type), got: {BACKEND_TYPE}"
        )

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )
    initialization_time = time.perf_counter() - init_tic

    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Untimed warm-up run, so one-off first-call costs are kept out of the
    # averages.
    print("Performing cold run...")
    run_iteration(
        llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
    )

    total_time = 0.0
    total_tokens_per_second = 0.0
    total_memory_gen = 0.0
    completed = 0  # number of timed runs that finished successfully
    model_response = None

    print(f"Performing {args.iter} timed runs...")
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                DEFAULT_SYSTEM_PROMPT,
                DEFAULT_MAX_NEW_TOKENS,
            )
        except (Exception, KeyboardInterrupt) as err:
            # Stop early but still report the runs that completed. Ctrl-C is
            # handled deliberately, so it can be used to end the benchmark.
            print(f"Timed run {completed + 1} did not complete, stopping: {err!r}")
            break
        total_time += gen_time
        total_tokens_per_second += tokens_per_sec
        total_memory_gen += mem_gen
        completed += 1

    if completed == 0:
        print(f"Initialization time: {initialization_time:0.4f} seconds.")
        print("No timed runs completed; no averages to report.")
        return

    # Average only over the runs that actually finished.
    avg_time = total_time / completed
    avg_tokens_per_second = total_tokens_per_second / completed
    avg_memory_gen = total_memory_gen / completed
    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {completed} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {completed} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")
if __name__ == "__main__":
    # Start the benchmark only when this file is executed as a script,
    # not when it is imported as a module.
    main()