# CataLlama-Chat / app.py
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
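# Maximum number of prompt tokens passed to the model; generate() trims the
# oldest tokens when a conversation grows longer than this.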
MAX_INPUT_TOKEN_LENGTH = 8192
DESCRIPTION = """\
# CataLlama-v0.2-Instruct-DPO
This Space demonstrates the [CataLlama-v0.2-Instruct-DPO](https://huggingface.co/catallama/CataLlama-v0.2-Instruct-DPO) model.
CataLlama is a fine-tune of Llama-3-8B that enhances its proficiency in the Catalan language.
The model is capable of performing the following **tasks in Catalan**:
- Translation from English to Catalan and Catalan to English
- Summarization - both short form and long form
- Information extraction (suitable for RAG)
- Named Entity Recognition (NER)
- Open question answering
- Sentiment analysis
"""
LICENSE = """\
As a derivative work of [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta, this demo is governed by the original [Llama 3 license](https://llama.meta.com/llama3/license).
"""
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
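# Load the model and tokenizer only when a GPU is available; the demo does not work on CPU.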
if torch.cuda.is_available():
model_id = "catallama/CataLlama-v0.2-Instruct-SFT"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
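# Request a GPU allocation of up to 120 seconds per call (used on Hugging Face ZeroGPU Spaces).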
@spaces.GPU(duration=120)
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int,
temperature: float,
top_p: float,
) -> Iterator[str]:
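    """Build the chat prompt from the history and stream the model's reply token by token."""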
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
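    # Stream decoded text incrementally, skipping the prompt and special tokens.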
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
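    # Run generation in a background thread so partial output can be yielded as it streams in.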
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
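    # Accumulate streamed chunks and yield the growing response to the chat UI.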
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
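# The additional inputs below are forwarded to generate() after the message and chat history,
# in order: system prompt, max new tokens, temperature, top-p.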
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(
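            # Default system prompt (Catalan): "You are a friendly chatbot. Answer questions and help users."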
value="Ets un chatbot amigable. Responeu preguntes i ajudeu els usuaris.",
label="System message",
lines=6
),
gr.Slider(
minimum=1,
maximum=2048,
value=1024,
step=256,
label="Max new tokens"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.3,
step=0.05,
label="Temperature"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.90,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
examples=[
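        # Example prompts (Catalan): "How fast can crocodiles fly?",
        # "Explain step by step how to solve the following equation: 2x + 10 = 0",
        # "Can Donald Trump have dinner with Julius Caesar?"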
["A quina velocitat poden volar els cocodrils?"],
["Explica pas a pas com resoldre l'equació següent: 2x + 10 = 0"],
["Pot Donald Trump sopar amb Juli Cèsar?"],
],
)
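# Page layout: description, the chat interface, then the license notice.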
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.queue(max_size=20).launch()