yirmibesogluz committed
Commit 4a4e551 · 1 Parent(s): a4a0e50

Re-added home page

Files changed (3)
  1. app.py +2 -61
  2. apps/home.py +57 -0
  3. apps/summarization.py +57 -5
app.py CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
 from transformers import pipeline
 
 import apps.summarization
+import apps.home
 
 st.set_page_config(
     page_title="Turna",
@@ -10,9 +11,8 @@ st.set_page_config(
     layout='wide'
 )
 
-API_URL = "https://api-inference.huggingface.co/models/boun-tabi-LMG/TURNA"
-
 PAGES = {
+    "Turna": apps.home,
     "Text Summarization": apps.summarization
 }
 
@@ -24,65 +24,6 @@ page = PAGES[selection]
 # with st.spinner(f"Loading {selection} ..."):
 ast.shared.components.write_page(page)
 
-st.markdown(
-    """
-    <h1 style="text-align:left;">TURNA</h1>
-    """,
-    unsafe_allow_html=True,
-)
-
-st.write("#")
-
-col = st.columns(2)
-
-col[0].image("images/turna-logo.png", width=100)
-
-st.markdown(
-    """
-
-    <h3 style="text-align:right;">TURNA is a Turkish encoder-decoder language model.</h3>
-
-    <p style="text-align:right;"><p>
-    <p style="text-align:right;">Use the generation paramters on the sidebar to adjust generation quality.</p>
-    <p style="text-align:right;"><p>
-    """,
-    unsafe_allow_html=True,
-)
-
-#st.title('Turkish Language Generation')
-#st.write('...with Turna')
-input_text = st.text_area(label='Enter a text: ', height=100,
-                          value="Türkiye'nin başkeni neresidir?")
-if st.button("Generate"):
-    with st.spinner('Generating...'):
-        output = query(input_text)
-        st.success(output)
-
-def query(payload):
-    #{"inputs": payload, ""}
-    while True:
-        response = requests.post(API_URL, json=payload)
-        if 'error' not in response.json():
-            output = response.json()[0]["generated_text"]
-            return output
-        else:
-            time.sleep(15)
-            print('Sending request again', flush=True)
-
-def pipe():
-    pipe = pipeline("text2text-generation", model="boun-tabi-LMG/TURNA", tokenizer="boun-tabi-LMG/TURNA", temperature=0.7, repetition_penalty=0.5, top_p=0.9)
-
-"""PAGES = {
-    "Turkish Language Generation": pages.turna,
-}
-
-st.sidebar.title("Navigation")
-selection = st.sidebar.radio("Pages", list(PAGES.keys()))
-
-page = PAGES[selection]
-# with st.spinner(f"Loading {selection} ..."):
-ast.shared.components.write_page(page)"""
-
 st.sidebar.header("Info")
 
 st.sidebar.write(
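
The routing above relies on each entry in PAGES being a module that exposes a write() function (as apps/summarization.py does); ast.shared.components.write_page(page) is expected to call that function. Note that apps/home.py as added below runs its UI code at import time instead of defining write(). The following is a minimal sketch of the intended pattern, not the committed code; the direct page.write() call stands in for what write_page is assumed to do.

# Minimal sketch of the sidebar routing in app.py, assuming every page module
# exposes a write() function. Illustration only, not the committed code.
import streamlit as st

import apps.home
import apps.summarization

PAGES = {
    "Turna": apps.home,
    "Text Summarization": apps.summarization,
}

st.sidebar.title("Navigation")
selection = st.sidebar.radio("Pages", list(PAGES.keys()))

page = PAGES[selection]
page.write()  # what ast.shared.components.write_page(page) is expected to do
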
apps/home.py ADDED
@@ -0,0 +1,57 @@
+import requests
+import streamlit as st
+import time
+from transformers import pipeline
+import os
+
+st.set_page_config(page_title="TURNA")
+
+API_URL = "https://api-inference.huggingface.co/models/boun-tabi-LMG/TURNA"
+
+st.markdown(
+    """
+    <h1 style="text-align:left;">TURNA</h1>
+    """,
+    unsafe_allow_html=True,
+)
+
+st.write("#")
+
+col = st.columns(2)
+
+col[0].image("images/turna-logo.png", width=100)
+
+st.markdown(
+    """
+
+    <h3 style="text-align:right;">TURNA is a Turkish encoder-decoder language model.</h3>
+
+    <p style="text-align:right;"><p>
+    <p style="text-align:right;">Use the generation paramters on the sidebar to adjust generation quality.</p>
+    <p style="text-align:right;"><p>
+    """,
+    unsafe_allow_html=True,
+)
+
+#st.title('Turkish Language Generation')
+#st.write('...with Turna')
+input_text = st.text_area(label='Enter a text: ', height=100,
+                          value="Türkiye'nin başkeni neresidir?")
+if st.button("Generate"):
+    with st.spinner('Generating...'):
+        output = query(input_text)
+        st.success(output)
+
+def query(payload):
+    #{"inputs": payload, ""}
+    while True:
+        response = requests.post(API_URL, json=payload)
+        if 'error' not in response.json():
+            output = response.json()[0]["generated_text"]
+            return output
+        else:
+            time.sleep(15)
+            print('Sending request again', flush=True)
+
+def pipe():
+    pipe = pipeline("text2text-generation", model="boun-tabi-LMG/TURNA", tokenizer="boun-tabi-LMG/TURNA", temperature=0.7, repetition_penalty=0.5, top_p=0.9)
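
As committed, apps/home.py calls query() before the function is defined in the script, and it posts the raw text without the {"inputs": ...} wrapper or Authorization header used in apps/summarization.py. Below is a hedged sketch of the same retry-until-loaded pattern with query() defined first; the payload shape and the HF_AUTH_TOKEN bearer token are borrowed from apps/summarization.py, and applying them to this endpoint is an assumption rather than part of the commit.

# Sketch of the query/retry pattern from apps/home.py, reordered so query()
# exists before the button handler uses it. Payload shape and bearer token
# follow apps/summarization.py; using them here is an assumption.
import os
import time

import requests
import streamlit as st

API_URL = "https://api-inference.huggingface.co/models/boun-tabi-LMG/TURNA"
HEADERS = {"Authorization": f"Bearer {os.getenv('HF_AUTH_TOKEN', '')}"}


def query(text):
    data = {"inputs": text}
    while True:
        response = requests.post(API_URL, headers=HEADERS, json=data)
        result = response.json()
        if 'error' not in result:
            return result[0]["generated_text"]
        time.sleep(15)  # model may still be loading; retry, as the original loop does
        print('Sending request again', flush=True)


input_text = st.text_area(label='Enter a text: ', height=100,
                          value="Türkiye'nin başkenti neresidir?")
if st.button("Generate"):
    with st.spinner('Generating...'):
        st.success(query(input_text))
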
apps/summarization.py CHANGED
@@ -6,7 +6,6 @@ import os
 
 st.set_page_config(page_title="Text Summarization", page_icon="📈")
 
-API_URL = "https://api-inference.huggingface.co/models/boun-tabi-LMG/turna_summarization_mlsum"
 HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
 headers = {"Authorization": f"Bearer {HF_AUTH_TOKEN}"}
 
@@ -18,18 +17,71 @@ def write():
         """Here, you can summarize your text using the fine-tuned TURNA summarization models. """
     )
 
+    # Sidebar
+
+    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+    st.sidebar.subheader("Configurable parameters")
+
+    model_name = st.sidebar.selectbox(
+        "Model Selector",
+        options=[
+            "turna_summarization_mlsum",
+            "turna_summarization_tr_news",
+        ],
+        index=0,
+    )
+    max_new_tokens = st.sidebar.number_input(
+        "Maximum length",
+        min_value=0,
+        max_value=128,
+        value=128,
+        help="The maximum length of the sequence to be generated.",
+    )
+    length_penalty = st.sidebar.number_input(
+        "Length penalty",
+        value=2.0,
+        help=" length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. ",
+    )
+    """do_sample = st.sidebar.selectbox(
+        "Sampling?",
+        (True, False),
+        help="Whether or not to use sampling; use greedy decoding otherwise.",
+    )
+    num_beams = st.sidebar.number_input(
+        "Number of beams",
+        min_value=1,
+        max_value=10,
+        value=3,
+        help="The number of beams to use for beam search.",
+    )
+    repetition_penalty = st.sidebar.number_input(
+        "Repetition Penalty",
+        min_value=0.0,
+        value=3.0,
+        step=0.1,
+        help="The parameter for repetition penalty. 1.0 means no penalty",
+    )"""
+    no_repeat_ngram_size = st.sidebar.number_input(
+        "No Repeat N-Gram Size",
+        min_value=0,
+        value=3,
+        help="If set to int > 0, all ngrams of that size can only occur once.",
+    )
+
     input_text = st.text_area(label='Enter a text: ', height=200,
                               value="Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip atlattığı düşünülüyor. Peki grip virüsü ne yapıyor da kalp krizine yol açıyor? Karpuz şöyle açıkladı: Grip virüsü kanın yapışkanlığını veya pıhtılaşmasını artırıyor.")
+    url = ("https://api-inference.huggingface.co/models/boun-tabi-LMG/" + model_name.lower())
+    params = {"length_penalty": length_penalty, "no_repeat_ngram_size": no_repeat_ngram_size, "max_new_tokens": max_new_tokens}
     if st.button("Generate"):
         with st.spinner('Generating...'):
-            output = query(input_text)
+            output = query(input_text, url, params)
             st.success(output)
 
 
-def query(payload):
-    data = {"inputs": payload, "parameters": {"length_penalty": 2.0, "no_repeat_ngram_size": 3, "max_length":128}}
+def query(text, url, params):
+    data = {"inputs": text, "parameters": params}
     while True:
-        response = requests.post(API_URL, headers=headers, json=data)
+        response = requests.post(url, headers=headers, json=data)
         if 'error' not in response.json():
             output = response.json()[0]["generated_text"]
             return output
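
For reference, the refactored query() now takes the endpoint URL and the generation parameters built from the sidebar controls instead of hard-coding them. A short standalone usage sketch under the same Inference API convention follows; the single-shot error handling (no retry loop) is a simplification for illustration.

# Standalone sketch of the parameterized summarization request. Model names,
# parameter keys, and HF_AUTH_TOKEN come from the diff; the single-shot error
# handling is a simplification, not the app's retry behaviour.
import os
import requests

headers = {"Authorization": f"Bearer {os.getenv('HF_AUTH_TOKEN', '')}"}

model_name = "turna_summarization_tr_news"  # or "turna_summarization_mlsum"
url = "https://api-inference.huggingface.co/models/boun-tabi-LMG/" + model_name.lower()
params = {"length_penalty": 2.0, "no_repeat_ngram_size": 3, "max_new_tokens": 128}


def query(text, url, params):
    data = {"inputs": text, "parameters": params}
    response = requests.post(url, headers=headers, json=data)
    result = response.json()
    if 'error' in result:
        raise RuntimeError(result['error'])
    return result[0]["generated_text"]


print(query("Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip atlattığı düşünülüyor.", url, params))
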