updated app.py
app.py
CHANGED
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+# print
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
@@ -608,10 +608,9 @@ def generate_text_pplm(
     last_reps = torch.ones(50257)
     last_reps = last_reps.to(device)
     for i in range_func:
-
         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current_token
-
+
         # run model forward to obtain unperturbed
         if past is None and output_so_far is not None:
             last = output_so_far[:, -1:]
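A note on the past/current_token split those comments describe: GPT-2 caches attention keys/values for the already-processed prefix ("past"), and each step only feeds the newest token forward. A minimal sketch of that pattern in plain transformers (not code from app.py; assumes a transformers 4.x-style API):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
output_so_far = tokenizer.encode("The president", return_tensors="pt")

past = None
if past is None and output_so_far is not None:
    last = output_so_far[:, -1:]  # current_token
    if output_so_far.shape[1] > 1:
        # one pass over the prefix builds the cache ("past")
        past = model(output_so_far[:, :-1]).past_key_values
logits = model(last, past_key_values=past).logits  # cheap single-token step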
@@ -739,7 +738,7 @@ def set_generic_model_params(discrim_weights, discrim_meta):
     DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
 
 
-pretrained_model="
+pretrained_model="microsoft/DialoGPT-large"
 cond_text=""
 uncond=False
 num_samples=1
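microsoft/DialoGPT-large is a GPT-2 checkpoint, so it loads through the same GPT2Tokenizer/GPT2LMHeadModel path the app already uses further down. A quick sanity check of the new default (a sketch; variable names mirror the config block above):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

pretrained_model = "microsoft/DialoGPT-large"
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)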
@@ -758,15 +757,15 @@ grad_length=10000
 horizon_length=5
 window_length=0
 decay=False
-gamma=1.
+gamma=1.0
 gm_scale=0.95
 kl_scale=0.01
 seed=0
 no_cuda=False
 colorama=False
 verbosity="quiet"
-fp="./paper_code/discrim_models/persoothe_classifier.pt"
-model_fp="./paper_code/discrim_models/persoothe_encoder.pt"
+fp="./paper_code/discrim_models/persoothe_classifier.pt" #"/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_lrggpt_fit_deeper_2/3_PerSoothe_classifier_head_epoch_8.pt"
+model_fp="./paper_code/discrim_models/persoothe_encoder.pt" #None
 calc_perplexity=False
 is_deep=False
 is_deeper=True
@@ -801,10 +800,7 @@ model = GPT2LMHeadModel.from_pretrained(
     output_hidden_states=True
 )
 if model_fp != None and model_fp != "":
-    try:
-        model.load_state_dict(torch.load(model_fp))
-    except:
-        print("Can't load local model")
+    model.load_state_dict(torch.load(model_fp))
 model.to(device)
 model.eval()
 
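Removing the try/except means a missing or incompatible checkpoint now fails loudly instead of silently falling back to base weights, which is the safer behavior for a hosted demo. One caveat: checkpoints saved on GPU will not load on a CPU-only Space without a map_location. A hedged variant of the same load (a sketch, not the committed code):

import torch

if model_fp:  # equivalent to: model_fp != None and model_fp != ""
    state_dict = torch.load(model_fp, map_location="cpu")
    model.load_state_dict(state_dict)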
@@ -817,16 +813,19 @@ for param in model.parameters():
 
 eot_token = "<|endoftext|>"
 
-def get_reply(response, history = "How are you?<|endoftext|>"):
-
-
-
+def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length = 5, in_num_iterations = 10, in_top_k = 2):
+    stepsize = in_stepsize
+    horizon_length = int(in_horizon_length)
+    num_iterations = int(in_num_iterations)
+    top_k = int(in_top_k)
+    if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
+        return "<div class='chatbot'>Chatbot restarted</div>", None
+    convo_hist = (history if history != None else "How are you?<|endoftext|>") + response + eot_token
     # figure out conditioning text
     tokenized_cond_text = tokenizer.encode(
-        convo_hist,
+        eot_token + convo_hist,
         add_special_tokens=False
     )
-
     # generate perturbed texts
 
     # full_text_generation returns:
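get_reply now receives Gradio state as history plus four tuning knobs, restarts the chat when the user ends with some variant of "bye", and prepends eot_token so generation starts from an end-of-turn boundary. A worked example of the string that reaches tokenizer.encode (illustrative turn text only):

eot_token = "<|endoftext|>"
history = "How are you?" + eot_token  # the seed turn used when history is None
response = "I'm doing well, thanks."
convo_hist = history + response + eot_token
# "<|endoftext|>How are you?<|endoftext|>I'm doing well, thanks.<|endoftext|>"
print(eot_token + convo_hist)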
@@ -861,30 +860,21 @@ def get_reply(response, history = "How are you?<|endoftext|>"):
     )
 
     # iterate through the perturbed texts
-
-
-
-
-
-
-
-
-    html += "
-
-
-
-
-
-
-
-    # write some HTML
-    html = "<div class='chatbot'>"
-    for m, msg in enumerate(convo_hist_split):
-        cls = "user" if m%2 == 0 else "bot"
-        html += "<div class='msg {}'> {}</div>".format(cls, msg)
-    html += "</div>"
-
-    convo_hist = history + "*ai has no response*" + eot_token
+    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
+        try:
+            pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
+            convo_hist_split = pert_gen_text.split(eot_token)
+            html = "<div class='chatbot'>"
+            for m, msg in enumerate(convo_hist_split[1:-1]):
+                cls = "user" if m%2 == 0 else "bot"
+                html += "<div class='msg {}'> {}</div>".format(cls, msg)
+            html += "</div>"
+
+            if len(convo_hist_split) > 4: convo_hist_split = convo_hist_split[-4:]
+            convo_hist = eot_token.join(convo_hist_split)
+
+        except:
+            return "<div class='chatbot'>Error occured, chatbot restarted</div>", None
 
     return html, convo_hist
 
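The rewritten loop decodes each perturbed sample, renders the eot-delimited turns as alternating user/bot bubbles, and caps the stored history at the last four segments so the conditioning text cannot grow without bound. The windowing step in isolation (a sketch with toy turns):

eot_token = "<|endoftext|>"
convo_hist_split = ["hi", "hello", "how are you?", "good", "great"]  # stand-in for pert_gen_text.split(eot_token)
if len(convo_hist_split) > 4:
    convo_hist_split = convo_hist_split[-4:]
convo_hist = eot_token.join(convo_hist_split)
# "hello<|endoftext|>how are you?<|endoftext|>good<|endoftext|>great"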
@@ -898,6 +888,11 @@ css = """
 
 gr.Interface(fn=get_reply,
              theme="default",
-             inputs=[gr.inputs.Textbox(placeholder="How are you?"),
+             inputs=[gr.inputs.Textbox(placeholder="How are you?"),
+                     "state",
+                     gr.inputs.Number(default=2.56, label="Step"),
+                     gr.inputs.Number(default=5, label="Horizon"),
+                     gr.inputs.Number(default=10, label="Iterations"),
+                     gr.inputs.Number(default=2, label="Top_k")],
              outputs=["html", "state"],
-             css=css).launch()
+             css=css).launch(share=True)
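Adding "state" to both inputs and outputs is what makes the chat persistent: Gradio feeds the second return value of get_reply back in as history on the next call. A minimal standalone demo of the same pattern, using the same pre-1.0 gr.inputs API as this commit (not the app itself):

import gradio as gr

def echo(message, history):
    history = (history or "") + message + "<|endoftext|>"
    return "<div class='chatbot'>{}</div>".format(history), history

gr.Interface(fn=echo,
             inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
             outputs=["html", "state"]).launch()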