yirmibesogluz commited on
Commit
6805c60
1 Parent(s): 3255b79

Added generation params to summarization

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -110,14 +110,11 @@ def paraphrase(input, model_choice="turna_paraphrasing_tatoeba"):
110
  return paraphrasing_sub(input)[0]["generated_text"]
111
 
112
  @spaces.GPU
113
- def summarize(input, model_choice="turna_summarization_tr_news"):
114
- if model_choice=="turna_summarization_tr_news":
115
- news_sum = pipeline(model="boun-tabi-LMG/turna_summarization_tr_news", device=0)
116
-
117
- return news_sum(input)[0]["generated_text"]
118
- else:
119
- summarization_model = pipeline(model="boun-tabi-LMG/turna_summarization_mlsum", device=0)
120
- return summarization_model(input)[0]["generated_text"]
121
 
122
  @spaces.GPU
123
  def categorize(input):
@@ -258,12 +255,23 @@ with gr.Blocks(theme="abidlabs/Lime") as demo:
258
  with gr.Row():
259
  with gr.Column():
260
  sum_choice = gr.Radio(choices = ["turna_summarization_mlsum", "turna_summarization_tr_news"], label ="Model", value="turna_summarization_mlsum")
 
 
 
 
 
 
 
 
 
 
 
261
  sum_input = gr.Textbox(label = "Summarization Input")
262
  sum_submit = gr.Button()
263
  sum_output = gr.Textbox(label = "Summarization Output")
264
 
265
- sum_submit.click(summarize, inputs=[sum_input, sum_choice], outputs=sum_output)
266
- sum_examples = gr.Examples(examples = long_text, inputs = [sum_input, sum_choice], outputs=sum_output, fn=summarize)
267
 
268
  gr.Markdown(CITATION)
269
 
 
110
  return paraphrasing_sub(input)[0]["generated_text"]
111
 
112
  @spaces.GPU
113
+ def summarize(input, model_choice="turna_summarization_tr_news", max_new_tokens, length_penalty, no_repeat_ngram_size):
114
+ model_mapping = {"turna_summarization_tr_news": "boun-tabi-LMG/turna_summarization_tr_news",
115
+ "turna_summarization_mlsum": "boun-tabi-LMG/turna_summarization_mlsum"}
116
+ summarization_model = pipeline(model=model_mapping[model_choice], device=0)
117
+ return summarization_model(input, max_new_tokens = max_new_tokens, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size)[0]["generated_text"]
 
 
 
118
 
119
  @spaces.GPU
120
  def categorize(input):
 
255
  with gr.Row():
256
  with gr.Column():
257
  sum_choice = gr.Radio(choices = ["turna_summarization_mlsum", "turna_summarization_tr_news"], label ="Model", value="turna_summarization_mlsum")
258
+ with gr.Accordion("Advanced Generation Parameters"):
259
+ max_new_tokens = gr.Slider(label = "Maximum length",
260
+ minimum = 0,
261
+ maximum = 512,
262
+ value = 128)
263
+ length_penalty = gr.Slider(label = "Length penalty",
264
+ minimum = -10,
265
+ maximum = 10,
266
+ value=2.0)
267
+ no_repeat_ngram_size =gr.Slider(label="No Repeat N-Gram Size", minimum=0,value=3,)
268
+ with gr.Column():
269
  sum_input = gr.Textbox(label = "Summarization Input")
270
  sum_submit = gr.Button()
271
  sum_output = gr.Textbox(label = "Summarization Output")
272
 
273
+ sum_submit.click(summarize, inputs=[sum_input, sum_choice, max_new_tokens, length_penalty, no_repeat_ngram_size], outputs=sum_output)
274
+ sum_examples = gr.Examples(examples = long_text, inputs = [sum_input, sum_choice, max_new_tokens, length_penalty, no_repeat_ngram_size], outputs=sum_output, fn=summarize)
275
 
276
  gr.Markdown(CITATION)
277