project-baize committed on
Commit
3d2cf1e
1 Parent(s): 7edc905

Update app_modules/utils.py

Browse files
Files changed (1) hide show
  1. app_modules/utils.py +36 -28
app_modules/utils.py CHANGED
@@ -13,6 +13,7 @@ import html
13
  import markdown2
14
  import torch
15
  import sys
 
16
  from pygments.lexers import guess_lexer, ClassNotFound
17
 
18
  import gradio as gr
@@ -255,34 +256,41 @@ def greedy_search(input_ids: torch.Tensor,
255
  logits = outputs.logits[:, -1, :]
256
  past_key_values = outputs.past_key_values
257
 
258
- # apply temperature
259
- logits /= temperature
260
-
261
- probs = torch.softmax(logits, dim=-1)
262
- # apply top_p
263
- probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
264
- probs_sum = torch.cumsum(probs_sort, dim=-1)
265
- mask = probs_sum - probs_sort > top_p
266
- probs_sort[mask] = 0.0
267
-
268
- # apply top_k
269
- #if top_k is not None:
270
- # probs_sort1, _ = torch.topk(probs_sort, top_k)
271
- # min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
272
- # probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)
273
-
274
- probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
275
- next_token = torch.multinomial(probs_sort, num_samples=1)
276
- next_token = torch.gather(probs_idx, -1, next_token)
277
-
278
- input_ids = torch.cat((input_ids, next_token), dim=-1)
279
-
280
- generated_tokens.append(next_token[0].item())
281
- text = tokenizer.decode(generated_tokens)
282
-
283
- yield text
284
- if any([x in text for x in stop_words]):
285
- return
 
 
 
 
 
 
 
286
 
287
  def generate_prompt_with_history(text,history,tokenizer,max_length=2048):
288
  prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
 
13
  import markdown2
14
  import torch
15
  import sys
16
+ import gc
17
  from pygments.lexers import guess_lexer, ClassNotFound
18
 
19
  import gradio as gr
 
256
  logits = outputs.logits[:, -1, :]
257
  past_key_values = outputs.past_key_values
258
 
259
+ # apply temperature
260
+ logits /= temperature
261
+
262
+ probs = torch.softmax(logits, dim=-1)
263
+ # apply top_p
264
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
265
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
266
+ mask = probs_sum - probs_sort > top_p
267
+ probs_sort[mask] = 0.0
268
+
269
+ # apply top_k
270
+ #if top_k is not None:
271
+ # probs_sort1, _ = torch.topk(probs_sort, top_k)
272
+ # min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
273
+ # probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)
274
+
275
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
276
+ next_token = torch.multinomial(probs_sort, num_samples=1)
277
+ next_token = torch.gather(probs_idx, -1, next_token)
278
+
279
+ input_ids = torch.cat((input_ids, next_token), dim=-1)
280
+
281
+ generated_tokens.append(next_token[0].item())
282
+ text = tokenizer.decode(generated_tokens)
283
+
284
+ yield text
285
+ if any([x in text for x in stop_words]):
286
+ del past_key_values
287
+ del logits
288
+ del probs
289
+ del probs_sort
290
+ del probs_idx
291
+ del probs_sum
292
+ gc.collect()
293
+ return
294
 
295
  def generate_prompt_with_history(text,history,tokenizer,max_length=2048):
296
  prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"