Narsil HF staff commited on
Commit
06e7970
·
verified ·
1 Parent(s): 94ea337

Spectrum visalizer.

Browse files
Files changed (2) hide show
  1. app.py +74 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+ import pandas as pd
5
+ import re
6
+ import torch
7
+
8
+ number_re = re.compile(r"\.[0-9]*\.")
9
+
10
+ STATE_DICT = {}
11
+ DATA = pd.DataFrame()
12
+
13
+
14
+ def scatter_plot_fn(group_name):
15
+ global DATA
16
+ df = DATA[DATA.group_name == group_name]
17
+ return gr.LinePlot.update(
18
+ value=df,
19
+ x="rank",
20
+ y="val",
21
+ color="layer",
22
+ tooltip=["val", "rank", "layer"],
23
+ caption="",
24
+ )
25
+
26
+
27
+ def find_choices(state_dict):
28
+ global DATA
29
+ layered_tensors = [
30
+ k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2
31
+ ]
32
+ choices = set()
33
+ data = []
34
+ for name in layered_tensors:
35
+ group_name = number_re.sub(".{N}.", name)
36
+ choices.add(group_name)
37
+ layer = int(number_re.search(name).group()[1:-1])
38
+
39
+ svdvals = torch.linalg.svdvals(state_dict[name])
40
+ svdvals /= svdvals.sum()
41
+ for rank, val in enumerate(svdvals.tolist()[:20]):
42
+ data.append((name, layer, group_name, rank, val))
43
+ data = np.array(data)
44
+ DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
45
+ DATA["val"] = DATA["val"].astype("float")
46
+ DATA["layer"] = DATA["layer"].astype("category")
47
+ DATA["rank"] = DATA["rank"].astype("int32")
48
+ return choices
49
+
50
+
51
+ def weights_fn(model_id):
52
+ global STATE_DICT
53
+ try:
54
+ pipe = pipeline(model=model_id)
55
+ STATE_DICT = pipe.model.state_dict()
56
+ except Exception as e:
57
+ print(e)
58
+ STATE_DICT = {}
59
+ choices = find_choices(STATE_DICT)
60
+ return gr.Dropdown.update(choices=choices)
61
+
62
+
63
+ with gr.Blocks() as scatter_plot:
64
+ with gr.Row():
65
+ with gr.Column():
66
+ model_id = gr.Textbox(value="gpt")
67
+ weights = gr.Dropdown(choices=["qkv", "c_fc"])
68
+ with gr.Column():
69
+ plot = gr.LinePlot(show_label=False).style(container=True)
70
+ model_id.change(weights_fn, inputs=model_id, outputs=weights)
71
+ weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
72
+
73
+ if __name__ == "__main__":
74
+ scatter_plot.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ numpy
4
+ pandas