Text Generation
English
crayon
language-technologies
Pascrayon commited on
Commit
9537b41
·
1 Parent(s): 41ed25e

Code for training and inference

Browse files
Files changed (1) hide show
  1. README.md +156 -0
README.md CHANGED
@@ -12,6 +12,162 @@ tags:
12
 
13
  # Bloom 560M Finetuned on Instructions
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Training Code
16
 
17
  ```python
 
12
 
13
  # Bloom 560M Finetuned on Instructions
14
 
15
+ ## Credit
16
+
17
+ Code 99.99% copied from
18
+ *https://github.com/bofenghuang/vigogne*
19
+ *https://colab.research.google.com/drive/1jCkpikz0J2o20FBQmYmAGdiKmJGOMo-o?usp=sharing#scrollTo=DpYr24pR8T_0*
20
+
21
+
22
+ # Inference Code
23
+
24
+ ```python
25
+
26
+ from peft import PeftModel
27
+ from transformers import PreTrainedTokenizer, PreTrainedModel, AutoTokenizer, AutoModelForCausalLM
28
+ from peft import PeftModelForCausalLM, LoraConfig
29
+ from typing import Optional
30
+ from transformers import GenerationConfig
31
+ import torch
32
+
33
+ PROMPT_DICT = {
34
+ "prompt_input": (
35
+ "Below is a^n instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
36
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
37
+ ),
38
+ "prompt_no_input": (
39
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
40
+ "### Instruction:\n{instruction}\n\n### Response:\n"
41
+ ),
42
+ }
43
+
44
+
45
+ def get_model(model_name_or_path: str, load_in_8bit: bool = True, device_map="auto",
46
+ cpu: bool = False) -> PreTrainedModel:
47
+ if cpu:
48
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=device_map,
49
+ low_cpu_mem_usage=True)
50
+ else:
51
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=load_in_8bit,
52
+ device_map=device_map, torch_dtype=torch.float16)
53
+
54
+ return model
55
+
56
+
57
+ def get_peft_model(model: PreTrainedModel, lora_model_name_or_path: Optional[str] = None) -> PeftModelForCausalLM:
58
+ model = PeftModel.from_pretrained(model, lora_model_name_or_path, torch_dtype=torch.float16)
59
+
60
+ return model
61
+
62
+
63
+ def get_tokenizer(model_name_or_path: str, max_input_len: int) -> PreTrainedTokenizer:
64
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, model_max_length=max_input_len,
65
+ padding_side="right", use_fast=False)
66
+
67
+ return tokenizer
68
+
69
+
70
+ def get_llm_inference_model(base_model_name_or_path: str, lora_model_name_or_path: str, load_in_8bit: bool,
71
+ device_map) -> PeftModel:
72
+ cpu = True if not torch.cuda.is_available() else False
73
+
74
+ model = get_model(base_model_name_or_path, load_in_8bit, device_map, cpu=cpu)
75
+
76
+ model = get_peft_model(model, lora_model_name_or_path=lora_model_name_or_path)
77
+
78
+ if not load_in_8bit:
79
+ model.half()
80
+
81
+ model.eval()
82
+
83
+ if torch.__version__ >= "2":
84
+ model = torch.compile(model)
85
+
86
+ return model
87
+
88
+
89
+ def generate_prompt(example):
90
+ return (
91
+ PROMPT_DICT["prompt_input"].format_map(example)
92
+ if example["input"]
93
+ else PROMPT_DICT["prompt_no_input"].format_map(example)
94
+ )
95
+
96
+
97
+ def infer(instruction: str, input_text: Optional[str] = None, temperature: float = 0.1, top_p: float = 0.95,
98
+ max_new_tokens: int = 512, early_stopping: bool = True, do_sample: bool = True,
99
+ repetition_penalty: float = 2.5) -> str:
100
+ prompt = generate_prompt({"instruction": instruction, "input": input_text})
101
+
102
+ tokenized_inputs = tokenizer(prompt, return_tensors="pt")
103
+
104
+ device = "cuda" if torch.cuda.is_available() else "cpu"
105
+
106
+ input_ids = tokenized_inputs["input_ids"].to(device)
107
+
108
+ generation_config = GenerationConfig(temperature=temperature, top_p=top_p, do_sample=do_sample,
109
+ repetition_penalty=repetition_penalty, early_stopping=early_stopping)
110
+
111
+ with torch.inference_mode():
112
+ generation_output = model.generate(input_ids=input_ids, generation_config=generation_config,
113
+ return_dict_in_generate=True, max_new_tokens=max_new_tokens)
114
+
115
+ output = generation_output.sequences[0]
116
+
117
+ output = tokenizer.decode(output, skip_special_tokens=True)
118
+
119
+ return output.split("### Response:")[1].strip()
120
+
121
+
122
+ base_model_name_or_path = "bigscience/bloom-560m"
123
+
124
+ lora_model_name_or_path = "crayon-coe/alpaca-bloom-560m-en"
125
+
126
+ model = get_llm_inference_model(base_model_name_or_path, lora_model_name_or_path, True, "auto")
127
+
128
+ tokenizer = get_tokenizer(base_model_name_or_path, 512)
129
+
130
+ context = "Write a letter expressing your love for computers"
131
+
132
+ output = infer(context)
133
+
134
+ print(output)
135
+
136
+ # Output
137
+ # I am so grateful to have been able access this wonderful computer system and its amazing features, which I can now use daily with ease.
138
+ #
139
+ # My heartfelt thanks go out in advance of all my friends who are using it as well.
140
+ # Thank you again!
141
+
142
+ ```
143
+
144
+ # Training Parameters
145
+
146
+ ```json
147
+ {
148
+ "max_input_len": 512,
149
+ "load_in_8bit": True,
150
+ "model_name_or_path": "bigscience/bloom-560m",
151
+ "device_map": "auto",
152
+ "bias": "none",
153
+ "lora_dropout": 0.05,
154
+ "lora_alpha": 32,
155
+ "target_modules": ["query_key_value"],
156
+ "task_type": "CAUSAL_LM",
157
+ "lora_r": 16,
158
+ "pad_to_multiple_of": 8,
159
+ "num_train_epochs": 3,
160
+ "learning_rate": 0.0003,
161
+ "gradient_accumulation_steps": 16,
162
+ "per_device_train_batch_size": 8,
163
+ "val_set_size": 500,
164
+ "save_steps": 200,
165
+ "eval_steps": 200,
166
+ "evaluation_strategy": "steps",
167
+ "save_strategy": "steps"
168
+ }
169
+ ```
170
+
171
  # Training Code
172
 
173
  ```python