Narsil HF staff commited on
Commit
e5280de
1 Parent(s): 06e7970

Better visualization.

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -38,7 +38,7 @@ def find_choices(state_dict):
38
 
39
  svdvals = torch.linalg.svdvals(state_dict[name])
40
  svdvals /= svdvals.sum()
41
- for rank, val in enumerate(svdvals.tolist()[:20]):
42
  data.append((name, layer, group_name, rank, val))
43
  data = np.array(data)
44
  DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
@@ -63,8 +63,8 @@ def weights_fn(model_id):
63
  with gr.Blocks() as scatter_plot:
64
  with gr.Row():
65
  with gr.Column():
66
- model_id = gr.Textbox(value="gpt")
67
- weights = gr.Dropdown(choices=["qkv", "c_fc"])
68
  with gr.Column():
69
  plot = gr.LinePlot(show_label=False).style(container=True)
70
  model_id.change(weights_fn, inputs=model_id, outputs=weights)
 
38
 
39
  svdvals = torch.linalg.svdvals(state_dict[name])
40
  svdvals /= svdvals.sum()
41
+ for rank, val in enumerate(svdvals.tolist()[:300]):
42
  data.append((name, layer, group_name, rank, val))
43
  data = np.array(data)
44
  DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
 
63
  with gr.Blocks() as scatter_plot:
64
  with gr.Row():
65
  with gr.Column():
66
+ model_id = gr.Textbox(label="model_id")
67
+ weights = gr.Dropdown(label="weights")
68
  with gr.Column():
69
  plot = gr.LinePlot(show_label=False).style(container=True)
70
  model_id.change(weights_fn, inputs=model_id, outputs=weights)