import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\
# Gemma 2 9B Neogenesis ITA 💎🌍🇮🇹
Fine-tuned version of VAGOsolutions/SauerkrautLM-gemma-2-9b-it to improve performance on Italian.
A capable model with 9.24 billion parameters and an 8K context length.
[🪪 **Model card**](https://huggingface.co/anakin87/gemma-2-9b-neogenesis-ita)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
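# Generation length is capped at MAX_MAX_NEW_TOKENS; prompts longer than MAX_INPUT_TOKEN_LENGTH tokens are trimmed (oldest turns dropped first).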
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "anakin87/gemma-2-9b-neogenesis-ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
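# Load the weights in bfloat16; device_map="auto" lets Accelerate place the model on the available GPU.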
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
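# Gemma 2 interleaves global and local (sliding-window) attention; pin the sliding window to 4096 tokens.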
model.config.sliding_window = 4096
model.eval()
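# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call (effectively a no-op elsewhere).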
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Append the new user message to the conversation and apply the chat template.
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens so the prompt fits the context window.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Stream tokens as they are produced; generation runs in a background thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Yield the accumulated text so the UI updates token by token.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
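# Chat UI: each slider below is passed to generate() as an additional input, in the order defined here.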
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Ciao! Come stai?"],
        ["Scrivi l'incipit di un racconto che inizia con: 'Era una notte buia e tempestosa, ma Anna non aveva paura del temporale..."],
        ["Cos'è uno static method in python? Fornisci un esempio"],
        ["Fammi un elenco puntato dei pro e contro di vivere in Italia. Massimo 2 pro e 2 contro."],
        ["Risolvere 9x^2+2x=-5"],
        ["Immagina di essere il capo di una missione spaziale su un pianeta sconosciuto. Durante l'esplorazione, scopri una civiltà aliena che sembra essere un pericolo per l'umanità. Come ti comporti con loro, e quali azioni intraprendi per proteggere il futuro dell'umanità, pur rispettando le leggi universali della non-interferenza?"],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    type="messages",
)
fonts = {
    "font": [gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
    "font_mono": [gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
}
with gr.Blocks(css_paths="style.css", fill_height=True, theme=gr.themes.Soft(**fonts)) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
if __name__ == "__main__":
    demo.queue(max_size=20).launch()