complete-dope committed on
Commit e29386b
1 Parent(s): 8a8d3b0
Files changed (3)
  1. main.py +155 -0
  2. prov_data2.jsonl +0 -0
  3. requirements.txt +9 -0
main.py ADDED
@@ -0,0 +1,155 @@
+ # This repo contains code for fine-tuning a Mixtral model to predict ICD-10 codes. The script
+ # already runs on a single GPU; the goal now is to make sure it also runs in a multi-GPU environment.
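+ # Editor note (assumption, not from the original repo): with the Accelerate FSDP plugin configured
+ # below, a multi-GPU run would typically be started with something like
+ #   accelerate launch --num_processes <num_gpus> main.py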
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ from accelerate import FullyShardedDataParallelPlugin, Accelerator
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
+ from datasets import load_dataset
+ import torch
+ import transformers
+ from datetime import datetime
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
+
+ fsdp_plugin = FullyShardedDataParallelPlugin(
+     state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
+     optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
+ )  # FSDP plugin so the weights can be sharded across a multi-GPU environment; full state dicts are offloaded to CPU when saved
+
+ accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
+
+ ## Loading the dataset
+ def Profiler_load_dataset(data_files, field='train'):
+     return load_dataset('json', data_files=data_files, field=field)
+
+
+ ## High RAM usage here: both splits are read from the same JSON file
+ train_dataset = Profiler_load_dataset(data_files='/content/prov_data2.jsonl', field='train')
+ eval_dataset = Profiler_load_dataset(data_files='/content/prov_data2.jsonl', field='test')
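+ # Editor note (assumption inferred from the code; the data file itself is not rendered in this diff):
+ # prov_data2.jsonl is expected to expose "train" and "test" fields whose records carry an "Input"
+ # (clinical description) and an "Output" (ICD-10 code), as consumed by format_fun below.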
+
+
+ ## Formatting function: turns each record into a single prompt string for the Mixtral model,
+ ## which is the shape needed for an instruction fine-tuning setup.
+ def format_fun(example):
+     text = f" The ICD10 code for {example['Input']} is , {example['Output']} "
+     return text
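+ # Illustrative example (hypothetical record, not taken from the dataset): {"Input": "acute appendicitis",
+ # "Output": "K35.80"} would be rendered as " The ICD10 code for acute appendicitis is , K35.80 "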
+
+ # base_model_id = "mistralai/Mixtral-8x7B-v0.1"
+ # Try out different models from the Hugging Face Hub (the weights released by the authors would be
+ # best, but they are not quantised, so they are unlikely to fit in this setup).
+
+
+ base_model_id = ''  # passed in as an argument -> args.model_id
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="cuda")
+ ## The model loads and runs with this 4-bit config
+
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     base_model_id,
+     padding_side="left",
+     add_eos_token=True,
+     add_bos_token=True,
+ )
+ tokenizer.pad_token = tokenizer.eos_token
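+ # Editor note: the Mixtral tokenizer has no dedicated pad token, so the EOS token is reused for padding here.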
+
+
+ max_length = 50  # maximum number of tokens per tokenized example
+ def generate_and_tokenize_prompt(prompt):
+     result = tokenizer(
+         format_fun(prompt),
+         truncation=True,
+         max_length=max_length,
+         padding="max_length",
+     )
+     result["labels"] = result["input_ids"].copy()  # for causal LM training the labels are just a
+     # copy of the input_ids; the model shifts them internally when computing the loss
+     return result
+
+ tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
+ tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)
+
+
+ # Fine-tuning the model
+ model.gradient_checkpointing_enable()
+ model = prepare_model_for_kbit_training(model)
+
+ config = LoraConfig(
+     r=32,
+     lora_alpha=64,
+     target_modules=[
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "o_proj",
+         "w1",
+         "w2",
+         "w3",
+         "lm_head",
+     ],
+     bias="none",
+     lora_dropout=0.05,  # conventional value
+     task_type="CAUSAL_LM",
+ )
+
+ model = get_peft_model(model, config)
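+ # Editor sketch (assumption, not in the original script): for the FSDP-enabled Accelerator above to
+ # actually wrap the model in a multi-GPU run, one option is:
+ # model = accelerator.prepare_model(model)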
+
+ if torch.cuda.device_count() > 1:  # if more than 1 GPU is available
+     model.is_parallelizable = True
+     model.model_parallel = True
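+ # Editor note: these flags tell the Hugging Face Trainer to treat the model as already model-parallel,
+ # so it does not wrap it in DataParallel itself.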
+
+
+ project = "icd-finetune"
+ base_model_name = "mixtral"
+ run_name = base_model_name + "-" + project
+ output_dir = "./" + run_name
+
+ trainer = transformers.Trainer(
+     model=model,
+     train_dataset=tokenized_train_dataset,
+     eval_dataset=tokenized_val_dataset,
+     args=transformers.TrainingArguments(
+         output_dir=output_dir,
+         warmup_steps=1,
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=1,
+         gradient_checkpointing=True,
+         max_steps=300,
+         learning_rate=2.5e-5,             # small learning rate for fine-tuning
+         fp16=True,
+         optim="paged_adamw_8bit",
+         logging_steps=25,                 # report the loss every 25 steps
+         logging_dir="./logs",             # directory for storing logs
+         save_strategy="steps",            # save checkpoints on a step interval
+         save_steps=25,                    # save a checkpoint every 25 steps
+         evaluation_strategy="steps",      # evaluate on a step interval
+         eval_steps=25,                    # evaluate every 25 steps
+         do_eval=True,                     # enable evaluation during training
+     ),
+     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+ )
+
+ model.config.use_cache = False  # silence the warnings; please re-enable for inference!
+ trainer.train()
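+ # Editor sketch (assumption, not part of the original script): once training finishes, the saved LoRA
+ # adapter could be reloaded on top of the base model roughly like this; the checkpoint path below is
+ # hypothetical.
+ # from peft import PeftModel
+ # base = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)
+ # ft_model = PeftModel.from_pretrained(base, "./mixtral-icd-finetune/checkpoint-300")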
+
+
+ # TODO: implement RAG on top of the fine-tuned model
+
+
+ # Final model preparation checklist
+ '''
+ 1) Make sure the script runs in a multi-GPU environment!
+ 2) The dataset is loaded
+ 3) Add a LangChain implementation to oversee prompt generation
+ 4) Also try BERT-style models rather than using the Mixtral model directly
+ 5) Once the model is trained, copy the checkpoint folder into a local environment
+ '''
prov_data2.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ numpy
+ git+https://github.com/huggingface/transformers.git
+ git+https://github.com/huggingface/peft.git
+ git+https://github.com/huggingface/accelerate.git
+ datasets
+ scipy
+ ipywidgets
+ matplotlib
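+ bitsandbytes  # editor addition (assumed requirement): needed for 4-bit loading via BitsAndBytesConfig and the paged_adamw_8bit optimizer in main.py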