zachlopez commited on
Commit
7d1b80e
·
1 Parent(s): 41d1c1c

Added parallelism workaround

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -28,6 +28,7 @@ import json
28
  from operator import add
29
  from typing import List, Optional, Tuple, Union
30
  from random import choice, randint
 
31
  import numpy as np
32
  import torch
33
  import torch.nn.functional as F
@@ -748,7 +749,7 @@ discrim_weights=None
748
  discrim_meta=None
749
  class_label=0
750
  length=100
751
- stepsize=2.56
752
  temperature=1.0
753
  top_k=2
754
  sample=True
@@ -813,13 +814,16 @@ for param in model.parameters():
813
 
814
  eot_token = "<|endoftext|>"
815
 
816
- def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length = 1, in_num_iterations = 0, in_top_k = 2):
 
817
  stepsize = in_stepsize
818
  horizon_length = int(in_horizon_length)
819
  num_iterations = int(in_num_iterations)
820
  top_k = int(in_top_k)
821
  if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
822
- return "<div class='chatbot'>Chatbot restarted</div>", None
 
 
823
  convo_hist = (history if history != None else "How are you?<|endoftext|>") + response + eot_token
824
  # figure out conditioning text
825
  tokenized_cond_text = tokenizer.encode(
@@ -874,9 +878,10 @@ def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length =
874
  convo_hist = eot_token.join(convo_hist_split)
875
 
876
  except:
877
- return "<div class='chatbot'>Error occured, chatbot restarted</div>", None
878
-
879
- return html, convo_hist
 
880
 
881
  css = """
882
  .chatbox {display:flex;flex-direction:column}
@@ -888,6 +893,8 @@ css = """
888
 
889
  gr.Interface(fn=get_reply,
890
  theme="default",
891
- inputs=[gr.inputs.Textbox(placeholder="How are you?"), "state"],
 
 
892
  outputs=["html", "state"],
893
- css=css).launch()
 
28
  from operator import add
29
  from typing import List, Optional, Tuple, Union
30
  from random import choice, randint
31
+ from matplotlib import use
32
  import numpy as np
33
  import torch
34
  import torch.nn.functional as F
 
749
  discrim_meta=None
750
  class_label=0
751
  length=100
752
+ stepsize=5.12
753
  temperature=1.0
754
  top_k=2
755
  sample=True
 
814
 
815
  eot_token = "<|endoftext|>"
816
 
817
+ def get_reply(response, username = None, histories = {}, in_stepsize = 5.12, in_horizon_length = 1, in_num_iterations = 0, in_top_k = 2):
818
+ if username == None or username == "": return "<div class='chatbot'>Enter a username</div>", histories
819
  stepsize = in_stepsize
820
  horizon_length = int(in_horizon_length)
821
  num_iterations = int(in_num_iterations)
822
  top_k = int(in_top_k)
823
  if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
824
+ histories[username] = None
825
+ return "<div class='chatbot'>Chatbot restarted</div>", histories
826
+ history = histories.get(username, None)
827
  convo_hist = (history if history != None else "How are you?<|endoftext|>") + response + eot_token
828
  # figure out conditioning text
829
  tokenized_cond_text = tokenizer.encode(
 
878
  convo_hist = eot_token.join(convo_hist_split)
879
 
880
  except:
881
+ histories[username] = None
882
+ return "<div class='chatbot'>Error occured, chatbot restarted</div>", histories
883
+ histories[username] = convo_hist
884
+ return html, histories
885
 
886
  css = """
887
  .chatbox {display:flex;flex-direction:column}
 
893
 
894
  gr.Interface(fn=get_reply,
895
  theme="default",
896
+ inputs=[gr.inputs.Textbox(placeholder="How are you?"),
897
+ gr.inputs.Textbox(label="Username"),
898
+ "state"],
899
  outputs=["html", "state"],
900
+ css=css).launch(debug=True, enable_queue=True)