nafisehNik commited on
Commit
93d78ca
·
1 Parent(s): 2ae5795

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -36,19 +36,17 @@ with st.spinner(text="Please wait while the model is loading...."):
36
  tokenizer = load_tokenizer('nafisehNik/girt-t5-base')
37
 
38
 
39
- def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length):
40
 
41
  inputs = tokenizer(sample, return_tensors="pt").to('cpu')
42
 
43
  outputs = model.generate(
44
  **inputs,
45
- num_beams=num_beams,
46
- num_return_sequences=1,
47
- length_penalty=length_penalty,
48
- no_repeat_ngram_size=2,
49
- early_stopping=early_stopping,
50
  max_length=max_length,
51
- min_length=min_length).to('cpu')
 
 
52
 
53
  generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
54
  generated_text = generated_texts[0]
@@ -83,12 +81,14 @@ with tab2:
83
  sent = st.text_input(
84
  "Sentence:", placeholder="Enter a prompt.", on_change=None
85
  )
86
-
87
  # TODO: Check if this is needed!
88
  clicked = st.button("Submit")
89
 
90
- if sent:
91
- res = compute(sent, num_beams=2, length_penalty=1.0, early_stopping=True, max_length=300, min_length=20)
92
- st.code(res, language="python")
 
 
93
 
94
 
 
36
  tokenizer = load_tokenizer('nafisehNik/girt-t5-base')
37
 
38
 
39
+ def compute(sample, top_p, top_k, do_sample, max_length, min_length):
40
 
41
  inputs = tokenizer(sample, return_tensors="pt").to('cpu')
42
 
43
  outputs = model.generate(
44
  **inputs,
45
+ min_length= min_length,
 
 
 
 
46
  max_length=max_length,
47
+ do_sample=do_sample,
48
+ top_p=top_p,
49
+ top_k=top_k).to('cpu')
50
 
51
  generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
52
  generated_text = generated_texts[0]
 
81
  sent = st.text_input(
82
  "Sentence:", placeholder="Enter a prompt.", on_change=None
83
  )
84
+
85
  # TODO: Check if this is needed!
86
  clicked = st.button("Submit")
87
 
88
+ with st.spinner("Please Wait..."):
89
+
90
+ if sent:
91
+ res = compute(sent, top_p=0.92, top_k=0, do_sample=True, max_length=512, min_length=0)
92
+ st.code(res, language="python")
93
 
94