import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time
#token = os.environ["HF_TOKEN"]
# Load the model in 4-bit so it fits on a single GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    "shisa-ai/shisa-v1-qwen2-7b", quantization_config=quantization_config
)
tok = AutoTokenizer.from_pretrained("shisa-ai/shisa-v1-qwen2-7b")
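# Optional tweak (an assumption, not part of the original Space): BitsAndBytesConfig
# also accepts bnb_4bit_quant_type="nf4" (the 4-bit type recommended in the QLoRA
# paper) and bnb_4bit_use_double_quant=True, which saves a little more memory, e.g.:
#   BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
#                      bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)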
#terminators = [
# tok.eos_token_id,
# tok.convert_tokens_to_ids("<|eot_id|>")
#]
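# Note (assumption): "<|eot_id|>" is a Llama-3-style token and is likely absent from
# the Qwen2 tokenizer, which is probably why the list above is disabled. If it were
# re-enabled, transformers' generate() accepts a list for eos_token_id, so
# eos_token_id=terminators could be passed in generate_kwargs below.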
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# model = model.to(device)  # skipped: 4-bit bitsandbytes models are dispatched by
# accelerate at load time, and moving them with .to() raises a dispatch error.
@spaces.GPU(duration=120)
def chat(message, history, temperature, do_sample, max_tokens):
    # Rebuild the conversation in the message format the chat template expects.
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": message})

    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)

    # Stream new tokens back to the UI as they are generated.
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        eos_token_id=tok.eos_token_id,  # changed terminators to eos_token_id
    )
    if temperature == 0:
        # A temperature of 0 means greedy decoding.
        generate_kwargs["do_sample"] = False

    # Run generation in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    tokens = len(tok.tokenize(partial_text))
    yield partial_text
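# Hedged sketch (an assumption, not part of the original Space): a small helper for
# exercising the same generation path outside Gradio. The name _local_smoke_test is
# hypothetical; it is defined but never called, so the app's behaviour is unchanged.
def _local_smoke_test(prompt="Hello!", max_new_tokens=64):
    messages = tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tok([messages], return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Strip the prompt tokens before decoding so only the reply is returned.
    return tok.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )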
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [shisa-ai/shisa-v1-qwen2-7b](https://huggingface.co/shisa-ai/shisa-v1-qwen2-7b) in 4bit",
)
demo.launch()