XFious committed
Commit c53e345 · 1 Parent(s): 7569c73

add feature: free memory

Files changed (1): app.py (+86 -39)
app.py CHANGED
@@ -8,30 +8,42 @@ from dearth_model import DearthForCausalLM
 
 import random
 import time
-
-
-
-tk = transformers.AutoTokenizer.from_pretrained("./tk")
-model_path = "./ts100-re2-h1-4000-model.pt"
-states = torch.load(model_path, map_location="cpu")
-model_states = states
-unwanted_prefix_dueto_compile = '_orig_mod.'
-unwanted_prefix_dueto_ddp = 'module.'
-unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
-
-for k,v in list(model_states.items()):
-    if k.startswith(unwanted_prefix_dueto_ddp_compiled):
-        new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
-        model_states[k[len(unwanted_prefix_dueto_ddp_compiled):]] = model_states.pop(k)
-    elif k.startswith(unwanted_prefix_dueto_ddp):
-        new_key = k[len(unwanted_prefix_dueto_ddp):]
-        model_states[k[len(unwanted_prefix_dueto_ddp):]] = model_states.pop(k)
-    elif k.startswith(unwanted_prefix_dueto_compile):
-        new_key = k[len(unwanted_prefix_dueto_compile):]
-        model_states[k[len(unwanted_prefix_dueto_compile):]] = model_states.pop(k)
-
-def generate(input, num_more_tokens):
-
+import threading
+import asyncio
+
+
+
+tk = None
+model_states = None
+model = None
+lock_using_model = threading.Lock()
+recent_generate_timestamp = time.time()
+
+MODEL_LIVE_TIME = 15 * 60 # 15 minutes
+
+
+def load_model():
+    global tk, model_states, model
+
+    tk = transformers.AutoTokenizer.from_pretrained("./tk")
+    model_path = "./ts100-re2-h1-4000-model.pt"
+    states = torch.load(model_path, map_location="cpu")
+    model_states = states
+    unwanted_prefix_dueto_compile = '_orig_mod.'
+    unwanted_prefix_dueto_ddp = 'module.'
+    unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
+
+    for k,v in list(model_states.items()):
+        if k.startswith(unwanted_prefix_dueto_ddp_compiled):
+            new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
+            model_states[new_key] = model_states.pop(k)
+        elif k.startswith(unwanted_prefix_dueto_ddp):
+            new_key = k[len(unwanted_prefix_dueto_ddp):]
+            model_states[new_key] = model_states.pop(k)
+        elif k.startswith(unwanted_prefix_dueto_compile):
+            new_key = k[len(unwanted_prefix_dueto_compile):]
+            model_states[new_key] = model_states.pop(k)
+
     yml_path = "./ts100-re2-h1.yml"
     with open(yml_path, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)['model']
@@ -43,8 +55,47 @@ def generate(input, num_more_tokens):
     model = DearthForCausalLM(config)
 
     model.load_state_dict(model_states)
+    model.eval()
+
+
+def main_free_mem():
+    event_loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(event_loop)
+    event_loop.call_later(MODEL_LIVE_TIME, free_mem)
+    event_loop.run_forever()
 
 
+def free_mem():
+    global tk, model_states, model, recent_generate_timestamp, lock_using_model
+    lock_using_model.acquire()
+    if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and model is not None:
+        tk = None
+        model_states = None
+        model = None
+        print(f"free mem, {time.time()}")
+    lock_using_model.release()
+    try:
+        event_loop = asyncio.get_event_loop()
+        event_loop.call_later(MODEL_LIVE_TIME, free_mem)
+    except:
+        pass
+
+
+def generate(input, num_more_tokens):
+    global tk, model_states, model, recent_generate_timestamp, lock_using_model
+    lock_using_model.acquire()
+    time_start = time.time()
+    if model is None:
+        load_model()
+    elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
+        tk = None
+        model_states = None
+        model = None
+        load_model()
+    recent_generate_timestamp = time.time()
+    print(f"load model time: {time.time() - time_start}")
+
+    time_start = time.time()
     num_more_tokens = int(num_more_tokens)
     # print(input)
     input = input.strip()
@@ -52,7 +103,9 @@ def generate(input, num_more_tokens):
     input_ids = [tk.bos_token_id] + input_ids
     input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
     # print(input_ids)
+    print(f"encode time: {time.time() - time_start}")
 
+    time_start = time.time()
     output_ids = input_ids.squeeze(0).tolist()
     for i in range(num_more_tokens):
         input = torch.tensor(output_ids, dtype=torch.long).view(1, -1)
@@ -70,8 +123,12 @@ def generate(input, num_more_tokens):
     # print(output_ids)
     # print(tk.decode(output_ids))
     output_ids = output_ids[1:]
+    print(f"inference time: {time.time() - time_start}\n")
+
+    ret = tk.decode(output_ids)
+    lock_using_model.release()
+    return ret
 
-    return tk.decode(output_ids)
 
 example_input = ["Once upon a time, there was a little girl",
                  "John and Sarah were playing together in their backyard when",
@@ -86,21 +143,11 @@ The PPL on the validation set is 1.7, in comparison, the teacher model has a PPL
 """
 
 
-# demo = gr.Interface(
-#     fn=generate,
-#     title="Tinystories LM 11M",
-#     description=Description,
-#     inputs=[
-#         gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)]),
-#         gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
-#     ],
-#     outputs="text"
-# )
-
-with open("./random_input_example.js" , "r") as f:
-    file_content = f.read()
-
 if __name__ == "__main__":
+    load_model()
+    thread_free_mem = threading.Thread(target=main_free_mem)
+    thread_free_mem.start()
+
     with gr.Blocks(
         title="Tinystories LM 11M",
         js="./random_input_example.js"
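A note on the checkpoint-key cleanup that now lives inside `load_model()`: checkpoints saved from a `torch.compile`d model carry an `_orig_mod.` prefix on every state-dict key, DistributedDataParallel adds `module.`, and a compiled DDP model adds `module._orig_mod.`; the loop strips whichever prefix matches (and, unlike the old top-level version, it now reuses `new_key` instead of re-slicing the key). The same idea as a standalone sketch; `strip_state_dict_prefixes` is an illustrative name, not part of the commit:

```python
# Hypothetical helper (not in app.py): strip the wrapper prefixes that
# torch.compile ('_orig_mod.') and DistributedDataParallel ('module.')
# prepend to state-dict keys, so weights load into a plain nn.Module.
def strip_state_dict_prefixes(state_dict):
    prefixes = ("module._orig_mod.", "module.", "_orig_mod.")  # longest first
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break  # strip at most one wrapper prefix, like the elif chain
        cleaned[key] = value
    return cleaned

# e.g. {'module._orig_mod.lm_head.weight': w} -> {'lm_head.weight': w}
```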
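The free-memory feature itself is the `main_free_mem` / `free_mem` pair: a background thread runs a private asyncio event loop, `call_later` re-arms `free_mem` every `MODEL_LIVE_TIME` seconds, and `free_mem` drops the tokenizer and weights only when no `generate` call has refreshed `recent_generate_timestamp` within that window, holding `lock_using_model` so it cannot unload mid-generation. The same idle-unload pattern can be sketched without asyncio using `threading.Timer`; the names below are illustrative stand-ins, not from app.py:

```python
import threading
import time

IDLE_SECONDS = 15 * 60      # mirrors MODEL_LIVE_TIME in app.py
_lock = threading.Lock()
_model = None               # stands in for tk / model_states / model
_last_used = time.time()    # callers refresh this under _lock

def _reap_if_idle():
    global _model
    with _lock:
        # Unload only if nothing has touched the model for a full window,
        # matching the recent_generate_timestamp check in free_mem().
        if _model is not None and time.time() - _last_used >= IDLE_SECONDS:
            _model = None
            print(f"free mem, {time.time()}")
    # Re-arm, like event_loop.call_later() re-arming free_mem().
    timer = threading.Timer(IDLE_SECONDS, _reap_if_idle)
    timer.daemon = True     # don't keep the process alive just for the reaper
    timer.start()

# Started once at startup, in place of the main_free_mem thread:
# threading.Timer(IDLE_SECONDS, _reap_if_idle).start()
```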
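One caveat worth flagging in the new `generate`: the lock is taken with a bare `acquire()` and only released on the success path, so if tokenization or the sampling loop raises, `lock_using_model` stays held and every later request, plus `free_mem`, blocks forever. `threading.Lock` works as a context manager that releases on any exit; a minimal runnable illustration of the difference (generic names, not the commit's code):

```python
import threading

lock = threading.Lock()

def risky(fn):
    lock.acquire()
    result = fn()        # if fn raises, release() below never runs:
    lock.release()       # the next acquire() then blocks forever
    return result

def safe(fn):
    with lock:           # released on return *and* on exception
        return fn()

try:
    safe(lambda: 1 / 0)
except ZeroDivisionError:
    pass
assert not lock.locked()  # the with-block released the lock
```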