Tonic commited on
Commit
ea83807
·
verified ·
1 Parent(s): e446f68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -19,7 +19,7 @@ user_message = "Allie kept track of how many kilometers she walked during the pa
19
  model_name = 'TIGER-Lab/StructLM-7B'
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
22
- model.generation_config = GenerationConfig.from_pretrained(model_name)
23
  # model.generation_config.pad_token_id = model.generation_config.eos_token_id
24
 
25
  @spaces.GPU
@@ -27,7 +27,6 @@ def predict(user_message, system_message="", assistant_message = "", tabular_dat
27
  prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
28
  inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
29
  input_ids = inputs["input_ids"].to(model.device)
30
-
31
  output_ids = model.generate(
32
  input_ids,
33
  max_length=input_ids.shape[1] + max_new_tokens,
@@ -37,7 +36,6 @@ def predict(user_message, system_message="", assistant_message = "", tabular_dat
37
  pad_token_id=tokenizer.eos_token_id,
38
  do_sample=do_sample
39
  )
40
-
41
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
42
  return response
43
 
@@ -58,7 +56,7 @@ def main():
58
  repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0)
59
  do_sample = gr.Checkbox(label="Do sample", value=False)
60
 
61
- output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B", interactive=True)
62
 
63
  gr.Button("Try🐯📏TigerAI-StructLM").click(
64
  predict,
 
19
  model_name = 'TIGER-Lab/StructLM-7B'
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
22
+ # model.generation_config = GenerationConfig.from_pretrained(model_name)
23
  # model.generation_config.pad_token_id = model.generation_config.eos_token_id
24
 
25
  @spaces.GPU
 
27
  prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
28
  inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
29
  input_ids = inputs["input_ids"].to(model.device)
 
30
  output_ids = model.generate(
31
  input_ids,
32
  max_length=input_ids.shape[1] + max_new_tokens,
 
36
  pad_token_id=tokenizer.eos_token_id,
37
  do_sample=do_sample
38
  )
 
39
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
  return response
41
 
 
56
  repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0)
57
  do_sample = gr.Checkbox(label="Do sample", value=False)
58
 
59
+ output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B")
60
 
61
  gr.Button("Try🐯📏TigerAI-StructLM").click(
62
  predict,