YangWu001 committed on
Commit cf229a5
1 Parent(s): c6a0006
Files changed (1)
  1. app.py +34 -10
app.py CHANGED
@@ -1,9 +1,12 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 import time
+import torch
+from transformers import pipeline
 
 # Inference client setup
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")
 
 # Global flag to handle cancellation
 stop_inference = False
@@ -22,10 +25,31 @@ def respond(
 
     if use_local_model:
         # Simulate local inference
-        time.sleep(2)  # simulate a delay
-        response = "This is a response from the local model."
+        messages = [{"role": "system", "content": system_message}]
+
+        for val in history:
+            if val[0]:
+                messages.append({"role": "user", "content": val[0]})
+            if val[1]:
+                messages.append({"role": "assistant", "content": val[1]})
+
+        messages.append({"role": "user", "content": message})
+
+        response = ""
+        for message in pipe(
+            messages,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            do_sample=True,
+            top_p=top_p,
+        ):
+            token = message['generated_text'][-1]['content']
+            response += token
+            yield response  # Yielding response directly
+
         history.append((message, response))
-        yield history
+        yield history  # Yield the updated history
+
     else:
         # API-based inference
         messages = [{"role": "system", "content": system_message}]
@@ -49,12 +73,10 @@ def respond(
                 break
             token = message_chunk.choices[0].delta.content
             response += token
-            history[-1] = (message, response)
-            yield history  # Yield the history list of tuples
+            yield response  # Yielding response directly
 
-    # Finalize response in history
-    history[-1] = (message, response)  # Update with the full response
-    yield history
+    history.append((message, response))
+    yield history  # Yield the updated history
 
 def cancel_inference():
     global stop_inference
@@ -127,8 +149,7 @@ with gr.Blocks(css=custom_css) as demo:
     cancel_button = gr.Button("Cancel Inference", variant="danger")
 
     def chat_fn(message, history):
-        history.append((message, ""))  # Initialize with empty response
-        return respond(
+        response_gen = respond(
             message,
             history,
             system_message.value,
@@ -137,6 +158,9 @@ with gr.Blocks(css=custom_css) as demo:
             top_p.value,
             use_local_model.value,
         )
+        for response in response_gen:
+            history[-1] = (message, response)
+            yield history
 
     user_input.submit(chat_fn, [user_input, chat_history], chat_history)
    cancel_button.click(cancel_inference)
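A note on the local branch above: iterating over the result of pipe(...) walks the list of completed generations (one element per returned sequence), so the loop receives the full reply in a single pass rather than token by token. For genuine token-level streaming with a local model, transformers provides TextIteratorStreamer; below is a minimal sketch assuming the same Phi-3 checkpoint as the commit. The helper name stream_reply and the threading pattern are illustrative, not part of this commit.

    from threading import Thread

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Sketch only: same model as the commit, loaded directly instead of via pipeline().
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )

    def stream_reply(messages, max_tokens=256, temperature=0.7, top_p=0.95):
        # Render the chat history with the model's chat template.
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks, so it runs in a thread while we consume the streamer.
        Thread(target=model.generate, kwargs=dict(
            inputs=input_ids, streamer=streamer, max_new_tokens=max_tokens,
            do_sample=True, temperature=temperature, top_p=top_p,
        )).start()
        response = ""
        for text_chunk in streamer:  # yields decoded text as tokens arrive
            response += text_chunk
            yield response  # growing partial reply, matching the commit's pattern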
 
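On the API branch, the diff only shows the lines inside the streaming loop; the surrounding call is unchanged context. For reference, a self-contained sketch of that loop, assuming the unchanged code calls InferenceClient.chat_completion with stream=True (the function name stream_api_reply is illustrative):

    from huggingface_hub import InferenceClient

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

    def stream_api_reply(messages, max_tokens=256, temperature=0.7, top_p=0.95):
        response = ""
        for message_chunk in client.chat_completion(
            messages, max_tokens=max_tokens, stream=True,
            temperature=temperature, top_p=top_p,
        ):
            # Each streamed chunk carries a small delta of the reply; the final
            # chunk's delta content can be None, hence the fallback to "".
            token = message_chunk.choices[0].delta.content or ""
            response += token
            yield response  # same growing-string pattern as the commit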
 