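# Hugging Face Space: a Gradio chat UI around a Trump-persona Llama-3.1-8B chat model,
# loaded in 4-bit and streamed to the client token by token.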
import os

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
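# Configuration comes from the environment so the Space can be re-pointed at another
# checkpoint without code changes; HF_TOKEN is only needed for gated or private models.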
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = os.environ.get("MODEL_ID", "nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat")

TITLE = "<h1 style='color: #B71C1C; text-align: center;'>Donald Trump Chatbot</h1>"
TRUMP_AVATAR = "https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg"
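# Light styling: white chat background, red accents, round avatars.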
CSS = """
.chatbot {
    background-color: white;
}
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: #B71C1C !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
    color: #B71C1C;
}
.contain {object-fit: contain}
.avatar {width: 40px; height: 40px; border-radius: 50%; object-fit: cover;}
.user-message {
    background-color: white !important;
    color: black !important;
}
.bot-message {
    background-color: #B71C1C !important;
    color: white !important;
}
"""
# 4-bit NF4 quantization with bfloat16 compute keeps the 8B model within a single GPU's
# memory budget; device placement is handled by device_map="auto" below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
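# Tokenizer and model load from the same checkpoint; the token is passed through in
# case the configured model requires authentication.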
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)
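# stream_chat is a generator: it yields the growing reply so Gradio can render it
# incrementally. @spaces.GPU requests GPU time for the call on ZeroGPU hardware.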
@spaces.GPU
def stream_chat(
    message: str,
    history: list,
):
    system_prompt = "You are a Donald Trump chatbot. You only answer like Trump in style and tone."
    temperature = 0.8
    max_new_tokens = 1024
    top_p = 1.0
    top_k = 20
    penalty = 1.2
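    # Rebuild the full multi-turn conversation in the message format the tokenizer's
    # chat template expects.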
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
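    # Apply the chat template, then stream decoded tokens back as they are generated.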
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[128001, 128008, 128009],  # Llama 3.1 terminators
        streamer=streamer,
    )
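    # generate() blocks, so run it on a worker thread while this generator drains the streamer.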
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
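# Gradio callbacks: add_text stages the user's turn in the history, then bot streams
# the model's reply into that turn.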
def add_text(history, text):
    # Use a list, not a tuple, so bot() can fill in the reply slot in place.
    history = history + [[text, None]]
    return history, ""
def bot(history):
    user_message = history[-1][0]
    bot_response = stream_chat(user_message, history[:-1])
    history[-1][1] = ""
    # stream_chat yields the accumulated reply so far, so replace rather than append.
    for chunk in bot_response:
        history[-1][1] = chunk
        yield history
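# Assemble the UI: an avatar chatbot, a text input, submit/clear buttons, and example prompts.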
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=(None, TRUMP_AVATAR),
        height=600,
        bubble_full_width=False,
        show_label=False,
    )
    msg = gr.Textbox(
        placeholder="Ask Donald Trump a question",
        container=False,
        scale=7,
    )
    with gr.Row():
        submit = gr.Button("Submit", scale=1, variant="primary")
        clear = gr.Button("Clear", scale=1)
    gr.Examples(
        examples=[
            ["What's your stance on immigration?"],
            ["How would you describe your economic policies?"],
            ["What are your thoughts on the media?"],
        ],
        inputs=msg,
    )
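    # Both Enter and the Submit button stage the message first (unqueued, so the box
    # clears immediately), then stream the bot's reply into the chat.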
    submit.click(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: [], outputs=[chatbot], queue=False)
    msg.submit(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
        bot, chatbot, chatbot
    )
if __name__ == "__main__":
    demo.launch()