Narsil HF staff commited on
Commit
3bd7542
1 Parent(s): e5280de

Fixing full spectrum.

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -4,6 +4,10 @@ import numpy as np
4
  import pandas as pd
5
  import re
6
  import torch
 
 
 
 
7
 
8
  number_re = re.compile(r"\.[0-9]*\.")
9
 
@@ -25,10 +29,10 @@ def scatter_plot_fn(group_name):
25
 
26
 
27
  def find_choices(state_dict):
 
 
28
  global DATA
29
- layered_tensors = [
30
- k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2
31
- ]
32
  choices = set()
33
  data = []
34
  for name in layered_tensors:
@@ -38,7 +42,7 @@ def find_choices(state_dict):
38
 
39
  svdvals = torch.linalg.svdvals(state_dict[name])
40
  svdvals /= svdvals.sum()
41
- for rank, val in enumerate(svdvals.tolist()[:300]):
42
  data.append((name, layer, group_name, rank, val))
43
  data = np.array(data)
44
  DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
@@ -71,4 +75,4 @@ with gr.Blocks() as scatter_plot:
71
  weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
72
 
73
  if __name__ == "__main__":
74
- scatter_plot.launch()
 
4
  import pandas as pd
5
  import re
6
  import torch
7
+ import altair as alt
8
+
9
+
10
+ alt.data_transformers.disable_max_rows()
11
 
12
  number_re = re.compile(r"\.[0-9]*\.")
13
 
 
29
 
30
 
31
  def find_choices(state_dict):
32
+ if not state_dict:
33
+ return []
34
  global DATA
35
+ layered_tensors = [k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2]
 
 
36
  choices = set()
37
  data = []
38
  for name in layered_tensors:
 
42
 
43
  svdvals = torch.linalg.svdvals(state_dict[name])
44
  svdvals /= svdvals.sum()
45
+ for rank, val in enumerate(svdvals.tolist()):
46
  data.append((name, layer, group_name, rank, val))
47
  data = np.array(data)
48
  DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
 
75
  weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
76
 
77
  if __name__ == "__main__":
78
+ scatter_plot.launch(share=True)