jmercat committed
Commit 5313bd0 · Parent: 01a6f6a

use autocast

Files changed (1): app.py (+25 −19)
app.py CHANGED
@@ -1,10 +1,14 @@
-import spaces
-import gradio as gr
 from threading import Thread
-from open_lm.hf import *
+
+import gradio as gr
+from gradio.layouts import Accordion
+import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
-from gradio.layouts import Accordion
+
+from open_lm.hf import *
+from open_lm.precision import get_autocast
+
 
 # Define model options
 MODEL_OPTIONS = {
@@ -39,23 +43,25 @@ def generate(
     top_p = float(top_p)
 
     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
-
-    generate_kwargs = dict(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        pad_token_id=current_tokenizer.eos_token_id
-    )
+    autocast = get_autocast("amp_bf16")
+
+    with autocast():
+        generate_kwargs = dict(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+            pad_token_id=current_tokenizer.eos_token_id
+        )
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
-    streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
-    generate_kwargs["streamer"] = streamer
+        streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+        streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
+        generate_kwargs["streamer"] = streamer
 
-    thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
-    thread.start()
+        thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
+        thread.start()
 
     # Write the prompt in blue
     output = "<span style='color: blue;'>" + prompt + "</span>"
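
For readers unfamiliar with the new import: open_lm's get_autocast maps a precision string to a context-manager factory, so "with autocast():" runs the wrapped ops in mixed precision. A minimal sketch of that pattern, assuming a CUDA device; the actual open_lm.precision implementation may differ in detail:

from contextlib import suppress

import torch

def get_autocast(precision):
    # Sketch: map a precision string to a context-manager factory.
    # The branch names mirror the strings used above; the bodies are assumptions.
    if precision == "amp":
        # Default AMP: float16 autocast
        return lambda: torch.autocast(device_type="cuda")
    if precision in ("amp_bfloat16", "amp_bf16"):
        # bfloat16 autocast: wider exponent range than fp16, so no loss
        # scaling is needed; matmuls run in bf16 while weights stay fp32
        return lambda: torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    # Anything else: full precision via a no-op context manager
    return suppress

Under "amp_bf16", the with-block in the diff therefore runs generation with bfloat16 compute without touching the stored weights.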
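
The generation call itself still runs on a background Thread, with TextIteratorStreamer bridging it back to the UI: generate() blocks until decoding finishes, so it goes on a worker thread while the caller iterates the streamer for decoded text chunks. A self-contained sketch of that pattern using the standard transformers API; "gpt2" is a stand-in for the app's MODEL_OPTIONS checkpoints:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# "gpt2" is a placeholder model; the app loads its own checkpoints.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)

# generate() blocks, so run it on a worker thread and consume chunks here.
thread = Thread(
    target=model.generate,
    kwargs=dict(
        **inputs,
        max_new_tokens=32,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    ),
)
thread.start()

output = ""
for chunk in streamer:  # yields decoded text pieces as tokens are produced
    output += chunk
    print(chunk, end="", flush=True)
thread.join()

This is also why the diff sets a custom stop_signal attribute on the streamer: the consumer loop (outside this hunk) can watch each chunk for the decoded EOS token and trim it before rendering.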