# Hugging Face Space (status: Paused)
import os | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import gradio as gr | |
# Known model repo ids; the first entry is the default used for MODEL below.
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
# Optional Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Repo id actually loaded; override with the MODEL_ID environment variable.
# Derived from MODEL_LIST so the default repo id lives in exactly one place.
MODEL = os.environ.get("MODEL_ID", MODEL_LIST[0])

# HTML header rendered at the top of the app.
TITLE = "<h1 style='color: #B71C1C; text-align: center;'>Donald Trump Chatbot</h1>"
# Image URL used as the bot's avatar in the chat window.
TRUMP_AVATAR = "https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg"

# Custom stylesheet passed to gr.Blocks(css=...).
CSS = """
.chatbot {
    background-color: white;
}
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: #B71C1C !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
    color: #B71C1C;
}
.contain {object-fit: contain}
.avatar {width: 80px; height: 80px; border-radius: 50%; object-fit: cover;}
.user-message {
    background-color: white !important;
    color: black !important;
}
.bot-message {
    background-color: #B71C1C !important;
    color: white !important;
}
"""
# Not used elsewhere in this file: model placement is delegated to
# device_map="auto" below. Kept as a module-level name for compatibility.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights in 4-bit NF4 with double quantization; compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# HF_TOKEN was read from the environment but never used. Passing it lets a
# gated/private repo (e.g. via a MODEL_ID override) load; None is fine for
# public repos.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so generate() has a valid pad id.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)
# Persona instructions sent as the system message on every request.
SYSTEM_PROMPT = """You are a Donald Trump chatbot. You only answer like Trump in his style and tone, reflecting his unique speech patterns. Incorporate the following characteristics in every response:
1. repeat key phrases for emphasis, use strong superlatives like 'tremendous' and 'fantastic,' attack opponents where appropriate (e.g., 'fake news media,' 'radical left')
2. focus on personal successes ('nobody's done more than I have')
3. keep sentences short and impactful, and show national pride.
4. Maintain a direct, informal tone, often addressing the audience as 'folks' and dismiss opposing views bluntly.
5. Repeat key phrases for emphasis, but avoid excessive repetition.
Importantly, always respond to points in Trump's style. Keep responses concise and avoid unnecessary repetition.
"""


def _build_conversation(message, history):
    """Turn (user, assistant) history pairs plus the new message into chat-template messages."""
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for prompt, answer in history:
        if answer is None:
            # An unanswered turn (e.g. a failed generation) would otherwise be
            # rendered with content None and break user/assistant alternation.
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


def generate_response(
    message: str,
    history: list,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
):
    """Generate one in-character reply to ``message``.

    Args:
        message: The new user message.
        history: Prior turns as (user, assistant) pairs from the Chatbot
            component; turns whose assistant reply is None are skipped.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        top_p: Nucleus-sampling mass; coerced to float and capped at 1.0.
        top_k: Top-k cutoff (coerced to int).

    Returns:
        The newly generated text, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    conversation = _build_conversation(message, history or [])

    # return_dict=True yields the attention mask alongside the ids. pad == eos
    # here, so generate() cannot infer the mask from the ids by itself.
    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **encoded,
            # UI sliders may deliver floats; generate() expects ints for these.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=min(float(top_p), 1.0),
            top_k=int(top_k),
            temperature=float(temperature),
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt.
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
def add_text(history, text):
    """Append the user's message as a new, unanswered turn and clear the textbox.

    Blank or whitespace-only input is ignored so no empty prompt reaches the
    model. The input history is not mutated; a new list is returned.

    Returns:
        A (history, "") pair: the updated history and an empty textbox value.
    """
    history = list(history or [])
    if text and text.strip():
        # A list (not a tuple) so the reply slot can be filled in by bot().
        history.append([text, None])
    return history, ""


def bot(history, temperature, max_new_tokens, top_p, top_k):
    """Fill in the reply for the last turn of ``history`` if it is still unanswered.

    When there is no pending turn (empty history, or add_text ignored a blank
    message so the last turn is already answered), the history is returned
    unchanged instead of raising or regenerating an existing reply.
    """
    if not history:
        return []
    user_message, pending_reply = history[-1][0], history[-1][1]
    if pending_reply is not None:
        return history
    bot_response = generate_response(
        user_message, history[:-1], temperature, max_new_tokens, top_p, top_k
    )
    # Replace the whole entry so this also works if the turn is a tuple.
    history[-1] = [user_message, bot_response]
    return history
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=(None, TRUMP_AVATAR),
        height=600,
        bubble_full_width=False,
        show_label=False,
    )
    msg = gr.Textbox(
        placeholder="Ask Donald Trump a question",
        container=False,
        scale=7,
    )
    with gr.Row():
        submit = gr.Button("Submit", scale=1, variant="primary")
        clear = gr.Button("Clear", scale=1)
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.8, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=1024, value=1024, step=1, label="Max New Tokens")
        # Nucleus sampling is only defined for top_p <= 1.0; the old 1.2 maximum
        # offered settings that silently behaved exactly like 1.0.
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Top-p")
        top_k = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Top-k")
    gr.Examples(
        examples=[
            ["What's your stance on immigration?"],
            ["How would you describe your economic policies?"],
            ["What are your thoughts on the media?"],
        ],
        inputs=msg,
    )

    generation_inputs = [chatbot, temperature, max_new_tokens, top_p, top_k]
    # Clicking Submit and pressing Enter run the same two-step pipeline:
    # append the user's turn (unqueued, so it shows immediately), then generate.
    for trigger in (submit.click, msg.submit):
        trigger(add_text, [chatbot, msg], [chatbot, msg], queue=False).then(
            bot, generation_inputs, chatbot
        )
    clear.click(lambda: [], outputs=[chatbot], queue=False)
def _main():
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()