|
import spaces |
|
import torch |
|
import sys |
|
import html |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
|
from threading import Thread |
|
import gradio as gr |
|
|
|
|
|
# Markdown header rendered at the top of the Gradio UI via gr.Markdown(title).
# Fixed: removed a stray "</h3>" closing tag that had no matching opening tag,
# and corrected a few typos in the user-facing text.
title = """# 🙋🏻♂️Welcome to🌟Tonic's🐯📏TigerAI-StructLM-7B

StructLM is a series of open-source large language models (LLMs) finetuned for structured knowledge grounding (SKG) tasks. You can build with this endpoint using 🐯📏TigerAI-StructLM available here: [TIGER-Lab/StructLM-7B](https://huggingface.co/TIGER-Lab/StructLM-7B).

You can also use 🐯📏TigerAI-StructLM by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/TigerLM?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>

Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) Math with [introspector](https://huggingface.co/introspector) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [SciTonic](https://github.com/Tonic-AI/scitonic)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗

"""
|
# Default example prompt pieces. They are used as the hint/prefill text of the
# UI text boxes, and their names mirror the keyword parameters of predict().

# Instruction placed immediately before the serialized table in the prompt.
assistant_message = (
    "Use the information in the following table to solve the problem, "
    "choose between the choices if they are provided. table:"
)

# Text placed inside the <<SYS>> section of the prompt.
system_message = (
    "You are an AI assistant that specializes in analyzing and reasoning over structured information. "
    "You will be given a task, optionally with some structured knowledge input. "
    "Your answer must strictly adhere to the output format, if specified."
)

# Example table, flattened into a single line: a header, then one segment per row.
tabular_data = (
    "col : day | kilometers "
    "row 1 : tuesday | 0 "
    "row 2 : wednesday | 0 "
    "row 3 : thursday | 4 "
    "row 4 : friday | 0 "
    "row 5 : saturday | 0"
)

# Example question about the table above.
user_message = (
    "Allie kept track of how many kilometers she walked during the past 5 days. "
    "What is the range of the numbers?"
)
|
model_name = "TIGER-Lab/StructLM-7B"

# The tokenizer and the weights come from the same Hub repository. They are
# loaded once at import time and reused by every predict() call through these
# module-level names.
# NOTE(review): trust_remote_code=True allows transformers to execute Python
# code shipped inside the model repo; confirm this repo actually requires it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",          # let transformers/accelerate choose device placement
    torch_dtype=torch.float16,  # load weights in half precision
)
|
|
|
|
|
|
|
@spaces.GPU
def predict(
    user_message,
    system_message="",
    assistant_message="",
    tabular_data="",
    max_new_tokens=125,
    temperature=0.1,
    top_p=0.9,
    repetition_penalty=1.9,
    do_sample=False,
):
    """Answer one question over optional structured data with StructLM.

    Builds a ``[INST] <<SYS>> ... <</SYS>> ... [/INST]`` prompt from the given
    pieces, runs ``model.generate`` and decodes the result.

    Args:
        user_message: The question, placed after ``Question:`` in the prompt.
        system_message: Text placed inside the ``<<SYS>>`` section.
        assistant_message: Instruction text placed before the table.
        tabular_data: Serialized structured input (e.g. a flattened table).
        max_new_tokens: Maximum number of tokens to generate. Cast to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature; only used when ``do_sample`` is true.
        top_p: Nucleus-sampling threshold; only used when ``do_sample`` is true.
        repetition_penalty: Passed through to ``model.generate``.
        do_sample: Sample if true, otherwise decode greedily.

    Returns:
        The decoded text of the newly generated tokens only. The prompt is not
        echoed back, and special tokens are skipped.
    """
    prompt = (
        f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
        f"{assistant_message}\n\n{tabular_data}\n\n\n"
        f"Question:\n\n{user_message}[/INST]"
    )

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[1]

    generation_kwargs = {
        # Same budget as the former max_length = prompt_length + max_new_tokens.
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": bool(do_sample),
    }
    # temperature/top_p only affect sampling; passing them with greedy decoding
    # makes transformers emit configuration warnings.
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    output_ids = model.generate(
        input_ids=input_ids,
        # Pass the mask explicitly; without it generate() has to guess which
        # tokens are padding and warns about it.
        attention_mask=inputs.get("attention_mask"),
        **generation_kwargs,
    )

    # generate() returns prompt + continuation. Slice off the prompt tokens so
    # the UI shows only the model's answer.
    new_tokens = output_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
def main():
    """Build the StructLM Gradio UI and launch it.

    The four text inputs are prefilled with the module-level example values
    (``system_message``, ``assistant_message``, ``tabular_data`` and
    ``user_message``). Clicking the button without editing anything therefore
    runs that example through ``predict``. The advanced-settings controls map
    one-to-one, in order, onto the remaining ``predict`` parameters.

    NOTE(review): no visible line in this file calls ``main()``. Confirm that
    the script ends with ``if __name__ == "__main__": main()`` (or an
    equivalent call); otherwise the app never launches.
    """
    with gr.Blocks() as demo:
        gr.Markdown(title)

        with gr.Row():
            # Set value= as well as placeholder=. Placeholder text is only a
            # visual hint and is never submitted, so with placeholder alone
            # predict() received empty strings for every untouched box.
            system_message_input = gr.Textbox(
                label="📉System Prompt", value=system_message, placeholder=system_message
            )
            assistant_message_input = gr.Textbox(
                label="Assistant Message", value=assistant_message, placeholder=assistant_message
            )
            tabular_data_input = gr.Textbox(
                label="Tabular Data", value=tabular_data, placeholder=tabular_data
            )
            user_message_input = gr.Textbox(
                label="🫡Enter your query here...", value=user_message, placeholder=user_message
            )

        with gr.Accordion("Advanced Settings"):
            with gr.Row():
                # step=1 keeps the token budget on whole numbers.
                max_new_tokens = gr.Slider(
                    label="Max new tokens", value=125, minimum=25, maximum=1250, step=1
                )
                temperature = gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0)
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0
                )
                do_sample = gr.Checkbox(label="Do sample", value=False)

        output_text = gr.Textbox(label="🐯📏TigerAI-StructLM-7B")

        # Order must match predict()'s positional parameters.
        gr.Button("Try🐯📏TigerAI-StructLM").click(
            predict,
            inputs=[
                user_message_input,
                system_message_input,
                assistant_message_input,
                tabular_data_input,
                max_new_tokens,
                temperature,
                top_p,
                repetition_penalty,
                do_sample,
            ],
            outputs=output_text,
        )

    demo.launch()