stablelm-2-chat / app.py
pvduy's picture
Update app.py
eb95198 verified
raw
history blame
3.94 kB
import argparse
import os
import spaces
import gradio as gr
import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Upper bound on prompt length in tokens: predict() keeps only the last
# MAX_LENGTH input ids when the encoded prompt is longer than this.
MAX_LENGTH = 4096
# Generation budget (max_new_tokens) used by predict().
DEFAULT_MAX_NEW_TOKENS = 1024
def parse_args():
    """Parse the command-line options of the demo script.

    Returns:
        argparse.Namespace: ``base_model`` (str, ``None`` when not given) and
        ``n_gpus`` (int, defaults to 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, help="model path")
    parser.add_argument("--n_gpus", type=int, default=1, help="number of GPUs")
    return parser.parse_args()
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream a chat completion for ``message`` given the conversation so far.

    Args:
        message: The latest user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs from
            earlier turns.
        system_prompt: Content of the leading system message.
        temperature: Sampling temperature; ``0`` (or ``None``) switches to
            greedy decoding instead of sampling.
        max_tokens: Maximum number of new tokens to generate. Falls back to
            ``DEFAULT_MAX_NEW_TOKENS`` when empty.

    Yields:
        str: The reply accumulated so far, growing as the streamer delivers
        new text.
    """
    # model, tokenizer and device are module-level names created by the
    # __main__ block; they are only read here, so no ``global`` is needed.
    messages = [{'role': 'system', 'content': system_prompt}]
    for human, assistant in history:
        messages.append({'role': 'user', 'content': human})
        messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})

    problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)

    enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask
    if input_ids.shape[1] > MAX_LENGTH:
        # Keep the most recent tokens. The mask must be cut the same way,
        # otherwise its shape no longer matches input_ids.
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Sampling needs a strictly positive temperature; the UI slider can be 0,
    # in which case we decode greedily.
    sampling = temperature is not None and temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "streamer": streamer,
        "do_sample": sampling,
        # Honour the "Max Tokens" slider instead of always using the default.
        "max_new_tokens": int(max_tokens) if max_tokens else DEFAULT_MAX_NEW_TOKENS,
        "use_cache": True,
        "eos_token_id": 100278,  # <|im_end|>
    }
    if sampling:
        generate_kwargs.update(top_p=0.95, temperature=temperature)

    # generate() blocks until done, so it runs in a worker thread while this
    # generator consumes the streamer and yields partial text.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted, so generation has finished; reap the worker.
    t.join()
if __name__ == "__main__":
    # Script entry point: load the model once, then serve the Gradio chat UI.
    args = parse_args()
    # Honour --base_model when given; otherwise use the checkpoint this demo
    # was built for.
    model_name = args.base_model or "stabilityai/stablelm-2-12b-chat"
    # tokenizer, model and device are module-level globals read by predict().
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    gr.ChatInterface(
        predict,
        title="StableLM 2 12B Chat - Demo",
        description="StableLM 2 12B Chat - StabilityAI",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() after (message, history), in this order:
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["Implement snake game using pygame"],
            ["Escribe un poema corto sobre la historia del Mediterráneo."],
            ["Scrivi un Haiku che celebri il gelato"],
            ["Schreibe ein Haiku über die Alpen."],
            ["Ecris une prose a propos de la mer du Nord."],
            ["Escreva um poema sobre a saudade."],
            ["""\
This is the error that is thrown for the given python code.
```python
import pandas as pd
df = pd.read_csv(PATH)
This throws an error,
```
```shell
pandas/io/common.py"", line 873, in get_handle
handle = open(
^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'data.csv'```
How to solve this ?
"""
            ],
            ["Jane has 8 apples, out of which 2 are red and 3 are green. How many of them are white? Assuming there are only 3 colors, can you solve this with python?"],
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()