martin-gorner commited on
Commit
2ca0c5e
1 Parent(s): d96a4ed

initial commit

Browse files
Files changed (12) hide show
  1. .gitignore +3 -0
  2. app.py +209 -52
  3. chatstate.py +94 -0
  4. img/bot.png +0 -0
  5. img/gemma.png +0 -0
  6. img/keras_logo_k.png +0 -0
  7. img/llama.png +0 -0
  8. img/mistral.png +0 -0
  9. img/usr.png +0 -0
  10. img/vicuna.png +0 -0
  11. models.py +105 -0
  12. requirements.txt +6 -1
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .DS_Store
2
+ .vscode
3
+ __pycache__
app.py CHANGED
@@ -1,63 +1,220 @@
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
8
 
 
 
 
9
 
10
- def respond(
 
11
  message,
12
- history: list[tuple[str, str]],
 
 
 
13
  system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  if __name__ == "__main__":
 
1
+ import os
2
+
3
+ os.environ["KERAS_BACKEND"] = "jax"
4
+
5
  import gradio as gr
6
+ from gradio import ChatMessage
7
+ import keras_hub
8
+
9
+ from chatstate import ChatState
10
+ from models import (
11
+ model_presets,
12
+ load_model,
13
+ model_labels,
14
+ preset_to_website_url,
15
+ get_appropriate_chat_template,
16
+ )
17
+
18
+ model_labels_list = list(model_labels)
19
+
20
+ # lod a warm up (compile) all the models
21
+ models = []
22
+ for preset in model_presets:
23
+ model = load_model(preset)
24
+ chat_template = get_appropriate_chat_template(preset)
25
+ chat_state = ChatState(model, "", chat_template)
26
+ prompt, response = chat_state.send_message("Hello")
27
+ print("model " + preset + "loaded and initialized.")
28
+ print("The model responded: " + response)
29
+
30
+ models = [load_model(preset) for preset in model_presets]
31
+ # model = keras_hub.models.Llama3CausalLM.from_preset(
32
+ # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
33
+ # )
34
+ # models = [model, model]
35
+
36
+
37
+ def chat_turn_assistant_1(
38
+ model,
39
+ message,
40
+ history,
41
+ system_message,
42
+ preset,
43
+ # max_tokens,
44
+ # temperature,
45
+ # top_p,
46
+ ):
47
+ chat_template = get_appropriate_chat_template(preset)
48
+ chat_state = ChatState(model, system_message, chat_template)
49
 
50
+ for msg in history:
51
+ msg = ChatMessage(**msg)
52
+ if msg.role == "user":
53
+ chat_state.add_to_history_as_user(msg.content)
54
+ elif msg.role == "assistant":
55
+ chat_state.add_to_history_as_model(msg.content)
56
 
57
+ prompt, response = chat_state.send_message(message)
58
+ history.append(ChatMessage(role="assistant", content=response))
59
+ return history
60
 
61
+
62
+ def chat_turn_assistant(
63
  message,
64
+ sel1,
65
+ history1,
66
+ sel2,
67
+ history2,
68
  system_message,
69
+ # max_tokens,
70
+ # temperature,
71
+ # top_p,
72
  ):
73
+ history1 = chat_turn_assistant_1(
74
+ models[sel1], message, history1, system_message, model_presets[sel1]
75
+ )
76
+ history2 = chat_turn_assistant_1(
77
+ models[sel2], message, history2, system_message, model_presets[sel2]
78
+ )
79
+ return "", history1, history2
80
+
81
+
82
+ def chat_turn_user_1(message, history):
83
+ history.append(ChatMessage(role="user", content=message))
84
+ return history
85
+
86
+
87
+ def chat_turn_user(message, history1, history2):
88
+ history1 = chat_turn_user_1(message, history1)
89
+ history2 = chat_turn_user_1(message, history2)
90
+ return "", history1, history2
91
+
92
+
93
+ def bot_icon_select(model_name):
94
+ if "gemma" in model_name:
95
+ return "img/gemma.png"
96
+ elif "llama" in model_name:
97
+ return "img/llama.png"
98
+ elif "vicuna" in model_name:
99
+ return "img/vicuna.png"
100
+ elif "mistral" in model_name:
101
+ return "img/mistral.png"
102
+ # default
103
+ return "img/bot.png"
104
+
105
+
106
+ def instantiate_chatbots(sel1, sel2):
107
+ model_name1 = model_presets[sel1]
108
+ chatbot1 = gr.Chatbot(
109
+ type="messages",
110
+ show_label=False,
111
+ avatar_images=("img/usr.png", bot_icon_select(model_name1)),
112
+ )
113
+ model_name2 = model_presets[sel2]
114
+ chatbot2 = gr.Chatbot(
115
+ type="messages",
116
+ show_label=False,
117
+ avatar_images=("img/usr.png", bot_icon_select(model_name2)),
118
+ )
119
+ return chatbot1, chatbot2
120
+
121
+
122
+ def instantiate_select_boxes(sel1, sel2, model_labels):
123
+ sel1 = gr.Dropdown(
124
+ choices=[(name, i) for i, name in enumerate(model_labels)],
125
+ show_label=False,
126
+ info="<span style='color:black'>Selected model 1:</span> "
127
+ + "<a href='"
128
+ + preset_to_website_url(model_presets[sel1])
129
+ + "'>"
130
+ + preset_to_website_url(model_presets[sel1])
131
+ + "</a>",
132
+ value=sel1,
133
+ )
134
+ sel2 = gr.Dropdown(
135
+ choices=[(name, i) for i, name in enumerate(model_labels)],
136
+ show_label=False,
137
+ info="<span style='color:black'>Selected model 2:</span> "
138
+ + "<a href='"
139
+ + preset_to_website_url(model_presets[sel2])
140
+ + "'>"
141
+ + preset_to_website_url(model_presets[sel2])
142
+ + "</a>",
143
+ value=sel2,
144
+ )
145
+ return sel1, sel2
146
+
147
+
148
+ def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
149
+ chatbot1, chatbot2 = instantiate_chatbots(sel1, sel2)
150
+ sel1, sel2 = instantiate_select_boxes(sel1, sel2, model_labels)
151
+ return sel1, chatbot1, sel2, chatbot2
152
+
153
+
154
+ with gr.Blocks(fill_width=True, title="Keras demo") as demo:
155
+
156
+ with gr.Row():
157
+ gr.Image(
158
+ "img/keras_logo_k.png",
159
+ width=80,
160
+ height=80,
161
+ min_width=80,
162
+ show_label=False,
163
+ show_download_button=False,
164
+ show_fullscreen_button=False,
165
+ interactive=False,
166
+ scale=0.01,
167
+ container=False,
168
+ )
169
+ gr.HTML(
170
+ "<H2> Battle of the Keras chatbots on TPU</H2>"
171
+ + "All the models are loaded into the TPU memory. "
172
+ + "You can call them at will and compare their answers. <br/>"
173
+ + "The entire chat history is fed to the models at every submission."
174
+ + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
175
+ )
176
+ with gr.Row():
177
+ sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)
178
+
179
+ with gr.Row():
180
+ chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)
181
+
182
+ msg = gr.Textbox(
183
+ label="Your message:",
184
+ )
185
+ with gr.Row():
186
+ gr.ClearButton([msg, chatbot1, chatbot2])
187
+ with gr.Accordion("Additional settings", open=False):
188
+ system_message = gr.Textbox(
189
+ label="Sytem prompt",
190
+ value="You are a helpful assistant and your name is Eliza.",
191
+ )
192
+
193
+ sel1.select(
194
+ lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
195
+ sel1, sel2, model_labels_list
196
  ),
197
+ inputs=[sel1, sel2],
198
+ outputs=[sel1, chatbot1, sel2, chatbot2],
199
+ )
200
+
201
+ sel2.select(
202
+ lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
203
+ sel1, sel2, model_labels_list
204
+ ),
205
+ inputs=[sel1, sel2],
206
+ outputs=[sel1, chatbot1, sel2, chatbot2],
207
+ )
208
+
209
+ msg.submit(
210
+ chat_turn_user,
211
+ inputs=[msg, chatbot1, chatbot2],
212
+ outputs=[msg, chatbot1, chatbot2],
213
+ ).then(
214
+ chat_turn_assistant,
215
+ [msg, sel1, chatbot1, sel2, chatbot2, system_message],
216
+ outputs=[msg, chatbot1, chatbot2],
217
+ )
218
 
219
 
220
  if __name__ == "__main__":
chatstate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chat helper
2
+ class ChatState:
3
+
4
+ def __init__(self, model, system="", chat_template="auto"):
5
+ chat_template = (
6
+ type(model).__name__ if chat_template == "auto" else chat_template
7
+ )
8
+
9
+ if chat_template == "Llama3CausalLM":
10
+ self.__START_TURN_SYSTEM__ = (
11
+ "<|start_header_id|>system<|end_header_id|>\n\n"
12
+ )
13
+ self.__START_TURN_USER__ = (
14
+ "<|start_header_id|>user<|end_header_id|>\n\n"
15
+ )
16
+ self.__START_TURN_MODEL__ = (
17
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
18
+ )
19
+ self.__END_TURN_SYSTEM__ = "<|eot_id|>"
20
+ self.__END_TURN_USER__ = "<|eot_id|>"
21
+ self.__END_TURN_MODEL__ = "<|eot_id|>"
22
+ print("Using chat template for: Llama")
23
+ elif chat_template == "GemmaCausalLM":
24
+ self.__START_TURN_SYSTEM__ = ""
25
+ self.__START_TURN_USER__ = "<start_of_turn>user\n"
26
+ self.__START_TURN_MODEL__ = "<start_of_turn>model\n"
27
+ self.__END_TURN_SYSTEM__ = "\n"
28
+ self.__END_TURN_USER__ = "<end_of_turn>\n"
29
+ self.__END_TURN_MODEL__ = "<end_of_turn>\n"
30
+ print("Using chat template for: Gemma")
31
+ elif chat_template == "MistralCausalLM":
32
+ self.__START_TURN_SYSTEM__ = ""
33
+ self.__START_TURN_USER__ = "[INST]"
34
+ self.__START_TURN_MODEL__ = ""
35
+ self.__END_TURN_SYSTEM__ = "<s>"
36
+ self.__END_TURN_USER__ = "[/INST]"
37
+ self.__END_TURN_MODEL__ = "</s>"
38
+ print("Using chat template for: Mistral")
39
+ elif chat_template == "Vicuna":
40
+ self.__START_TURN_SYSTEM__ = ""
41
+ self.__START_TURN_USER__ = "USER: "
42
+ self.__START_TURN_MODEL__ = "ASSISTANT: "
43
+ self.__END_TURN_SYSTEM__ = "\n\n"
44
+ self.__END_TURN_USER__ = "\n"
45
+ self.__END_TURN_MODEL__ = "</s>\n"
46
+ print("Using chat template for : Vicuna")
47
+ else:
48
+ assert (0, "Unknown turn tags for this model class")
49
+
50
+ self.model = model
51
+ self.system = system
52
+ self.history = []
53
+
54
+ def add_to_history_as_user(self, message):
55
+ self.history.append(
56
+ self.__START_TURN_USER__ + message + self.__END_TURN_USER__
57
+ )
58
+
59
+ def add_to_history_as_model(self, message):
60
+ self.history.append(
61
+ self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
62
+ )
63
+
64
+ def get_history(self):
65
+ return "".join([*self.history])
66
+
67
+ def get_full_prompt(self):
68
+ prompt = self.get_history() + self.__START_TURN_MODEL__
69
+ if len(self.system) > 0:
70
+ prompt = (
71
+ self.__START_TURN_SYSTEM__
72
+ + self.system
73
+ + self.__END_TURN_SYSTEM__
74
+ + prompt
75
+ )
76
+ return prompt
77
+
78
+ def send_message(self, message):
79
+ """
80
+ Handles sending a user message and getting a model response.
81
+
82
+ Args:
83
+ message: The user's message.
84
+
85
+ Returns:
86
+ The model's response.
87
+ """
88
+ self.add_to_history_as_user(message)
89
+ prompt = self.get_full_prompt()
90
+ response = self.model.generate(
91
+ prompt, max_length=1024, strip_prompt=True
92
+ )
93
+ self.add_to_history_as_model(response)
94
+ return (message, response)
img/bot.png ADDED
img/gemma.png ADDED
img/keras_logo_k.png ADDED
img/llama.png ADDED
img/mistral.png ADDED
img/usr.png ADDED
img/vicuna.png ADDED
models.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import keras
2
+ import keras_hub
3
+
4
+ model_presets = [
5
+ "hf://google/gemma-2-instruct-9b-keras",
6
+ "hf://meta-llama/Llama-3.1-8B-Instruct",
7
+ "hf://google/codegemma-7b-it-keras",
8
+ "hf://keras/mistral_instruct_7b_en",
9
+ "hf://keras/vicuna_1.5_7b_en",
10
+ ]
11
+
12
+ model_labels = map(lambda s: s.removeprefix("hf://"), model_presets)
13
+ model_labels = map(lambda s: s.removeprefix("google/"), model_labels)
14
+ model_labels = map(lambda s: s.removeprefix("keras/"), model_labels)
15
+ model_labels = map(lambda s: s.removeprefix("meta-llama/"), model_labels)
16
+
17
+
18
+ def preset_to_website_url(preset):
19
+ preset = preset.removeprefix("hf://")
20
+ url = "http://huggingface.co/" + preset
21
+ return url
22
+
23
+
24
+ def get_appropriate_chat_template(preset):
25
+ return "Vicuna" if "vicuna" in preset else "auto"
26
+
27
+
28
+ def get_default_layout_map(preset_name, device_mesh):
29
+ # Llama's default layout map works for mistral and vicuna
30
+ # because their transformer layers have the same names.
31
+ if (
32
+ "Llama" in preset_name
33
+ or "mistral" in preset_name
34
+ or "vicuna" in preset_name
35
+ ):
36
+ return keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
37
+ elif "gemma" in preset_name:
38
+ return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
39
+
40
+
41
+ def log_applied_layout_map(model):
42
+ if "Gemma" in type(model):
43
+ transformer_decoder_block_name = "decoder_block_1"
44
+ elif "Llama3" in type(model) or "Mistral" in type(model):
45
+ transformer_decoder_block_name = "transformer_layer_1"
46
+ else:
47
+ assert (0, "Model type not recognized. Cannot display model layout.")
48
+ # See how layer sharding was applied
49
+ embedding_layer = model.backbone.get_layer("token_embedding")
50
+ print(embedding_layer)
51
+ decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
52
+ print(type(decoder_block))
53
+ for variable in embedding_layer.weights + decoder_block.weights:
54
+ print(
55
+ f"{variable.path:<58} \
56
+ {str(variable.shape):<16} \
57
+ {str(variable.value.sharding.spec):<35} \
58
+ {str(variable.dtype)}"
59
+ )
60
+
61
+
62
+ def load_model(preset):
63
+ devices = keras.distribution.list_devices()
64
+ device_mesh = keras.distribution.DeviceMesh(
65
+ shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
66
+ )
67
+ model_parallel = keras.distribution.ModelParallel(
68
+ layout_map=get_default_layout_map(preset, device_mesh),
69
+ batch_dim_name="batch",
70
+ )
71
+
72
+ with model_parallel.scope():
73
+ # These two buggy models need this workaround to be loaded in bfloat16
74
+ if "google/gemma-2-instruct-9b-keras" in preset:
75
+ model = keras_hub.models.GemmaCausalLM(
76
+ backbone=keras_hub.models.GemmaBackbone.from_preset(
77
+ preset, dtype="bfloat16"
78
+ ),
79
+ preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
80
+ preset
81
+ ),
82
+ )
83
+ elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
84
+ model = keras_hub.models.Llama3CausalLM(
85
+ backbone=keras_hub.models.Llama3Backbone.from_preset(
86
+ preset, dtype="bfloat16"
87
+ ),
88
+ preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
89
+ preset
90
+ ),
91
+ )
92
+ else:
93
+ model = keras_hub.models.CausalLM.from_preset(
94
+ preset, dtype="bfloat16"
95
+ )
96
+
97
+ log_applied_layout_map(model)
98
+ return model
99
+
100
+
101
+ # Some small models too
102
+ # model1 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
103
+ # model2 = keras_hub.models.CausalLM.from_preset("hf://google/gemma-2b-it-keras", dtype="bfloat16")
104
+ # model3 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct", dtype="bfloat16")
105
+ # keras/gemma_1.1_instruct_7b_en
requirements.txt CHANGED
@@ -1 +1,6 @@
1
- huggingface_hub==0.25.2
 
 
 
 
 
 
1
+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
2
+ jax[tpu]
3
+ keras>=3
4
+ keras-hub
5
+ safetensors
6
+ huggingface_hub