File size: 3,892 Bytes
06e7970
 
 
 
 
 
3bd7542
 
 
 
06e7970
 
 
 
cfc1bbd
06e7970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bd7542
cfc1bbd
06e7970
cfc1bbd
 
 
06e7970
 
cfc1bbd
06e7970
 
 
 
cfc1bbd
 
06e7970
 
 
3bd7542
06e7970
 
 
 
 
 
cfc1bbd
06e7970
 
 
cfc1bbd
06e7970
 
cfc1bbd
06e7970
 
 
 
cfc1bbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e7970
 
 
 
 
e5280de
 
cfc1bbd
06e7970
 
cfc1bbd
 
 
06e7970
cfc1bbd
 
 
06e7970
 
4c415fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
import torch
import altair as alt


# Turn off Altair's max-rows limit for chart data. find_choices() emits one
# row per singular value of every layered weight matrix, so the table handed
# to the plot can get large.
# NOTE(review): presumably needed because gr.LinePlot renders through Altair
# -- confirm against the installed gradio version.
alt.data_transformers.disable_max_rows()

# Matches a numeric path component of a state-dict key, dots included, e.g.
# ".12." in "h.12.attn.weight". At least one digit is required so that an
# empty component ("..") is never mistaken for a layer index (it would later
# be fed to int("") and raise ValueError).
number_re = re.compile(r"\.[0-9]+\.")

# Module-level state shared by the Gradio callbacks below.
STATE_DICT = {}  # state dict of the currently loaded model ({} when none)
PIPE = None  # transformers pipeline of the currently loaded model
DATA = pd.DataFrame()  # normalized singular values, filled by find_choices()


def scatter_plot_fn(group_name):
    """Build the line-plot update for one weight group.

    Keeps the rows of the module-level ``DATA`` table whose ``group_name``
    column equals *group_name* and returns a ``gr.LinePlot`` update that plots
    ``val`` against ``rank``, colored by ``layer``.
    """
    # DATA is only read here, so no ``global`` declaration is required.
    in_group = DATA.group_name == group_name
    plot_options = {
        "x": "rank",
        "y": "val",
        "color": "layer",
        "tooltip": ["val", "rank", "layer"],
        "caption": "",
    }
    return gr.LinePlot.update(value=DATA[in_group], **plot_options)


def find_choices(state_dict):
    """Collect the layered 2-D weights of a model and their singular values.

    A tensor is "layered" when its key contains a numeric path component
    (matched by ``number_re``, e.g. ``h.3.attn.weight``). Every numeric
    component is replaced by ``{N}`` to form the group name, and the first
    numeric component is taken as the layer index.

    Side effect: rebinds the module-level ``DATA`` to a DataFrame with the
    columns ``name``, ``layer`` (category), ``group_name``, ``rank`` (int32)
    and ``val`` (float), holding one row per singular value. Each tensor's
    singular values are divided by their sum. ``DATA`` is left untouched when
    *state_dict* is empty.

    Args:
        state_dict: mapping of parameter name to tensor.

    Returns:
        ``(choices, layers)``: the sorted list of group names and
        ``list(range(max_layer + 1))``. Both are empty lists when
        *state_dict* is empty or contains no layered 2-D tensor.
    """
    global DATA
    if not state_dict:
        return [], []

    columns = ["name", "layer", "group_name", "rank", "val"]
    choices = set()
    rows = []
    max_layer = -1  # so that "no layered tensor" yields an empty layer list
    for name, tensor in state_dict.items():
        match = number_re.search(name)
        if match is None or len(tensor.shape) != 2:
            continue
        digits = match.group()[1:-1]
        if not digits:
            # An empty component ("..") carries no layer index.
            continue
        layer = int(digits)
        group_name = number_re.sub(".{N}.", name)
        choices.add(group_name)
        max_layer = max(max_layer, layer)

        # torch.linalg.svdvals only supports float32/float64 (and complex)
        # inputs, so half-precision checkpoints are upcast first.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        svdvals = torch.linalg.svdvals(tensor)
        svdvals = svdvals / svdvals.sum()
        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )

    # Build the frame straight from the row tuples: going through np.array
    # would turn every column into strings and fails when ``rows`` is empty.
    frame = pd.DataFrame(rows, columns=columns)
    frame["val"] = frame["val"].astype("float")
    frame["layer"] = frame["layer"].astype("category")
    frame["rank"] = frame["rank"].astype("int32")
    DATA = frame
    return sorted(choices), list(range(max_layer + 1))


def weights_fn(model_id):
    """Load *model_id* and refresh the weight-group and layer dropdowns.

    Builds a transformers pipeline for *model_id* and stores it, together
    with its model's state dict, in the module-level ``PIPE`` and
    ``STATE_DICT``. If anything in the load fails, the exception is printed
    and both globals are reset (``None`` / ``{}``), so they never describe
    two different models and a previously loaded model is released.

    Args:
        model_id: model identifier passed to ``pipeline(model=...)``.

    Returns:
        A list of two ``gr.Dropdown`` updates: the weight-group choices and
        the layer choices produced by ``find_choices``.
    """
    global STATE_DICT, PIPE
    try:
        pipe = pipeline(model=model_id)
        state_dict = pipe.model.state_dict()
    except Exception as e:
        # UI boundary: a failed load must not break the callback; report it
        # and fall back to the "nothing loaded" state.
        print(e)
        pipe, state_dict = None, {}
    # Commit both globals together so they always belong to the same model.
    PIPE, STATE_DICT = pipe, state_dict

    choices, layers = find_choices(STATE_DICT)
    return [
        # sorted() gives the dropdown a list in a stable order even when the
        # group names arrive as a set.
        gr.Dropdown.update(choices=sorted(choices)),
        gr.Dropdown.update(choices=layers),
    ]


def _unit_rows(vectors):
    """Scale every row of a 2-D tensor to unit Euclidean length."""
    return vectors / vectors.norm(dim=1, keepdim=True)


def _decode_table(indices):
    """Decode a 2-D tensor of token ids into a DataFrame of token strings."""
    table = [[PIPE.tokenizer.decode(w) for w in row] for row in indices]
    return pd.DataFrame(np.array(table))


def layer_fn(weights, layer):
    """List the tokens most similar to a weight's leading singular directions.

    Resolves the weight named by *weights* with ``{N}`` replaced by *layer*,
    computes its SVD and, for the first 10 rows and the first 10 columns of
    the selected singular basis, finds the 5 input-embedding vectors with the
    highest cosine similarity. The matching token ids are decoded with the
    pipeline's tokenizer.

    Args:
        weights: weight-group name containing the ``{N}`` placeholder.
        layer: layer index substituted for ``{N}``.

    Returns:
        A tuple of two ``gr.Dataframe`` updates: decoded tokens for the rows
        of the basis, then for its columns (one table row per direction).

    Raises:
        KeyError: if the resolved weight name is not in ``STATE_DICT``.
    """
    k = 5  # tokens listed per direction
    directions = 10  # number of directions inspected

    embeddings = PIPE.model.get_input_embeddings().weight
    weight_name = weights.replace("{N}", str(layer))

    weight = STATE_DICT[weight_name]

    # Inference only: do not let autograd track ops on the embedding parameter.
    with torch.no_grad():
        U, _, Vh = torch.linalg.svd(weight)

        # The basis vectors are multiplied with embeddings.T, so their length
        # must equal the embedding width (embeddings.shape[1]), not the number
        # of tokens (embeddings.shape[0]). Vh is preferred when it fits, which
        # keeps the result for square weights; otherwise fall back to U.
        width = embeddings.shape[1]
        D = Vh if Vh.shape[0] == width else U

        # Cosine similarity: unit-length directions against unit-length
        # embedding vectors (one embedding per column after the transpose).
        unit_embeddings = embeddings.T / embeddings.T.norm(dim=0)
        words = _unit_rows(D[:directions]).matmul(unit_embeddings).topk(k=k)
        words_t = (
            _unit_rows(D[:, :directions].T).matmul(unit_embeddings).topk(k=k)
        )

    data = _decode_table(words.indices)
    data_t = _decode_table(words_t.indices)

    return (
        gr.Dataframe.update(value=data, interactive=False),
        gr.Dataframe.update(value=data_t, interactive=False),
    )


# UI definition: one row with two columns. The left column holds the inputs
# (model id, weight group, layer); the right column holds the outputs (the
# singular-value line plot and two token tables).
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(label="model_id")
            # Choices for both dropdowns are filled in by weights_fn.
            weights = gr.Dropdown(label="weights")
            layer = gr.Dropdown(label="layer")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
            # Filled by layer_fn: first table, then second table of its return.
            directions = gr.Dataframe(interactive=False)
            directions_t = gr.Dataframe(interactive=False)
    # Event wiring.
    # Changing the model id loads the model and repopulates both dropdowns.
    # NOTE(review): `change` presumably fires on every edit of the textbox, so
    # a model load may be attempted per keystroke -- confirm; `submit` may be
    # the better trigger.
    model_id.change(weights_fn, inputs=model_id, outputs=[weights, layer])
    # Selecting a weight group redraws the plot for that group.
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
    # Selecting a layer recomputes both token tables for the current group.
    layer.change(
        fn=layer_fn, inputs=[weights, layer], outputs=[directions, directions_t]
    )

# Start the Gradio app only when the file is run as a script.
if __name__ == "__main__":
    scatter_plot.launch()