---
license: apache-2.0
language:
- en
---

GPT-J (with value head weights) fine-tuned with PPO on the HH (helpful and harmless) dialogue data, following [@reciprocated's](https://github.com/reciprocated) `trlx` HH example [here](https://github.com/CarperAI/trlx/blob/2f90ba0ecd640ae18cd62adb5e934a4b779f534b/examples/hh/ppo_hh.py).

- Dataset: [Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)
- Logs: https://wandb.ai/jon-tow/trlx/reports/hh-gpt-j--VmlldzozODE1NjAw
- Notebook: https://colab.research.google.com/drive/1B-XKZv7h6u_pkyvckGocukEX5zLmACqc
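The usage example below writes prompts as alternating `Human:` / `Assistant:` lines and ends with an open `Assistant:` turn so the model continues as the assistant. The helper here is a hypothetical convenience (`format_hh_prompt` is not part of `trlx` or `transformers`) that builds prompts in that same layout; whether it exactly matches the formatting used during training is an assumption.

```python
def format_hh_prompt(turns):
    """Build a dialogue prompt in the Human:/Assistant: layout of the usage example.

    `turns` is a list of (speaker, text) pairs. The prompt ends with an open
    "Assistant:" turn so that generation continues as the assistant.
    """
    lines = []
    for speaker, text in turns:
        if speaker not in ("Human", "Assistant"):
            raise ValueError(f"unexpected speaker: {speaker!r}")
        lines.append(f"{speaker}: {text}")
    lines.append("Assistant:")  # no trailing newline, matching the example prompts
    return "\n".join(lines)


prompt = format_hh_prompt([
    ("Human", "Hello, can you help me?"),
    ("Assistant", "Sure, what can I do for you?"),
    ("Human", "I'm looking for a good recipe for a strawberry cake. What ingredients do I need?"),
])
print(prompt)
```

This reproduces `prompt_1` from the usage code character for character.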

Usage:

```python
from transformers import AutoTokenizer
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("jon-tow/hh-gpt-j")
# original_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompt_1 = """\
Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant:\
"""
prompt_2 = """\
Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant:\
"""
prompts = [prompt_1, prompt_2]
inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
)

samples = model.generate(
    **inputs,
    max_new_tokens=64,
    top_k=0,
    top_p=1.0,
    do_sample=True,
)

responses = []
# With left padding, every (padded) prompt occupies the first `prompt_len`
# positions of each sample, so the new tokens start at the same offset in every row.
prompt_len = inputs["input_ids"].shape[1]
stop_sequences = ["Human:", "human:", "Assistant:", "assistant:"]
for sample in samples:
    response = tokenizer.decode(sample[prompt_len:], skip_special_tokens=True)
    # Trim off extra dialogue
    for stop in stop_sequences:
        stop_i = response.find(stop)
        if stop_i >= 0:
            response = response[:stop_i].rstrip()
    responses.append(response)

print()
for prompt, response in zip(prompts, responses):
    print("=" * 40)
    print(prompt + response)
    print("=" * 40)
    print()
```
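Because the tokenizer pads on the left (`padding_side = "left"`), each row returned by `generate` is the padded prompt followed by the new tokens, so the prompt part has the same width in every row. The sketch below illustrates this with made-up token ids and a hypothetical pad id of `0` (no model or tokenizer needed): slicing at the padded width isolates the continuation, while slicing at each prompt's unpadded length would leave prompt tokens inside the shorter row's response.

```python
PAD = 0  # hypothetical pad id standing in for tokenizer.pad_token_id


def left_pad(batch, pad_id=PAD):
    """Left-pad token-id lists to a common length, like padding_side="left"."""
    width = max(len(ids) for ids in batch)
    return [[pad_id] * (width - len(ids)) + ids for ids in batch]


def new_tokens(samples, padded_prompt_len):
    """Return only the generated continuation of each row.

    With left padding, every prompt ends at column `padded_prompt_len`,
    so one offset works for all rows regardless of each prompt's true length.
    """
    return [row[padded_prompt_len:] for row in samples]


prompts = [[11, 12, 13, 14], [21, 22]]  # token ids of a longer and a shorter prompt
padded = left_pad(prompts)              # [[11, 12, 13, 14], [0, 0, 21, 22]]
# Pretend the model appended two new tokens to each row.
samples = [padded[0] + [15, 16], padded[1] + [23, 24]]

print(new_tokens(samples, len(padded[0])))  # [[15, 16], [23, 24]]

# Slicing by each prompt's unpadded length instead leaks prompt tokens:
print([row[len(p):] for row, p in zip(samples, prompts)])  # [[15, 16], [21, 22, 23, 24]]
```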

Example output (generation samples with `do_sample=True`, so results differ between runs):
```
========================================
Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant: Is strawberry flavour a primary flavour you want in the cake?
========================================

========================================
Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant: One thing you can try is to form chords and strums. Form chords and strums will help you to practice and learn how to play instruments easily. You can also download free music online. Besure to check out different genres and instruments. You don't have to learn everyone all at once. Learning the basics is
========================================
```
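The trimming loop in the usage code cuts each response at the earliest stop sequence, so the model does not ramble on into invented `Human:` / `Assistant:` turns. It does not handle length limits: the second response above ends mid-sentence, which looks like the `max_new_tokens=64` cap rather than a stop sequence. A compact, standalone version of that loop might look like this (`trim_at_stop` is a hypothetical helper name):

```python
def trim_at_stop(response, stop_sequences=("Human:", "human:", "Assistant:", "assistant:")):
    """Cut `response` at the earliest occurrence of any stop sequence.

    If a stop sequence is found, return the text before it with trailing
    whitespace removed; otherwise return the response unchanged.
    """
    positions = [i for i in (response.find(s) for s in stop_sequences) if i >= 0]
    if positions:
        return response[:min(positions)].rstrip()
    return response


print(trim_at_stop("Sure thing.\nHuman: thanks\nAssistant: ok"))  # Sure thing.
```

Taking the minimum position gives the same result as truncating repeatedly in a loop, since each later cut can only move the end earlier.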