# TigerLM / app.py (Hugging Face Space by Tonic)
# Commit 1d347ad ("Update app.py", verified), 5.43 kB
import spaces
import torch
import sys
import html
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
# Markdown header rendered at the top of the Gradio page by main().
title = """# 🙋🏻‍♂️Welcome to🌟Tonic's🐯📏TigerAI-StructLM-7B
StructLM is a series of open-source large language models (LLMs) finetuned for structured knowledge grounding (SKG) tasks. You can build with this endpoint using 🐯📏TigerAI-StructLM available here : [TIGER-Lab/StructLM-7B](https://huggingface.co/TIGER-Lab/StructLM-7B).
You can also use 🐯📏TigerAI-StructLM by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/TigerLM?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) Math with [introspector](https://huggingface.co/introspector) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [SciTonic](https://github.com/Tonic-AI/scitonic)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Example inputs used to seed the UI textboxes.
# The assistant message is the task instruction that predict() places
# between the <<SYS>> block and the table in its prompt.
assistant_message = """Use the information in the following table to solve the problem, choose between the choices if they are provided. table:"""
system_message = "You are an AI assistant that specializes in analyzing and reasoning over structured information. You will be given a task, optionally with some structured knowledge input. Your answer must strictly adhere to the output format, if specified."
# Linearized table: a "col : ..." header followed by "row i : ..." entries,
# with cells separated by "|".
tabular_data = "col : day | kilometers row 1 : tuesday | 0 row 2 : wednesday | 0 row 3 : thursday | 4 row 4 : friday | 0 row 5 : saturday | 0"
user_message = "Allie kept track of how many kilometers she walked during the past 5 days. What is the range of the numbers?"
# Hugging Face Hub id of the StructLM checkpoint served by this demo.
model_name = 'TIGER-Lab/StructLM-7B'
# Loaded once at import time so every request reuses the same tokenizer and
# weights. trust_remote_code=True allows code shipped in the model repo to
# run; the repo id is fixed above rather than user-supplied.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# float16 halves weight memory compared with float32; device_map="auto" lets
# transformers/accelerate place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
@spaces.GPU
def predict(user_message, system_message="", assistant_message="", tabular_data="", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
    """Answer one structured-knowledge question with StructLM.

    Builds an ``[INST] <<SYS>> ... <</SYS>> ... [/INST]`` prompt from the
    inputs, tokenizes it, runs ``model.generate`` and decodes the result.

    Args:
        user_message: The question to answer.
        system_message: Text placed inside the ``<<SYS>>`` block.
        assistant_message: Task instruction placed before the table.
        tabular_data: Linearized table / structured input.
        max_new_tokens: Upper bound on generated tokens. Cast to ``int``
            because UI sliders may deliver float values.
        temperature: Sampling temperature; only applied when ``do_sample``.
        top_p: Nucleus-sampling threshold; only applied when ``do_sample``.
        repetition_penalty: Penalty applied to already-generated tokens.
        do_sample: Sample instead of greedy decoding.

    Returns:
        The generated answer text only, without the echoed prompt.
    """
    prompt = (
        f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
        f"{assistant_message}\n\n{tabular_data}\n\n\n"
        f"Question:\n\n{user_message}[/INST]"
    )
    # Move every returned tensor (input_ids and attention_mask) to the model.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
    input_ids = inputs["input_ids"]

    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    # temperature/top_p only influence sampling; passing them to greedy
    # decoding just triggers "only used in sample-based generation" warnings.
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    output_ids = model.generate(
        input_ids,
        # Explicit mask avoids transformers' "attention mask ... not set" warning.
        attention_mask=inputs.get("attention_mask"),
        **generation_kwargs,
    )
    # generate() returns prompt + continuation for decoder-only models; keep
    # only the newly generated tokens so the UI shows just the answer.
    new_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def main():
    """Build the StructLM Gradio interface and launch it.

    Each textbox is pre-filled with the module-level example (and also shows
    it as a placeholder once cleared). Placeholder text alone is never
    submitted, so without ``value=`` the default run would send empty
    system/instruction/table/question strings to ``predict``. The button
    passes the inputs to ``predict`` in its positional-parameter order and
    shows the returned text in the output box.
    """
    with gr.Blocks() as demo:
        gr.Markdown(title)
        with gr.Row():
            system_message_input = gr.Textbox(label="📉System Prompt", value=system_message, placeholder=system_message)
            assistant_message_input = gr.Textbox(label="Assistant Message", value=assistant_message, placeholder=assistant_message)
            tabular_data_input = gr.Textbox(label="Tabular Data", value=tabular_data, placeholder=tabular_data)
            user_message_input = gr.Textbox(label="🫡Enter your query here...", value=user_message, placeholder=user_message)
        with gr.Accordion("Advanced Settings"):
            with gr.Row():
                # step=1 keeps the token budget integral for model.generate.
                max_new_tokens = gr.Slider(label="Max new tokens", value=125, minimum=25, maximum=1250, step=1)
                temperature = gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99)
                repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0)
                # predict() applies temperature and top_p only when this is checked.
                do_sample = gr.Checkbox(label="Do sample", value=False)
        output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B")
        gr.Button("Try🐯📏TigerAI-StructLM").click(
            predict,
            # Order must match predict's positional parameters.
            inputs=[
                user_message_input,
                system_message_input,
                assistant_message_input,
                tabular_data_input,
                max_new_tokens,
                temperature,
                top_p,
                repetition_penalty,
                do_sample,
            ],
            outputs=output_text,
        )
    demo.launch()


# Nothing else calls main(); without this guard `python app.py` never launches the UI.
if __name__ == "__main__":
    main()