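"""Gradio demo for BiasProbe.

Type a comma-separated list of words to see how Mistral-7B-Instruct associates
them with positive and negative emotions: a trained binary probe scores pairs of
hidden states, and each word's win rate is the fraction of pairwise comparisons
it wins. Assumes a trained probe checkpoint at ./probe.pt next to this script.
"""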
import argparse

import numpy as np
from matplotlib import pyplot as plt
import gradio as gr
import torch
import pandas as pd

from biasprobe import BinaryProbe, PairwiseExtractionRunner, SimplePairPromptBuilder, ProbeConfig


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=0, help='the random seed')
    parser.add_argument('--port', '-p', type=int, default=8080, help='the port on which to serve the demo')
    parser.add_argument('--no-cuda', action='store_true', help='use the CPU instead of the GPU')
    return parser.parse_args()
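
# Example invocation (the script name app.py is an assumption; adjust to your filename):
#   python app.py --seed 0 --port 8080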


def main():
    args = get_args()
    # Seed torch and numpy for reproducibility, per the --seed flag.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = 'cpu' if args.no_cuda else 'cuda'
    plt.switch_backend('agg')  # headless matplotlib backend; no display needed
    dmap = 'auto'
    mdict = {0: '24GiB'}  # cap GPU 0 at 24 GiB when sharding the model
    config = ProbeConfig.create_for_model('mistralai/Mistral-7B-Instruct-v0.1')
    probe = BinaryProbe(config).to(device)
    probe.load_state_dict(torch.load('probe.pt', map_location=device))

    runner = PairwiseExtractionRunner.from_pretrained(
        'mistralai/Mistral-7B-Instruct-v0.1',
        optimize=False,
        torch_dtype=torch.float16,
        max_memory=mdict,
        device_map=dmap,
        low_cpu_mem_usage=True,
    )

    @torch.no_grad()
    def run_extraction(prompt):
        builder = SimplePairPromptBuilder(criterion='more positive')
        # Normalize the input: lowercase, cap at 300 characters and 100 comma-separated words.
        lst = [x.strip() for x in prompt.lower()[:300].split(',')][:100]
        # Compare every pair of words and collect hidden states at layer 15.
        exp = runner.run_extraction(lst, lst, layers=[15], num_repeat=50, builder=builder, parallel=False,
                                    run_inference=True, debug=True, max_new_tokens=2)
        test_ds = exp.make_dataset(15)

        raw_scores = []
        preds_list = []

        for idx, (tensor, labels) in enumerate(test_ds):
            labels = labels - 1  # labels are 1-indexed; shift to 0-indexed

            if tensor.shape[0] != 2:  # skip anything that isn't a clean pair
                continue

            try:
                x = probe(tensor.unsqueeze(0).to(device).float()).squeeze()
            except IndexError:
                continue

            # Reorder the pair according to the sign of the probe score.
            pred = np.array([0, 1] if x.item() > 0 else [1, 0])

            if test_ds.original_examples is not None:
                items = [hit.content for hit in test_ds.original_examples[idx].hits]
                preds_list.append(np.array(items, dtype=object)[labels][pred].tolist())

            raw_scores.append(x.item())

        # Aggregate per-word win rates: the fraction of comparisons each word wins.
        df = pd.DataFrame({'Win Rate': np.array(raw_scores) > 0, 'Word': [x[0] for x in preds_list]})
        win_df = df.groupby('Word')['Win Rate'].mean()
        win_df = win_df.reset_index().sort_values('Win Rate')
        win_df['Win Rate'] = [f'{x}%' for x in (win_df['Win Rate'] * 100).round(2).tolist()]

        return win_df

    # UI: a textbox for comma-separated words, a submit button, and a results table.
    with gr.Blocks(css='scrollbar.css') as demo:
        md = '''# BiasProbe: Revealing Preference Biases in Language Model Representations

What do llamas really "think" about controversial words? Type some words below to see how Mistral-7B-Instruct associates them with positive and negative emotions. Higher win rates indicate that a word is more likely than the others in the list to be associated with positive emotions.

Check out our paper, [What Do Llamas Really Think? Revealing Preference Biases in Language Model Representations](http://arxiv.org/abs/2311.18812), and see our [codebase](https://github.com/castorini/biasprobe) on GitHub.
'''
        gr.Markdown(md)

        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label='Words', value='republican, democrat, libertarian, authoritarian')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
            # Sample output shown before the first query runs.
            output = gr.DataFrame(pd.DataFrame({'Word': ['authoritarian', 'republican', 'democrat', 'libertarian'],
                                                'Win Rate': ['44.44%', '81.82%', '100%', '100%']}))

            submit_btn.click(
                fn=run_extraction,
                inputs=[text],
                outputs=[output])

    while True:
        try:
            demo.launch(server_name='0.0.0.0', server_port=args.port)
        except OSError:
            gr.close_all()  # stale server or port conflict; close and retry
        except KeyboardInterrupt:
            gr.close_all()
            break


if __name__ == '__main__':
    main()