Tags: Text Generation · Transformers · Safetensors · English · deberta · reward_model · reward-model · RLHF · evaluation · llm · instruction · reranking · Inference Endpoints
yuchenlin committed
Commit 90f9aa4 · 1 parent: 447053d

Update README.md

Files changed (1): README.md (+40 −24)
README.md CHANGED
@@ -61,7 +61,7 @@ blender.loadranker("llm-blender/PairRM") # load PairRM
 
 ## Usage
 
- ### Use case 1: Comparing/Ranking output candidates given an instruction
+ ### Use Case 1: Comparing/Ranking output candidates given an instruction
 
 - Ranking a list of candidate responses
 
@@ -88,7 +88,8 @@ comparison_results = blender.compare(inputs, candidates_A, candidates_B)
 # comparison_results[0] --> True
 ```
 
- - Directly compare two multi-turn conversations given that the user's query in each turn is fixed and the responses are different.
+ <details><summary>Comparing two multi-turn conversations.</summary>
+
 ```python
 conv1 = [
     {
@@ -96,7 +97,7 @@ conv1 = [
         "role": "USER"
     },
     {
-         "content": "<assistant1's response 1>",
+         "content": "[assistant1's response 1]",
         "role": "ASSISTANT"
     },
     ...
@@ -107,7 +108,7 @@ conv2 = [
         "role": "USER"
     },
     {
-         "content": "<assistant2's response 1>",
+         "content": "[assistant2's response 1]",
         "role": "ASSISTANT"
     },
     ...
@@ -115,36 +116,51 @@ conv2 = [
 comparison_results = blender.compare_conversations([conv1], [conv2])
 # comparison_results is a list of bool, where each element denotes whether all the responses in conv1 together are better than those in conv2
 ```
+ </details>
 
- ### Use case 2: Best-of-n Sampling (Decoding Enhancement)
- **Best-of-n Sampling**, a.k.a. rejection sampling, is a strategy to enhance response quality by selecting the response that was ranked highest by the reward model (learn more in [OpenAI WebGPT section 3.2](https://arxiv.org/pdf/2112.09332.pdf) and the [OpenAI Blog](https://openai.com/research/measuring-goodharts-law)).
+
+ ### Use Case 2: Best-of-n Sampling (Decoding Enhancement)
 
- Best-of-n sampling is an easy way to improve your LLM with just a few lines of code. An example of applying it to zephyr is as follows.
+ **Best-of-n Sampling**, a.k.a. rejection sampling, is a strategy to enhance response quality by selecting the response that the reward model ranks highest
+ (see [OpenAI WebGPT, section 3.2](https://arxiv.org/pdf/2112.09332.pdf) and the [OpenAI Blog](https://openai.com/research/measuring-goodharts-law)).
+ Best-of-n sampling with PairRM is an easy way to improve your LLMs with only a few changes to your inference code:
 
 ```python
+ # loading models
+ import llm_blender
 from transformers import AutoTokenizer, AutoModelForCausalLM
-
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
 model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", device_map="auto")
+ system_message = {"role": "system", "content": "You are a friendly chatbot."}
 
- inputs = [...] # your list of inputs
- system_message = {
-     "role": "system",
-     "content": "You are a friendly chatbot who always responds in the style of a pirate",
- }
- messages = [
-     [
-         system_message,
-         {"role": "user", "content": _input},
-     ]
-     for _input in zip(inputs)
- ]
+ # formatting your inputs
+ inputs = ["can you tell me a joke about OpenAI?"]
+ messages = [[system_message, {"role": "user", "content": _input}] for _input in inputs]
 prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]
+
+ # Conventional generation method
+ input_ids = tokenizer(prompts[0], return_tensors="pt").input_ids
+ sampled_outputs = model.generate(input_ids, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
+ print(tokenizer.decode(sampled_outputs[0][len(input_ids[0]):], skip_special_tokens=False))
+ # --> The output can be a bad case, e.g., a very short reply such as `Sure`
+
+ # PairRM for best-of-n sampling
+ blender = llm_blender.Blender()
+ blender.loadranker("llm-blender/PairRM")  # load ranker checkpoint
 outputs = blender.best_of_n_generate(model, tokenizer, prompts, n=10)
- print("### Prompt:")
- print(prompts[0])
- print("### best-of-n generations:")
- print(outputs[0])
+
+ print("### Prompt:\n", prompts[0])
+ print("### best-of-n generations:\n", outputs[0])
+ # --> The output will be much more stable and consistently better than a single sample, for example:
+ """
+ Sure, here's a joke about OpenAI:
+
+ Why did OpenAI decide to hire a mime as their new AI researcher?
+
+ Because they wanted someone who could communicate complex ideas without making a sound!
+
+ (Note: This is a joke, not a reflection of OpenAI's actual hiring practices.)
+ """
 ```
 
 ### Use case 3: RLHF
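
The comparison snippet in hunk `@@ -88,7 +88,8 @@` is truncated at the hunk boundary, so only its last lines are visible above. A minimal self-contained version of that use case might look like the sketch below; it relies only on the `Blender`, `loadranker`, and `compare` calls that already appear in this diff, while the instruction and candidate strings are made-up examples, not part of the commit.

```python
# Minimal sketch of Use Case 1 (pairwise comparison); the example strings are illustrative.
import llm_blender

blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")  # load PairRM, as in the README above

inputs = ["What is the capital of France?"]           # hypothetical instruction
candidates_A = ["The capital of France is Paris."]    # hypothetical response A
candidates_B = ["I am not sure, maybe Lyon?"]         # hypothetical response B

comparison_results = blender.compare(inputs, candidates_A, candidates_B)
# comparison_results[0] --> True if candidates_A[0] is judged better than candidates_B[0]
```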
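The `### Use case 3: RLHF` section is likewise cut off by the hunk boundary. As one hedged illustration (not necessarily the recipe the README describes), PairRM's pairwise judgments could be turned into chosen/rejected preference pairs for preference-based fine-tuning; in the sketch below the prompts, responses, and pairing loop are assumptions, and only `blender.compare` comes from the API shown above.

```python
# Hedged sketch: build preference pairs from PairRM comparisons (illustrative data and logic).
import llm_blender

blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")

prompts = ["Explain photosynthesis in one sentence."]                           # hypothetical
responses_A = ["Plants turn sunlight, water, and CO2 into sugars and oxygen."]  # hypothetical
responses_B = ["It's just something plants do."]                                 # hypothetical

a_is_better = blender.compare(prompts, responses_A, responses_B)  # list of bool, per the README

preference_pairs = []
for prompt, resp_a, resp_b, a_wins in zip(prompts, responses_A, responses_B, a_is_better):
    chosen, rejected = (resp_a, resp_b) if a_wins else (resp_b, resp_a)
    preference_pairs.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
# preference_pairs can then be fed to a preference-optimization trainer of your choice
```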