jed-tiotuico committed on
Commit
ffa493c
1 Parent(s): 5628e55

first commit

Files changed (3)
  1. app.py +332 -0
  2. pre-requirements.txt +2 -0
  3. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,332 @@
+ import datetime
+ from google.protobuf import message
+ import torch
+ import time
+ import threading
+ import streamlit as st
+ import random
+ from typing import Iterable
+ # from unsloth import FastLanguageModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
+ from datetime import datetime
+ from threading import Thread
+
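+ # note: the Unsloth fast path (FastLanguageModel) is commented out above; the app
+ # currently loads everything through the plain transformers Auto* classes.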
+ # fine_tuned_model_name = "jed-tiotuico/twitter-llama"
+ # sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
+
+ fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
+ sota_model_name = "MBZUAI/LaMini-GPT-124M"
+ alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # if CUDA is unavailable, fall back to Apple Silicon MPS when present
+ if device.type == "cpu":
+     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
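+ # note: cache_dir below is hard-coded to the author's local macOS path; adjust or
+ # drop it on other machines so the default Hugging Face cache is used instead.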
+ def get_model_tokenizer(sota_model_name):
+     tokenizer = AutoTokenizer.from_pretrained(
+         sota_model_name,
+         cache_dir="/Users/jedtiotuico/.hf_cache",
+         trust_remote_code=True,
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         sota_model_name,
+         cache_dir="/Users/jedtiotuico/.hf_cache",
+         trust_remote_code=True,
+     ).to(device)
+
+     return model, tokenizer
+
+ def write_user_chat_message(user_chat, customer_msg):
+     if customer_msg:
+         if user_chat is None:
+             user_chat = st.chat_message("user")
+
+         user_chat.write(customer_msg)
+
+ def write_stream_user_chat_message(user_chat, model, tokenizer, prompt):
+     if prompt:
+         if user_chat is None:
+             user_chat = st.chat_message("user")
+
+         new_customer_msg = user_chat.write_stream(
+             stream_generation(
+                 prompt,
+                 show_prompt=False,
+                 tokenizer=tokenizer,
+                 model=model,
+             )
+         )
+
+         return new_customer_msg
+
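+ # note: this loader is currently identical to get_model_tokenizer; presumably kept
+ # separate so a Mistral/Unsloth-specific loading path can diverge later.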
+ def get_mistral_model_tokenizer(sota_model_name):
+     tokenizer = AutoTokenizer.from_pretrained(
+         sota_model_name,
+         cache_dir="/Users/jedtiotuico/.hf_cache",
+         trust_remote_code=True,
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         sota_model_name,
+         cache_dir="/Users/jedtiotuico/.hf_cache",
+         trust_remote_code=True,
+     ).to(device)
+
+     return model, tokenizer
+
+ class DeckPicker:
+     def __init__(self, items):
+         self.items = items[:]  # Make a copy of the items to shuffle
+         self.original_items = items[:]  # Keep the original order
+         random.shuffle(self.items)  # Shuffle the items
+         self.index = -1  # Initialize the index
+
+     def pick(self):
+         """Pick the next item from the deck. If all items have been picked, reshuffle."""
+         self.index += 1
+         if self.index >= len(self.items):
+             self.index = 0
+             random.shuffle(self.items)  # Reshuffle if at the end
+         return self.items[self.index]
+
+     def get_state(self):
+         """Return the current state of the deck and the last picked index."""
+         return self.items, self.index
+
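+ # usage sketch: picker = DeckPicker(nouns); picker.pick() returns the next shuffled
+ # word and the deck reshuffles automatically once it has been exhausted.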
+ # Word decks used to randomize the few-shot customer-message prompt below
+ nouns = [
+     "service", "issue", "account", "support", "problem", "help", "team",
+     "request", "response", "email", "ticket", "update", "error", "system",
+     "connection", "downtime", "billing", "charge", "refund", "password",
+     "outage", "agent", "feature", "access", "status", "interface", "network",
+     "subscription", "upgrade", "notification", "data", "server", "log", "message",
+     "renewal", "setup", "security", "feedback", "confirmation", "printer"
+ ]
+
+ verbs = [
+     "have", "print", "need", "help", "update", "resolve", "access", "contact",
+     "receive", "reset", "support", "experience", "report", "request", "process",
+     "check", "confirm", "explain", "manage", "handle", "disconnect", "renew",
+     "change", "fix", "cancel", "complete", "notify", "respond", "fail", "restore",
+     "review", "escalate", "submit", "configure", "troubleshoot", "log", "operate",
+     "suspend", "pay", "adjust"
+ ]
+
+ adjectives = [
+     "quick", "immediate", "urgent", "unable", "detailed", "frequent", "technical",
+     "possible", "slow", "helpful", "unresponsive", "secure", "successful", "necessary",
+     "available", "scheduled", "regular", "interrupted", "automatic", "manual", "last",
+     "online", "offline", "new", "current", "prior", "due", "related", "temporary",
+     "permanent", "next", "previous", "complicated", "easy", "difficult", "major",
+     "minor", "alternative", "additional", "expired"
+ ]
+
+ def create_few_shots(noun_picker, verb_picker, adjective_picker):
+     noun = noun_picker.pick()
+     verb = verb_picker.pick()
+     adjective = adjective_picker.pick()
+
+     context = f"""
+ Write a short realistic customer support tweet message by a customer for another company.
+ Avoid adding hashtags or mentions in the message.
+ Ensure that the sentiment is negative.
+ Ensure that the word count is around 15 to 25 words.
+ Ensure the message contains the noun: {noun}, verb: {verb}, and adjective: {adjective}.
+
+ Example of return messages 5/5:
+
+ 1/5: your website is straight up garbage. how do you sell high end technology but you cant get a website right?
+ 2/5: my phone is all static during calls and when i plug in headphones any audio still comes thru the speaks wtf
+ 3/5: hi, i'm having trouble logging into my groceries account it keeps refreshing back to the log in page, any ideas?
+ 4/5: please check you dms asap if you're really about customer service. 2 weeks since my accident and nothing.
+ 5/5: I'm extremely disappointed with your service. You charged me for a temporary solution, and there's no adjustment in sight.
+
+ Now it's your turn, ensure to only generate one message
+ 1/1:
+ """
+     return context
+
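+ # illustrative example: with noun="printer", verb="print", adjective="slow" the prompt
+ # asks the model for one negative 15-25 word customer tweet containing those three words.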
+ st.header("ReplyCaddy")
+ st.write("AI-powered customer support assistant. Reduces anxiety when responding to customer support on social media.")
+ # image https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true
+ # st.write("Made with [Unsloth](https://github.com/unslothai/unsloth/blob/main/images/made%20with%20unsloth.png?raw=true")
+
+ def stream_generation(
+     prompt: str,
+     tokenizer: PreTrainedTokenizerFast,
+     model: AutoModelForCausalLM,
+     max_new_tokens: int = 2048,
+     temperature: float = 0.7,
+     top_p: float = 0.9,
+     top_k: int = 100,
+     repetition_penalty: float = 1.1,
+     penalty_alpha: float = 0.25,
+     no_repeat_ngram_size: int = 3,
+     show_prompt: bool = False,
+ ) -> Iterable[str]:
+     """
+     Stream the generation of a prompt.
+
+     Args:
+         prompt (str): the prompt
+         tokenizer (PreTrainedTokenizerFast): the tokenizer
+         model (AutoModelForCausalLM): the model
+         max_new_tokens (int, optional): the maximum number of tokens to generate. Defaults to 2048.
+         temperature (float, optional): the temperature of the generation. Defaults to 0.7.
+         top_p (float, optional): the top-p value of the generation. Defaults to 0.9.
+         top_k (int, optional): the top-k value of the generation. Defaults to 100.
+         repetition_penalty (float, optional): the repetition penalty of the generation. Defaults to 1.1.
+         penalty_alpha (float, optional): the penalty alpha of the generation. Defaults to 0.25.
+         no_repeat_ngram_size (int, optional): the no-repeat n-gram size of the generation. Defaults to 3.
+         show_prompt (bool, optional): whether to include the prompt in the stream. Defaults to False.
+
+     Yields:
+         str: the generated text
+     """
+     # stream tokens as they are generated; skip the prompt unless show_prompt is set
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=not show_prompt, skip_special_tokens=True)  # type: ignore
+
+     # set up kwargs for generation
+     generation_kwargs = dict(
+         input_ids=tokenizer(prompt, return_tensors="pt")["input_ids"].to(device),
+         streamer=streamer,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         penalty_alpha=penalty_alpha,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         max_new_tokens=max_new_tokens,
+     )
+
+     # run generation in a separate thread so tokens can be yielded as they arrive
+     generation_thread = threading.Thread(
+         target=model.generate, kwargs=generation_kwargs  # type: ignore
+     )
+     generation_thread.start()
+
+     blacklisted_tokens = ["<|url|>"]
+     for new_text in streamer:
+         # filter out blacklisted tokens
+         if any(token in new_text for token in blacklisted_tokens):
+             continue
+
+         yield new_text
+
+     # wait for the generation to finish
+     generation_thread.join()
+
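+ # usage sketch (assumes a model/tokenizer pair is already loaded), mirroring the calls below:
+ #   st.chat_message("assistant").write_stream(
+ #       stream_generation(prompt, tokenizer=tokenizer, model=model)
+ #   )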
+ twitter_llama_model = None
+ twitter_llama_tokenizer = None
+ streamer = None
+
+ # define session state and the chat message containers
+ def init_session_states():
+     if "user_msg_as_prompt" not in st.session_state:
+         st.session_state["user_msg_as_prompt"] = ""
+
+     user_chat = None
+     if "user_msg_as_prompt" in st.session_state:
+         user_chat = st.chat_message("user")
+
+     assistant_chat = st.chat_message("assistant")
+     if "greet" not in st.session_state:
+         st.session_state["greet"] = False
+         greeting_text = "Hello! I'm here to help. Copy and paste your customer's message, or generate using AI."
+         assistant_chat.write(greeting_text)
+
+     return assistant_chat, user_chat
+
+ assistant_chat, user_chat = init_session_states()
+
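+ # Streamlit re-runs this script top to bottom on every interaction, so
+ # st.session_state["user_msg_as_prompt"] is what carries the customer message
+ # between the buttons and the response generation below.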
+ # Generate Response Tweet
+ if user_chat:
+     if st.button("Generate Polite and Friendly Response"):
+         if "user_msg_as_prompt" in st.session_state:
+             customer_msg = st.session_state["user_msg_as_prompt"]
+             if customer_msg:
+                 write_user_chat_message(user_chat, customer_msg)
+
+                 model, tokenizer = get_model_tokenizer(sota_model_name)
+
+                 input_text = alpaca_input_text_format.format(customer_msg)
+                 st.markdown(f"""```\n{input_text}```""", unsafe_allow_html=True)
+                 response_tweet = assistant_chat.write_stream(
+                     stream_generation(
+                         input_text,
+                         show_prompt=False,
+                         tokenizer=tokenizer,
+                         model=model,
+                     )
+                 )
+             else:
+                 st.error("Please enter a customer message, or generate one for the AI to respond to.")
+
+ # main ui prompt
+ # - text box
+ # - submit
+ with st.form(key="my_form"):
+     prompt = st.text_area("Customer Message")
+     write_user_chat_message(user_chat, prompt)
+     if st.form_submit_button("Submit"):
+         # placeholder acknowledgement; the buttons above and below drive the actual generation
+         assistant_chat.write("Hi, Human.")
+
+ # below ui prompt
+ # - examples
+ # st.markdown("<b>Example:</b>", unsafe_allow_html=True)
+ if st.button("your website is straight up garbage. how do you sell high end technology but you cant get a website right?"):
+     customer_msg = "your website is straight up garbage. how do you sell high end technology but you cant get a website right?"
+     st.session_state["user_msg_as_prompt"] = customer_msg
+     write_user_chat_message(user_chat, customer_msg)
+     model, tokenizer = get_model_tokenizer(sota_model_name)
+     input_text = alpaca_input_text_format.format(customer_msg)
+     st.write(f"```\n{input_text}```")
+     assistant_chat.write_stream(
+         stream_generation(
+             input_text,
+             show_prompt=False,
+             tokenizer=tokenizer,
+             model=model,
+         )
+     )
+
+ # - Generate Customer Tweet
+ if st.button("Generate Customer Message using Few Shots"):
+     # currently unused settings, left over from the commented-out Unsloth configuration
+     max_seq_length = 2048
+     dtype = torch.float16
+     load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
+
+     model, tokenizer = get_mistral_model_tokenizer(sota_model_name)
+
+     noun_picker = DeckPicker(nouns)
+     verb_picker = DeckPicker(verbs)
+     adjective_picker = DeckPicker(adjectives)
+     few_shots = create_few_shots(noun_picker, verb_picker, adjective_picker)
+     # [INST] ... [/INST] follows the Mistral instruction template referenced at the top of the file
+     few_shot_prompt = f"<s>[INST]{few_shots}[/INST]\n"
+     st.markdown("Prompt:")
+     st.markdown(f"""```\n{few_shot_prompt}```""", unsafe_allow_html=True)
+
+     new_customer_msg = write_stream_user_chat_message(user_chat, model, tokenizer, few_shot_prompt)
+     st.session_state["user_msg_as_prompt"] = new_customer_msg
+
+
+ st.markdown("------------")
+ st.markdown("<p>Thanks to:</p>", unsafe_allow_html=True)
+ st.markdown("""Unsloth https://github.com/unslothai - check out the [wiki](https://github.com/unslothai/unsloth/wiki)""")
+ st.markdown("""Georgi Gerganov's ggml https://github.com/ggerganov/ggml""")
+ st.markdown("""Meta's Llama https://github.com/meta-llama""")
+ st.markdown("""Mistral AI - https://github.com/mistralai""")
+ st.markdown("""Zhang Peiyuan's TinyLlama https://github.com/jzhang38/TinyLlama""")
+ st.markdown("""Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois,
+ Xuechen Li, Carlos Guestrin, Percy Liang, Tatsunori B. Hashimoto
+ - [Alpaca: A Strong, Replicable Instruction-Following Model](https://crfm.stanford.edu/2023/03/13/alpaca.html)""")
+
+ if device.type == "cuda":
+     gpu_stats = torch.cuda.get_device_properties(0)
+     max_memory = gpu_stats.total_memory / 1024 ** 3
+     start_gpu_memory = torch.cuda.memory_reserved(0) / 1024 ** 3
+     st.write(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
+     st.write(f"{start_gpu_memory} GB of memory reserved.")
+
+ st.write("Packages:")
+ st.write(f"pytorch: {torch.__version__}")
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
+ pip >= 24.0
+ wheel
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch==2.2.1
+ peft==0.10.0
+ transformers
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
+ packaging
+ ninja
+ einops
+ xformers<0.0.26
+ trl
+ accelerate
+ bitsandbytes
+ jsonlines
+ regex
+ streamlit