nicholasKluge commited on
Commit
c627036
·
1 Parent(s): b31de11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -29
app.py CHANGED
@@ -1,35 +1,53 @@
1
  import time
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
5
 
6
- model_id = "nicholasKluge/Aira-Instruct-124M"
7
  token = "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"
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
 
 
11
 
12
- if device == "cuda":
13
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, load_in_8bit=True)
14
-
15
- else:
16
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
19
  model.to(device)
 
 
 
 
 
 
 
20
 
21
  intro = """
22
  ## What is `Aira`?
23
 
24
- [`Aira`](https://github.com/Nkluge-correa/Aira-EXPERT) is a `chatbot` designed to simulate the way a human (expert) would behave during a round of questions and answers (Q&A). `Aira` has many iterations, from a closed-domain chatbot based on pre-defined rules to an open-domain chatbot achieved via fine-tuning pre-trained large language models. Aira has an area of expertise that comprises topics related to AI Ethics and AI Safety research.
 
 
25
 
26
  We developed our open-domain conversational chatbots via conditional text generation/instruction fine-tuning. This approach has a lot of limitations. Even though we can make a chatbot that can answer questions about anything, forcing the model to produce good-quality responses is hard. And by good, we mean **factual** and **nontoxic** text. This leads us to two of the most common problems of generative models used in conversational applications:
27
 
28
  🤥 Generative models can perpetuate the generation of pseudo-informative content, that is, false information that may appear truthful.
29
-
30
  🤬 In certain types of tasks, generative models can produce harmful and discriminatory content inspired by historical stereotypes.
31
-
 
 
32
  `Aira` is intended only for academic research. For more information, visit our [HuggingFace models](https://huggingface.co/nicholasKluge) to see how we developed `Aira`.
 
 
 
 
33
  """
34
 
35
  disclaimer = """
@@ -39,43 +57,88 @@ If you would like to complain about any message produced by `Aira`, please conta
39
  """
40
 
41
  with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
42
-
43
  gr.Markdown("""<h1><center>Aira Demo 🤓💬</h1></center>""")
44
  gr.Markdown(intro)
45
-
46
- chatbot = gr.Chatbot(label="Aira").style(height=500)
47
 
48
- with gr.Accordion(label="Parameters ⚙️", open=False):
49
- top_k = gr.Slider( minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k",)
50
- top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.70, step=0.05, interactive=True, label="Top-p",)
51
- temperature = gr.Slider( minimum=0.001, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature",)
52
- max_length = gr.Slider( minimum=10, maximum=500, value=100, step=10, interactive=True, label="Max Length",)
53
-
54
  msg = gr.Textbox(label="Write a question or comment to Aira ...", placeholder="Hi Aira, how are you?")
55
 
 
 
 
 
 
 
 
 
56
  clear = gr.Button("Clear Conversation 🧹")
57
  gr.Markdown(disclaimer)
58
 
59
  def user(user_message, chat_history):
60
  return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
61
 
62
- def generate_response(user_msg, top_p, temperature, top_k, max_length, chat_history):
63
 
64
- inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(device)
65
 
66
  generated_response = model.generate(**inputs,
67
  bos_token_id=tokenizer.bos_token_id,
68
  pad_token_id=tokenizer.pad_token_id,
69
  eos_token_id=tokenizer.eos_token_id,
70
  do_sample=True,
71
- early_stopping=True,
72
- top_k=top_k,
73
  max_length=max_length,
74
  top_p=top_p,
75
- temperature=temperature,
76
- num_return_sequences=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- bot_message = tokenizer.decode(generated_response[0], skip_special_tokens=True).replace(user_msg, "")
 
79
 
80
  chat_history[-1][1] = "🤖 "
81
  for character in bot_message:
@@ -84,10 +147,10 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
84
  yield chat_history
85
 
86
  response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
87
- generate_response, [msg, top_p, temperature, top_k, max_length, chatbot], chatbot
88
  )
89
  response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
90
- msg.submit(lambda x: gr.update(value=''), [],[msg])
91
  clear.click(lambda: None, None, chatbot, queue=False)
92
 
93
  demo.queue()
 
1
  import time
2
  import torch
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
5
+
6
+ model_id = "nicholasKluge/Aira-Instruct-124M"
7
+ rewardmodel_id = "nicholasKluge/RewardModel"
8
+ toxicitymodel_id = "nicholasKluge/ToxicityModel"
9
 
 
10
  token = "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"
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
+ model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
15
+ rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id, use_auth_token=token)
16
+ toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id, use_auth_token=token)
17
 
18
+ model.eval()
19
+ rewardModel.eval()
20
+ toxicityModel.eval()
 
 
21
 
 
22
  model.to(device)
23
+ rewardModel.to(device)
24
+ toxicityModel.to(device)
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
27
+ rewardTokenizer = AutoTokenizer.from_pretrained(rewardmodel_id, use_auth_token=token)
28
+ toxiciyTokenizer = AutoTokenizer.from_pretrained(toxicitymodel_id, use_auth_token=token)
29
+
30
 
31
  intro = """
32
  ## What is `Aira`?
33
 
34
+ [`Aira`](https://github.com/Nkluge-correa/Aira-EXPERT) is a `chatbot` designed to simulate the way a human (expert) would behave during a round of questions and answers (Q&A). `Aira` has many iterations, from a closed-domain chatbot based on pre-defined rules to an open-domain chatbot achieved via fine-tuning pre-trained large language models. Aira has an area of expertise that comprises topics related to AI Ethics and AI Safety research.
35
+
36
+ ## Limitations
37
 
38
  We developed our open-domain conversational chatbots via conditional text generation/instruction fine-tuning. This approach has a lot of limitations. Even though we can make a chatbot that can answer questions about anything, forcing the model to produce good-quality responses is hard. And by good, we mean **factual** and **nontoxic** text. This leads us to two of the most common problems of generative models used in conversational applications:
39
 
40
  🤥 Generative models can perpetuate the generation of pseudo-informative content, that is, false information that may appear truthful.
41
+
42
  🤬 In certain types of tasks, generative models can produce harmful and discriminatory content inspired by historical stereotypes.
43
+
44
+ ## Intended Use
45
+
46
  `Aira` is intended only for academic research. For more information, visit our [HuggingFace models](https://huggingface.co/nicholasKluge) to see how we developed `Aira`.
47
+
48
+ ## How this demo works?
49
+
50
+ This demo employs a [`reward model`](https://huggingface.co/nicholasKluge/RewardModel) and a [`toxicity model`](https://huggingface.co/nicholasKluge/ToxicityModel) to evaluate the score of each candidate's response, considering its alignment with the user's message and its level of toxicity. The generation function arranges the candidate responses in order of their reward scores and eliminates any responses deemed toxic or harmful. Subsequently, the generation function returns the candidate response with the highest score that surpasses the safety threshold, or a default message if no safe candidates are identified.
51
  """
52
 
53
  disclaimer = """
 
57
  """
58
 
59
  with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
60
+
61
  gr.Markdown("""<h1><center>Aira Demo 🤓💬</h1></center>""")
62
  gr.Markdown(intro)
 
 
63
 
64
+ chatbot = gr.Chatbot(label="Aira").style(height=500)
 
 
 
 
 
65
  msg = gr.Textbox(label="Write a question or comment to Aira ...", placeholder="Hi Aira, how are you?")
66
 
67
+ with gr.Accordion(label="Parameters ⚙️", open=True):
68
+ safety = gr.Radio(["On", "Off"], label="Guard Rail 🛡️", value="On", info="Helps prevent the model from generating toxic/harmful content.")
69
+ top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k", info="Controls the number of highest probability tokens to consider for each step.")
70
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.70, step=0.05, interactive=True, label="Top-p", info="Controls the cumulative probability of the generated tokens.")
71
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperature", info="Controls the randomness of the generated tokens.")
72
+ max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, interactive=True, label="Max Length", info="Controls the maximum length of the generated text.")
73
+ smaple_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Sample From", info="Controls the number of generations that the reward model will sample from.")
74
+
75
  clear = gr.Button("Clear Conversation 🧹")
76
  gr.Markdown(disclaimer)
77
 
78
  def user(user_message, chat_history):
79
  return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
80
 
81
+ def generate_response(user_msg, top_p, temperature, top_k, max_length, smaple_from, safety, chat_history):
82
 
83
+ inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(model.device)
84
 
85
  generated_response = model.generate(**inputs,
86
  bos_token_id=tokenizer.bos_token_id,
87
  pad_token_id=tokenizer.pad_token_id,
88
  eos_token_id=tokenizer.eos_token_id,
89
  do_sample=True,
90
+ early_stopping=True,
91
+ top_k=top_k,
92
  max_length=max_length,
93
  top_p=top_p,
94
+ temperature=temperature,
95
+ num_return_sequences=smaple_from)
96
+
97
+ decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]
98
+
99
+ rewards = list()
100
+ toxicities = list()
101
+
102
+ for text in decoded_text:
103
+ reward_tokens = rewardTokenizer(user_msg, text,
104
+ truncation=True,
105
+ max_length=512,
106
+ return_token_type_ids=False,
107
+ return_tensors="pt",
108
+ return_attention_mask=True)
109
+
110
+ reward_tokens.to(rewardModel.device)
111
+
112
+ reward = rewardModel(**reward_tokens)[0].item()
113
+
114
+ toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
115
+ truncation=True,
116
+ max_length=512,
117
+ return_token_type_ids=False,
118
+ return_tensors="pt",
119
+ return_attention_mask=True)
120
+
121
+ toxicity_tokens.to(toxicityModel.device)
122
+
123
+ toxicity = toxicityModel(**toxicity_tokens)[0].item()
124
+
125
+ rewards.append(reward)
126
+ toxicities.append(toxicity)
127
+
128
+ toxicity_threshold = 5
129
+
130
+ ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
131
+
132
+ print(ordered_generations)
133
+
134
+ if safety == "On":
135
+ ordered_generations = [(x, y, z) for (x, y, z) in ordered_generations if z >= toxicity_threshold]
136
+
137
+ if len(ordered_generations) == 0:
138
+ bot_message = """I apologize for the inconvenience, but it appears that no suitable responses meeting our safety standards could be identified. Unfortunately, this indicates that the generated content may contain elements of toxicity or may not help address your message. Your input is valuable to us, and we strive to ensure a safe and constructive conversation. Please feel free to provide further details or ask any other questions, and I will do my best to assist you."""
139
 
140
+ else:
141
+ bot_message = ordered_generations[0][0]
142
 
143
  chat_history[-1][1] = "🤖 "
144
  for character in bot_message:
 
147
  yield chat_history
148
 
149
  response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
150
+ generate_response, [msg, top_p, temperature, top_k, max_length, smaple_from, safety, chatbot], chatbot
151
  )
152
  response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
153
+ msg.submit(lambda x: gr.update(value=''), None,[msg])
154
  clear.click(lambda: None, None, chatbot, queue=False)
155
 
156
  demo.queue()