peder committed
Commit c18b92f · 1 Parent(s): ba2588f

fix on button press + prompting

Files changed (2):
  1. app.py +59 -17
  2. app.py.d264c618bf578cebbd21d2c2379d7b50.tmp +321 -0
app.py CHANGED
@@ -24,7 +24,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "NbAiLab/nb-gpt-j-6B-alpaca")
 MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 256))
 
 HEADER_INFO = """
-# CBS_Alpaca-GPT-j
+# GPT-NorPaca
 Norwegian GPT-J-6B NorPaca Model.
 """.strip()
 LOGO = "https://upload.wikimedia.org/wikipedia/commons/thumb/1/19/Logo_CopenhagenBusinessSchool.svg/1200px-Logo_CopenhagenBusinessSchool.svg.png"
@@ -44,7 +44,8 @@ For more information, visit the [model repository](https://huggingface.co/CBSMas
 
 ## Configuration
 """.strip()
-PROMPT_BOX = "Enter your text..."
+PROMPT_BOX_INSTRUCTION = "Enter your Instructions here..."
+PROMPT_BOX_INPUT = "Enter your Input here..."
 EXAMPLES = [
     "Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Analyser fordelene ved å jobbe i et team. ### Respons:",
     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Oppsummer den faglige artikkelen "Kunstig intelligens og arbeidets fremtid". ### Respons:',
@@ -141,9 +142,32 @@ class TextGeneration:
         **generation_kwargs,
     )[0]["generated_text"]
 
 
+# Generate responses
+def generate_prompt(instruction, input=None):
+    if input:
+        prompt = f"""Nedenfor er en instruksjon som beskriver en oppgave, sammen med et input som gir ytterligere kontekst. Skriv et svar som fullfører forespørselen på riktig måte.
+
+### Instruksjon:
+{instruction}
+
+### Input:
+{input}
+
+### Respons:"""
+    else:
+        prompt = f"""Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte.
+
+### Instruksjon:
+{instruction}
+
+### Respons:"""
+    return prompt
+
+
 # @st.cache(allow_output_mutation=True, hash_funcs={AutoModelForCausalLM: lambda _: None})
-@st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
+# @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
 def load_text_generator():
     generator = TextGeneration()
     generator.load()
@@ -188,7 +212,7 @@ def main():
             "generation.",
        min_value=0.0,
        max_value=1.0,
-        value=float(query_params.get("top_p", [0.95])[0]),
+        value=float(query_params.get("top_p", [0.75])[0]),
        step=0.01
    )
    temperature = st.sidebar.slider(
@@ -196,7 +220,7 @@ def main():
        help="The value used to modulate the next token probabilities",
        min_value=0.1,
        max_value=10.0,
-        value=float(query_params.get("temperature", [0.8])[0]),
+        value=float(query_params.get("temperature", [0.2])[0]),
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
@@ -206,13 +230,6 @@ def main():
        index=int(query_params.get("do_sample", ["true"])[
            0].lower()[0] in ("t", "y", "1")),
    )
-    # do_clean = st.sidebar.selectbox(
-    #     label='Clean text?',
-    #     options=(False, True),
-    #     help="Whether or not to remove repeated words and trim unfinished last sentences.",
-    #     index=int(query_params.get("do_clean", ["true"])[
-    #         0].lower()[0] in ("t", "y", "1")),
-    # )
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
@@ -226,19 +243,44 @@ def main():
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
 
    if prompt == "Custom":
-        prompt_box = query_params.get("text", [PROMPT_BOX])[0].strip()
+        prompt_box_instruction = query_params.get(
+            "text1", [PROMPT_BOX_INSTRUCTION])[0].strip()
+        prompt_box_input = query_params.get(
+            "text2", [PROMPT_BOX_INPUT])[0].strip()
+        prompt_box = f"{prompt_box_instruction} {prompt_box_input}"
    else:
+        if "### Input:" in prompt:
+            prompt_box_instruction = prompt.split("### Instruksjon:")[
+                1].split("### Input:")[0].strip()
+            prompt_box_input = prompt.split(
+                "### Input:")[1].split("### Respons:")[0].strip()
+        else:
+            prompt_box_instruction = prompt.split(
+                "### Instruksjon:")[1].split("### Respons:")[0].strip()
+            prompt_box_input = None
        prompt_box = prompt
 
-    text = st.text_area("Enter text", prompt_box)
+    if prompt == "Custom":
+        text_instruction = st.text_area(
+            "Enter Instruction", PROMPT_BOX_INSTRUCTION)
+        text_input = st.text_area("Enter Input", PROMPT_BOX_INPUT)
+    else:
+        text_instruction = st.text_area(
+            "Enter Instruction", prompt_box_instruction)
+        text_input = st.text_area("Enter Input", prompt_box_input) if "### Input:" in prompt else st.text_area(
+            "Enter Input", PROMPT_BOX_INPUT)
+
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()
-    if st.button("Generate!") or text != PROMPT_BOX:
+    if st.button("Generate!"):
        output = st.empty()
        with st.spinner(text="Generating..."):
            generation_kwargs_ph.markdown(
                ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
-            if text:
+            if text_instruction:
+                text = generate_prompt(text_instruction, text_input) if text_input != PROMPT_BOX_INPUT else generate_prompt(
+                    text_instruction)
+                # print("TEXT OUT", text)
                share_args = {"text": text, **generation_kwargs}
                st.experimental_set_query_params(**share_args)
                for _ in range(5):
@@ -260,7 +302,7 @@ def main():
                    components.html(
                        f"""
                        <a class="twitter-share-button"
-                        data-text="Check my prompt using NB-GPT-J-6B-NorPaca!🇳🇴 https://ai.nb.no/demo/nb-gpt-j-6B-NorPaca/?{urlencode(share_args)}"
+                        data-text="Check my prompt using NB-GPT-J-6B-NorPaca!🇳🇴 https://huggingface.co/spaces/MasterThesisCBS/NorPaca_GPT?{urlencode(share_args)}"
                        data-show-count="false"
                        data-size="Small"
                        data-hashtags="nb,gpt-j">
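
The new generate_prompt() helper builds the Alpaca-style prompt from the two text boxes instead of requiring users to paste the full template. A minimal usage sketch of the two shapes it emits (illustrative only: the example strings here are made up, and the Space itself only calls the helper from main()):

    from app import generate_prompt

    # Instruction-only prompt: the "### Input:" section is omitted.
    print(generate_prompt("Analyser fordelene ved å jobbe i et team."))

    # Instruction plus context: the extra text lands under "### Input:".
    print(generate_prompt("Oppsummer artikkelen.",
                          input="Kunstig intelligens og arbeidets fremtid"))
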
app.py.d264c618bf578cebbd21d2c2379d7b50.tmp ADDED
@@ -0,0 +1,321 @@
+import random
+import os
+from urllib.parse import urlencode
+# from pyngrok import ngrok
+
+import streamlit as st
+import streamlit.components.v1 as components
+import torch
+from transformers import pipeline, set_seed
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# Uncomment to verify the CUDA device at startup:
+# print(f"Is CUDA available: {torch.cuda.is_available()}")
+# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+# Read the access token from the environment; never hard-code secrets.
+HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DTYPE = torch.float32 if DEVICE.type == "cpu" else torch.float16
+MODEL_NAME = os.environ.get("MODEL_NAME", "NbAiLab/nb-gpt-j-6B-alpaca")
+MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 256))
+
+HEADER_INFO = """
+# GPT-NorPaca
+Norwegian GPT-J-6B NorPaca Model.
+""".strip()
+LOGO = "https://upload.wikimedia.org/wikipedia/commons/thumb/1/19/Logo_CopenhagenBusinessSchool.svg/1200px-Logo_CopenhagenBusinessSchool.svg.png"
+SIDEBAR_INFO = f"""
+<div align=center>
+<img src="{LOGO}" width=100/>
+
+# NB-GPT-J-6B-NorPaca
+
+</div>
+
+NB-GPT-J-6B-NorPaca is a GPT-J-6B model fine-tuned on Norwegian Alpaca-style instruction data. It is a 6 billion parameter model.
+
+This model has been trained with [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) using TPUs provided by Google through the TPU Research Cloud program, starting off the [GPT-J-6B model weights from EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B), and trained on the [Norwegian Colossal Corpus](https://huggingface.co/datasets/NbAiLab/NCC) and other Internet sources. *This demo runs on {DEVICE}*.
+
+For more information, visit the [model repository](https://huggingface.co/CBSMasterThesis).
+
+## Configuration
+""".strip()
+PROMPT_BOX_INSTRUCTION = "Enter your Instructions here..."
+PROMPT_BOX_INPUT = "Enter your Input here..."
+EXAMPLES = [
+    "Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Analyser fordelene ved å jobbe i et team. ### Respons:",
+    'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Oppsummer den faglige artikkelen "Kunstig intelligens og arbeidets fremtid". ### Respons:',
+    'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Generer et kreativt slagord for en bedrift som bruker fornybare energikilder. ### Respons:',
+    'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Regn ut arealet av en firkant med lengde 10m. Skriv ut et flyttall. ### Respons:',
+]
+
+
+def style():
+    st.markdown("""
+    <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
+    <style>
+    .ltr,
+    textarea {
+        font-family: Roboto !important;
+        text-align: left;
+        direction: ltr !important;
+    }
+    .ltr-box {
+        border-bottom: 1px solid #ddd;
+        padding-bottom: 20px;
+    }
+    .rtl {
+        text-align: left;
+        direction: ltr !important;
+    }
+    span.result-text {
+        padding: 3px 3px;
+        line-height: 32px;
+    }
+    span.generated-text {
+        background-color: rgb(118 200 147 / 13%);
+    }
+    </style>""", unsafe_allow_html=True)
+
+
+class Normalizer:
+    def remove_repetitions(self, text):
+        """Remove repeated sentences, keeping first occurrences."""
+        first_ocurrences = []
+        for sentence in text.split("."):
+            if sentence not in first_ocurrences:
+                first_ocurrences.append(sentence)
+        return '.'.join(first_ocurrences)
+
+    def trim_last_sentence(self, text):
+        """Trim the last sentence if incomplete."""
+        return text[:text.rfind(".") + 1]
+
+    def clean_txt(self, text):
+        return self.trim_last_sentence(self.remove_repetitions(text))
+
+
+class TextGeneration:
+    def __init__(self):
+        self.tokenizer = None
+        self.generator = None
+        self.task = "text-generation"
+        self.model_name_or_path = MODEL_NAME
+        set_seed(42)
+
+    def load(self):
+        print("Loading model... ", end="")
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
+            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
+            torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE.type == "cpu" else True
+        ).to(device=DEVICE, non_blocking=True)
+        _ = self.model.eval()
+        # Use CPU (-1) for the pipeline when CUDA is unavailable.
+        device_number = -1 if DEVICE.type == "cpu" else torch.cuda.current_device()
+        self.generator = pipeline(
+            self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
+        print("Done")
+
+    def generate(self, prompt, generation_kwargs):
+        max_length = len(self.tokenizer(prompt)[
+            "input_ids"]) + generation_kwargs["max_length"]
+        generation_kwargs["max_length"] = min(
+            max_length, self.model.config.n_positions)
+        # generation_kwargs["num_return_sequences"] = 1
+        # generation_kwargs["return_full_text"] = False
+        return self.generator(
+            prompt,
+            **generation_kwargs,
+        )[0]["generated_text"]
+
+
+# Generate responses
+def generate_prompt(instruction, input=None):
+    if input:
+        prompt = f"""Nedenfor er en instruksjon som beskriver en oppgave, sammen med et input som gir ytterligere kontekst. Skriv et svar som fullfører forespørselen på riktig måte.
+
+### Instruksjon:
+{instruction}
+
+### Input:
+{input}
+
+### Respons:"""
+    else:
+        prompt = f"""Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte.
+
+### Instruksjon:
+{instruction}
+
+### Respons:"""
+    return prompt
+
+
+# @st.cache(allow_output_mutation=True, hash_funcs={AutoModelForCausalLM: lambda _: None})
+# @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
+def load_text_generator():
+    generator = TextGeneration()
+    generator.load()
+    return generator
+
+
+def main():
+    st.set_page_config(
+        page_title="NB-GPT-J-6B-NorPaca",
+        page_icon="🇳🇴",
+        layout="wide",
+        initial_sidebar_state="expanded"
+    )
+    style()
+    with st.spinner('Loading the model. Please wait...'):
+        generator = load_text_generator()
+
+    st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
+    query_params = st.experimental_get_query_params()
+    if query_params:
+        st.experimental_set_query_params(**dict())
+
+    max_length = st.sidebar.slider(
+        label='Max words to generate',
+        help="The maximum length of the sequence to be generated.",
+        min_value=1,
+        max_value=MAX_LENGTH,
+        value=int(query_params.get("max_length", [50])[0]),
+        step=1
+    )
+    top_k = st.sidebar.slider(
+        label='Top-k',
+        help="The number of highest probability vocabulary tokens to keep for top-k filtering",
+        min_value=40,
+        max_value=80,
+        value=int(query_params.get("top_k", [50])[0]),
+        step=1
+    )
+    top_p = st.sidebar.slider(
+        label='Top-p',
+        help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for "
+             "generation.",
+        min_value=0.0,
+        max_value=1.0,
+        value=float(query_params.get("top_p", [0.75])[0]),
+        step=0.01
+    )
+    temperature = st.sidebar.slider(
+        label='Temperature',
+        help="The value used to modulate the next token probabilities",
+        min_value=0.1,
+        max_value=10.0,
+        value=float(query_params.get("temperature", [0.2])[0]),
+        step=0.05
+    )
+    do_sample = st.sidebar.selectbox(
+        label='Sampling?',
+        options=(False, True),
+        help="Whether or not to use sampling; use greedy decoding otherwise.",
+        index=int(query_params.get("do_sample", ["true"])[
+            0].lower()[0] in ("t", "y", "1")),
+    )
+    generation_kwargs = {
+        "max_length": max_length,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "do_sample": do_sample,
+        # "do_clean": do_clean,
+    }
+    st.markdown(HEADER_INFO)
+    prompts = EXAMPLES + ["Custom"]
+    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
+
+    if prompt == "Custom":
+        prompt_box_instruction = query_params.get(
+            "text1", [PROMPT_BOX_INSTRUCTION])[0].strip()
+        prompt_box_input = query_params.get(
+            "text2", [PROMPT_BOX_INPUT])[0].strip()
+        prompt_box = f"{prompt_box_instruction} {prompt_box_input}"
+    else:
+        if "### Input:" in prompt:
+            prompt_box_instruction = prompt.split("### Instruksjon:")[
+                1].split("### Input:")[0].strip()
+            prompt_box_input = prompt.split(
+                "### Input:")[1].split("### Respons:")[0].strip()
+        else:
+            prompt_box_instruction = prompt.split(
+                "### Instruksjon:")[1].split("### Respons:")[0].strip()
+            prompt_box_input = None
+        prompt_box = prompt
+
+    if prompt == "Custom":
+        text_instruction = st.text_area(
+            "Enter Instruction", PROMPT_BOX_INSTRUCTION)
+        text_input = st.text_area("Enter Input", PROMPT_BOX_INPUT)
+    else:
+        text_instruction = st.text_area(
+            "Enter Instruction", prompt_box_instruction)
+        text_input = st.text_area("Enter Input", prompt_box_input) if "### Input:" in prompt else st.text_area(
+            "Enter Input", PROMPT_BOX_INPUT)
+
+    generation_kwargs_ph = st.empty()
+    cleaner = Normalizer()
+    if st.button("Generate!"):
+        output = st.empty()
+        with st.spinner(text="Generating..."):
+            generation_kwargs_ph.markdown(
+                ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
+            if text_instruction:
+                text = generate_prompt(text_instruction, text_input) if text_input != PROMPT_BOX_INPUT else generate_prompt(
+                    text_instruction)
+                # print("TEXT OUT", text)
+                share_args = {"text": text, **generation_kwargs}
+                st.experimental_set_query_params(**share_args)
+                for _ in range(5):
+                    generated_text = generator.generate(
+                        text, generation_kwargs)
+                    # if do_clean:
+                    #     generated_text = cleaner.clean_txt(generated_text)
+                    if generated_text.strip().startswith(text):
+                        generated_text = generated_text.replace(
+                            text, "", 1).strip()
+                    output.markdown(
+                        f'<p class="ltr ltr-box">'
+                        f'<span class="result-text">{text} </span>'
+                        f'<span class="result-text generated-text">{generated_text}</span>'
+                        f'</p>',
+                        unsafe_allow_html=True
+                    )
+                    if generated_text.strip():
+                        components.html(
+                            f"""
+                            <a class="twitter-share-button"
+                            data-text="Check my prompt using NB-GPT-J-6B-NorPaca!🇳🇴 https://ai.nb.no/demo/nb-gpt-j-6B-NorPaca/?{urlencode(share_args)}"
+                            data-show-count="false"
+                            data-size="Small"
+                            data-hashtags="nb,gpt-j">
+                            Tweet
+                            </a>
+                            <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
+                            """
+                        )
+                        break
+                if not generated_text.strip():
+                    st.markdown(
+                        "*Tried 5 times but did not produce any result. Try again!*")
+
+
+if __name__ == '__main__':
+    main()
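
Note that both @st.cache decorators on load_text_generator() are now commented out, so the six-billion-parameter model is re-instantiated on every Streamlit rerun. A hedged alternative sketch, assuming a Streamlit version that ships st.experimental_singleton (renamed st.cache_resource in Streamlit 1.18), which caches unhashable resources such as loaded models without the hash_funcs workaround:

    import streamlit as st
    from app import TextGeneration

    @st.experimental_singleton  # on newer Streamlit: @st.cache_resource
    def load_text_generator():
        # Runs once per process; subsequent reruns reuse the cached object.
        generator = TextGeneration()
        generator.load()
        return generator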