Tonic committed · Commit e446f68 · verified · 1 parent: c9d06d9

Update app.py

Files changed (1)
  1. app.py +3 -4
app.py CHANGED
@@ -19,12 +19,11 @@ user_message = "Allie kept track of how many kilometers she walked during the pa
  model_name = 'TIGER-Lab/StructLM-7B'
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
- # model.generation_config = GenerationConfig.from_pretrained(model_name)
+ model.generation_config = GenerationConfig.from_pretrained(model_name)
  # model.generation_config.pad_token_id = model.generation_config.eos_token_id
 
- @torch.inference_mode()
  @spaces.GPU
- def predict_math_bot(user_message, system_message="", assistant_message = "", tabular_data = "", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
+ def predict(user_message, system_message="", assistant_message = "", tabular_data = "", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
      prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
      inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
      input_ids = inputs["input_ids"].to(model.device)
@@ -62,7 +61,7 @@ def main():
      output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B", interactive=True)
 
      gr.Button("Try🐯📏TigerAI-StructLM").click(
-         predict_math_bot,
+         predict,
          inputs=[user_message_input, system_message_input, assistant_message_input, tabular_data_input, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
          outputs=output_text
      )
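
The net change is small: the GenerationConfig line is uncommented, the @torch.inference_mode() decorator is dropped, and predict_math_bot is renamed to predict, with the Gradio click handler updated to match. Since the hunk cuts off right after input_ids, the following is a minimal sketch of how the renamed function might continue end to end, assuming a standard generate-and-decode step; the model.generate call and the decoding are assumptions and are not part of the diff shown, and @spaces.GPU is omitted so the snippet runs outside a Hugging Face Space.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model_name = 'TIGER-Lab/StructLM-7B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)

def predict(user_message, system_message="", assistant_message="", tabular_data="",
            max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
    prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs["input_ids"].to(model.device)
    # Assumed continuation: generate without gradient tracking and decode only
    # the newly generated tokens (the prompt portion is sliced off).
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
        )
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)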