AgainstEntropy commited on
Commit
9fae5ad
Β·
0 Parent(s):
Files changed (2) hide show
  1. README.md +15 -0
  2. app.py +269 -0
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Kanji Streaming
3
+ emoji: πŸ‰‘
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app.py
9
+ pinned: false
10
+ # suggested_hardware: a10g-small
11
+ models:
12
+ - mistralai/Mixtral-8x7B-Instruct-v0.1
13
+ ---
14
+
15
+ Check out [original repo](https://github.com/AgainstEntropy/kanji) for mroe details!
app.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from queue import SimpleQueue
4
+ from threading import Thread
5
+ from typing import Iterator
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ from gradio import Chatbot
11
+ from huggingface_hub import InferenceClient
12
+
13
+ from image_utils import ImageStitcher
14
+ from StreamDiffusionIO import LatentConsistencyModelStreamIO
15
+
16
+ MAX_MAX_NEW_TOKENS = 2048
17
+ DEFAULT_MAX_NEW_TOKENS = 1024
18
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
19
+
20
+ DESCRIPTION = """\
21
+ # Kanji-Streaming Chat
22
+
23
+ 🌍 This Space is adapted from [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) space, demonstrating how to "chat" with LLM with [Kanji-Streaming](https://github.com/AgainstEntropy/kanji).
24
+
25
+ πŸ”¨ The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion), *but especially allows to render text streams into image streams*.
26
+
27
+ πŸ”Ž For more details about Kanji-Streaming, take a look at the [github repository](https://github.com/AgainstEntropy/kanji).
28
+ """
29
+
30
+ LICENSE = """
31
+ <p/>
32
+
33
+ ---
34
+ As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
35
+ this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
36
+ """
37
+
38
+
39
+ parser = argparse.ArgumentParser(description="Gradio launcher for Streaming-Kanji.")
40
+ parser.add_argument(
41
+ "--sd_model_id_or_path",
42
+ type=str,
43
+ default="runwayml/stable-diffusion-v1-5",
44
+ required=False,
45
+ help="Path to downloaded sd-1-5 model or model identifier from huggingface.co/models.",
46
+ )
47
+ parser.add_argument(
48
+ "--lora_path",
49
+ type=str,
50
+ default="AgainstEntropy/kanji-lora-sd-v1-5",
51
+ required=False,
52
+ help="Path to downloaded LoRA weight or model identifier from huggingface.co/models.",
53
+ )
54
+ parser.add_argument(
55
+ "--lcm_lora_path",
56
+ type=str,
57
+ default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
58
+ required=False,
59
+ help="Path to downloaded LCM-LoRA weight or model identifier from huggingface.co/models.",
60
+ )
61
+ parser.add_argument(
62
+ "--img_res",
63
+ type=int,
64
+ default=64,
65
+ required=False,
66
+ help="Image resolution for displaying Kanji characters in ChatBot.",
67
+ )
68
+ parser.add_argument(
69
+ "--img_per_line",
70
+ type=int,
71
+ default=16,
72
+ required=False,
73
+ help="Number of Kanji characters to display in a single line.",
74
+ )
75
+ parser.add_argument(
76
+ "--tmp_dir",
77
+ type=str,
78
+ default="./tmp",
79
+ required=False,
80
+ help="Path to save temporary images generated by StreamDiffusionIO.",
81
+ )
82
+
83
+ args = parser.parse_args()
84
+
85
+ if torch.cuda.is_available():
86
+ device = "cuda"
87
+ else:
88
+ device = "cpu"
89
+ DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo works best on GPU.</p>"
90
+
91
+ DESCRIPTION += "\n<p>This demo will get the best kanji streaming experience in localhost (or SSH forward), instead of shared link generated by Gradio.</p>"
92
+
93
+ client = InferenceClient(
94
+ "mistralai/Mixtral-8x7B-Instruct-v0.1"
95
+ )
96
+
97
+ def format_prompt(message, history):
98
+ prompt = "<s>"
99
+ for user_prompt, bot_response in history:
100
+ prompt += f"[INST] {user_prompt} [/INST]"
101
+ prompt += f" {bot_response}</s> "
102
+ prompt += f"[INST] {message} [/INST]"
103
+ return prompt
104
+
105
+ lcm_stream = LatentConsistencyModelStreamIO(
106
+ model_id_or_path=args.sd_model_id_or_path,
107
+ lcm_lora_path=args.lcm_lora_path,
108
+ lora_dict={args.lora_path: 1},
109
+ resolution=128,
110
+ device=device,
111
+ use_xformers=True,
112
+ verbose=True,
113
+ )
114
+
115
+ tmp_dir_template = f"{args.tmp_dir}/%d"
116
+ response_num = 0
117
+ response_cache = ""
118
+
119
+ stitcher = ImageStitcher(
120
+ tmp_dir=tmp_dir_template % response_num,
121
+ img_res=args.img_res,
122
+ img_per_line=args.img_per_line,
123
+ verbose=True,
124
+ )
125
+
126
+
127
+ @spaces.GPU
128
+ def generate(
129
+ message: str,
130
+ chat_history: list[tuple[str, str]],
131
+ seed: int,
132
+ system_prompt: str,
133
+ max_new_tokens: int = 1024,
134
+ temperature: float = 0.6,
135
+ top_p: float = 0.9,
136
+ top_k: int = 50,
137
+ repetition_penalty: float = 1.2,
138
+ ) -> Iterator[str]:
139
+
140
+ if temperature < 1e-2:
141
+ temperature = 1e-2
142
+
143
+ global response_cache
144
+
145
+ generate_kwargs = dict(
146
+ max_new_tokens=max_new_tokens,
147
+ do_sample=True,
148
+ top_p=top_p,
149
+ top_k=top_k,
150
+ temperature=temperature,
151
+ repetition_penalty=repetition_penalty,
152
+ )
153
+ formatted_prompt = format_prompt(f"{system_prompt}, {message}", chat_history)
154
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
155
+
156
+ outputs = ""
157
+ prompt_queue = SimpleQueue()
158
+
159
+ lcm_stream.reset(seed)
160
+ stitcher.reset()
161
+
162
+ global response_num
163
+ response_num += 1
164
+ stitcher.update_tmp_dir(tmp_dir_template % response_num)
165
+
166
+ def append_to_queue():
167
+ for response in stream:
168
+ outputs += response.token.text
169
+ text = text.strip()
170
+ if text:
171
+ prompt_queue.put(text)
172
+ prompt_queue.put(None)
173
+
174
+ append_thread = Thread(target=append_to_queue)
175
+ append_thread.start()
176
+
177
+ def show_image(prompt: str = None):
178
+ image, text = lcm_stream(prompt)
179
+ img_path = None
180
+ if image is not None:
181
+ img_path = stitcher.add(image, text)
182
+ print(img_path)
183
+ return img_path
184
+
185
+ while True:
186
+ prompt = prompt_queue.get()
187
+ if prompt is None:
188
+ break
189
+ img_path = show_image(prompt)
190
+ if img_path is not None:
191
+ yield (img_path, )
192
+
193
+ # Continue to display the remaining images
194
+ while True:
195
+ img_path = show_image()
196
+ if img_path is not None:
197
+ yield (img_path, )
198
+ if lcm_stream.stop():
199
+ break
200
+
201
+ response_cache = outputs
202
+ return outputs
203
+
204
+
205
+ chat_interface = gr.ChatInterface(
206
+ fn=generate,
207
+ chatbot=Chatbot(height=400),
208
+ additional_inputs=[
209
+ gr.Number(
210
+ label="Seed",
211
+ info="Random Seed for Kanji Generation (maybe some kind of accent πŸ€”)",
212
+ step=1,
213
+ value=1026,
214
+ ),
215
+ gr.Textbox(label="System prompt", lines=4),
216
+ gr.Slider(
217
+ label="Max new tokens",
218
+ minimum=1,
219
+ maximum=MAX_MAX_NEW_TOKENS,
220
+ step=1,
221
+ value=DEFAULT_MAX_NEW_TOKENS,
222
+ ),
223
+ gr.Slider(
224
+ label="Temperature",
225
+ minimum=0.1,
226
+ maximum=4.0,
227
+ step=0.1,
228
+ value=0.6,
229
+ ),
230
+ gr.Slider(
231
+ label="Top-p (nucleus sampling)",
232
+ minimum=0.05,
233
+ maximum=1.0,
234
+ step=0.05,
235
+ value=0.9,
236
+ ),
237
+ gr.Slider(
238
+ label="Top-k",
239
+ minimum=1,
240
+ maximum=1000,
241
+ step=1,
242
+ value=50,
243
+ ),
244
+ gr.Slider(
245
+ label="Repetition penalty",
246
+ minimum=1.0,
247
+ maximum=2.0,
248
+ step=0.05,
249
+ value=1.2,
250
+ ),
251
+ ],
252
+ stop_btn=None,
253
+ examples=[
254
+ ["Hello there! How are you doing?"],
255
+ ["Can you explain briefly to me what is the Python programming language?"],
256
+ ["Explain the plot of Cinderella in a sentence."],
257
+ ["How many hours does it take a man to eat a Helicopter?"],
258
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
259
+ ],
260
+ )
261
+
262
+ with gr.Blocks(css="style.css") as demo:
263
+ gr.Markdown(DESCRIPTION)
264
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
265
+ chat_interface.render()
266
+ gr.Markdown(LICENSE)
267
+
268
+ if __name__ == "__main__":
269
+ demo.queue(max_size=20).launch(show_api=False)