kjcjohnson committed on
Commit
de0bfd0
1 Parent(s): 5c10330

Better performance and config?

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. loop.py +45 -7
app.py CHANGED
@@ -5,14 +5,15 @@ MODEL_ID = "TinyLlama/TinyLlama_v1.1_math_code"
5
 
6
  handler = loop.EndpointHandler(MODEL_ID)
7
 
8
- def respond(prompt, grammar):
9
- args = { "inputs": prompt, "grammar": grammar }
10
  return handler(args)[0]
11
 
12
  demo = gr.Interface(
13
  respond,
14
  inputs=["textarea", "textarea"],
15
- outputs=["textarea"]
 
16
  )
17
 
18
  if __name__ == "__main__":
 
5
 
6
  handler = loop.EndpointHandler(MODEL_ID)
7
 
8
+ def respond(prompt, grammar, max_new_tokens, max_time):
9
+ args = { "inputs": prompt, "grammar": grammar, "max-new-tokens": max_new_tokens, "max-time": max_time }
10
  return handler(args)[0]
11
 
12
  demo = gr.Interface(
13
  respond,
14
  inputs=["textarea", "textarea"],
15
+ outputs=["textarea"],
16
+ additional_inputs=[gr.Number(value=512, precision=0), gr.Number(value=30, precision=0)]
17
  )
18
 
19
  if __name__ == "__main__":
loop.py CHANGED
@@ -1,23 +1,51 @@
1
  from typing import Dict, List, Any
2
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
5
  from transformers_gad.grammar_utils import IncrementalGrammarConstraint
6
  from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor
7
 
 
 
 
 
 
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
  # Preload
 
 
 
 
 
11
  self.tokenizer = AutoTokenizer.from_pretrained(path)
 
 
12
  self.model = AutoModelForCausalLM.from_pretrained(path)
 
 
 
 
13
 
14
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
  # do it!
16
- inputs = data.get("inputs",data)
 
 
 
 
 
 
 
17
  grammar_str = data.get("grammar", "")
18
- MAX_NEW_TOKENS=4096
19
- MAX_TIME=300
 
 
20
  print(grammar_str)
 
21
  grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
22
 
23
  # Initialize logits processor for the grammar
@@ -28,14 +56,24 @@ class EndpointHandler():
28
  gad_oracle_processor,
29
  ])
30
 
31
- input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt")["input_ids"]
 
32
 
33
  output = self.model.generate(
34
  input_ids,
35
  do_sample=True,
36
- max_time=MAX_TIME,
37
- max_new_tokens=MAX_NEW_TOKENS,
38
- logits_processor=logits_processors
 
 
 
 
 
 
 
 
 
39
  )
40
 
41
  gad_oracle_processor.reset()
 
1
  from typing import Dict, List, Any
2
 
3
+ import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
6
  from transformers_gad.grammar_utils import IncrementalGrammarConstraint
7
  from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor
8
 
9
def safe_int_cast(str, default):
    """Convert a value to ``int``, falling back to ``default`` on failure.

    Args:
        str: The value to convert. Despite the name (kept for interface
            compatibility, even though it shadows the builtin), this may be
            any object, including ``None`` when a caller passes the result
            of ``dict.get`` for a missing key.
        default: The value returned when the conversion is not possible.

    Returns:
        ``int(str)`` when the conversion succeeds, otherwise ``default``.
    """
    try:
        return int(str)
    except (TypeError, ValueError, OverflowError):
        # TypeError: non-numeric objects such as None (missing request key).
        # ValueError: unparsable strings and NaN floats.
        # OverflowError: infinite floats.
        return default
14
+
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
  # Preload
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ DTYPE = torch.bfloat16
20
+
21
+ self.device = torch.device(DEVICE)
22
+
23
  self.tokenizer = AutoTokenizer.from_pretrained(path)
24
+ self.tokenizer.pad_token = self.tokenizer.eos_token
25
+
26
  self.model = AutoModelForCausalLM.from_pretrained(path)
27
+ self.model.to(self.device)
28
+ self.model.to(dtype=DTYPE)
29
+ self.model.resize_token_embeddings(len(self.tokenizer))
30
+ self.model = torch.compile(self.model, mode='reduce-overhead', fullgraph=True)
31
 
32
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
  # do it!
34
+ MAX_NEW_TOKENS=512
35
+ MAX_TIME=30
36
+ TEMPERATURE = 1.0
37
+ REPETITION_PENALTY = 1.0
38
+ TOP_P = 1.0
39
+ TOP_K = 0
40
+
41
+ inputs = data.get("inputs", data)
42
  grammar_str = data.get("grammar", "")
43
+ max_new_tokens = safe_int_cast(data.get("max-new-tokens"), MAX_NEW_TOKENS)
44
+ max_time = safe_int_cast(data.get("max-time"), MAX_TIME)
45
+
46
+ print("=== GOT GRAMMAR ===")
47
  print(grammar_str)
48
+ print("===================")
49
  grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
50
 
51
  # Initialize logits processor for the grammar
 
56
  gad_oracle_processor,
57
  ])
58
 
59
+ input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]
60
+ input_ids = input_ids.to(self.model.device)
61
 
62
  output = self.model.generate(
63
  input_ids,
64
  do_sample=True,
65
+ pad_token_id=self.tokenizer.eos_token_id,
66
+ eos_token_id=self.tokenizer.eos_token_id,
67
+ max_time=max_time,
68
+ max_new_tokens=max_new_tokens,
69
+ top_p=TOP_P,
70
+ top_k=TOP_K,
71
+ repetition_penalty=REPETITION_PENALTY,
72
+ temperature=TEMPERATURE,
73
+ logits_processor=logits_processors,
74
+ num_return_sequences=1,
75
+ return_dict_in_generate=True,
76
+ output_scores=True
77
  )
78
 
79
  gad_oracle_processor.reset()