JCai committed on
Commit
22cebd2
1 Parent(s): 9cd3c4e

upload app.py file

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from huggingface_hub import InferenceClient
import torch
from transformers import pipeline

from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes

# Inference client setup
# Remote endpoint used when the "Use Local Model" checkbox is off.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Local generation pipeline used when the checkbox is on.
# NOTE(review): this downloads/loads Phi-3-mini weights at import time,
# which can be slow and memory-heavy — confirm that is acceptable on the
# target host (device_map="auto" places layers on whatever is available).
pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")

# Global flag to handle cancellation
# NOTE(review): a single module-level flag is shared by all concurrent
# sessions, so one user's "Cancel" stops every in-flight generation.
stop_inference = False
17
def respond(
    message,
    history: list[tuple[str, str]],
    system_message="You are not a friendly Chatbot.",
    max_tokens=512,
    temperature=0.7,
    top_p=0.95,
    use_local_model=False,
):
    """Stream a chat reply for *message*, yielding the growing history.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turn pairs; may be None on first call.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        use_local_model: If True, generate with the local `pipe`; otherwise
            stream from the remote `client` endpoint.

    Yields:
        history + [(message, partial_response)] after each streamed token,
        so the Chatbot component updates incrementally. If the global
        `stop_inference` flag is set, yields a cancellation entry and stops.
    """
    global stop_inference
    stop_inference = False  # Reset cancellation flag for this run

    # Initialize history if it's None
    if history is None:
        history = []

    messages = _build_messages(system_message, history, message)

    response = ""
    if use_local_model:
        # local inference
        for output in pipe(
            messages,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
        ):
            if stop_inference:
                yield history + [(message, "Inference cancelled.")]
                return
            token = output['generated_text'][-1]['content']
            response += token
            yield history + [(message, response)]  # Yield history + new response
    else:
        # API-based inference
        for message_chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            if stop_inference:
                yield history + [(message, "Inference cancelled.")]
                return
            # BUGFIX: streamed deltas may carry content=None (e.g. role-only
            # or final chunks); guard against TypeError on concatenation.
            token = message_chunk.choices[0].delta.content or ""
            response += token
            yield history + [(message, response)]  # Yield history + new response


def _build_messages(system_message, history, message):
    """Convert (user, assistant) tuple history into chat-format message dicts."""
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})
    return messages
87
+
88
+
89
def cancel_inference():
    """Signal any in-flight respond() generator to stop at its next token.

    Sets the module-level `stop_inference` flag, which respond() polls
    between streamed tokens.
    """
    global stop_inference
    stop_inference = True
92
+
93
# Custom CSS for a fancy look
# NOTE(review): this stylesheet is defined but never passed to gr.Blocks
# below (no css=custom_css argument) — confirm whether it should be wired
# up or removed.
custom_css = """
#main-container {
    background-color: #FFC0CB;
    font-family: 'Arial', sans-serif;
}

.gradio-container {
    max-width: 700px;
    margin: 0 auto;
    padding: 20px;
    background: white;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
}

.gr-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}

.gr-button:hover {
    background-color: #45a049;
}

.gr-slider input {
    color: #4CAF50;
}

.gr-chat {
    font-size: 16px;
}

#title {
    text-align: center;
    font-size: 2em;
    margin-bottom: 20px;
    color: #333;
}
"""
138
+
139
class UI_design(Base):
    """Custom Gradio theme.

    Emerald/blue palette over the stock Base theme, with a striped
    gradient page background, gradient primary buttons, and heavier
    block borders/shadows applied via `set()` overrides.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Collect the palette/sizing choices and forward them to Base.
        base_kwargs = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_kwargs)

        # Cosmetic overrides layered on top of the base theme; the
        # "*token" strings are Gradio theme-variable references.
        overrides = dict(
            body_background_fill="repeating-linear-gradient(45deg, *primary_200, *primary_200 10px, *primary_50 10px, *primary_50 20px)",
            body_background_fill_dark="repeating-linear-gradient(45deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
            button_primary_text_color="white",
            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
            slider_color="*secondary_300",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_shadow="*shadow_drop_lg",
            button_large_padding="32px",
        )
        super().set(**overrides)
189
+
190
ui_design = UI_design()

# Define the interface
# NOTE(review): indentation below reconstructs the intended layout — the
# original nesting of the second Row onward is ambiguous; verify against
# the rendered app.
with gr.Blocks(theme=ui_design) as demo:
    gr.Markdown("<h1 style='text-align: left;'>🌟 NOT Fancy AI Chatbot 🌟</h1>")
    gr.Markdown("DONT Interact with the AI chatbot using customizable settings below.")

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Sketch"):
                    # NOTE(review): the sketchpad is never read by any
                    # handler below — confirm whether it should feed respond().
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)

            # NOTE(review): also unused by any handler.
            input_text = gr.Textbox(label="input your question")

            with gr.Row():
                # with gr.Column():
                #     clear_btn = gr.ClearButton(
                #         [input_sketchpad, input_text])
                with gr.Column():
                    # NOTE(review): no .click() handler is wired to this button.
                    submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        system_message = gr.Textbox(value="You are not a friendly Chatbot.", label="System message", interactive=True)
        use_local_model = gr.Checkbox(label="Use Local Model", value=False)
        # NOTE(review): second unwired Submit button — likely one of the two
        # should call respond(); confirm intent.
        button_1 = gr.Button("Submit", variant="primary")
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chat_history = gr.Chatbot(label="Chat")

    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")

    cancel_button = gr.Button("Cancel Inference", variant="danger")

    # Adjusted to ensure history is maintained and passed correctly
    # Pressing Enter in the textbox streams respond()'s yields into the chat.
    user_input.submit(respond, [user_input, chat_history, system_message, max_tokens, temperature, top_p, use_local_model], chat_history)

    cancel_button.click(cancel_inference)
232
# Script entry point: launch the Gradio app locally.
if __name__ == "__main__":
    demo.launch(share=False)  # Remove share=True because it's not supported on HF Spaces
234
+