File size: 12,434 Bytes
c053a24
 
 
6ddba8c
c053a24
 
 
 
 
8a881cc
c053a24
56d4b07
c053a24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c627036
d1e315a
33cc93c
c1ee7c5
33cc93c
b38997d
c627036
c1ee7c5
33cc93c
c1ee7c5
8651bfd
c1ee7c5
8651bfd
e74070b
8651bfd
c1ee7c5
c627036
c1ee7c5
33cc93c
c1ee7c5
33cc93c
c1ee7c5
ca1f04e
d1e315a
 
9b402c2
 
 
ca1f04e
9b402c2
ca1f04e
9b402c2
 
d1e315a
 
c1ee7c5
33cc93c
d1e315a
e7cefea
d287f72
c627036
d1e315a
 
 
218415e
a9c7069
 
 
 
 
 
 
 
218415e
33cc93c
d1e315a
9b402c2
595e5fc
c627036
2677645
 
c627036
2677645
f79078a
c627036
 
d6f0dcf
9b402c2
 
 
65ab98d
9b402c2
 
 
 
 
c053a24
9b402c2
94d458b
9b402c2
 
 
 
 
cdb7c8c
6666c9d
c053a24
 
 
 
bb6b840
c053a24
f79078a
c053a24
 
 
 
1239f40
 
 
c053a24
 
 
 
 
 
 
 
f79078a
c053a24
 
 
 
 
c627036
c053a24
bb6b840
 
 
c053a24
 
8a881cc
c053a24
 
 
 
 
8a881cc
 
 
 
 
c053a24
8a881cc
c053a24
8a881cc
bb6b840
 
 
 
 
8a881cc
 
 
 
 
 
 
c053a24
 
3701393
c053a24
 
3701393
 
 
c053a24
 
 
 
 
 
bb6b840
c053a24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f79078a
c053a24
 
 
 
 
6ddba8c
d1e315a
c053a24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import time
import torch
import joblib
import gradio as gr
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

# English split of Aira's fine-tuning dataset, exposed as a DataFrame so the
# search tool further down can rank its rows.
dataset = load_dataset("nicholasKluge/instruct-aira-dataset", split='english')
df = dataset.to_pandas()
df.columns = ['Prompt', 'Completion']
df['Cosine Similarity'] = None  # placeholder; each search fills it in on a copy

# Pre-fitted TF-IDF vectorizers and their matrices, one pair per dataset column.
# NOTE(review): joblib.load unpickles arbitrary objects - only ship .pkl files
# that come from a trusted source.
prompt_tfidf_vectorizer, prompt_tfidf_matrix = (
    joblib.load(path) for path in ('prompt_vectorizer.pkl', 'prompt_tfidf_matrix.pkl')
)
completion_tfidf_vectorizer, completion_tfidf_matrix = (
    joblib.load(path) for path in ('completion_vectorizer.pkl', 'completion_tfidf_matrix.pkl')
)

# Hub identifiers of the chat model and of the two scoring models.
model_id = "nicholasKluge/Aira-OPT-125M"
rewardmodel_id = "nicholasKluge/RewardModel"
toxicitymodel_id = "nicholasKluge/ToxicityModel"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load every network, put it in evaluation mode and move it to `device`.
model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(device)
rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id).eval().to(device)
toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id).eval().to(device)

# Each model comes with its own tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id)
rewardTokenizer = AutoTokenizer.from_pretrained(rewardmodel_id)
toxiciyTokenizer = AutoTokenizer.from_pretrained(toxicitymodel_id)

# Markdown shown at the top of the demo page: what Aira is, its limitations,
# its intended use and how the demo picks a response.
intro = """
## What is Aira?

[Aira](https://huggingface.co/nicholasKluge/Aira-OPT-125M) is a series of open-domain chatbots (Portuguese and English) achieved via supervised fine-tuning and DPO. Aira-2 is the second version of the Aira series. The Aira series was developed to help researchers explore the challenges related to the Alignment problem.

## Limitations

We developed our chatbots via supervised fine-tuning and DPO. This approach has a lot of limitations. Even though we can make a chatbot that can answer questions about anything, forcing the model to produce good-quality responses is hard. And by good, we mean **factual** and **nontoxic**  text. This leads us to some problems:

**Hallucinations:** This model can produce content that can be mistaken for truth but is, in fact, misleading or entirely false, i.e., hallucination.

**Biases and Toxicity:** This model inherits the social and historical stereotypes from the data used to train it. Given these biases, the model can produce toxic content, i.e., harmful, offensive, or detrimental to individuals, groups, or communities.

**Repetition and Verbosity:** The model may get stuck on repetition loops (especially if the repetition penalty during generations is set to a meager value) or produce verbose responses unrelated to the prompt it was given.

## Intended Use

Aira is intended only for academic research. For more information, read our [model card](https://huggingface.co/nicholasKluge/Aira-OPT-125M).

## How does this demo work?

For this demo, we use the lighter model we have trained from the OPT series (Aira-OPT-125M). This demo employs a [reward model](https://huggingface.co/nicholasKluge/RewardModel) and a [toxicity model](https://huggingface.co/nicholasKluge/ToxicityModel) to evaluate the score of each candidate's response, considering its alignment with the user's message and its level of toxicity. The generation function arranges the candidate responses in order of their reward scores and eliminates any responses deemed toxic or harmful. Subsequently, the generation function returns the candidate response with the highest score that surpasses the safety threshold, or a default message if no safe candidates are identified.
"""

# Markdown shown above the dataset search tool. The heading's tags are closed
# in the reverse order they were opened so the HTML is well formed.
search_intro = """
<h2><center>Explore Aira's Dataset 🔍</center></h2>

Here, users can look for instances in Aira's fine-tuning dataset. We use the Term Frequency-Inverse Document Frequency (TF-IDF) representation and cosine similarity to enable a fast search to explore the dataset. The pre-trained TF-IDF vectorizers and corresponding TF-IDF matrices are available in this repository. Below, we present the top ten most similar instances in Aira's dataset for every search query.

Users can use this tool to explore how the model interpolates on the fine-tuning data and if it can follow instructions that are out of the fine-tuning distribution.
"""

# Markdown shown at the bottom of the page.
disclaimer = """
**Disclaimer:** You should use this demo for research purposes only. Moderators do not censor the model output, and the authors do not endorse the opinions generated by this model.

If you would like to complain about any message produced by Aira, please contact [nicholas@airespucrs.org](mailto:nicholas@airespucrs.org).
"""

with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:

    # Page title (tags closed in the reverse order they were opened) and intro.
    gr.Markdown("""<h1><center>Aira Demo 🤓💬</center></h1>""")
    gr.Markdown(intro)

    # Conversation panel and the textbox the user types into.
    chatbot = gr.Chatbot(
        label="Aira",
        height=500,
        show_copy_button=True,
        avatar_images=("./astronaut.png", "./robot.png"),
        render_markdown=True,
        line_breaks=True,
        likeable=False,
        layout='panel',
    )

    msg = gr.Textbox(label="Write a question or instruction ...", placeholder="What is the capital of Brazil?")

    # Parameters to control the generation
    with gr.Accordion(label="Parameters ⚙️", open=False):
        safety = gr.Radio(["On", "Off"], label="Guard Rail 🛡️", value="On", info="Helps prevent the model from generating toxic/harmful content.")
        top_k = gr.Slider(minimum=10, maximum=100, value=30, step=5, interactive=True, label="Top-k", info="Controls the number of highest probability tokens to consider for each step.")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.30, step=0.05, interactive=True, label="Top-p", info="Controls the cumulative probability of the generated tokens.")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature", info="Controls the randomness of the generated tokens.")
        repetition_penalty = gr.Slider(minimum=1, maximum=2, value=1.1, step=0.1, interactive=True, label="Repetition Penalty", info="Higher values help the model to avoid repetition in text generation.")
        max_new_tokens = gr.Slider(minimum=10, maximum=500, value=200, step=10, interactive=True, label="Max Length", info="Controls the maximum number of new token (not considering the prompt) to generate.")
        smaple_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Sample From", info="Controls the number of generations that the reward model will sample from.")

    clear = gr.Button("Clear Conversation 🧹")

    # Dataset search tool: query box, column selector, button and result table.
    gr.Markdown(search_intro)

    search_input = gr.Textbox(label="Paste here the prompt or completion you would like to search ...", placeholder="What is the Capital of Brazil?")
    search_field = gr.Radio(['Prompt', 'Completion'], label="Dataset Column", value='Prompt')
    submit = gr.Button(value="Search")

    with gr.Row():
        # One table column per DataFrame column, ten rows for the top-ten hits.
        out_dataframe = gr.Dataframe(
            headers=df.columns.tolist(),
            datatype=["str", "str", "str"],
            row_count=10,
            col_count=(3, "fixed"),
            wrap=True,
            interactive=False
        )

    gr.Markdown(disclaimer)

    def user(user_message, chat_history):
        """
        Record the user's message as a new, still unanswered chat turn.

        Returns a textbox update (same text, editable) together with a new
        history list ending in `[user_message, None]`; the `None` slot is the
        bot reply that has yet to be produced.
        """
        textbox_update = gr.update(value=user_message, interactive=True)
        pending_turn = [user_message, None]
        return textbox_update, chat_history + [pending_turn]

    def generate_response(user_msg, top_p, temperature, top_k, max_new_tokens, smaple_from, repetition_penalty, safety, chat_history):
        """
        Chatbot's response generator.

        Samples several candidate completions for `user_msg`, scores each one
        with the reward model (and with the toxicity model when `safety` is
        "On"), and streams the highest-reward surviving candidate into the
        last turn of `chat_history`, one character at a time.

        Args:
            user_msg: Text of the user's message.
            top_p, temperature, top_k, repetition_penalty: Sampling settings
                forwarded to `model.generate`.
            max_new_tokens: Maximum number of tokens generated per candidate.
            smaple_from: Number of candidates to sample
                (`num_return_sequences`).
            safety: "On" drops candidates whose toxicity-model score is below
                the threshold; any other value skips that filter.
            chat_history: List of `[user, bot]` pairs; the bot slot of the
                last pair is overwritten in place.

        Yields:
            `chat_history`, once with an empty reply and then again after
            every appended character.
        """
        # Minimum toxicity-model score a candidate must reach to be kept.
        toxicity_threshold = 5
        guard_rail_on = safety == "On"

        inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.sep_token,
            add_special_tokens=False,
            return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # int() guards against the sliders delivering float values for
        # arguments that must be whole numbers.
        generated_response = model.generate(**inputs,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            early_stopping=True,
            renormalize_logits=True,
            length_penalty=0.3,
            top_k=int(top_k),
            max_new_tokens=int(max_new_tokens),
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=int(smaple_from))

        # Decode only the tokens that follow the prompt. Stripping the prompt
        # with str.replace would also delete every later occurrence of the
        # user's text inside the answer itself (e.g. a one-word prompt).
        decoded_text = [
            tokenizer.decode(tokens[prompt_length:], skip_special_tokens=True)
            for tokens in generated_response
        ]

        # (text, reward) pairs for the candidates that pass the guard rail.
        candidates = []

        # Scoring is inference only, so no autograd graph is needed.
        with torch.no_grad():
            for text in decoded_text:
                reward_tokens = rewardTokenizer(user_msg, text,
                            truncation=True,
                            max_length=512,
                            return_token_type_ids=False,
                            return_tensors="pt",
                            return_attention_mask=True)

                reward_tokens.to(rewardModel.device)

                reward = rewardModel(**reward_tokens)[0].item()

                if guard_rail_on:
                    toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
                                truncation=True,
                                max_length=512,
                                return_token_type_ids=False,
                                return_tensors="pt",
                                return_attention_mask=True)

                    toxicity_tokens.to(toxicityModel.device)

                    toxicity = toxicityModel(**toxicity_tokens)[0].item()

                    if toxicity < toxicity_threshold:
                        continue

                candidates.append((text, reward))

        # Highest reward first; sorted() is stable, so ties keep sampling order.
        ordered_generations = sorted(candidates, key=lambda x: x[1], reverse=True)

        if not ordered_generations:
            bot_message = """I apologize for the inconvenience, but it appears that no suitable responses meeting our safety standards could be identified. Unfortunately, this indicates that the generated content may contain elements of toxicity or may not help address your message. Your input is valuable to us, and we strive to ensure a safe and constructive conversation. Please feel free to provide further details or ask any other questions, and I will do my best to assist you."""

        else:
            bot_message = ordered_generations[0][0]

        # Stream the reply character by character for a typing effect. The
        # first yield makes sure the chat is updated even for an empty reply.
        chat_history[-1][1] = ""
        yield chat_history
        for character in bot_message:
            chat_history[-1][1] += character
            time.sleep(0.005)
            yield chat_history

    def search_in_datset(column_name, search_string):
        """
        Return the ten dataset rows most similar to `search_string`.

        The query is vectorized with the TF-IDF vectorizer belonging to
        `column_name` ('Prompt' or 'Completion'), compared against that
        column's TF-IDF matrix by cosine similarity, and the rows are ranked
        by that score. Any other `column_name` yields `None`.
        """
        temp_df = df.copy()

        # Each searchable column has its own fitted vectorizer and matrix.
        tfidf_by_column = {
            'Prompt': (prompt_tfidf_vectorizer, prompt_tfidf_matrix),
            'Completion': (completion_tfidf_vectorizer, completion_tfidf_matrix),
        }

        if column_name not in tfidf_by_column:
            return None

        vectorizer, tfidf_matrix = tfidf_by_column[column_name]
        query_vector = vectorizer.transform([search_string])
        temp_df['Cosine Similarity'] = cosine_similarity(tfidf_matrix, query_vector)

        ranked = temp_df.sort_values('Cosine Similarity', ascending=False)
        return ranked.head(10)

    
    
    # Submitting a message first echoes it into the chat (unqueued, so it
    # appears immediately) and then streams the model's answer into the chatbot.
    response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_response, [msg, top_p, temperature, top_k, max_new_tokens, smaple_from, repetition_penalty, safety, chatbot], chatbot
    )
    # Once the answer is complete, empty the textbox and keep it editable.
    # The textbox is cleared here, not in a separate submit listener: a
    # listener registered with no inputs is called with no arguments (so it
    # must not declare a parameter), and clearing at submit time could blank
    # `msg` before `generate_response` reads it.
    response.then(lambda: gr.update(value='', interactive=True), None, [msg], queue=False)

    # "Clear Conversation" resets the chatbot by sending it None.
    clear.click(lambda: None, None, chatbot, queue=False)

    # "Search" fills the table with the rows returned by the dataset search.
    submit.click(fn=search_in_datset, inputs=[search_field, search_input], outputs=out_dataframe)

# Turn on the request queue, then start serving the demo.
demo.queue().launch()