MuskanMjn commited on
Commit
c7961a7
1 Parent(s): 553ef5c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from sklearn import neighbors
4
+
5
+ def train_and_plot(weights, n_neighbors):
6
+ np.random.seed(0)
7
+ X = np.sort(5 * np.random.rand(40, 1), axis=0)
8
+ T = np.linspace(0, 5, 500)[:, np.newaxis]
9
+ y = np.sin(X).ravel()
10
+
11
+ # Add noise to targets
12
+ y[::5] += 1 * (0.5 - np.random.rand(8))
13
+
14
+ knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
15
+ fit = knn.fit(X, y)
16
+ y_ = knn.predict(T)
17
+ score = knn.score(T, y_)
18
+
19
+ plt.scatter(X, y, color="darkorange", label="data")
20
+ plt.plot(T, y_, color="navy", label="prediction")
21
+ plt.axis("tight")
22
+ plt.legend()
23
+ plt.title("KNeighborsRegressor (k = %i, weights = '%s')" % (n_neighbors, weights))
24
+
25
+ plt.tight_layout()
26
+ return plt, score
27
+
28
+
29
+ with gr.Blocks() as demo:
30
+ link = "https://scikit-learn.org/stable/auto_examples/neighbors/plot_regression.html#sphx-glr-auto-examples-neighbors-plot-regression-py"
31
+ gr.Markdown("## Nearest Neighbors regression")
32
+ gr.Markdown(f"This demo is based on this [scikit-learn example]({link}).")
33
+ gr.HTML("<hr>")
34
+ gr.Markdown("In this demo, we learn a noise-infused sine function using k-Nearest Neighbor and observe how the function learned varies as we change the following hyperparameters:")
35
+ gr.Markdown("""1. Weight function
36
+ 2. Number of neighbors""")
37
+
38
+ with gr.Row():
39
+ weights = gr.Radio(['uniform', "distance"], label="Weights", info="Choose the weight function")
40
+ n_neighbors = gr.Slider(label="Neighbors", info="Choose the number of neighbors", minimum =1, maximum=15, step=1)
41
+
42
+ btn = gr.Button(value="Submit")
43
+
44
+
45
+ with gr.Row():
46
+ with gr.Column(scale=3):
47
+ plot = gr.Plot(label="KNeighborsRegressor Plot")
48
+ with gr.Column(scale=1):
49
+ num = gr.Textbox(label="Test Accuracy")
50
+
51
+
52
+ btn.click(train_and_plot, inputs=[weights, n_neighbors], outputs=[plot, num])
53
+
54
+
55
+ if __name__ == "__main__":
56
+ demo.launch()