Stefan Dumitrescu commited on
Commit
c44f938
1 Parent(s): 0957f7e
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -33,9 +33,9 @@ top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, valu
33
 
34
 
35
  @st.cache(allow_output_mutation=True)
36
- def setModel(model_name):
37
- model = AutoModelWithLMHead.from_pretrained(model_name)
38
- tokenizer = AutoTokenizer.from_pretrained(model_name)
39
  return model, tokenizer
40
 
41
  def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top_p):
@@ -52,7 +52,7 @@ def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top
52
 
53
  return output_sequences
54
 
55
-
56
  output_sequences = infer(model, tokenizer, text_element, input_ids, max_length, temperature, top_k, top_p)
57
 
58
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
 
33
 
34
 
35
  @st.cache(allow_output_mutation=True)
36
+ def setModel(model_checkpoint):
37
+ model = AutoModelWithLMHead.from_pretrained(model_checkpoint)
38
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
39
  return model, tokenizer
40
 
41
  def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top_p):
 
52
 
53
  return output_sequences
54
 
55
+ model, tokenizer = setModel(model_checkpoint)
56
  output_sequences = infer(model, tokenizer, text_element, input_ids, max_length, temperature, top_k, top_p)
57
 
58
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):