import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
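
# Generation limits: cap for the "Max new tokens" slider, its default value, and the
# maximum number of prompt tokens kept before the conversation is truncated.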
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# CantoneseLLM Chat
Please join our [Discord server](https://discord.gg/gG6GPp8XxQ) and share your feedback with us.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "hon9kon9ize/CantoneseLLMChat-v1.0-7B"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model = torch.compile(model)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokenizer.use_default_system_prompt = False
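

# spaces.GPU marks this function for ZeroGPU allocation on Hugging Face Spaces:
# a GPU is attached only while a call to generate() is running.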
@spaces.GPU(queue=False)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 2048,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    # Build the conversation in the format expected by the chat template,
    # falling back to a default Cantonese system prompt if none is given.
    conversation = [
        {
            "role": "system",
            "content": system_prompt
            if system_prompt
            else "你係由 hon9kon9ize 開發嘅 CantoneseLLM,你係一個好幫得手嘅助理",
        }
    ]
    for user, assistant in chat_history:
        conversation.extend(
            [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    # Keep only the most recent tokens if the prompt exceeds the input limit.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, skipping the prompt and special tokens.
    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response
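

# Hypothetical direct call, bypassing the Gradio UI (argument order matches how
# gr.ChatInterface passes the message, history, and additional inputs):
#   generate("你好!", chat_history=[], system_prompt="", max_new_tokens=64)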
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["如果我食咗早餐,我就唔會肚餓。我今日冇肚餓,咁我今日食咗早餐未?"],
        ["小明有5粒糖,小華有3粒糖。如果小明畀咗一粒糖俾小華,咁佢哋兩個一共仲有幾多粒糖?"],
        ["咩嘢係氣候變化?"],
        ["香港最高嘅山係邊座山?"],
        ["人體最重要嘅器官係咩?"],
    ],
)
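
# Wrap the chat interface in a Blocks layout so the Markdown description renders above it.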
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
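
# To run locally (assuming this file is saved as app.py): `python app.py`, then open
# the local URL that Gradio prints; the queue limits pending requests to 20.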