zachlopez committed on
Commit 795dec4
1 Parent(s): b8fbd3e

updated app.py

Files changed (1)
  1. app.py +38 -43
app.py CHANGED
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+# print
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
@@ -608,10 +608,9 @@ def generate_text_pplm(
     last_reps = torch.ones(50257)
     last_reps = last_reps.to(device)
     for i in range_func:
-
         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current_token
-
+
         # run model forward to obtain unperturbed
         if past is None and output_so_far is not None:
             last = output_so_far[:, -1:]
@@ -739,7 +738,7 @@ def set_generic_model_params(discrim_weights, discrim_meta):
     DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
 
 
-pretrained_model="gpt2-medium"
+pretrained_model="microsoft/DialoGPT-large"
 cond_text=""
 uncond=False
 num_samples=1
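The hunk above swaps the base model from gpt2-medium to microsoft/DialoGPT-large. DialoGPT is a GPT-2-architecture checkpoint, so the app's existing GPT-2 loading path accepts it unchanged; a minimal standalone sketch of that load, using the same transformers classes app.py already uses:

# Minimal sketch: DialoGPT-large loads through the GPT-2 classes
# because it shares the GPT-2 architecture.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

pretrained_model = "microsoft/DialoGPT-large"
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
model = GPT2LMHeadModel.from_pretrained(
    pretrained_model,
    output_hidden_states=True,  # PPLM perturbs activations, so keep hidden states
)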
@@ -758,15 +757,15 @@ grad_length=10000
 horizon_length=5
 window_length=0
 decay=False
-gamma=1.5
+gamma=1.0
 gm_scale=0.95
 kl_scale=0.01
 seed=0
 no_cuda=False
 colorama=False
 verbosity="quiet"
-fp="./paper_code/discrim_models/persoothe_classifier.pt"
-model_fp="./paper_code/discrim_models/persoothe_encoder.pt"
+fp="./paper_code/discrim_models/persoothe_classifier.pt" #"/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_lrggpt_fit_deeper_2/3_PerSoothe_classifier_head_epoch_8.pt"
+model_fp="./paper_code/discrim_models/persoothe_encoder.pt" #None
 calc_perplexity=False
 is_deep=False
 is_deeper=True
@@ -801,10 +800,7 @@ model = GPT2LMHeadModel.from_pretrained(
     output_hidden_states=True
 )
 if model_fp != None and model_fp != "":
-    try:
-        model.load_state_dict(torch.load(model_fp))
-    except:
-        print("Can't load local model")
+    model.load_state_dict(torch.load(model_fp))
 model.to(device)
 model.eval()
 
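The hunk above removes the try/except guard, so a missing or incompatible persoothe_encoder.pt now raises at startup instead of printing "Can't load local model". One hedged hardening, not part of this commit: pass map_location so a GPU-trained checkpoint still loads on a CPU-only host.

# Hypothetical variant of the load above (an assumption, not in this commit):
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if model_fp:  # same check as: model_fp != None and model_fp != ""
    model.load_state_dict(torch.load(model_fp, map_location=device))
model.to(device)
model.eval()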
@@ -817,16 +813,19 @@ for param in model.parameters():
 
 eot_token = "<|endoftext|>"
 
-def get_reply(response, history = "How are you?<|endoftext|>"):
-    if response.endswith(("bye", "Bye", "bye.", "Bye.")):
-        return
-    convo_hist = history + response + eot_token
+def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length = 5, in_num_iterations = 10, in_top_k = 2):
+    stepsize = in_stepsize
+    horizon_length = int(in_horizon_length)
+    num_iterations = int(in_num_iterations)
+    top_k = int(in_top_k)
+    if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
+        return "<div class='chatbot'>Chatbot restarted</div>", None
+    convo_hist = (history if history != None else "How are you?<|endoftext|>") + response + eot_token
     # figure out conditioning text
     tokenized_cond_text = tokenizer.encode(
-        convo_hist,
+        eot_token + convo_hist,
         add_special_tokens=False
    )
-
     # generate perturbed texts
 
     # full_text_generation returns:
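The reworked get_reply threads four PPLM knobs (step size, horizon length, iteration count, top-k) from the UI into generation, uses None as the fresh-conversation sentinel, and now prepends eot_token, apparently so the conditioning text starts at an end-of-text boundary. How a single turn is invoked, mirroring the argument order of the Gradio inputs list wired up at the bottom of the file:

# One chat turn, as the Gradio interface calls it:
html, new_history = get_reply(
    "I had a rough day.",  # response: the user's textbox message
    None,                  # history ("state"): None on the first turn
    in_stepsize=2.56,
    in_horizon_length=5,
    in_num_iterations=10,
    in_top_k=2,
)
# html renders the chat bubbles; new_history is fed back as state.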
@@ -861,30 +860,21 @@ def get_reply(response, history = "How are you?<|endoftext|>"):
     )
 
     # iterate through the perturbed texts
-    try:
-        pert_gen_text = tokenizer.decode(pert_gen_tok_texts[0].tolist()[0])
-        convo_hist_split = pert_gen_text.split(eot_token)
-
-        # write some HTML
-        html = "<div class='chatbot'>"
-        for m, msg in enumerate(convo_hist_split):
-            cls = "user" if m%2 == 0 else "bot"
-            html += "<div class='msg {}'> {}</div>".format(cls, msg)
-        html += "</div>"
-
-        if len(convo_hist_split) > 4: convo_hist_split = convo_hist_split[-4:]
-        convo_hist = eot_token.join(convo_hist_split)
-    except:
-        convo_hist_split = history.split(eot_token) + ["*ai has no response*"]
-
-        # write some HTML
-        html = "<div class='chatbot'>"
-        for m, msg in enumerate(convo_hist_split):
-            cls = "user" if m%2 == 0 else "bot"
-            html += "<div class='msg {}'> {}</div>".format(cls, msg)
-        html += "</div>"
-
-        convo_hist = history + "*ai has no response*" + eot_token
+    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
+        try:
+            pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
+            convo_hist_split = pert_gen_text.split(eot_token)
+            html = "<div class='chatbot'>"
+            for m, msg in enumerate(convo_hist_split[1:-1]):
+                cls = "user" if m%2 == 0 else "bot"
+                html += "<div class='msg {}'> {}</div>".format(cls, msg)
+            html += "</div>"
+
+            if len(convo_hist_split) > 4: convo_hist_split = convo_hist_split[-4:]
+            convo_hist = eot_token.join(convo_hist_split)
+
+        except:
+            return "<div class='chatbot'>Error occured, chatbot restarted</div>", None
 
     return html, convo_hist
 
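The rewritten loop above decodes each perturbed sample, splits it on <|endoftext|>, renders alternating user/bot bubbles (skipping the first and last fragments via convo_hist_split[1:-1]), and caps the carried history at the last four segments. A standalone toy run of that history bookkeeping, with made-up strings:

# Toy illustration of the truncation logic in the hunk above:
eot_token = "<|endoftext|>"
turns = ["How are you?", "I'm well.", "Good to hear.", "Thanks!", "Anytime."]
convo_hist_split = eot_token.join(turns).split(eot_token)
if len(convo_hist_split) > 4:
    convo_hist_split = convo_hist_split[-4:]  # keep only the newest 4 segments
convo_hist = eot_token.join(convo_hist_split)
print(convo_hist)  # the oldest turn has been dropped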
@@ -898,6 +888,11 @@ css = """
 
 gr.Interface(fn=get_reply,
              theme="default",
-             inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
+             inputs=[gr.inputs.Textbox(placeholder="How are you?"),
+                     "state",
+                     gr.inputs.Number(default=2.56, label="Step"),
+                     gr.inputs.Number(default=5, label="Horizon"),
+                     gr.inputs.Number(default=10, label="Iterations"),
+                     gr.inputs.Number(default=2, label="Top_k")],
              outputs=["html", "state"],
-             css=css).launch()
+             css=css).launch(share=True)
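launch(share=True) additionally requests a temporary public share URL for the app. The gr.inputs.* namespace is Gradio's legacy (pre-3.0) API; a hypothetical port to current Gradio, unverified against any specific version, would look roughly like:

# Hypothetical Gradio 3+/4 equivalent of the wiring above (assumptions:
# gr.inputs.* becomes top-level components, "state" becomes gr.State(),
# and Number's default= becomes value=):
import gradio as gr

gr.Interface(fn=get_reply,
             inputs=[gr.Textbox(placeholder="How are you?"),
                     gr.State(),
                     gr.Number(value=2.56, label="Step"),
                     gr.Number(value=5, label="Horizon"),
                     gr.Number(value=10, label="Iterations"),
                     gr.Number(value=2, label="Top_k")],
             outputs=[gr.HTML(), gr.State()],
             css=css).launch(share=True)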
 