peder committed on
Commit b08d994 · 1 Parent(s): c18b92f

cache added

app.py CHANGED
@@ -107,6 +107,7 @@ class TextGeneration:
         self.model_name_or_path = MODEL_NAME
         set_seed(42)
 
+    @st.cache_resource
     def load(self):
         print("Loading model... ", end="")
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -168,6 +169,7 @@ def generate_prompt(instruction, input=None):
 
 # @st.cache(allow_output_mutation=True, hash_funcs={AutoModelForCausalLM: lambda _: None})
 # @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
+#@st.cache_resource
 def load_text_generator():
     generator = TextGeneration()
     generator.load()
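
For reference, st.cache_resource (Streamlit 1.18 and later) is the successor to the deprecated st.cache(allow_output_mutation=True, hash_funcs=...) pattern that remains commented out above. It is normally applied to a module-level loader so the heavy model object is built once per process and reused across reruns and sessions. The sketch below is illustrative only and is not this repository's code; it borrows MODEL_NAME from the app, and a gated checkpoint may additionally require an auth token.

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "NbAiLab/nb-gpt-j-6B-alpaca"  # default model name from the app's config

@st.cache_resource  # cache the returned objects once per process
def load_model_and_tokenizer():
    # A gated/private checkpoint may also need an auth token passed to from_pretrained.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()  # later reruns reuse the cached pair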
app.py.d264c618bf578cebbd21d2c2379d7b50.tmp DELETED
@@ -1,321 +0,0 @@
- import random
- import os
- from urllib.parse import urlencode
- #from pyngrok import ngrok
-
- import streamlit as st
- import streamlit.components.v1 as components
- import torch
- from transformers import pipeline, set_seed
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # #import torch
- # print(f"Is CUDA available: {torch.cuda.is_available()}")
- # # True
- # print(
- #     f"CUDA device for you Perrito: {torch.cuda.get_device_name(torch.cuda.current_device())}")
- # # Tesla T4
-
- HF_AUTH_TOKEN = "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"
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #print("DEVICE SENOOOOOR", DEVICE)
- DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
- MODEL_NAME = os.environ.get("MODEL_NAME", "NbAiLab/nb-gpt-j-6B-alpaca")
- MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 256))
-
- HEADER_INFO = """
- # GPT-NorPaca
- Norwegian GPT-J-6B NorPaca Model.
- """.strip()
- LOGO = "https://upload.wikimedia.org/wikipedia/commons/thumb/1/19/Logo_CopenhagenBusinessSchool.svg/1200px-Logo_CopenhagenBusinessSchool.svg.png"
- SIDEBAR_INFO = f"""
- <div align=center>
- <img src="{LOGO}" width=100/>
-
- # NB-GPT-J-6B-NorPaca
-
- </div>
-
- NB-GPT-J-6B NorPaca is a hybrid of a GPT-3 and Llama model, trained on the Norwegian Colossal Corpus and other Internet sources. It is a 6.7 billion parameter model, and is the largest model in the GPT-J family.
-
- This model has been trained with [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) using TPUs provided by Google through the Tensor Research Cloud program, starting off the [GPT-J-6B model weigths from EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B), and trained on the [Norwegian Colossal Corpus](https://huggingface.co/datasets/NbAiLab/NCC) and other Internet sources. *This demo runs on {DEVICE}*.
-
- For more information, visit the [model repository](https://huggingface.co/CBSMasterThesis).
-
- ## Configuration
- """.strip()
- PROMPT_BOX_INSTRUCTION = "Enter your Instructions here..."
- PROMPT_BOX_INPUT = "Enter your Input here..."
- EXAMPLES = [
-     "Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Analyser fordelene ved å jobbe i et team. ### Respons:",
-     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Oppsummer den faglige artikkelen "Kunstig intelligens og arbeidets fremtid". ### Respons:',
-     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Generer et kreativt slagord for en bedrift som bruker fornybare energikilder. ### Respons:',
-     'Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte. ### Instruksjon: Regn ut arealet av en firkant med lengde 10m. Skriv ut et flyttall. ### Respons:',
- ]
-
-
- def style():
-     st.markdown("""
-         <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap%22%20rel=%22stylesheet%22" rel="stylesheet">
-     <style>
-         .ltr,
-         textarea {
-             font-family: Roboto !important;
-             text-align: left;
-             direction: ltr !important;
-         }
-         .ltr-box {
-             border-bottom: 1px solid #ddd;
-             padding-bottom: 20px;
-         }
-         .rtl {
-             text-align: left;
-             direction: ltr !important;
-         }
-         span.result-text {
-             padding: 3px 3px;
-             line-height: 32px;
-         }
-         span.generated-text {
-             background-color: rgb(118 200 147 / 13%);
-         }
-     </style>""", unsafe_allow_html=True)
-
-
- class Normalizer:
-     def remove_repetitions(self, text):
-         """Remove repetitions"""
-         first_ocurrences = []
-         for sentence in text.split("."):
-             if sentence not in first_ocurrences:
-                 first_ocurrences.append(sentence)
-         return '.'.join(first_ocurrences)
-
-     def trim_last_sentence(self, text):
-         """Trim last sentence if incomplete"""
-         return text[:text.rfind(".") + 1]
-
-     def clean_txt(self, text):
-         return self.trim_last_sentence(self.remove_repetitions(text))
-
-
- class TextGeneration:
-     def __init__(self):
-         self.tokenizer = None
-         self.generator = None
-         self.task = "text-generation"
-         self.model_name_or_path = MODEL_NAME
-         set_seed(42)
-
-     def load(self):
-         print("Loading model... ", end="")
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
-         )
-         self.model = AutoModelForCausalLM.from_pretrained(
-             self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
-             pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
-             torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
-         ).to(device=DEVICE, non_blocking=True)
-         _ = self.model.eval()
-         # -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
-         device_number = torch.cuda.current_device()
-         self.generator = pipeline(
-             self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
-         print("Done")
-         # with torch.no_grad():
-         #     tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
-         #     gen_tokens = self.model.generate(tokens, do_sample=True, temperature=0.8, max_length=128)
-         #     generated = tokenizer.batch_decode(gen_tokens)[0]
-
-         #     return generated
-
-     def generate(self, prompt, generation_kwargs):
-         max_length = len(self.tokenizer(prompt)[
-             "input_ids"]) + generation_kwargs["max_length"]
-         generation_kwargs["max_length"] = min(
-             max_length, self.model.config.n_positions)
-         # generation_kwargs["num_return_sequences"] = 1
-         # generation_kwargs["return_full_text"] = False
-         return self.generator(
-             prompt,
-             **generation_kwargs,
-         )[0]["generated_text"]
-
- # Generate responses
-
-
- def generate_prompt(instruction, input=None):
-     if input:
-         prompt = f"""Nedenfor er en instruksjon som beskriver en oppgave, sammen med et input som gir ytterligere kontekst. Skriv et svar som fullfører forespørselen på riktig måte.
-
- ### Instruksjon:
- {instruction}
-
- ### Input:
- {input}
-
- ### Respons:"""
-     else:
-         prompt = f""""Nedenfor er en instruksjon som beskriver en oppgave. Skriv et svar som fullfører forespørselen på riktig måte.
-
- ### Instruksjon:
- {instruction}
-
- ### Respons:"""
-     return prompt
-
-
- # @st.cache(allow_output_mutation=True, hash_funcs={AutoModelForCausalLM: lambda _: None})
- # @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
- def load_text_generator():
-     generator = TextGeneration()
-     generator.load()
-     return generator
-
-
- def main():
-     st.set_page_config(
-         page_title="NB-GPT-J-6B-NorPaca",
-         page_icon="🇳🇴",
-         layout="wide",
-         initial_sidebar_state="expanded"
-     )
-     style()
-     with st.spinner('Loading the model. Please, wait...'):
-         generator = load_text_generator()
-
-     st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
-     query_params = st.experimental_get_query_params()
-     if query_params:
-         st.experimental_set_query_params(**dict())
-
-     max_length = st.sidebar.slider(
-         label='Max words to generate',
-         help="The maximum length of the sequence to be generated.",
-         min_value=1,
-         max_value=MAX_LENGTH,
-         value=int(query_params.get("max_length", [50])[0]),
-         step=1
-     )
-     top_k = st.sidebar.slider(
-         label='Top-k',
-         help="The number of highest probability vocabulary tokens to keep for top-k-filtering",
-         min_value=40,
-         max_value=80,
-         value=int(query_params.get("top_k", [50])[0]),
-         step=1
-     )
-     top_p = st.sidebar.slider(
-         label='Top-p',
-         help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for "
-              "generation.",
-         min_value=0.0,
-         max_value=1.0,
-         value=float(query_params.get("top_p", [0.75])[0]),
-         step=0.01
-     )
-     temperature = st.sidebar.slider(
-         label='Temperature',
-         help="The value used to module the next token probabilities",
-         min_value=0.1,
-         max_value=10.0,
-         value=float(query_params.get("temperature", [0.2])[0]),
-         step=0.05
-     )
-     do_sample = st.sidebar.selectbox(
-         label='Sampling?',
-         options=(False, True),
-         help="Whether or not to use sampling; use greedy decoding otherwise.",
-         index=int(query_params.get("do_sample", ["true"])[
-             0].lower()[0] in ("t", "y", "1")),
-     )
-     generation_kwargs = {
-         "max_length": max_length,
-         "top_k": top_k,
-         "top_p": top_p,
-         "temperature": temperature,
-         "do_sample": do_sample,
-         # "do_clean": do_clean,
-     }
-     st.markdown(HEADER_INFO)
-     prompts = EXAMPLES + ["Custom"]
-     prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
-
-     if prompt == "Custom":
-         prompt_box_instruction = query_params.get(
-             "text1", [PROMPT_BOX_INSTRUCTION])[0].strip()
-         prompt_box_input = query_params.get(
-             "text2", [PROMPT_BOX_INPUT])[0].strip()
-         prompt_box = f"{prompt_box_instruction} {prompt_box_input}"
-     else:
-         if "### Input:" in prompt:
-             prompt_box_instruction = prompt.split("### Instruksjon:")[
-                 1].split("### Input:")[0].strip()
-             prompt_box_input = prompt.split(
-                 "### Input:")[1].split("### Respons:")[0].strip()
-         else:
-             prompt_box_instruction = prompt.split(
-                 "### Instruksjon:")[1].split("### Respons:")[0].strip()
-             prompt_box_input = None
-         prompt_box = prompt
-
-     if prompt == "Custom":
-         text_instruction = st.text_area(
-             "Enter Instruction", PROMPT_BOX_INSTRUCTION)
-         text_input = st.text_area("Enter Input", PROMPT_BOX_INPUT)
-     else:
-         text_instruction = st.text_area(
-             "Enter Instruction", prompt_box_instruction)
-         text_input = st.text_area("Enter Input", prompt_box_input) if "### Input:" in prompt else st.text_area(
-             "Enter Input", PROMPT_BOX_INPUT)
-
-     generation_kwargs_ph = st.empty()
-     cleaner = Normalizer()
-     if st.button("Generate!"):
-         output = st.empty()
-         with st.spinner(text="Generating..."):
-             generation_kwargs_ph.markdown(
-                 ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
-             if text_instruction:
-                 text = generate_prompt(text_instruction, text_input) if text_input != "Enter your Input here..." else generate_prompt(
-                     text_instruction)
-                 #print("TEXT OUT", text)
-                 share_args = {"text": text, **generation_kwargs}
-                 st.experimental_set_query_params(**share_args)
-                 for _ in range(5):
-                     generated_text = generator.generate(
-                         text, generation_kwargs)
-                     # if do_clean:
-                     #     generated_text = cleaner.clean_txt(generated_text)
-                     if generated_text.strip().startswith(text):
-                         generated_text = generated_text.replace(
-                             text, "", 1).strip()
-                     output.markdown(
-                         f'<p class="ltr ltr-box">'
-                         f'<span class="result-text">{text} <span>'
-                         f'<span class="result-text generated-text">{generated_text}</span>'
-                         f'</p>',
-                         unsafe_allow_html=True
-                     )
-                     if generated_text.strip():
-                         components.html(
-                             f"""
-                             <a class="twitter-share-button"
-                             data-text="Check my prompt using NB-GPT-J-6B-NorPaca!🇳🇴 https://ai.nb.no/demo/nb-gpt-j-6B-NorPaca/?{urlencode(share_args)}"
-                             data-show-count="false">
-                             data-size="Small"
-                             data-hashtags="nb,gpt-j"
-                             Tweet
-                             </a>
-                             <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
-                             """
-                         )
-                         break
-             if not generated_text.strip():
-                 st.markdown(
-                     "*Tried 5 times but did not produce any result. Try again!*")
-
-
- if __name__ == '__main__':
-     main()