Update app.py
Browse files
app.py
CHANGED
@@ -19,12 +19,11 @@ user_message = "Allie kept track of how many kilometers she walked during the pa
|
|
19 |
model_name = 'TIGER-Lab/StructLM-7B'
|
20 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
21 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
|
22 |
-
|
23 |
# model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
24 |
|
25 |
-
@torch.inference_mode()
|
26 |
@spaces.GPU
|
27 |
-
def
|
28 |
prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
|
29 |
inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
|
30 |
input_ids = inputs["input_ids"].to(model.device)
|
@@ -62,7 +61,7 @@ def main():
|
|
62 |
output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B", interactive=True)
|
63 |
|
64 |
gr.Button("Try🐯📏TigerAI-StructLM").click(
|
65 |
-
|
66 |
inputs=[user_message_input, system_message_input, assistant_message_input, tabular_data_input, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
|
67 |
outputs=output_text
|
68 |
)
|
|
|
19 |
model_name = 'TIGER-Lab/StructLM-7B'
|
20 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
21 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
|
22 |
+
model.generation_config = GenerationConfig.from_pretrained(model_name)
|
23 |
# model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
24 |
|
|
|
25 |
@spaces.GPU
|
26 |
+
def predict(user_message, system_message="", assistant_message = "", tabular_data = "", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
|
27 |
prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{assistant_message}\n\n{tabular_data}\n\n\nQuestion:\n\n{user_message}[/INST]"
|
28 |
inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
|
29 |
input_ids = inputs["input_ids"].to(model.device)
|
|
|
61 |
output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B", interactive=True)
|
62 |
|
63 |
gr.Button("Try🐯📏TigerAI-StructLM").click(
|
64 |
+
predict,
|
65 |
inputs=[user_message_input, system_message_input, assistant_message_input, tabular_data_input, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
|
66 |
outputs=output_text
|
67 |
)
|