---
license: apache-2.0
language:
- en
---

GPT-J (with value head weights) trained on HH with PPO, following [@reciprocated's](https://github.com/reciprocated) `trlx` example [here](https://github.com/CarperAI/trlx/blob/2f90ba0ecd640ae18cd62adb5e934a4b779f534b/examples/hh/ppo_hh.py).

- Dataset: [Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)
- Logs: https://wandb.ai/jon-tow/trlx/reports/hh-gpt-j--VmlldzozODE1NjAw
- Notebook: https://colab.research.google.com/drive/1B-XKZv7h6u_pkyvckGocukEX5zLmACqc

Usage:

```python
from transformers import AutoTokenizer
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("jon-tow/hh-gpt-j")
# original_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Left-pad so that every prompt in the batch ends where generation begins
tokenizer.padding_side = "left"

prompt_1 = """\
Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant:\
"""
prompt_2 = """\
Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant:\
"""
prompts = [prompt_1, prompt_2]
inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
)

samples = model.generate(
    **inputs,
    max_new_tokens=64,
    top_k=0,
    top_p=1.0,
    do_sample=True,
)

responses = []
# With left padding, each row's prompt occupies the full padded input width,
# so the generated tokens start at that width for every sample
prompt_length = inputs["input_ids"].shape[1]
stop_sequences = ["Human:", "human:", "Assistant:", "assistant:"]
for sample in samples:
    response = tokenizer.decode(sample[prompt_length:], skip_special_tokens=True)
    # Trim off extra dialogue
    for stop in stop_sequences:
        stop_i = response.find(stop)
        if stop_i >= 0:
            response = response[:stop_i].rstrip()
    responses.append(response)

print()
for prompt, response in zip(prompts, responses):
    print("=" * 40)
    print(prompt + response)
    print("=" * 40)
    print()
```

Output:

```
========================================
Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant: Is strawberry flavour a primary flavour you want in the cake?
========================================

========================================
Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant: One thing you can try is to form chords and strums. Form chords and strums will help you to practice and learn how to play instruments easily. You can also download free music online. Besure to check out different genres and instruments. You don't have to learn everyone all at once. Learning the basics is
========================================
```
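
The prompt layout and the stop-sequence trimming used in the usage example can be pulled out into small helpers that run without loading the model. The following is a minimal sketch in plain Python. The names `build_prompt`, `trim_response` and `STOP_SEQUENCES` are hypothetical and are not part of `trlx` or this model card. The `Human:` / `Assistant:` layout is assumed from the example prompts above.

```python
# Hypothetical helpers; the names below are illustrative and not part of trlx.
STOP_SEQUENCES = ("Human:", "human:", "Assistant:", "assistant:")


def build_prompt(turns):
    """Format (speaker, text) pairs, one per line, and end with an open Assistant turn."""
    lines = [f"{speaker}: {text}" for speaker, text in turns]
    lines.append("Assistant:")
    return "\n".join(lines)


def trim_response(text, stop_sequences=STOP_SEQUENCES):
    """Cut text at the earliest stop sequence (if any) and strip trailing whitespace."""
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index >= 0:
            cut = min(cut, index)
    return text[:cut].rstrip()


prompt = build_prompt([
    ("Human", "Hello, can you help me?"),
    ("Assistant", "Sure, what can I do for you?"),
    ("Human", "What ingredients do I need for a strawberry cake?"),
])
print(prompt)

# A made-up continuation that runs on into a new dialogue turn
print(trim_response(" Flour, sugar and strawberries.\nHuman: Thanks!"))
```

Unlike the loop in the usage example, `trim_response` also strips trailing whitespace when no stop sequence is found. Leading whitespace is kept so that `prompt + response` still reads naturally after `Assistant:`.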