eswardivi committed
Commit 7dc3087
1 Parent(s): 9b7a2bc

Added Better Inferencing techq

Files changed (1)
  1. app.py +24 -12
app.py CHANGED
@@ -1,29 +1,36 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import os
 from threading import Thread
 import spaces
+import time
 
 token = os.environ["HF_TOKEN"]
+
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16
+)
+
 model = AutoModelForCausalLM.from_pretrained("google/gemma-1.1-7b-it",
-                                             # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                                             torch_dtype=torch.float16,
+                                             quantization_config=quantization_config,
                                              token=token)
-tok = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it",token=token)
-# using CUDA for an optimal experience
-# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+tok = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it", token=token)
+
 if torch.cuda.is_available():
     device = torch.device('cuda')
     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
 else:
     device = torch.device('cpu')
     print("Using CPU")
+
 model = model.to(device)
-
+model = model.to_bettertransformer()
 
 @spaces.GPU
 def chat(message, history):
+    start_time = time.time()
     chat = []
     for item in history:
         chat.append({"role": "user", "content": item[0]})
@@ -31,7 +38,6 @@ def chat(message, history):
         chat.append({"role": "assistant", "content": item[1]})
     chat.append({"role": "user", "content": message})
     messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-    # Tokenize the messages string
     model_inputs = tok([messages], return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(
         tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
@@ -48,15 +54,21 @@ def chat(message, history):
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # Initialize an empty string to store the generated text
     partial_text = ""
+    first_token_time = None
     for new_text in streamer:
-        # print(new_text)
+        if not first_token_time:
+            first_token_time = time.time() - start_time
         partial_text += new_text
-        # Yield an empty string to cleanup the message textbox and the updated conversation history
         yield partial_text
 
+    total_time = time.time() - start_time
+    tokens = len(tok.tokenize(partial_text))
+    tokens_per_second = tokens / total_time if total_time > 0 else 0
 
+    # Append the timing information to the final output
+    timing_info = f"\nTime taken to first token: {first_token_time:.2f} seconds\nTokens per second: {tokens_per_second:.2f}"
+    yield partial_text + timing_info
 
-demo = gr.ChatInterface(fn=chat, examples=[["Write me a poem about Machine Learning."]], title="gemma-1.1-7b-it")
+demo = gr.ChatInterface(fn=chat, examples=[["Write me a poem about Machine Learning."]], title="Chat With LLMS")
 demo.launch()
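
For reference, below is a minimal standalone sketch of the 4-bit loading and streamed-generation pattern this commit switches to. It assumes a CUDA GPU with bitsandbytes installed and HF_TOKEN set in the environment; the helper name generate_once, the device_map="auto" placement, and max_new_tokens=256 are illustrative assumptions, not part of the Space's code.

# Sketch only: 4-bit quantized load plus streamed generation with time-to-first-token.
import os
import time
from threading import Thread

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, TextIteratorStreamer)

token = os.environ["HF_TOKEN"]

# Same 4-bit configuration as the commit: weights in 4-bit, compute in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map="auto" (an assumption in this sketch) lets accelerate place the
# quantized weights on the GPU at load time, so no separate model.to(device) step.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-1.1-7b-it",
    quantization_config=quantization_config,
    device_map="auto",
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it", token=token)


def generate_once(prompt: str) -> str:
    """Stream one completion and report time-to-first-token, as the Space does."""
    chat = [{"role": "user", "content": prompt}]
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(**model_inputs, streamer=streamer, max_new_tokens=256)
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    start = time.time()
    first_token_time = None
    text = ""
    for new_text in streamer:
        if first_token_time is None:
            first_token_time = time.time() - start
        text += new_text
    if first_token_time is not None:
        print(f"Time to first token: {first_token_time:.2f}s")
    return text


if __name__ == "__main__":
    print(generate_once("Write me a poem about Machine Learning."))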