nawhgnuj committed
Commit 6836f82 · verified · 1 Parent(s): fd40b8f

Update app.py

Files changed (1): app.py (+20, -36)
app.py CHANGED
@@ -1,10 +1,7 @@
 import os
-import time
-import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import gradio as gr
-from threading import Thread
 
 MODEL_LIST = ["nawhgnuj/KamalaHarris-Llama-3.1-8B-Chat"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -60,8 +57,7 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
     quantization_config=quantization_config)
 
-@spaces.GPU()
-def stream_chat(
+def generate_response(
     message: str,
     history: list,
     temperature: float,
@@ -91,33 +87,23 @@ Crucially, Keep responses concise and impactful."""
     conversation.append({"role": "user", "content": message})
 
     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
-    attention_mask = torch.ones_like(input_ids)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        streamer=streamer,
-    )
-
     with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-            yield buffer
+        output = model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+    response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
+    return response.strip()
 
 def add_text(history, text):
     history = history + [(text, None)]
@@ -125,11 +111,9 @@ def add_text(history, text):
 
 def bot(history, temperature, max_new_tokens, top_p, top_k, repetition_penalty, no_repeat_ngram_size):
     user_message = history[-1][0]
-    bot_response = stream_chat(user_message, history[:-1], temperature, max_new_tokens, top_p, top_k, repetition_penalty, no_repeat_ngram_size)
-    history[-1][1] = ""
-    for character in bot_response:
-        history[-1][1] += character
-        yield history
+    bot_response = generate_response(user_message, history[:-1], temperature, max_new_tokens, top_p, top_k, repetition_penalty, no_repeat_ngram_size)
+    history[-1][1] = bot_response
+    return history
 
 with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
     gr.HTML(TITLE)
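
For context, the new code path is the standard blocking-generation idiom in transformers: run model.generate once under torch.no_grad(), then decode only the tokens that come after the prompt. Below is a minimal self-contained sketch of that idiom; the tiny stand-in checkpoint and all sampling values are illustrative assumptions, not taken from this Space, which loads nawhgnuj/KamalaHarris-Llama-3.1-8B-Chat with a BitsAndBytesConfig instead.

# Sketch of the blocking-generation pattern this commit switches to.
# "sshleifer/tiny-gpt2" is a tiny stand-in so the sketch runs on CPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

input_ids = tokenizer("Hello there,", return_tensors="pt").input_ids
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=32,                    # illustrative values throughout
        do_sample=True,
        top_p=0.9,
        top_k=50,
        temperature=0.8,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )

# Slice off the prompt so only newly generated tokens are decoded,
# the same output[0][input_ids.shape[1]:] idiom used in app.py.
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))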
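
The diff cuts off inside the gr.Blocks section, so the event wiring that calls add_text and bot is not visible here. For orientation, here is a hedged sketch of how such a pair is commonly chained in Gradio. Every component name, slider range, and the .then() chain are assumptions, and both callbacks are stubbed so the sketch runs without loading the model.

# Hypothetical wiring sketch, not taken from this commit. The stubs mirror
# the shapes of app.py's add_text/bot; names and ranges are invented.
import gradio as gr

def add_text(history, text):
    # Same shape as app.py: append the user turn with an empty bot slot.
    return history + [(text, None)]

def bot(history, temperature, max_new_tokens, top_p, top_k,
        repetition_penalty, no_repeat_ngram_size):
    # Stand-in for generate_response(...); the real app fills in model output.
    history[-1][1] = "(stub reply)"
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature")
    max_new_tokens = gr.Slider(64, 1024, value=256, step=32, label="Max new tokens")
    top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-p")
    top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="Repetition penalty")
    no_repeat_ngram_size = gr.Slider(0, 10, value=3, step=1, label="No-repeat n-gram")

    # Two-step chain: add_text posts the user turn immediately; bot then
    # computes the reply. Because bot now returns (rather than yields), the
    # chat pane updates once, after generation finishes.
    msg.submit(add_text, [chatbot, msg], chatbot).then(
        bot,
        [chatbot, temperature, max_new_tokens, top_p, top_k,
         repetition_penalty, no_repeat_ngram_size],
        chatbot,
    )

demo.launch()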