Corianas committed
Commit 64f85a2 · 1 Parent(s): 713c8ba

Update app.py

Files changed (1)
  1. app.py +15 -18
app.py CHANGED
@@ -67,23 +67,6 @@ with open(meta_path, 'rb') as f:
 stoi, itos = meta['stoi'], meta['itos']
 encode = lambda s: [stoi[c] for c in s]
 decode = lambda l: ''.join([itos[i] for i in l])
-
-def gen(input):
-    generated_text = ''
-    start_ids = encode(add_caseifer(input))
-    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
-    for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
-        # convert the index to a character and print it to the screen
-        char = decode([idx_next])
-
-
-        # check for newline character
-        if char == '\n':
-            out = remove_caseifer(generated_text)
-            return input + out
-        else:
-            # append the character to the generated text
-            generated_text += char
 
 def load_model(model_name):
     ckpt_path = os.path.join(out_dir, model_name)
@@ -108,8 +91,22 @@ def get_model_list():
 
 def gen(input, model_name):
     model = load_model(model_name)
-    # the rest of your gen function using this model...
+    generated_text = ''
+    start_ids = encode(add_caseifer(input))
+    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+    for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
+        # convert the index to a character and print it to the screen
+        char = decode([idx_next])
+
 
+        # check for newline character
+        if char == '\n':
+            out = remove_caseifer(generated_text)
+            return input + out
+        else:
+            # append the character to the generated text
+            generated_text += char
+
 iface = gr.Interface(fn=gen,
                      inputs=["text", gr.inputs.Dropdown(get_model_list(), label="Select Model")],
                      outputs="text")
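
For context, a minimal usage sketch of the refactored gen(input, model_name). None of these lines are part of the commit; the sketch assumes app.py's other globals (out_dir, device, max_new_tokens, temperature, top_k, add_caseifer, remove_caseifer) are defined above and that get_model_list() returns checkpoint filenames under out_dir:

# Hypothetical usage sketch; none of these lines are in the commit.
# Assumes the names defined earlier in app.py: gen, get_model_list, iface.
if __name__ == "__main__":
    models = get_model_list()   # checkpoint filenames found under out_dir
    if models:
        # gen() streams characters until a newline is generated, then returns
        # the prompt plus the de-caseified completion.
        print(gen("Hello", models[0]))
    iface.launch()              # serve the Gradio interface defined above

Loading the checkpoint inside gen() is what lets the dropdown switch models per request; the trade-off is that the weights are reloaded on every call.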