ZamiSanj committed on
Commit
cda715f
1 Parent(s): 319c652

Update app.py

Files changed (1)
  1. app.py +213 -117
app.py CHANGED
@@ -1,125 +1,221 @@
- import gradio as gr
  import torch
- import re, os, warnings
- from langchain import PromptTemplate, LLMChain
- from langchain.llms.base import LLM
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
- from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
  warnings.filterwarnings("ignore")

- def init_model_and_tokenizer(PEFT_MODEL):
-     config = PeftConfig.from_pretrained(PEFT_MODEL)
-     bnb_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_compute_dtype=torch.float16,
-     )

-     peft_base_model = AutoModelForCausalLM.from_pretrained(
-         config.base_model_name_or_path,
-         return_dict=True,
-         quantization_config=bnb_config,
-         device_map="auto",
-         trust_remote_code=True,
-     )

-     peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

-     peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
-     peft_tokenizer.pad_token = peft_tokenizer.eos_token

-     return peft_model, peft_tokenizer

- def init_llm_chain(peft_model, peft_tokenizer):
-     class CustomLLM(LLM):
-         def _call(self, prompt: str, stop=None, run_manager=None) -> str:
-             device = "cuda:0"
-             peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
-             peft_outputs = peft_model.generate(input_ids=peft_encoding.input_ids, generation_config=GenerationConfig(max_new_tokens=256, pad_token_id=peft_tokenizer.eos_token_id, \
-                                                eos_token_id=peft_tokenizer.eos_token_id, attention_mask=peft_encoding.attention_mask, \
-                                                temperature=0.4, top_p=0.6, repetition_penalty=1.3, num_return_sequences=1,))
-             peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
-             return peft_text_output

-         @property
-         def _llm_type(self) -> str:
-             return "custom"

-     llm = CustomLLM()

-     template = """Answer the following question truthfully.
-     If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
-     If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

-     Example Format:
-     <HUMAN>: question here
-     <ASSISTANT>: answer here

-     Begin!

-     <HUMAN>: {query}
-     <ASSISTANT>:"""

-     prompt = PromptTemplate(template=template, input_variables=["query"])
-     llm_chain = LLMChain(prompt=prompt, llm=llm)

-     return llm_chain

- def user(user_message, history):
-     return "", history + [[user_message, None]]

- def bot(history):
-     if len(history) >= 2:
-         query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
-     else:
-         query = history[-1][0]

-     bot_message = llm_chain.run(query)
-     bot_message = post_process_chat(bot_message)

-     history[-1][1] = ""
-     history[-1][1] += bot_message
-     return history

- def post_process_chat(bot_message):
-     try:
-         bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
-     except IndexError:
-         pass

-     bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]

-     bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
-     try:
-         bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
-     except AttributeError:
-         pass

-     bot_message = re.sub(r"\n\d.$", "", bot_message)
-     bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
-     bot_message = bot_message.replace("\n\n", "\n")

-     return bot_message

- model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
- peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)

- with gr.Blocks() as interface:
-     gr.HTML("""<h1>Welcome to Mental Health Conversational AI</h1>""")
-     gr.Markdown(
-         """Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.<br>
-         Get instant response for any mental health related queries. If the chatbot seems you need external support, then it will respond appropriately.<br>"""
-     )

-     chatbot = gr.Chatbot()
-     query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
-     clear = gr.Button(value="Clear Chat History!")
-     clear.style(size="sm")

-     llm_chain = init_llm_chain(peft_model, peft_tokenizer)

-     query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
-     clear.click(lambda: None, None, chatbot, queue=False)

- interface.queue().launch()
  import torch
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
+ from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
+ from trl import SFTTrainer
+ import warnings
  warnings.filterwarnings("ignore")

+ data = load_dataset("heliosbrahma/mental_health_chatbot_dataset")
+ model_name = "vilsonrodrigues/falcon-7b-instruct-sharded"  # sharded falcon-7b model
+ 
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,                       # load model in 4-bit precision
+     bnb_4bit_quant_type="nf4",               # quantize the pre-trained weights in 4-bit NF format
+     bnb_4bit_use_double_quant=True,          # use double quantization as described in the QLoRA paper
+     bnb_4bit_compute_dtype=torch.bfloat16,   # perform computations in BF16
+ )
+ 
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     quantization_config=bnb_config,  # use the bitsandbytes config above
+     device_map="auto",               # let HF Accelerate decide which device to put each layer on
+     trust_remote_code=True,          # required because falcon-7b ships custom model code
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token  # set pad_token to the same value as eos_token
+ model = prepare_model_for_kbit_training(model)

+ lora_alpha = 32     # scaling factor for the weight matrices
+ lora_dropout = 0.05 # dropout probability of the LoRA layers
+ lora_rank = 16      # dimension of the low-rank matrices
+ 
+ peft_config = LoraConfig(
+     lora_alpha=lora_alpha,
+     lora_dropout=lora_dropout,
+     r=lora_rank,
+     bias="none",           # train only the weight params, not the biases
+     task_type="CAUSAL_LM",
+     target_modules=[       # modules of the falcon-7b model that LoRA is applied to
+         "query_key_value",
+         "dense",
+         "dense_h_to_4h",
+         "dense_4h_to_h",
+     ]
+ )
+ 
+ peft_model = get_peft_model(model, peft_config)

+ output_dir = "./falcon-7b-sharded-fp16-finetuned-mental-health-conversational"
+ per_device_train_batch_size = 16  # reduce batch size by 2x on out-of-memory errors
+ gradient_accumulation_steps = 4   # increase gradient accumulation steps by 2x if batch size is reduced
+ optim = "paged_adamw_32bit"       # paged optimizer for better memory management
+ save_strategy = "steps"           # checkpoint save strategy to adopt during training
+ save_steps = 10                   # number of update steps between two checkpoint saves
+ logging_steps = 10                # number of update steps between two logs if logging_strategy="steps"
+ learning_rate = 2e-4              # learning rate for the AdamW optimizer
+ max_grad_norm = 0.3               # maximum gradient norm (for gradient clipping)
+ max_steps = 70                    # train for 70 steps
+ warmup_ratio = 0.03               # fraction of steps used for a linear warmup from 0 to learning_rate
+ lr_scheduler_type = "cosine"      # learning rate scheduler
+ 
+ training_arguments = TrainingArguments(
+     output_dir=output_dir,
+     per_device_train_batch_size=per_device_train_batch_size,
+     gradient_accumulation_steps=gradient_accumulation_steps,
+     optim=optim,
+     save_steps=save_steps,
+     logging_steps=logging_steps,
+     learning_rate=learning_rate,
+     bf16=True,
+     max_grad_norm=max_grad_norm,
+     max_steps=max_steps,
+     warmup_ratio=warmup_ratio,
+     group_by_length=True,
+     lr_scheduler_type=lr_scheduler_type,
+     push_to_hub=True,
+ )
+ trainer = SFTTrainer(
+     model=peft_model,
+     train_dataset=data['train'],
+     peft_config=peft_config,
+     dataset_text_field="text",
+     max_seq_length=1024,
+     tokenizer=tokenizer,
+     args=training_arguments,
+ )
+ 
+ # upcast the layer norms to torch.bfloat16 for more stable training
+ for name, module in trainer.model.named_modules():
+     if "norm" in name:
+         module = module.to(torch.bfloat16)
+ 
+ peft_model.config.use_cache = False
+ trainer.train()
+ trainer.push_to_hub("therapx")

+ # import gradio as gr
+ # import torch
+ # import re, os, warnings
+ # from langchain import PromptTemplate, LLMChain
+ # from langchain.llms.base import LLM
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
+ # from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
+ # warnings.filterwarnings("ignore")

+ # def init_model_and_tokenizer(PEFT_MODEL):
+ #     config = PeftConfig.from_pretrained(PEFT_MODEL)
+ #     bnb_config = BitsAndBytesConfig(
+ #         load_in_4bit=True,
+ #         bnb_4bit_quant_type="nf4",
+ #         bnb_4bit_use_double_quant=True,
+ #         bnb_4bit_compute_dtype=torch.float16,
+ #     )

+ #     peft_base_model = AutoModelForCausalLM.from_pretrained(
+ #         config.base_model_name_or_path,
+ #         return_dict=True,
+ #         quantization_config=bnb_config,
+ #         device_map="auto",
+ #         trust_remote_code=True,
+ #     )

+ #     peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

+ #     peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+ #     peft_tokenizer.pad_token = peft_tokenizer.eos_token

+ #     return peft_model, peft_tokenizer

+ # def init_llm_chain(peft_model, peft_tokenizer):
+ #     class CustomLLM(LLM):
+ #         def _call(self, prompt: str, stop=None, run_manager=None) -> str:
+ #             device = "cuda:0"
+ #             peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
+ #             peft_outputs = peft_model.generate(input_ids=peft_encoding.input_ids, generation_config=GenerationConfig(max_new_tokens=256, pad_token_id=peft_tokenizer.eos_token_id, \
+ #                                                eos_token_id=peft_tokenizer.eos_token_id, attention_mask=peft_encoding.attention_mask, \
+ #                                                temperature=0.4, top_p=0.6, repetition_penalty=1.3, num_return_sequences=1,))
+ #             peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
+ #             return peft_text_output

+ #         @property
+ #         def _llm_type(self) -> str:
+ #             return "custom"

+ #     llm = CustomLLM()

+ #     template = """Answer the following question truthfully.
+ #     If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
+ #     If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

+ #     Example Format:
+ #     <HUMAN>: question here
+ #     <ASSISTANT>: answer here

+ #     Begin!

+ #     <HUMAN>: {query}
+ #     <ASSISTANT>:"""

+ #     prompt = PromptTemplate(template=template, input_variables=["query"])
+ #     llm_chain = LLMChain(prompt=prompt, llm=llm)

+ #     return llm_chain

+ # def user(user_message, history):
+ #     return "", history + [[user_message, None]]

+ # def bot(history):
+ #     if len(history) >= 2:
+ #         query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
+ #     else:
+ #         query = history[-1][0]

+ #     bot_message = llm_chain.run(query)
+ #     bot_message = post_process_chat(bot_message)

+ #     history[-1][1] = ""
+ #     history[-1][1] += bot_message
+ #     return history

+ # def post_process_chat(bot_message):
+ #     try:
+ #         bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
+ #     except IndexError:
+ #         pass

+ #     bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]

+ #     bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
+ #     try:
+ #         bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
+ #     except AttributeError:
+ #         pass

+ #     bot_message = re.sub(r"\n\d.$", "", bot_message)
+ #     bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
+ #     bot_message = bot_message.replace("\n\n", "\n")

+ #     return bot_message

+ # model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
+ # peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)

+ # with gr.Blocks() as interface:
+ #     gr.HTML("""<h1>Welcome to Mental Health Conversational AI</h1>""")
+ #     gr.Markdown(
+ #         """Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.<br>
+ #         Get instant response for any mental health related queries. If the chatbot seems you need external support, then it will respond appropriately.<br>"""
+ #     )

+ #     chatbot = gr.Chatbot()
+ #     query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
+ #     clear = gr.Button(value="Clear Chat History!")
+ #     clear.style(size="sm")

+ #     llm_chain = init_llm_chain(peft_model, peft_tokenizer)

+ #     query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
+ #     clear.click(lambda: None, None, chatbot, queue=False)

+ # interface.queue().launch()
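
After this commit, app.py only runs the QLoRA fine-tuning job and pushes the resulting LoRA adapter to the Hub; the Gradio chat code is left commented out. For reference, below is a minimal sketch of how the pushed adapter could be loaded back for generation, mirroring the commented-out inference code above. The adapter repo id is a placeholder assumption (the actual id depends on the account and output_dir used with push_to_hub) and must be substituted with the real one.

# Hypothetical inference sketch; ADAPTER_REPO is an assumed placeholder, not taken from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from peft import PeftConfig, PeftModel

ADAPTER_REPO = "your-username/falcon-7b-sharded-fp16-finetuned-mental-health-conversational"  # placeholder

config = PeftConfig.from_pretrained(ADAPTER_REPO)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)  # attach the LoRA adapter
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

# Prompt format matches the <HUMAN>/<ASSISTANT> template used by the commented-out app.
prompt = "<HUMAN>: How can I manage anxiety before an exam?\n<ASSISTANT>:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    generation_config=GenerationConfig(
        max_new_tokens=256,
        temperature=0.4,
        top_p=0.6,
        repetition_penalty=1.3,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))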