Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import os | |
from collections.abc import Iterator | |
from threading import Thread | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
BitsAndBytesConfig, | |
TextIteratorStreamer, | |
) | |
# Markdown header rendered at the top of the demo page.
DESCRIPTION = """# Swallow-13B instruct"""

if torch.cuda.is_available():
    # The tokenizer and model are only loaded when a CUDA device is present;
    # the weights are loaded with bitsandbytes 8-bit quantization.
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        low_cpu_mem_usage=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
else:
    # No GPU: nothing is loaded, so warn the visitor in the page header.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Prompts longer than this many tokens are rejected before generation.
MAX_INPUT_TOKENS = 2048

# Japanese instruction templates filled in with ``str.format``.
# "prompt_input" expects ``{instruction}`` and ``{input}``;
# "prompt_no_input" expects only ``{instruction}``.
PROMPT_DICT = dict(
    prompt_input=(
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n"
        "### 入力:\n{input}\n\n"
        "### 応答:"
    ),
    prompt_no_input=(
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n"
        "### 応答:"
    ),
)
def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Build the model prompt for an instruction and optional extra context.

    The template comes from ``PROMPT_DICT``: ``"prompt_input"`` is used when
    ``input_text`` is truthy, ``"prompt_no_input"`` otherwise (so both ``None``
    and the empty string select the no-input template).

    Args:
        instruction: The instruction describing the task.
        input_text: Additional context for the task. Defaults to None.

    Returns:
        The formatted prompt string.
    """
    if not input_text:
        # No extra context: only the instruction is substituted.
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
    # Extra context supplied: fill both the instruction and the input slots.
    return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Generate a response for the given instruction and stream it back.

    The prompt is built with ``create_prompt``, tokenized, and passed to
    ``model.generate`` on a background thread; text is read from a
    ``TextIteratorStreamer`` as it is produced.

    Args:
        instruction: The instruction describing the task.
        input_text: Optional additional context; an empty string is treated
            the same as None.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value of 0 (or below) switches
            to greedy decoding instead of sampling.
        top_p: Nucleus-sampling probability mass (only used when sampling).

    Yields:
        The full text generated so far, growing with each streamed chunk.

    Raises:
        gr.Error: If no CUDA device is available, or if the tokenized prompt
            is longer than ``MAX_INPUT_TOKENS``.
    """
    if not torch.cuda.is_available():
        # `tokenizer` and `model` are only created at import time when CUDA is
        # available, so fail with a readable message instead of a NameError.
        error_message = "This demo requires a GPU and does not work on CPU."
        raise gr.Error(error_message)

    # An empty textbox arrives as ""; treat it as "no input".
    prompt = create_prompt(instruction, input_text or None)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        error_message = f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})"
        raise gr.Error(error_message)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling requires a strictly positive temperature; the UI slider
        # allows 0.0, so fall back to greedy decoding in that case.
        generate_kwargs["do_sample"] = False

    # Generation blocks, so run it on a worker thread and consume the streamer here.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    generated = ""
    for chunk in streamer:
        generated += chunk
        yield generated

    # The streamer is exhausted once generation has finished; reap the thread.
    worker.join()
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Stream the response for one example row.

    Delegates to ``run`` with only the instruction and input, so the
    generation settings fall back to ``run``'s defaults.

    Args:
        instruction: The example's instruction text.
        input_text: The example's optional input text.

    Yields:
        The accumulated output text, exactly as yielded by ``run``.
    """
    yield from run(instruction=instruction, input_text=input_text)
# Page layout: inputs and generation settings on the left, streamed output on the right.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: prompt fields, run button and sampling controls.
        with gr.Column():
            instruction = gr.Textbox(lines=5, label="Instruction")
            input_text = gr.Textbox(lines=5, label="Input (optional)")
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.99, step=0.01, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top P")

        # Right column: generated text.
        with gr.Column():
            output = gr.Textbox(lines=10, label="Output")

    # The button streams `run` into the output box and is exposed as the "run" API endpoint.
    run_button.click(
        api_name="run",
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
    )

    # Clickable examples; they go through `process_example`, which uses run()'s default settings.
    gr.Examples(
        fn=process_example,
        inputs=[instruction, input_text],
        outputs=output,
        examples=[
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "東京工業大学の主なキャンパスについて教えてください。",
            ],
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "夢オチとは何かについて教えてください。",
            ],
            ["暴れん坊将軍って誰のことですか?", ""],  # noqa: RUF001
        ],
        # Example outputs are cached only when CACHE_EXAMPLES=1 is set in the environment.
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()