File size: 3,109 Bytes
95b97b8
 
 
 
 
 
a4a0e50
 
 
 
 
95b97b8
 
 
 
 
 
 
 
4a4e551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95b97b8
a4a0e50
4a4e551
 
95b97b8
 
4a4e551
95b97b8
 
 
4a4e551
 
95b97b8
4a4e551
95b97b8
 
 
 
a4a0e50
95b97b8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import requests
import streamlit as st
import time 
from transformers import pipeline
import os

# Configure the Streamlit page (title and tab icon) at import time.
st.set_page_config(page_title="Text Summarization", page_icon="📈")

# Hugging Face Inference API token read from the environment; None if unset,
# in which case the Authorization header below carries "Bearer None".
HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
# Shared auth header for all Inference API requests (used by query()).
headers = {"Authorization": f"Bearer {HF_AUTH_TOKEN}"}

def write():
    """Render the Text Summarization page.

    Builds the sidebar controls (model selector and generation parameters),
    shows a text area with a Turkish example document, and on "Generate"
    sends the text to the Hugging Face Inference API via query() and
    displays the returned summary.

    Side effects: draws Streamlit widgets; performs a network request when
    the Generate button is pressed.
    """
    # BUG FIX: the original body mixed tab- and space-indented lines, which
    # raises TabError at import time in Python 3. Re-indented consistently
    # with 4 spaces; logic is unchanged.
    st.markdown("# Text Summarization")
    st.sidebar.header("Text Summarization")
    st.write(
        """Here, you can summarize your text using the fine-tuned TURNA summarization models. """
    )

    # Sidebar
    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
    st.sidebar.subheader("Configurable parameters")

    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "turna_summarization_mlsum",
            "turna_summarization_tr_news",
        ],
        index=0,
    )
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=0,
        max_value=128,
        value=128,
        help="The maximum length of the sequence to be generated.",
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty",
        value=2.0,
        help=" length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. ",
    )
    # NOTE(review): controls for do_sample / num_beams / repetition_penalty
    # were present as dead code (a no-op string literal) and have been
    # removed; re-add real widgets here if those parameters should become
    # user-tunable again.
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to int > 0, all ngrams of that size can only occur once.",
    )

    input_text = st.text_area(label='Enter a text: ', height=200,
            value="Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip atlattığı düşünülüyor. Peki grip virüsü ne yapıyor da kalp krizine yol açıyor? Karpuz şöyle açıkladı: Grip virüsü kanın yapışkanlığını veya pıhtılaşmasını artırıyor.")
    # Model repos on the Hub are lowercase; model_name already is, but
    # .lower() is kept for safety and behavioral parity.
    url = ("https://api-inference.huggingface.co/models/boun-tabi-LMG/" + model_name.lower())
    params = {"length_penalty": length_penalty, "no_repeat_ngram_size": no_repeat_ngram_size, "max_new_tokens": max_new_tokens }
    if st.button("Generate"):
        with st.spinner('Generating...'):
            output = query(input_text, url, params)
            st.success(output)


def query(text, url, params):
    """POST *text* to the Hugging Face Inference API and return the summary.

    Retries every 15 seconds while the API response contains an 'error' key
    (typically while the model is still loading).
    NOTE(review): this retries forever on a persistent error (e.g. a bad
    token or model name) — consider adding a retry cap.

    Args:
        text: Input document to summarize.
        url: Full Inference API endpoint for the selected model.
        params: Generation parameters forwarded under "parameters".

    Returns:
        The generated summary string (first element's "generated_text").
    """
    # BUG FIX: the original referenced an undefined name `payload` here,
    # raising NameError on every call; the parameter is `text`.
    data = {"inputs": text, "parameters": params}
    while True:
        response = requests.post(url, headers=headers, json=data)
        # Parse once instead of calling response.json() three times.
        result = response.json()
        if 'error' not in result:
            return result[0]["generated_text"]
        print(result)
        time.sleep(15)
        print('Sending request again', flush=True)