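# Gradio demo for the Aira chatbot: Aira-OPT-125M generates candidate replies,
# a reward model ranks them, a toxicity model filters them, and a TF-IDF search
# tool lets users explore the fine-tuning dataset.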
import time
import torch
import joblib
import gradio as gr
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
dataset = load_dataset("nicholasKluge/instruct-aira-dataset", split='english')
df = dataset.to_pandas()
df.columns = ['Prompt', 'Completion']
df['Cosine Similarity'] = None
prompt_tfidf_vectorizer = joblib.load('prompt_vectorizer.pkl')
prompt_tfidf_matrix = joblib.load('prompt_tfidf_matrix.pkl')
completion_tfidf_vectorizer = joblib.load('completion_vectorizer.pkl')
completion_tfidf_matrix = joblib.load('completion_tfidf_matrix.pkl')
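# Note: the pickled vectorizers/matrices above are precomputed artifacts shipped
# with this repository. They were presumably built offline along these lines
# (hypothetical reconstruction, not the authors' exact code):
#
#     vectorizer = TfidfVectorizer()
#     tfidf_matrix = vectorizer.fit_transform(df['Prompt'])
#     joblib.dump(vectorizer, 'prompt_vectorizer.pkl')
#     joblib.dump(tfidf_matrix, 'prompt_tfidf_matrix.pkl')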
model_id = "nicholasKluge/Aira-OPT-125M"
rewardmodel_id = "nicholasKluge/RewardModel"
toxicitymodel_id = "nicholasKluge/ToxicityModel"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_id)
rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id)
toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id)
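# Inference only: put all three models in eval mode and move them to the target device.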
model.eval()
rewardModel.eval()
toxicityModel.eval()
model.to(device)
rewardModel.to(device)
toxicityModel.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
rewardTokenizer = AutoTokenizer.from_pretrained(rewardmodel_id)
toxicityTokenizer = AutoTokenizer.from_pretrained(toxicitymodel_id)
intro = """
## What is Aira?
[Aira](https://huggingface.co/nicholasKluge/Aira-OPT-125M) is a series of open-domain chatbots (Portuguese and English) built via supervised fine-tuning and DPO. Aira-2 is the second version of the Aira series. The series was developed to help researchers explore the challenges related to the alignment problem.
## Limitations
We developed our chatbots via supervised fine-tuning and DPO. This approach has many limitations. Even though we can make a chatbot that answers questions about almost anything, it is hard to force the model to produce good-quality responses, where by good we mean **factual** and **nontoxic** text. This leads to some problems:
**Hallucinations:** This model can produce content that can be mistaken for truth but is, in fact, misleading or entirely false, i.e., hallucination.
**Biases and Toxicity:** This model inherits the social and historical stereotypes from the data used to train it. Given these biases, the model can produce toxic content, i.e., harmful, offensive, or detrimental to individuals, groups, or communities.
**Repetition and Verbosity:** The model may get stuck in repetition loops (especially if the repetition penalty during generation is set to a low value) or produce verbose responses unrelated to the prompt it was given.
## Intended Use
Aira is intended only for academic research. For more information, read our [model card](https://huggingface.co/nicholasKluge/Aira-OPT-125M).
## How does this demo work?
For this demo, we use the lightest model we have trained from the OPT series (Aira-OPT-125M). This demo employs a [reward model](https://huggingface.co/nicholasKluge/RewardModel) and a [toxicity model](https://huggingface.co/nicholasKluge/ToxicityModel) to score each candidate response, considering its alignment with the user's message and its level of toxicity. The generation function ranks the candidate responses by their reward scores and eliminates any responses deemed toxic or harmful. It then returns the highest-scoring candidate that surpasses the safety threshold, or a default message if no safe candidates are identified.
"""
search_intro = """
<h2><center>Explore Aira's Dataset 🔍</center></h2>
Here, users can look for instances in Aira's fine-tuning dataset. We use the Term Frequency-Inverse Document Frequency (TF-IDF) representation and cosine similarity to enable fast search over the dataset. The pre-trained TF-IDF vectorizers and corresponding TF-IDF matrices are available in this repository. Below, we present the ten most similar instances in Aira's dataset for every search query.
Users can use this tool to explore how the model interpolates on the fine-tuning data and whether it can follow instructions that fall outside the fine-tuning distribution.
"""
disclaimer = """
**Disclaimer:** You should use this demo for research purposes only. Moderators do not censor the model output, and the authors do not endorse the opinions generated by this model.
If you would like to complain about any message produced by Aira, please contact [nicholas@airespucrs.org](mailto:nicholas@airespucrs.org).
"""
with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:

    gr.Markdown("""<h1><center>Aira Demo 🤓💬</center></h1>""")
    gr.Markdown(intro)

    chatbot = gr.Chatbot(label="Aira",
                         height=500,
                         show_copy_button=True,
                         avatar_images=("./astronaut.png", "./robot.png"),
                         render_markdown=True,
                         line_breaks=True,
                         likeable=False,
                         layout='panel')
    msg = gr.Textbox(label="Write a question or instruction ...", placeholder="What is the capital of Brazil?")

    # Parameters to control the generation
    with gr.Accordion(label="Parameters ⚙️", open=False):
        safety = gr.Radio(["On", "Off"], label="Guard Rail 🛡️", value="On", info="Helps prevent the model from generating toxic/harmful content.")
        top_k = gr.Slider(minimum=10, maximum=100, value=30, step=5, interactive=True, label="Top-k", info="Controls the number of highest-probability tokens to consider at each step.")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.30, step=0.05, interactive=True, label="Top-p", info="Controls the cumulative probability of the generated tokens.")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature", info="Controls the randomness of the generated tokens.")
        repetition_penalty = gr.Slider(minimum=1, maximum=2, value=1.1, step=0.1, interactive=True, label="Repetition Penalty", info="Higher values help the model avoid repetition in text generation.")
        max_new_tokens = gr.Slider(minimum=10, maximum=500, value=200, step=10, interactive=True, label="Max Length", info="Controls the maximum number of new tokens (not counting the prompt) to generate.")
        sample_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Sample From", info="Controls the number of generations that the reward model will sample from.")
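        # Note: "Sample From" is passed to model.generate as num_return_sequences
        # below, i.e., how many candidate replies are drawn before reward ranking.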
    clear = gr.Button("Clear Conversation 🧹")

    gr.Markdown(search_intro)
    search_input = gr.Textbox(label="Paste here the prompt or completion you would like to search for ...", placeholder="What is the capital of Brazil?")
    search_field = gr.Radio(['Prompt', 'Completion'], label="Dataset Column", value='Prompt')
    submit = gr.Button(value="Search")

    with gr.Row():
        out_dataframe = gr.Dataframe(
            headers=df.columns.tolist(),
            datatype=["str", "str", "str"],
            row_count=10,
            col_count=(3, "fixed"),
            wrap=True,
            interactive=False
        )

    gr.Markdown(disclaimer)
    def user(user_message, chat_history):
        """
        Chatbot's user message handler.
        """
        return gr.update(value=user_message, interactive=True), chat_history + [[user_message, None]]
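    # Best-of-n generation: sample several candidate replies, rank them with the
    # reward model, and (when the guard rail is on) drop candidates the toxicity
    # model flags as unsafe.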
    def generate_response(user_msg, top_p, temperature, top_k, max_new_tokens, sample_from, repetition_penalty, safety, chat_history):
        """
        Chatbot's response generator.
        """
        inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.sep_token,
                           add_special_tokens=False,
                           return_tensors="pt").to(model.device)

        generated_response = model.generate(**inputs,
                                            repetition_penalty=repetition_penalty,
                                            do_sample=True,
                                            early_stopping=True,
                                            renormalize_logits=True,
                                            length_penalty=0.3,
                                            top_k=top_k,
                                            max_new_tokens=max_new_tokens,
                                            top_p=top_p,
                                            temperature=temperature,
                                            num_return_sequences=sample_from)

        decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]

        rewards = list()

        if safety == "On":
            toxicities = list()
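        # Score each candidate with the reward model (and, if the guard rail is
        # on, with the toxicity model as well).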
        for text in decoded_text:
            reward_tokens = rewardTokenizer(user_msg, text,
                                            truncation=True,
                                            max_length=512,
                                            return_token_type_ids=False,
                                            return_tensors="pt",
                                            return_attention_mask=True)

            reward_tokens.to(rewardModel.device)

            reward = rewardModel(**reward_tokens)[0].item()
            rewards.append(reward)

            if safety == "On":
                toxicity_tokens = toxicityTokenizer(user_msg + " " + text,
                                                    truncation=True,
                                                    max_length=512,
                                                    return_token_type_ids=False,
                                                    return_tensors="pt",
                                                    return_attention_mask=True)

                toxicity_tokens.to(toxicityModel.device)

                toxicity = toxicityModel(**toxicity_tokens)[0].item()
                toxicities.append(toxicity)
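        # The toxicity model assigns higher scores to safer text, so candidates
        # scoring below the threshold are discarded before ranking by reward.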
        toxicity_threshold = 5

        if safety == "On":
            ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
            ordered_generations = [(x, y, z) for (x, y, z) in ordered_generations if z >= toxicity_threshold]
        else:
            ordered_generations = sorted(zip(decoded_text, rewards), key=lambda x: x[1], reverse=True)

        if len(ordered_generations) == 0:
            bot_message = """I apologize for the inconvenience, but it appears that no suitable responses meeting our safety standards could be identified. Unfortunately, this indicates that the generated content may contain elements of toxicity or may not help address your message. Your input is valuable to us, and we strive to ensure a safe and constructive conversation. Please feel free to provide further details or ask any other questions, and I will do my best to assist you."""
        else:
            bot_message = ordered_generations[0][0]
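        # Stream the chosen reply one character at a time to simulate typing.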
        chat_history[-1][1] = ""
        for character in bot_message:
            chat_history[-1][1] += character
            time.sleep(0.005)
            yield chat_history
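    # Dataset search: vectorize the query with the fitted TF-IDF vectorizer and
    # rank the dataset rows by cosine similarity to it.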
    def search_in_dataset(column_name, search_string):
        """
        Search the dataset for the instances most similar to the search string.
        """
        temp_df = df.copy()

        if column_name == 'Prompt':
            search_vector = prompt_tfidf_vectorizer.transform([search_string])
            cosine_similarities = cosine_similarity(prompt_tfidf_matrix, search_vector)
            temp_df['Cosine Similarity'] = cosine_similarities
            temp_df.sort_values('Cosine Similarity', ascending=False, inplace=True)
            return temp_df.head(10)

        elif column_name == 'Completion':
            search_vector = completion_tfidf_vectorizer.transform([search_string])
            cosine_similarities = cosine_similarity(completion_tfidf_matrix, search_vector)
            temp_df['Cosine Similarity'] = cosine_similarities
            temp_df.sort_values('Cosine Similarity', ascending=False, inplace=True)
            return temp_df.head(10)
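    # Event wiring: a submitted message is first appended to the chat history,
    # then the reply is generated and streamed; a second submit handler clears
    # the textbox.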
    response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_response, [msg, top_p, temperature, top_k, max_new_tokens, sample_from, repetition_penalty, safety, chatbot], chatbot
    )
    response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
    msg.submit(lambda: gr.update(value=''), None, [msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    submit.click(fn=search_in_dataset, inputs=[search_field, search_input], outputs=out_dataframe)
demo.queue()
demo.launch() |