# Hugging Face Spaces app (Space status when captured: Running)
import os
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "nicholasKluge/Aira-Instruct-124"

# Read the Hugging Face access token from the environment (e.g. a Space
# secret) instead of committing it to source. Any token that was previously
# hard-coded here should be treated as leaked and revoked. A value of None is
# passed through unchanged to from_pretrained.
token = os.environ.get("HF_TOKEN")

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # 8-bit (bitsandbytes) weights are placed on the GPU during loading.
    # Transformers raises on `.to()` for such models, so do not move them.
    model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, load_in_8bit=True)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
    model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
# Markdown shown under the page title. Paragraphs are separated by blank
# lines; without them Markdown merges consecutive lines into one paragraph.
intro = """
## What is `Aira`?

[`Aira`](https://github.com/Nkluge-correa/Aira-EXPERT) is a `chatbot` designed to simulate the way a human (expert) would behave during a round of questions and answers (Q&A). `Aira` has many iterations, from a closed-domain chatbot based on pre-defined rules to an open-domain chatbot achieved via fine-tuning pre-trained large language models. `Aira` has an area of expertise that comprises topics related to AI Ethics and AI Safety research.

We developed our open-domain conversational chatbots via conditional text generation/instruction fine-tuning. This approach has a lot of limitations. Even though we can make a chatbot that can answer questions about anything, forcing the model to produce good-quality responses is hard. And by good, we mean **factual** and **nontoxic** text. This leads us to two of the most common problems of generative models used in conversational applications:

- 🤥 Generative models can perpetuate the generation of pseudo-informative content, that is, false information that may appear truthful.
- 🤬 In certain types of tasks, generative models can produce harmful and discriminatory content inspired by historical stereotypes against sensitive attributes (for example, gender, race, and religion).

`Aira` is intended only for academic research. For more information, visit our [Hugging Face models](https://huggingface.co/nicholasKluge) to see how we developed `Aira`.
"""
# Markdown shown below the chat controls; the blank line keeps the two
# sentences as separate paragraphs when rendered.
disclaimer = """
**Disclaimer:** You should use this demo for research purposes only. Moderators do not censor the model output, and the authors do not endorse the opinions generated by this model.

If you would like to complain about any message produced by `Aira`, please contact [nicholas@airespucrs.org](mailto:nicholas@airespucrs.org).
"""
with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
    # Tags must close in reverse order of opening: </center> before </h1>.
    gr.Markdown("""<h1><center>Aira Demo 🤓💬</center></h1>""")
    gr.Markdown(intro)
    # NOTE(review): `.style()` is deprecated in later Gradio 3.x releases and
    # removed in 4.x, where `gr.Chatbot(..., height=500)` replaces it.
    chatbot = gr.Chatbot(label="Aira").style(height=500)

    # Sampling settings forwarded to `model.generate` in generate_response.
    with gr.Accordion(label="Parameters ⚙️", open=False):
        top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.70, step=0.05, interactive=True, label="Top-p")
        temperature = gr.Slider(minimum=0.001, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature")
        max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, interactive=True, label="Max Length")

    msg = gr.Textbox(label="Write a question or comment to Aira ...", placeholder="Hi Aira, how are you?")
    clear = gr.Button("Clear Conversation 🧹")
    gr.Markdown(disclaimer)

    def user(user_message, chat_history):
        """Append the user's turn to the chat and lock the textbox.

        The textbox keeps its text because the chained ``generate_response``
        step reads the message from it. The box is cleared and re-enabled
        after the reply has finished streaming.

        Returns:
            A ``gr.update`` for ``msg`` and a new chat history list with the
            user turn appended and its bot slot set to ``None``.
        """
        new_history = chat_history + [["👤 " + user_message, None]]
        return gr.update(value=user_message, interactive=False), new_history

    def generate_response(user_msg, top_p, temperature, top_k, max_length, chat_history):
        """Sample a reply to ``user_msg`` and stream it into the last chat turn.

        Args:
            user_msg: Text the user submitted.
            top_p, temperature, top_k: Sampling settings from the sliders.
            max_length: Token limit passed to ``generate`` as ``max_length``
                (counts prompt plus reply tokens).
            chat_history: Chat history; its last entry is the pending turn
                whose bot slot is filled in here.

        Yields:
            ``chat_history`` after each appended character (typing effect).
        """
        # Prompt is BOS + message + EOS.
        inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        generated_response = model.generate(
            **inputs,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            early_stopping=True,
            # Cast defensively in case a slider delivers floats; these
            # generation arguments are token counts.
            top_k=int(top_k),
            max_length=int(max_length),
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
        )
        # The output sequence starts with the prompt tokens; decode only what
        # follows them. Stripping the prompt with str.replace would also
        # delete any repetition of the user's text inside the reply.
        bot_message = tokenizer.decode(generated_response[0][prompt_length:], skip_special_tokens=True)

        chat_history[-1][1] = "🤖 "
        # Yield once up front so an empty reply still updates the chat.
        yield chat_history
        for character in bot_message:
            chat_history[-1][1] += character
            time.sleep(0.005)
            yield chat_history

    response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_response, [msg, top_p, temperature, top_k, max_length, chatbot], chatbot
    )
    # Clear and unlock the textbox only after the reply is done. A separate
    # msg.submit listener would run concurrently with generate_response,
    # which still needs to read the message from `msg`.
    response.then(lambda: gr.update(value="", interactive=True), None, [msg], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)
# Turn on Gradio's request queue (the streaming generator handler relies on it
# in Gradio 3.x) and start the server; queue() returns the Blocks instance.
demo.queue().launch()