# Author: taufiqdp
# Last commit: e9f724a ("Update app.py", verified)
import os
import subprocess
import sys
from threading import Thread

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Authenticate with the Hugging Face Hub only when a token is configured.
# login(None) would fall back to prompting for a token interactively, which
# cannot succeed in a headless deployment.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(_hf_token)

# Install flash-attn at startup (needed for attn_implementation="flash_attention_2"
# below). FLASH_ATTENTION_SKIP_CUDA_BUILD asks flash-attn's setup not to compile
# its CUDA extension locally.
# - Run pip through the current interpreter so the package lands in the same
#   environment that will import it.
# - Extend (rather than replace) os.environ so PATH, HOME, etc. stay visible
#   to the installer.
# - Pass an argument list instead of a shell string: no shell is needed.
# check=False keeps the original best-effort behaviour: a failed install does
# not abort startup here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
model_id = "microsoft/Phi-3-mini-128k-instruct"

# trust_remote_code=True lets transformers execute tokenizer/model code hosted
# in the hub repository, if the repository provides any.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load weights in bfloat16, let device_map="auto" place them on the available
# device(s), and select the FlashAttention-2 attention implementation, which
# relies on the flash-attn package installed at startup.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
@spaces.GPU()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Stream a chat completion for *message* given the prior conversation.

    Builds a chat-template prompt from the optional system prompt, the
    (user, assistant) pairs in *chat_history*, and the new user message.
    It then runs ``model.generate`` on a background thread and yields the
    accumulated decoded text each time the streamer produces a new chunk.

    Args:
        message: The new user message.
        chat_history: Previous turns as (user, assistant) pairs.
        system_prompt: Optional system prompt; omitted when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative probability threshold.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to repeated tokens (1.0 = none).

    Yields:
        The full response text generated so far.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # return_dict=True returns a BatchEncoding. Look tensors up by key instead
    # of unpacking .values(), which silently depends on the dict's key order.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        # Values arrive from UI sliders; cast defensively because generate()
        # requires integers for these two settings.
        max_new_tokens=int(max_new_tokens),
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )

    # model.generate blocks until it finishes. Run it on a worker thread and
    # consume the streamer here so partial output can be yielded as it arrives.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    # The streamer is exhausted once generation has ended; reap the thread.
    thread.join()
# Chat UI wired to `generate`. The additional_inputs below are passed to
# generate() after (message, chat_history), in this exact order:
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    # The title previously held mojibake ("πŸš€") of the UTF-8 rocket emoji.
    title="🚀 Phi-3 mini 128k instruct",
    description="",
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            lines=5,
            value="You are a helpful digital assistant.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Can you provide ways to eat combinations of bananas and dragonfruits?"],
        ["Write a story about a dragon fruit that flies into outer space!"],
        ["I am going to Bali, what should I see"],
    ],
)

# Enable the request queue (needed for streaming generator outputs) and start
# the server.
demo.queue().launch()