osanseviero commited on
Commit
1ee0b41
·
verified ·
1 Parent(s): 4a59ac9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+
5
+ def get_initial_distribution(seed=42):
6
+ np.random.seed(seed) # For reproducibility
7
+ token_probs = np.random.rand(10)
8
+ token_probs /= np.sum(token_probs) # Normalize to sum to 1
9
+ return token_probs
10
+
11
+ def adjust_distribution(temperature, top_k, top_p, initial_probs):
12
+ # Apply temperature scaling
13
+ token_probs = np.exp(np.log(initial_probs) / temperature)
14
+ token_probs /= np.sum(token_probs)
15
+
16
+ # Apply Top-K filtering
17
+ if top_k > 0:
18
+ top_k_indices = np.argsort(token_probs)[-top_k:]
19
+ top_k_probs = np.zeros_like(token_probs)
20
+ top_k_probs[top_k_indices] = token_probs[top_k_indices]
21
+ top_k_probs /= np.sum(top_k_probs) # Normalize after filtering
22
+ token_probs = top_k_probs
23
+
24
+ # Apply top_p (nucleus) filtering
25
+ if top_p < 1.0:
26
+ # Sort probabilities in descending order and compute cumulative sum
27
+ sorted_indices = np.argsort(token_probs)[::-1]
28
+ cumulative_probs = np.cumsum(token_probs[sorted_indices])
29
+
30
+ # Find the cutoff index for nucleus sampling
31
+ cutoff_index = np.searchsorted(cumulative_probs, top_p) + 1
32
+
33
+ # Get the indices that meet the threshold
34
+ top_p_indices = sorted_indices[:cutoff_index]
35
+ top_p_probs = np.zeros_like(token_probs)
36
+ top_p_probs[top_p_indices] = token_probs[top_p_indices]
37
+ top_p_probs /= np.sum(top_p_probs) # Normalize after filtering
38
+ token_probs = top_p_probs
39
+
40
+ # Plotting the probabilities
41
+ plt.figure(figsize=(10, 6))
42
+ plt.bar(range(10), token_probs, tick_label=[f'Token {i}' for i in range(10)])
43
+ plt.xlabel('Tokens')
44
+ plt.ylabel('Probabilities')
45
+ plt.title('Token Probability Distribution')
46
+ plt.ylim(0, 1)
47
+ plt.grid(True)
48
+ plt.tight_layout()
49
+
50
+ return plt
51
+
52
+ initial_probs = get_initial_distribution()
53
+
54
+ def update_plot(temperature, top_k, top_p):
55
+ return adjust_distribution(temperature, top_k, top_p, initial_probs)
56
+
57
+ interface = gr.Interface(
58
+ fn=update_plot,
59
+ inputs=[
60
+ gr.Slider(0.1, 2.0, step=0.1, value=1.0, label="Temperature"),
61
+ gr.Slider(0, 10, step=1, value=5, label="Top-k"),
62
+ gr.Slider(0.0, 1.0, step=0.01, value=0.9, label="Top-p"),
63
+ ],
64
+ outputs=gr.Plot(label="Token Probability Distribution"),
65
+ live=True
66
+ )
67
+
68
+ interface.launch()