XFious committed on
Commit
063d44d
·
1 Parent(s): c53e345

fix slow inference

Browse files
Files changed (3) hide show
  1. app.py +19 -21
  2. dearth_config.py +2 -2
  3. dearth_model.py +2 -2
app.py CHANGED
@@ -15,15 +15,14 @@ import asyncio
15
 
16
  tk = None
17
  model_states = None
18
- model = None
19
  lock_using_model = threading.Lock()
20
  recent_generate_timestamp = time.time()
21
 
22
- MODEL_LIVE_TIME = 15 * 60 # 15 minutes
23
 
24
 
25
  def load_model():
26
- global tk, model_states, model
27
 
28
  tk = transformers.AutoTokenizer.from_pretrained("./tk")
29
  model_path = "./ts100-re2-h1-4000-model.pt"
@@ -44,18 +43,6 @@ def load_model():
44
  new_key = k[len(unwanted_prefix_dueto_compile):]
45
  model_states[new_key] = model_states.pop(k)
46
 
47
- yml_path = "./ts100-re2-h1.yml"
48
- with open(yml_path, "r") as f:
49
- config = yaml.load(f, Loader=yaml.FullLoader)['model']
50
- if "vocab_size" not in config:
51
- config['vocab_size'] = tk.vocab_size
52
- config["attn_window_size"] = 500
53
- # print(config)
54
- config = DearthConfig(**config)
55
- model = DearthForCausalLM(config)
56
-
57
- model.load_state_dict(model_states)
58
- model.eval()
59
 
60
 
61
  def main_free_mem():
@@ -66,12 +53,11 @@ def main_free_mem():
66
 
67
 
68
  def free_mem():
69
- global tk, model_states, model, recent_generate_timestamp, lock_using_model
70
  lock_using_model.acquire()
71
- if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and model is not None:
72
  tk = None
73
  model_states = None
74
- model = None
75
  print(f"free mem, {time.time()}")
76
  lock_using_model.release()
77
  try:
@@ -85,13 +71,25 @@ def generate(input, num_more_tokens):
85
  global tk, model_states, model, recent_generate_timestamp, lock_using_model
86
  lock_using_model.acquire()
87
  time_start = time.time()
88
- if model is None:
89
  load_model()
90
  elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
91
  tk = None
92
  model_states = None
93
- model = None
94
  load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  recent_generate_timestamp = time.time()
96
  print(f"load model time: {time.time() - time_start}")
97
 
@@ -158,7 +156,7 @@ if __name__ == "__main__":
158
  with gr.Row():
159
  with gr.Column():
160
  inp = gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)], elem_id="input_textbox")
161
- generate_max_slider = gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
162
  generate_button = gr.Button(value="Generate")
163
  with gr.Column():
164
  out = gr.Textbox(lines=5, label="Output Text", value="")
 
15
 
16
  tk = None
17
  model_states = None
 
18
  lock_using_model = threading.Lock()
19
  recent_generate_timestamp = time.time()
20
 
21
+ MODEL_LIVE_TIME = 5#15 * 60 # 15 minutes
22
 
23
 
24
  def load_model():
25
+ global tk, model_states
26
 
27
  tk = transformers.AutoTokenizer.from_pretrained("./tk")
28
  model_path = "./ts100-re2-h1-4000-model.pt"
 
43
  new_key = k[len(unwanted_prefix_dueto_compile):]
44
  model_states[new_key] = model_states.pop(k)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def main_free_mem():
 
53
 
54
 
55
  def free_mem():
56
+ global tk, model_states, recent_generate_timestamp, lock_using_model
57
  lock_using_model.acquire()
58
+ if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and tk is not None:
59
  tk = None
60
  model_states = None
 
61
  print(f"free mem, {time.time()}")
62
  lock_using_model.release()
63
  try:
 
71
  global tk, model_states, model, recent_generate_timestamp, lock_using_model
72
  lock_using_model.acquire()
73
  time_start = time.time()
74
+ if tk is None:
75
  load_model()
76
  elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
77
  tk = None
78
  model_states = None
 
79
  load_model()
80
+
81
+ yml_path = "./ts100-re2-h1.yml"
82
+ with open(yml_path, "r") as f:
83
+ config = yaml.load(f, Loader=yaml.FullLoader)['model']
84
+ if "vocab_size" not in config:
85
+ config['vocab_size'] = tk.vocab_size
86
+ config["attn_window_size"] = 500
87
+ # print(config)
88
+ config = DearthConfig(**config)
89
+ model = DearthForCausalLM(config)
90
+
91
+ model.load_state_dict(model_states)
92
+ model.eval()
93
  recent_generate_timestamp = time.time()
94
  print(f"load model time: {time.time() - time_start}")
95
 
 
156
  with gr.Row():
157
  with gr.Column():
158
  inp = gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)], elem_id="input_textbox")
159
+ generate_max_slider = gr.Slider(8, 64, step=1.0, value=16, label="more tokens", info="")
160
  generate_button = gr.Button(value="Generate")
161
  with gr.Column():
162
  out = gr.Textbox(lines=5, label="Output Text", value="")
dearth_config.py CHANGED
@@ -46,7 +46,7 @@ class DearthConfig(PretrainedConfig):
46
  self.hidden_dim = hidden_dim
47
  if hidden_dim is None:
48
  self.hidden_dim = dim * 4
49
- print(f"hidden_dim is not specified. Set to {self.hidden_dim}")
50
  self.multiple_of = multiple_of
51
  self.dropout_rate = dropout_rate
52
  self.layer_init_factor = layer_init_factor
@@ -66,7 +66,7 @@ class DearthConfig(PretrainedConfig):
66
  self.mimic_use_alibi = mimic_use_alibi
67
 
68
  if "attn_window_size" in kwargs:
69
- print("Warning: attn_window_size is deprecated. Please use sliding_window_size instead !!!!!!!!!!!")
70
  self.sliding_window_size = kwargs["attn_window_size"]
71
 
72
  super().__init__(
 
46
  self.hidden_dim = hidden_dim
47
  if hidden_dim is None:
48
  self.hidden_dim = dim * 4
49
+ #print(f"hidden_dim is not specified. Set to {self.hidden_dim}")
50
  self.multiple_of = multiple_of
51
  self.dropout_rate = dropout_rate
52
  self.layer_init_factor = layer_init_factor
 
66
  self.mimic_use_alibi = mimic_use_alibi
67
 
68
  if "attn_window_size" in kwargs:
69
+ #print("Warning: attn_window_size is deprecated. Please use sliding_window_size instead !!!!!!!!!!!")
70
  self.sliding_window_size = kwargs["attn_window_size"]
71
 
72
  super().__init__(
dearth_model.py CHANGED
@@ -611,10 +611,10 @@ class DearthModel(nn.Module):
611
  self.residual_factor = config.residual_factor if config.residual_factor is not None else float(config.n_layer * 2) ** (1/4)
612
  if config.residual_factor is None:
613
  config.residual_factor = self.residual_factor
614
- logging.warning(f"residual_factor is not set, using default value {self.residual_factor} = (2 * n_layer) ** 1/4")
615
  if config.layer_init_factor is None:
616
  config.layer_init_factor = self.layer_init_factor
617
- logging.warning(f"layer_init_factor is not set, using default value {self.layer_init_factor} = (n_layer * 8) ** -1/2")
618
 
619
  self.config = config
620
 
 
611
  self.residual_factor = config.residual_factor if config.residual_factor is not None else float(config.n_layer * 2) ** (1/4)
612
  if config.residual_factor is None:
613
  config.residual_factor = self.residual_factor
614
+ #logging.warning(f"residual_factor is not set, using default value {self.residual_factor} = (2 * n_layer) ** 1/4")
615
  if config.layer_init_factor is None:
616
  config.layer_init_factor = self.layer_init_factor
617
+ #logging.warning(f"layer_init_factor is not set, using default value {self.layer_init_factor} = (n_layer * 8) ** -1/2")
618
 
619
  self.config = config
620