peder committed on
Commit 9feb130 · 1 Parent(s): 45d07a3
Files changed (1):
  1. app.py +264 -4
app.py CHANGED
@@ -1,10 +1,270 @@
+ import random
+ import os
+ from urllib.parse import urlencode
+
  import streamlit as st
+ import streamlit.components.v1 as components
+ import torch
+ from transformers import pipeline, set_seed
+ from transformers import AutoTokenizer, AutoModelForCausalLM


- def hello():
-     x = st.slider('Select a value')
-     st.write(x, 'squared is', x * x)
+ # Auth token and device are read from the environment ("DEVICE" is assumed as the
+ # variable name, e.g. DEVICE=cuda:0), so no secret token is hardcoded in the source.
+ HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
+ DEVICE = os.environ.get("DEVICE", "cpu")
+ DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
+ MODEL_NAME = os.environ.get("MODEL_NAME", "NbAiLab/nb-gpt-j-6B-alpaca")
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 256))
+ HEADER_INFO = """
+ # CBS_Alpaca-GPT-j
+ Norwegian GPT-J-6B NorPaca Model.
+ """.strip()
+ LOGO = "https://upload.wikimedia.org/wikipedia/commons/thumb/1/19/Logo_CopenhagenBusinessSchool.svg/1200px-Logo_CopenhagenBusinessSchool.svg.png"
+ SIDEBAR_INFO = f"""
+ <div align="center">
+ <img src="{LOGO}" width=100/>
+
+ # NB-GPT-J-6B-NorPaca
+
+ </div>
+
+ NB-GPT-J-6B NorPaca is a GPT-J model fine-tuned to follow Norwegian instructions. It is a 6 billion parameter model, trained on the Norwegian Colossal Corpus and other Internet sources.
+
+ This model has been trained with [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) using TPUs provided by Google through the TPU Research Cloud program, starting from the [GPT-J-6B model weights from EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B), and trained on the [Norwegian Colossal Corpus](https://huggingface.co/datasets/NbAiLab/NCC) and other Internet sources. *This demo runs on {DEVICE.split(':')[0].upper()}*.
+
+ For more information, visit the [model repository](https://huggingface.co/CBSMasterThesis).
+
+ ## Configuration
+ """.strip()
+ PROMPT_BOX = "Enter your text..."
+ EXAMPLES = [
+     "Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Analyser fordelene ved å jobbe i et team. ### Respons:",
+     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Oppsummer den faglige artikkelen "Kunstig intelligens og arbeidets fremtid". ### Respons:',
+     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Generer et kreativt slagord for en bedrift som bruker fornybare energikilder. ### Respons:',
+     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Regn ut arealet av en firkant med lengde 10m. Skriv ut et flyttall. ### Respons:',
+ ]
+
+
+ def style():
+     st.markdown("""
+         <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
+         <style>
+         .ltr,
+         textarea {
+             font-family: Roboto !important;
+             text-align: left;
+             direction: ltr !important;
+         }
+         .ltr-box {
+             border-bottom: 1px solid #ddd;
+             padding-bottom: 20px;
+         }
+         .rtl {
+             text-align: left;
+             direction: ltr !important;
+         }
+         span.result-text {
+             padding: 3px 3px;
+             line-height: 32px;
+         }
+         span.generated-text {
+             background-color: rgb(118 200 147 / 13%);
+         }
+         </style>""", unsafe_allow_html=True)
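+
+
+ # Light post-processing for generated text: drop repeated sentences and trim an
+ # unfinished trailing sentence.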
+ class Normalizer:
+     def remove_repetitions(self, text):
+         """Remove repeated sentences, keeping the first occurrence of each."""
+         first_occurrences = []
+         for sentence in text.split("."):
+             if sentence not in first_occurrences:
+                 first_occurrences.append(sentence)
+         return '.'.join(first_occurrences)
+
+     def trim_last_sentence(self, text):
+         """Trim the last sentence if it is incomplete."""
+         return text[:text.rfind(".") + 1]
+
+     def clean_txt(self, text):
+         return self.trim_last_sentence(self.remove_repetitions(text))
+
+
+ class TextGeneration:
+     def __init__(self):
+         self.tokenizer = None
+         self.generator = None
+         self.task = "text-generation"
+         self.model_name_or_path = MODEL_NAME
+         set_seed(42)
+
+     def load(self):
+         print("Loading model... ", end="")
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
+         )
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
+             pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
+             torch_dtype=DTYPE, low_cpu_mem_usage=(DEVICE != "cpu")
+         ).to(device=DEVICE, non_blocking=True)
+         _ = self.model.eval()
+         device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
+         self.generator = pipeline(
+             self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
+         print("Done")
+
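+     # generate() budgets max_length as prompt tokens plus the requested new tokens,
+     # capped at the model's context window (config.n_positions).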
+     def generate(self, prompt, generation_kwargs):
+         # Work on a copy so the caller's dict is not mutated between calls, and
+         # drop "do_clean", which is a UI flag rather than a generation argument.
+         generation_kwargs = dict(generation_kwargs)
+         generation_kwargs.pop("do_clean", None)
+         max_length = len(self.tokenizer(prompt)["input_ids"]) + generation_kwargs["max_length"]
+         generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
+         return self.generator(
+             prompt,
+             **generation_kwargs,
+         )[0]["generated_text"]
+ @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
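+
+
+ # The widgets below can be seeded from URL query parameters; the same mechanism
+ # lets the "Tweet" share link reproduce a generation.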
+ def main():
+     st.set_page_config(
+         page_title="NB-GPT-J-6B-NorPaca",
+         page_icon="🇳🇴",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+     style()
+     with st.spinner('Loading the model. Please wait...'):
+         generator = load_text_generator()
+
+     st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
+     query_params = st.experimental_get_query_params()
+     if query_params:
+         st.experimental_set_query_params()  # clear shared parameters from the URL
+
+     max_length = st.sidebar.slider(
+         label='Max tokens to generate',
+         help="The maximum length of the sequence to be generated.",
+         min_value=1,
+         max_value=MAX_LENGTH,
+         value=int(query_params.get("max_length", [50])[0]),
+         step=1
+     )
+     top_k = st.sidebar.slider(
+         label='Top-k',
+         help="The number of highest-probability vocabulary tokens to keep for top-k filtering",
+         min_value=40,
+         max_value=80,
+         value=int(query_params.get("top_k", [50])[0]),
+         step=1
+     )
+     top_p = st.sidebar.slider(
+         label='Top-p',
+         help="Only the most probable tokens with probabilities that add up to `top_p` or "
+              "higher are kept for generation.",
+         min_value=0.0,
+         max_value=1.0,
+         value=float(query_params.get("top_p", [0.95])[0]),
+         step=0.01
+     )
+     temperature = st.sidebar.slider(
+         label='Temperature',
+         help="The value used to modulate the next-token probabilities",
+         min_value=0.1,
+         max_value=10.0,
+         value=float(query_params.get("temperature", [0.8])[0]),
+         step=0.05
+     )
+     do_sample = st.sidebar.selectbox(
+         label='Sampling?',
+         options=(False, True),
+         help="Whether or not to use sampling; greedy decoding is used otherwise.",
+         index=int(query_params.get("do_sample", ["true"])[0].lower()[0] in ("t", "y", "1")),
+     )
+     do_clean = st.sidebar.selectbox(
+         label='Clean text?',
+         options=(False, True),
+         help="Whether or not to remove repeated words and trim unfinished last sentences.",
+         index=int(query_params.get("do_clean", ["true"])[0].lower()[0] in ("t", "y", "1")),
+     )
+     generation_kwargs = {
+         "max_length": max_length,
+         "top_k": top_k,
+         "top_p": top_p,
+         "temperature": temperature,
+         "do_sample": do_sample,
+         "do_clean": do_clean,
+     }
+     st.markdown(HEADER_INFO)
+     prompts = EXAMPLES + ["Custom"]
+     prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
+
+     if prompt == "Custom":
+         prompt_box = query_params.get("text", [PROMPT_BOX])[0].strip()
+     else:
+         prompt_box = prompt
+
+     text = st.text_area("Enter text", prompt_box)
+     generation_kwargs_ph = st.empty()
+     cleaner = Normalizer()
+     if st.button("Generate!") or text != PROMPT_BOX:
+         output = st.empty()
+         with st.spinner(text="Generating..."):
+             generation_kwargs_ph.markdown(
+                 ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
+             if text:
+                 share_args = {"text": text, **generation_kwargs}
+                 st.experimental_set_query_params(**share_args)
+                 # Retry up to 5 times until the model returns non-empty text.
+                 for _ in range(5):
+                     generated_text = generator.generate(text, generation_kwargs)
+                     if do_clean:
+                         generated_text = cleaner.clean_txt(generated_text)
+                     if generated_text.strip().startswith(text):
+                         generated_text = generated_text.replace(text, "", 1).strip()
+                     output.markdown(
+                         f'<p class="ltr ltr-box">'
+                         f'<span class="result-text">{text} </span>'
+                         f'<span class="result-text generated-text">{generated_text}</span>'
+                         f'</p>',
+                         unsafe_allow_html=True
+                     )
+                     if generated_text.strip():
+                         components.html(
+                             f"""
+                             <a href="https://twitter.com/intent/tweet"
+                                class="twitter-share-button"
+                                data-text="Check my prompt using NB-GPT-J-6B-NorPaca!🇳🇴 https://ai.nb.no/demo/nb-gpt-j-6B-NorPaca/?{urlencode(share_args)}"
+                                data-show-count="false"
+                                data-size="Small"
+                                data-hashtags="nb,gpt-j">
+                             Tweet
+                             </a>
+                             <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
+                             """
+                         )
+                         break
+                 if not generated_text.strip():
+                     st.markdown(
+                         "*Tried 5 times but did not produce any result. Try again!*")


  if __name__ == '__main__':
-     hello()
+     main()