File size: 2,181 Bytes
06e7970
 
 
 
 
 
3bd7542
 
 
 
06e7970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bd7542
 
06e7970
3bd7542
06e7970
 
 
 
 
 
 
 
 
3bd7542
06e7970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5280de
 
06e7970
 
 
 
 
 
4c415fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
import torch
import altair as alt


# Lift Altair's row limit for chart data: the plotted frame holds one row per
# singular value of every matching tensor (see find_choices), so it can grow
# large. NOTE(review): this is a process-wide Altair setting.
alt.data_transformers.disable_max_rows()

# Matches a dot-delimited numeric path segment in a parameter name, e.g. the
# ".12." in "encoder.layer.12.attention.weight". At least one digit is
# required so that an empty segment ("..") is not mistaken for a layer index
# (it would otherwise reach int("") downstream and raise).
number_re = re.compile(r"\.[0-9]+\.")

# State dict of the most recently loaded model; replaced by weights_fn.
STATE_DICT = {}
# Long-format table of singular values; rebuilt by find_choices and read by
# scatter_plot_fn. Starts out empty with no columns.
DATA = pd.DataFrame()


def scatter_plot_fn(group_name):
    """Build the line-plot update for one weight group.

    Filters the module-level ``DATA`` table down to the rows whose
    ``group_name`` column equals *group_name* and plots ``val`` against
    ``rank``, with one colour per ``layer``.

    Args:
        group_name: Value selected in the weights dropdown.

    Returns:
        The update produced by ``gr.LinePlot.update``. When ``DATA`` has not
        been populated yet (it has no ``group_name`` column), an argument-less
        update is returned so the plot is left as it is instead of raising
        ``AttributeError``.
    """
    # DATA is only read here, so no ``global`` declaration is needed.
    if "group_name" not in DATA.columns:
        return gr.LinePlot.update()

    selected = DATA[DATA["group_name"] == group_name]
    return gr.LinePlot.update(
        value=selected,
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )


def find_choices(state_dict):
    """Compute normalised singular values for every layered 2-D tensor.

    A tensor qualifies when it is two-dimensional and its key contains a
    numeric path segment matched by ``number_re``. Keys that differ only in
    that segment share a group name, obtained by replacing the segment with
    ``.{N}.``. For each qualifying tensor the singular values are computed,
    scaled to sum to one, and stored one row per value.

    Side effect: the module-level ``DATA`` frame is replaced with columns
    ``name``, ``layer``, ``group_name``, ``rank`` and ``val``. When
    *state_dict* is empty, ``DATA`` is left untouched.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        Sorted list of group names; empty when nothing qualifies.
    """
    global DATA
    if not state_dict:
        return []

    # dtypes torch.linalg.svdvals accepts directly; anything else (half
    # precision, integer buffers) is promoted to float32 first.
    svd_dtypes = (torch.float32, torch.float64, torch.complex64, torch.complex128)

    groups = set()
    rows = []
    for name, tensor in state_dict.items():
        if len(tensor.shape) != 2:
            continue
        match = number_re.search(name)
        if match is None:
            continue
        digits = match.group()[1:-1]
        if not digits:
            # A digit-less match ("..") carries no layer index; skip it
            # rather than crash on int("").
            continue
        layer = int(digits)
        group_name = number_re.sub(".{N}.", name)
        groups.add(group_name)

        if tensor.dtype not in svd_dtypes:
            tensor = tensor.to(torch.float32)
        svdvals = torch.linalg.svdvals(tensor)
        total = svdvals.sum()
        # An all-zero matrix has zero total; dividing would fill the
        # spectrum with NaN, so leave the zeros as they are.
        if total > 0:
            svdvals = svdvals / total
        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )

    # Build the frame straight from the row tuples: this also works when
    # ``rows`` is empty, and avoids round-tripping every cell through a
    # string array.
    frame = pd.DataFrame(rows, columns=["name", "layer", "group_name", "rank", "val"])
    frame["val"] = frame["val"].astype("float")
    # Layer labels are kept as text categories.
    frame["layer"] = frame["layer"].astype(str).astype("category")
    frame["rank"] = frame["rank"].astype("int32")
    DATA = frame
    return sorted(groups)


def weights_fn(model_id):
    """Load a model's weights and refresh the weight-group dropdown.

    Builds a ``transformers`` pipeline for *model_id*, stores the model's
    state dict in the module-level ``STATE_DICT`` and passes it to
    ``find_choices``. Any failure while loading is printed and treated as
    "no weights", so the dropdown ends up empty instead of the callback
    raising.

    Args:
        model_id: Model identifier typed by the user. Surrounding whitespace
            is ignored; a blank value skips loading entirely.

    Returns:
        The update produced by ``gr.Dropdown.update`` with the group names in
        sorted order.
    """
    global STATE_DICT
    STATE_DICT = {}

    cleaned_id = (model_id or "").strip()
    if cleaned_id:
        # NOTE(review): the id comes straight from the UI and is used to load
        # a model; confirm that loading arbitrary user-chosen checkpoints is
        # acceptable where this app is deployed.
        try:
            pipe = pipeline(model=cleaned_id)
            STATE_DICT = pipe.model.state_dict()
        except Exception as e:
            # Deliberately broad: failed loads are reported and result in an
            # empty dropdown rather than an error in the UI callback.
            print(e)
            STATE_DICT = {}

    choices = find_choices(STATE_DICT)
    # Sort so the dropdown order is stable regardless of the container type
    # find_choices returns.
    return gr.Dropdown.update(choices=sorted(choices))


# UI layout: a single row with the inputs (model id textbox and weight-group
# dropdown) in the first column and the line plot in the second.
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(label="model_id")
            weights = gr.Dropdown(label="weights")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
    # Changing the model id reloads the weights and repopulates the dropdown.
    # NOTE(review): `change` presumably fires on every edit of the textbox, so
    # a load may be attempted per keystroke - confirm; `submit` may fit better.
    model_id.change(weights_fn, inputs=model_id, outputs=weights)
    # Changing the selected weight group redraws the plot.
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    scatter_plot.launch()