Deci /
OferB committed
Commit f4091e9
1 Parent(s): 7c06d7a

Adding throughput benchmark example

Files changed (1)
  1. hf_benchmark_example.py +202 -0
hf_benchmark_example.py ADDED
@@ -0,0 +1,202 @@
import json

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser


def parse_args():
    parser = ArgumentParser()

    parser.add_argument(
        "--model",
        required=True,
        help="Model to evaluate, provide a repo name in Hugging Face hub or a local path",
    )
    parser.add_argument(
        "--temperature",
        default=0.2,
        type=float,
    )
    parser.add_argument(
        "--top_p",
        default=0.95,
        type=float,
    )
    parser.add_argument(
        "--top_k",
        default=0,
        type=int,
    )

    parser.add_argument(
        "--revision",
        default=None,
        help="Model revision to use",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=6,
        help="Number of timed iterations to run for each benchmark",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for evaluation on each worker, can be larger for HumanEval",
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=512,
        help="Number of prompt tokens taken from the text file",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate on top of the prompt",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, from: fp32, fp16 or bf16",
    )
    parser.add_argument(
        "--text_file",
        type=str,
        default="sample.txt",
        help="Text file that will be used to generate tokens for prompts",
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load model in 8bit",
    )
    parser.add_argument(
        "--load_in_4bit",
        action="store_true",
        help="Load model in 4bit",
    )
    return parser.parse_args()

def main():
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Invalid precision {args.precision}, choose from: fp16, fp32, bf16"
        )

    if args.load_in_8bit:
        print("Loading model in 8bit")
        # the model needs to fit in one GPU
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            load_in_8bit=args.load_in_8bit,
            trust_remote_code=True,
            use_auth_token=True,
            device_map={"": "cuda"},
        )
    elif args.load_in_4bit:
        print("Loading model in 4bit")
        # the model needs to fit in one GPU
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            load_in_4bit=args.load_in_4bit,
            trust_remote_code=True,
            use_auth_token=True,
            device_map={"": "cuda"},
        )
    else:
        print(f"Loading model in {args.precision}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            torch_dtype=dict_precisions[args.precision],
            trust_remote_code=True,
            use_auth_token=True,
        )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        revision=args.revision,
        trust_remote_code=True,
        use_auth_token=True,
    )

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    # quantized models are already dispatched to the GPU via device_map
    if not (args.load_in_8bit or args.load_in_4bit):
        model.cuda()
    model.eval()

    with open(args.text_file, "r") as f:
        prompt = f.read()

    # use the first prompt_length tokens of the text file as the prompt
    prompt = torch.tensor(tokenizer.encode(prompt))[:args.prompt_length].cuda()

    results = {
        'prefill': [],
        'gen': [],
        'max_new_tokens': args.max_new_tokens,
        'prompt_length': args.prompt_length,
        'model': args.model,
        'batch_size': args.batch_size,
    }
    # replicate the same prompt across the whole batch
    inputs = prompt.repeat(args.batch_size, 1)

    # warmup: a few untimed generations so CUDA kernels and caches are initialized
    print('start warmup')
    for _ in range(10):
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=1,
                do_sample=False,
            )
    print('finish warmup')
    torch.cuda.synchronize()

    # prefill benchmark: generating a single token, so the time is dominated by prompt processing
    for prefill_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=1,
                do_sample=False,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000
        results['prefill'].append(t)
        print(f'{args.batch_size} prefill iter {prefill_iter} took: {t}')

    # generation benchmark: prefill plus max_new_tokens greedy decode steps
    for gen_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000
        results['gen'].append(t)

        print(f'{args.batch_size} total generation iter {gen_iter} took: {t}')
        print(f'{args.batch_size * args.max_new_tokens / t} tokens per second')

    # save all timings to a JSON file named after the model and batch size
    model_str = args.model.split('/')[-1]
    with open(f'timing_{model_str}_{args.batch_size}.json', 'w') as f:
        json.dump(results, f)


if __name__ == "__main__":
    main()
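For reference, here is a minimal sketch of how the timing file written by the script could be summarized afterwards. The file name below is only a placeholder; the benchmark writes timing_<model>_<batch_size>.json, and the keys mirror the results dict it dumps.

import json
import statistics

# placeholder file name; point this at the file produced by your run,
# e.g. the output of: python hf_benchmark_example.py --model <repo_or_path> --batch_size 64
with open("timing_MODEL_64.json") as f:
    results = json.load(f)

batch_size = results["batch_size"]
max_new_tokens = results["max_new_tokens"]

print(f"model:           {results['model']}")
print(f"prompt length:   {results['prompt_length']} tokens, batch size {batch_size}")
print(f"mean prefill:    {statistics.mean(results['prefill']):.3f} s")
mean_gen = statistics.mean(results["gen"])
print(f"mean generation: {mean_gen:.3f} s ({batch_size * max_new_tokens / mean_gen:.1f} tokens/s)")

The mean of 'gen' is the end-to-end time for one batch, so batch_size * max_new_tokens divided by it reproduces the tokens-per-second figure the script prints per iteration.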