#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

DESCRIPTION = """# Swallow-13B instruct"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# The model is only loaded when a GPU is present; on CPU `model`/`tokenizer` are never defined.
if torch.cuda.is_available():
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        low_cpu_mem_usage=True,
        device_map="auto",
    )

# Prompts whose token count exceeds this are rejected in `run` before generation starts.
MAX_INPUT_TOKENS = 2048

# Japanese instruction-following prompt templates. This file must be saved as UTF-8: if these
# strings get mis-encoded, the model silently receives a garbage prompt.
#   prompt_input:    "Below is an instruction describing a task, paired with an input that provides
#                     further context. Write a response that appropriately completes the request."
#                     followed by instruction / input / response sections.
#   prompt_no_input: same, without the input section.
PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"
    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}


def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Generate a prompt based on the given instruction and an optional input.

    If input is provided (truthy), it uses the 'prompt_input' template from PROMPT_DICT.
    Otherwise it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input_text (str | None): Additional input providing context for the task. Defaults to None.

    Returns:
        str: The generated prompt.
    """
    if input_text:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
    # Use the 'prompt_no_input' template when no additional input is provided
    return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)


@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream the model's response to an instruction.

    Args:
        instruction: The task instruction.
        input_text: Optional extra context. Empty or whitespace-only values are treated as absent.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values <= 0 select greedy decoding, because sampling
            requires a strictly positive temperature.
        top_p: Nucleus-sampling probability mass. Only used when sampling.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        gr.Error: If the encoded prompt is longer than MAX_INPUT_TOKENS tokens.
    """
    # A blank optional-input box should select the no-input template.
    if not (input_text and input_text.strip()):
        input_text = None
    prompt = create_prompt(instruction, input_text)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        error_message = f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})"
        raise gr.Error(error_message)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # The UI slider allows 0.0; sampling would reject it, so decode greedily instead.
        generate_kwargs["do_sample"] = False

    # generate() runs in a worker thread while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # Make sure the generation thread has finished before this GPU-decorated function returns.
    thread.join()


def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Run an example with the default generation settings (used by gr.Examples)."""
    yield from run(instruction, input_text)


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input (optional)", lines=5)
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    run_button.click(
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        api_name="run",
    )

    gr.Examples(
        examples=[
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "東京工業大学の主なキャンパスについて教えてください。",
            ],
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "夢オチとは何かについて教えてください。",
            ],
            ["暴れん坊将軍って誰のことですか？", ""],  # noqa: RUF001
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        # Caching examples runs the model at startup; opt in via environment variable.
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

if __name__ == "__main__":
    demo.launch()