---
license: mit
datasets:
- llm-blender/mix-instruct
metrics:
- BERTScore
- BLEURT
- BARTScore
- Pairwise Rank
tags:
- pair-ranker
- pair_ranker
- reward_model
- reward-model
- RLHF
---

PairRanker is the ranker used in LLM-Blender, built on a deberta-v3-large backbone. This is the ranker model used in the experiments of the LLM-Blender paper.
It was trained on the [MixInstruct](https://huggingface.co/datasets/llm-blender/mix-instruct) dataset for 5 epochs.

- GitHub: [https://github.com/yuchenlin/LLM-Blender](https://github.com/yuchenlin/LLM-Blender)
- Paper: [https://arxiv.org/abs/2306.02561](https://arxiv.org/abs/2306.02561)


## Statistics

### Context length
|  PairRanker type  | Source max length | Candidate max length | Total max length |
|:-----------------:|:-----------------:|----------------------|------------------|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker) (This model)              | 128               | 128                  | 384              |
| [pair-reward-model](https://huggingface.co/llm-blender/pair-reward-model/) | 1224              | 412                  | 2048             |
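
The totals in the table are consistent with one source and two candidates sharing a single input (128 + 2 × 128 = 384 and 1224 + 2 × 412 = 2048). The snippet below is only a minimal check of that arithmetic, assuming this is how the budget is composed. The helper `total_max_length` is hypothetical and not part of llm-blender.

```python
def total_max_length(source_maxlength: int, candidate_maxlength: int) -> int:
    """Token budget for one source plus a pair of candidates (assumed layout)."""
    return source_maxlength + 2 * candidate_maxlength


# Values taken from the context length table above.
pair_ranker_total = total_max_length(128, 128)
pair_reward_model_total = total_max_length(1224, 412)
print(pair_ranker_total, pair_reward_model_total)
```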


### MixInstruct Performance

|    **Methods**    | BERTScore | BARTScore |   BLEURT  | GPT-Rank |  Beat Vic(%)  |   Beat OA(%)  |  Top-1(%)  |  Top-2(%)  |  Top-3(%)  |
|:-----------------:|:---------:|:---------:|:---------:|:--------:|:----------:|:----------:|:----------:|:----------:|:----------:|
|   Open Assistant  | **74.68** |   -3.45   | **-0.39** | **3.90** |  **62.78** |     N/A    |    17.35   |    35.67   |    51.98   |
|       Vicuna      |   69.60   | **-3.44** |   -0.61   |   4.13   |     N/A    |  **64.77** |  **25.47** |  **41.23** |  **52.88** |
|       Alpaca      |   71.46   |   -3.57   |   -0.53   |   4.62   |    56.70   |    61.35   |    15.41   |    29.81   |    44.46   |
|       Baize       |   65.57   |   -3.53   |   -0.66   |   4.86   |    52.76   |    56.40   |    14.23   |    26.91   |    38.80   |
|        moss       |   64.85   |   -3.65   |   -0.73   |   5.09   |    51.62   |    51.79   |    15.93   |    27.52   |    38.27   |
|      ChatGLM      |   70.38   |   -3.52   |   -0.62   |   5.63   |    44.04   |    45.67   |    9.41    |    19.37   |    28.78   |
|       Koala       |   63.96   |   -3.85   |   -0.84   |   6.76   |    39.93   |    39.01   |    8.15    |    15.72   |    22.55   |
|      Dolly v2     |   62.26   |   -3.83   |   -0.87   |   6.90   |    33.33   |    31.44   |    5.16    |    10.06   |    16.45   |
|     Mosaic MPT    |   63.21   |   -3.72   |   -0.82   |   7.19   |    30.87   |    30.16   |    5.39    |    10.61   |    16.24   |
|      StableLM     |   62.47   |   -4.12   |   -0.98   |   8.71   |    21.55   |    19.87   |    2.33    |    4.74    |    7.96    |
|      Flan-T5      |   64.92   |   -4.57   |   -1.23   |   8.81   |    23.89   |    19.93   |    1.30    |    2.87    |    5.32    |
| Oracle(BERTScore) | **77.67** |   -3.17   |   -0.27   |   3.88   |    54.41   |    38.84   |    20.16   |    38.11   |    53.49   |
|   Oracle(BLEURT)  |   75.02   |   -3.15   | **-0.15** |   3.77   |    55.61   |    45.80   |    21.48   |    39.84   |    55.36   |
| Oracle(BARTScore) |   73.23   | **-2.87** |   -0.38   |   3.69   |    50.32   |    57.01   |    26.10   |    43.70   |    57.33   |
|  Oracle(ChatGPT)  |   70.32   |   -3.33   |   -0.51   | **1.00** | **100.00** | **100.00** | **100.00** | **100.00** | **100.00** |
|     Random    |   66.36   |   -3.76   |   -0.77   |   6.14   |   37.75   |   36.91   |   11.28   |   20.69   |   29.05   |
|  MLM-Scoring  |   64.77   |   -4.03   |   -0.88   |   7.00   |   33.87   |   30.39   |    7.29   |   14.09   |   21.46   |
|     SimCLS    | **73.14** |   -3.22   |   -0.38   |   3.50   |   52.11   |   49.93   |   26.72   |   46.24   |   60.72   |
| SummaReranker |   71.60   |   -3.25   |   -0.41   |   3.66   | **55.63** |   48.46   |   23.89   |   42.44   |   57.54   |
|   [**PairRanker**](https://huggingface.co/llm-blender/pair-ranker)  |   72.97   | **-3.14** | **-0.37** | **3.20** |   54.76   | **57.79** | **30.08** | **50.68** | **65.12** |

## Usage Example
PairRanker contains some custom layers and tokens, so we recommend using it through our `llm-blender` Python package.
Loading it directly with the Hugging Face `from_pretrained()` API will raise errors.

- First install `llm-blender`
```bash
pip install git+https://github.com/yuchenlin/LLM-Blender.git
```

- Then use pairranker with the following code:
```python
import llm_blender
# ranker config
ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker" # only supports pairranker now.
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
ranker_config.load_checkpoint = "llm-blender/pair-ranker" # hugging face hub model path or your local ranker checkpoint <your checkpoint path>
ranker_config.cache_dir = "./hf_models" # hugging face model cache dir
ranker_config.source_maxlength = 128
ranker_config.candidate_maxlength = 128
ranker_config.n_tasks = 1 # number of signals used to train the ranker. This checkpoint was trained with BARTScore only, so it is 1.
fuser_config = llm_blender.GenFuserConfig()
# ignore fuser config as we don't use it here. You can load it if you want
blender_config = llm_blender.BlenderConfig()
# blender config
blender_config.device = "cuda" # blender ranker and fuser device
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
```

- Then you can rank candidates with the following function

```python
inputs = ["input1", "input2"]
candidates_texts = [["candidate1 for input1", "candidate2 for input1"], ["candidate1 for input2", "candidate2 for input2"]]
ranks = blender.rank(inputs, candidates_texts, return_scores=False, batch_size=2)
# ranks is a list of rank lists, where ranks[i][j] is the rank of candidate-j for input-i
```
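
Once you have `ranks`, a common next step is to keep the best candidate for each input. The sketch below assumes that a smaller rank means a better candidate (rank 1 is best) and uses hard-coded example ranks so that it runs without loading the model. The helper `select_best` is hypothetical and not part of llm-blender.

```python
def select_best(candidates_texts, ranks):
    """Return the best-ranked candidate per input, assuming rank 1 is best."""
    best = []
    for cands, cand_ranks in zip(candidates_texts, ranks):
        # Index of the candidate with the smallest rank value.
        best_idx = min(range(len(cands)), key=lambda j: cand_ranks[j])
        best.append(cands[best_idx])
    return best


# Illustrative data only; real ranks would come from blender.rank(...).
example_candidates = [["a1", "a2"], ["b1", "b2"]]
example_ranks = [[2, 1], [1, 2]]
top1 = select_best(example_candidates, example_ranks)
print(top1)
```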

- Using pairranker to directly compare two candidates
```python
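# split each candidate list from the previous example into its first and second candidate
candidates_A = [cands[0] for cands in candidates_texts]
candidates_B = [cands[1] for cands in candidates_texts]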
comparison_results = blender.compare(inputs, candidates_A, candidates_B)
# comparison_results is a list of bool, where element[i] denotes whether candidates_A[i] is better than candidates_B[i] for inputs[i]
```
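
Pairwise comparisons like these can be aggregated into a full ranking. One simple way is to count how many pairwise comparisons each candidate wins. The sketch below illustrates that idea with a toy comparator (`prefer_longer`, a hypothetical stand-in for the model's decision). It is not necessarily what `blender.rank` does internally.

```python
from itertools import combinations


def rank_by_wins(candidates, is_better):
    """Rank candidates by number of pairwise wins (rank 1 = most wins)."""
    wins = [0] * len(candidates)
    for i, j in combinations(range(len(candidates)), 2):
        if is_better(candidates[i], candidates[j]):
            wins[i] += 1
        else:
            wins[j] += 1
    # Stable sort: ties keep their original candidate order.
    order = sorted(range(len(candidates)), key=lambda k: -wins[k])
    ranks = [0] * len(candidates)
    for position, k in enumerate(order, start=1):
        ranks[k] = position
    return ranks


def prefer_longer(a, b):
    """Toy comparator for illustration only: the longer text wins."""
    return len(a) > len(b)


toy_ranks = rank_by_wins(["short", "a much longer answer", "medium one"], prefer_longer)
print(toy_ranks)
```

In practice, `is_better` would wrap a call to the pairwise comparison shown above. Each list of N candidates needs N(N-1)/2 comparisons with this scheme.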

See the LLM-Blender GitHub [README.md](https://github.com/yuchenlin/LLM-Blender#rank-and-fusion)
and the Jupyter notebook [blender_usage.ipynb](https://github.com/yuchenlin/LLM-Blender/blob/main/blender_usage.ipynb)
for detailed usage examples.