Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -67,23 +67,6 @@ with open(meta_path, 'rb') as f:
|
|
67 |
stoi, itos = meta['stoi'], meta['itos']
|
68 |
encode = lambda s: [stoi[c] for c in s]
|
69 |
decode = lambda l: ''.join([itos[i] for i in l])
|
70 |
-
|
71 |
-
def gen(input):
|
72 |
-
generated_text = ''
|
73 |
-
start_ids = encode(add_caseifer(input))
|
74 |
-
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
|
75 |
-
for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
|
76 |
-
# convert the index to a character and print it to the screen
|
77 |
-
char = decode([idx_next])
|
78 |
-
|
79 |
-
|
80 |
-
# check for newline character
|
81 |
-
if char == '\n':
|
82 |
-
out = remove_caseifer(generated_text)
|
83 |
-
return input + out
|
84 |
-
else:
|
85 |
-
# append the character to the generated text
|
86 |
-
generated_text += char
|
87 |
|
88 |
def load_model(model_name):
|
89 |
ckpt_path = os.path.join(out_dir, model_name)
|
@@ -108,8 +91,22 @@ def get_model_list():
|
|
108 |
|
109 |
def gen(input, model_name):
|
110 |
model = load_model(model_name)
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
iface = gr.Interface(fn=gen,
|
114 |
inputs=["text", gr.inputs.Dropdown(get_model_list(), label="Select Model")],
|
115 |
outputs="text")
|
|
|
67 |
stoi, itos = meta['stoi'], meta['itos']
|
68 |
encode = lambda s: [stoi[c] for c in s]
|
69 |
decode = lambda l: ''.join([itos[i] for i in l])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def load_model(model_name):
|
72 |
ckpt_path = os.path.join(out_dir, model_name)
|
|
|
91 |
|
92 |
def gen(input, model_name):
|
93 |
model = load_model(model_name)
|
94 |
+
generated_text = ''
|
95 |
+
start_ids = encode(add_caseifer(input))
|
96 |
+
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
|
97 |
+
for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
|
98 |
+
# convert the index to a character and print it to the screen
|
99 |
+
char = decode([idx_next])
|
100 |
+
|
101 |
|
102 |
+
# check for newline character
|
103 |
+
if char == '\n':
|
104 |
+
out = remove_caseifer(generated_text)
|
105 |
+
return input + out
|
106 |
+
else:
|
107 |
+
# append the character to the generated text
|
108 |
+
generated_text += char
|
109 |
+
|
110 |
iface = gr.Interface(fn=gen,
|
111 |
inputs=["text", gr.inputs.Dropdown(get_model_list(), label="Select Model")],
|
112 |
outputs="text")
|