mgoin committed on
Commit 21bd07c
Parent: 24e1981

Update app.py

Files changed (1)
  1. app.py +31 -13
app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import uuid
 
 import gradio as gr
-# import spaces
+
 import torch
 from transformers import AutoTokenizer
 from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
@@ -10,22 +10,27 @@ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+MODEL_ID = "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
 
-DESCRIPTION = """\
-# NM vLLM Hermes Mistral Chat
+DESCRIPTION = f"""\
+# NM vLLM Chat
+Model: {MODEL_ID}
 """
 
 if not torch.cuda.is_available():
     raise ValueError("Running on CPU 🥶 This demo does not work on CPU.")
 
-model_id = "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
-engine_args = AsyncEngineArgs(model=model_id, sparsity="sparse_w16a16", max_model_len=MAX_INPUT_TOKEN_LENGTH)
+engine_args = AsyncEngineArgs(
+    model=MODEL_ID,
+    sparsity="sparse_w16a16",
+    max_model_len=MAX_INPUT_TOKEN_LENGTH
+)
 engine = AsyncLLMEngine.from_engine_args(engine_args)
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 tokenizer.use_default_system_prompt = False
 
-# @spaces.GPU
+
 async def generate(
     message: str,
     chat_history: list[tuple[str, str]],
@@ -37,13 +42,22 @@ async def generate(
     repetition_penalty: float = 1.2,
 ):
     conversation = []
+
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
+
     for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
     conversation.append({"role": "user", "content": message})
 
-    formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
+    formatted_conversation = tokenizer.apply_chat_template(
+        conversation, tokenize=False, add_generation_prompt=True
+    )
 
     sampling_params = SamplingParams(
         max_tokens=max_new_tokens,
@@ -53,8 +67,10 @@
         repetition_penalty=repetition_penalty,
     )
 
-    stream = await engine.add_request(uuid.uuid4().hex, formatted_conversation, sampling_params)
-
+    stream = await engine.add_request(
+        uuid.uuid4().hex, formatted_conversation, sampling_params
+    )
+
     async for request_output in stream:
        text = request_output.outputs[0].text
        yield text
@@ -112,8 +128,10 @@ chat_interface = gr.ChatInterface(
 
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    gr.DuplicateButton(
+        value="Duplicate Space for private use", elem_id="duplicate-button"
+    )
    chat_interface.render()
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
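
Note: beyond formatting, the one behavioral change here is passing add_generation_prompt=True to tokenizer.apply_chat_template. A minimal standalone sketch of what that call produces, assuming only that the model's tokenizer can be downloaded (the exact special tokens depend on the model's chat template):

# Minimal sketch: inspect the prompt string the updated code feeds to vLLM.
# Assumes network access to fetch the tokenizer for the model used above.
from transformers import AutoTokenizer

MODEL_ID = "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# add_generation_prompt=True appends the template's assistant-turn prefix,
# so generation starts a new assistant reply rather than continuing the
# user's message.
prompt = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)
print(prompt)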