add feature: free memory
app.py CHANGED
@@ -8,30 +8,42 @@ from dearth_model import DearthForCausalLM
 
 import random
 import time
-model_states =
+import threading
+import asyncio
+
+
+
+tk = None
+model_states = None
+model = None
+lock_using_model = threading.Lock()
+recent_generate_timestamp = time.time()
+
+MODEL_LIVE_TIME = 15 * 60 # 15 minutes
+
+
+def load_model():
+    global tk, model_states, model
+
+    tk = transformers.AutoTokenizer.from_pretrained("./tk")
+    model_path = "./ts100-re2-h1-4000-model.pt"
+    states = torch.load(model_path, map_location="cpu")
+    model_states = states
+    unwanted_prefix_dueto_compile = '_orig_mod.'
+    unwanted_prefix_dueto_ddp = 'module.'
+    unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
+
+    for k,v in list(model_states.items()):
+        if k.startswith(unwanted_prefix_dueto_ddp_compiled):
+            new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
+            model_states[new_key] = model_states.pop(k)
+        elif k.startswith(unwanted_prefix_dueto_ddp):
+            new_key = k[len(unwanted_prefix_dueto_ddp):]
+            model_states[new_key] = model_states.pop(k)
+        elif k.startswith(unwanted_prefix_dueto_compile):
+            new_key = k[len(unwanted_prefix_dueto_compile):]
+            model_states[new_key] = model_states.pop(k)
+
     yml_path = "./ts100-re2-h1.yml"
     with open(yml_path, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)['model']
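Note on the hunk above: the key-renaming loop in load_model() normalizes checkpoints saved from torch.compile'd and/or DDP-wrapped models, whose state-dict keys gain '_orig_mod.' and 'module.' prefixes. A minimal standalone sketch of the same cleanup (the helper name strip_state_dict_prefixes is illustrative, not part of this commit):

    def strip_state_dict_prefixes(state_dict):
        # DDP wraps keys with 'module.', torch.compile with '_orig_mod.';
        # a compiled DDP model carries both. Strip the longest matching prefix
        # so the plain model can load the weights.
        prefixes = ("module._orig_mod.", "module.", "_orig_mod.")
        cleaned = {}
        for key, value in state_dict.items():
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    break
            cleaned[key] = value
        return cleaned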
@@ -43,8 +55,47 @@ def generate(input, num_more_tokens):
     model = DearthForCausalLM(config)
 
     model.load_state_dict(model_states)
+    model.eval()
+
+
+def main_free_mem():
+    event_loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(event_loop)
+    event_loop.call_later(MODEL_LIVE_TIME, free_mem)
+    event_loop.run_forever()
 
 
+def free_mem():
+    global tk, model_states, model, recent_generate_timestamp, lock_using_model
+    lock_using_model.acquire()
+    if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and model is not None:
+        tk = None
+        model_states = None
+        model = None
+        print(f"free mem, {time.time()}")
+    lock_using_model.release()
+    try:
+        event_loop = asyncio.get_event_loop()
+        event_loop.call_later(MODEL_LIVE_TIME, free_mem)
+    except:
+        pass
+
+
+def generate(input, num_more_tokens):
+    global tk, model_states, model, recent_generate_timestamp, lock_using_model
+    lock_using_model.acquire()
+    time_start = time.time()
+    if model is None:
+        load_model()
+    elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
+        tk = None
+        model_states = None
+        model = None
+        load_model()
+    recent_generate_timestamp = time.time()
+    print(f"load model time: {time.time() - time_start}")
+
+    time_start = time.time()
     num_more_tokens = int(num_more_tokens)
     # print(input)
     input = input.strip()
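This hunk carries the whole feature: every touch of the model goes through lock_using_model, free_mem() drops the globals once the Space has sat idle for MODEL_LIVE_TIME seconds, and generate() lazily reloads them on the next request. The same idle-unload pattern can be written without a dedicated asyncio loop, using threading.Timer for the periodic check; a self-contained sketch under assumed names (_load_model is a stand-in for the real load, not this commit's code):

    import threading
    import time

    IDLE_SECONDS = 15 * 60            # unload after 15 idle minutes
    _lock = threading.Lock()
    _model = None                     # lazily created, freed when idle
    _last_used = time.time()

    def _load_model():
        return object()               # stand-in for the expensive load

    def _reaper():
        global _model
        with _lock:                   # never free while a request holds the model
            if _model is not None and time.time() - _last_used >= IDLE_SECONDS:
                _model = None         # drop the reference so GC can reclaim it
        timer = threading.Timer(IDLE_SECONDS, _reaper)
        timer.daemon = True           # do not keep the process alive on exit
        timer.start()                 # re-arm the periodic check

    def generate(prompt):
        global _model, _last_used
        with _lock:
            if _model is None:        # first call, or freed since the last one
                _model = _load_model()
            _last_used = time.time()
            return f"generated from {prompt!r}"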
@@ -52,7 +103,9 @@ def generate(input, num_more_tokens):
     input_ids = [tk.bos_token_id] + input_ids
     input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
     # print(input_ids)
+    print(f"encode time: {time.time() - time_start}")
 
+    time_start = time.time()
     output_ids = input_ids.squeeze(0).tolist()
     for i in range(num_more_tokens):
         input = torch.tensor(output_ids, dtype=torch.long).view(1, -1)
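Small timing note: the instrumentation added here measures intervals with time.time(), which can jump if the wall clock is adjusted; time.perf_counter() is the monotonic choice for elapsed-time measurement. A hypothetical variant of the same lines:

    time_start = time.perf_counter()
    # ... encode ...
    print(f"encode time: {time.perf_counter() - time_start}")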
@@ -70,8 +123,12 @@ def generate(input, num_more_tokens):
     # print(output_ids)
     # print(tk.decode(output_ids))
     output_ids = output_ids[1:]
+    print(f"inference time: {time.time() - time_start}\n")
+
+    ret = tk.decode(output_ids)
+    lock_using_model.release()
+    return ret
 
-    return tk.decode(output_ids)
 
 example_input = ["Once upon a time, there was a little girl",
                  "John and Sarah were playing together in their backyard when",
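The ordering in this tail matters: tk.decode() still touches the tokenizer, so it has to run before the lock is given up, otherwise free_mem() could null out tk mid-decode. But release() sits only on the normal return path; if anything earlier in generate() raises, the lock is never released and every later request blocks on acquire(). A try/finally keeps both properties; an illustrative reshape of the same tail, not the commit's code:

    lock_using_model.acquire()
    try:
        # ... load model if needed, encode, sample ...
        ret = tk.decode(output_ids)   # still under the lock, tk cannot be freed here
    finally:
        lock_using_model.release()    # released on success and on error alike
    return ret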
@@ -86,21 +143,11 @@ The PPL on the validation set is 1.7, in comparison, the teacher model has a PPL
 """
 
 
-# demo = gr.Interface(
-#     fn=generate,
-#     title="Tinystories LM 11M",
-#     description=Description,
-#     inputs=[
-#         gr.Textbox(lines=5, label="Input Text", value=example_input[random.randint(0, len(example_input)-1)]),
-#         gr.Slider(16, 64, step=1.0, value=32, label="more tokens", info="")
-#     ],
-#     outputs="text"
-# )
-
-with open("./random_input_example.js" , "r") as f:
-    file_content = f.read()
-
 if __name__ == "__main__":
+    load_model()
+    thread_free_mem = threading.Thread(target=main_free_mem)
+    thread_free_mem.start()
+
     with gr.Blocks(
         title="Tinystories LM 11M",
         js="./random_input_example.js"
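On the startup wiring: main_free_mem() ends in run_forever(), and the thread is created without daemon=True, so the reaper thread will keep the Python process alive after the Gradio app shuts down. Marking it as a daemon lets the interpreter exit cleanly; an illustrative variant of the same block, not what the commit does:

    if __name__ == "__main__":
        load_model()    # warm start: the first request pays no load cost
        thread_free_mem = threading.Thread(target=main_free_mem, daemon=True)
        thread_free_mem.start()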