#!/usr/bin/env python
"""Gradio demo that streams completions from the Swallow-13B instruct model."""

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

DESCRIPTION = """# Swallow-13B instruct"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU đŸ„¶ This demo does not work on CPU.\n"

# The model and tokenizer are only loaded when a CUDA device is visible;
# on CPU-only hosts `run` reports an error instead of generating.
if torch.cuda.is_available():
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        low_cpu_mem_usage=True,
        device_map="auto",
    )

# Prompts longer than this many tokens are rejected by `run`.
MAX_INPUT_TOKENS = 2048

# Prompt templates: one for "instruction + input", one for "instruction only".
PROMPT_DICT = {
    "prompt_input": (
        "仄䞋に、あるタă‚čクをèȘŹæ˜Žă™ă‚‹æŒ‡ç€șăŒă‚ă‚Šă€ăă‚Œă«ä»˜éšă™ă‚‹ć…„ćŠ›ăŒæ›ŽăȘă‚‹æ–‡è„ˆă‚’æäŸ›ă—ăŠă„ăŸă™ă€‚"
        "ăƒȘクスă‚čăƒˆă‚’é©ćˆ‡ă«ćźŒäș†ă™ă‚‹ăŸă‚ăźć›žç­”ă‚’èš˜èż°ă—ăŠăă ă•ă„ă€‚\n\n"
        "### 指ç€ș:\n{instruction}\n\n### 慄抛:\n{input}\n\n### 濜答:"
    ),
    "prompt_no_input": (
        "仄䞋に、あるタă‚čクをèȘŹæ˜Žă™ă‚‹æŒ‡ç€șăŒă‚ă‚ŠăŸă™ă€‚"
        "ăƒȘクスă‚čăƒˆă‚’é©ćˆ‡ă«ćźŒäș†ă™ă‚‹ăŸă‚ăźć›žç­”ă‚’èš˜èż°ă—ăŠăă ă•ă„ă€‚\n\n"
        "### 指ç€ș:\n{instruction}\n\n### 濜答:"
    ),
}


def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Generate a prompt based on the given instruction and an optional input.

    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided (``None`` or an empty string), it uses the
    'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input_text (str | None): Additional input providing context for the task.
            Defaults to None.

    Returns:
        str: The generated prompt.
    """
    if input_text:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
    # Use the 'prompt_no_input' template when no additional input is provided
    return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)


@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream the model's answer for an instruction and optional input.

    Args:
        instruction: The instruction describing the task.
        input_text: Optional additional context; an empty string is treated
            the same as ``None``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value of 0 (or below) switches
            to greedy decoding instead of sampling.
        top_p: Nucleus-sampling probability mass (only used when sampling).

    Yields:
        The accumulated generated text after each streamed chunk.

    Raises:
        gr.Error: If no CUDA device is available (the model is not loaded),
            or if the tokenized prompt exceeds ``MAX_INPUT_TOKENS``.
    """
    if not torch.cuda.is_available():
        # The model/tokenizer globals are never created on CPU-only hosts;
        # fail with a readable message rather than a NameError.
        raise gr.Error("This demo does not work on CPU.")

    if input_text == "":
        input_text = None
    prompt = create_prompt(instruction, input_text)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        error_message = f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})"
        raise gr.Error(error_message)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling requires a strictly positive temperature; the UI slider
        # allows 0.0, so treat that as a request for greedy decoding. Without
        # this the failure would occur inside the worker thread and surface
        # only as a streamer timeout.
        generate_kwargs["do_sample"] = False

    # Generation runs in a background thread so that text can be yielded
    # from the streamer while the model is still producing tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Run an example row through `run` using the default generation settings."""
    yield from run(instruction, input_text)


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input (optional)", lines=5)
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    run_button.click(
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        api_name="run",
    )
    gr.Examples(
        examples=[
            [
                "ä»„äž‹ăźăƒˆăƒ”ăƒƒă‚Żă«é–ąă™ă‚‹è©łçŽ°ăȘæƒ…ć ±ă‚’æäŸ›ă—ăŠăă ă•ă„ă€‚",
                "東äșŹć·„æ„­ć€§ć­Šăźäž»ăȘキャンパă‚čă«ă€ă„ăŠæ•™ăˆăŠăă ă•ă„ă€‚",
            ],
            [
                "ä»„äž‹ăźăƒˆăƒ”ăƒƒă‚Żă«é–ąă™ă‚‹è©łçŽ°ăȘæƒ…ć ±ă‚’æäŸ›ă—ăŠăă ă•ă„ă€‚",
                "怹ă‚ȘăƒăšăŻäœ•ă‹ă«ă€ă„ăŠæ•™ăˆăŠăă ă•ă„ă€‚",
            ],
            ["æšŽă‚Œă‚“ćŠć°†è»ăŁăŠèȘ°ăźă“ăšă§ă™ă‹ïŒŸ", ""],  # noqa: RUF001
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

if __name__ == "__main__":
    demo.launch()