bhoomika28 commited on
Commit
dedf27f
·
1 Parent(s): 9c0c960

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py CHANGED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.linear_model import LinearRegression, Ridge
5
+ from sklearn.preprocessing import PolynomialFeatures
6
+ from sklearn.metrics import mean_squared_error
7
+
8
+ st.subheader("Ridge Demo")
9
+ col1, col2 = st.columns(2)
10
+
11
+ degree = st.slider('Degree', 2, 40, 1)
12
+ alpha = st.slider('Lambda (Regularisation)', 0, 500, 1)
13
+
14
+
15
+ with col1:
16
+ st.markdown("#### Un-regularized")
17
+
18
+ with col2:
19
+ st.markdown("#### Regularized")
20
+
21
+ x = np.linspace(-1., 1., 100)
22
+ y = 4 + 3*x + 2*np.sin(x) + 2*np.random.randn(len(x))
23
+
24
+
25
+ poly = PolynomialFeatures(degree=degree, include_bias=False)
26
+ x_new = poly.fit_transform(x.reshape(-1, 1))
27
+
28
+ lr = LinearRegression()
29
+ lr.fit(x_new, y)
30
+ y_pred = lr.predict(x_new)
31
+
32
+
33
+ ri = Ridge(alpha = alpha)
34
+ ri.fit(x_new, y)
35
+ y_pred_ri = ri.predict(x_new)
36
+
37
+
38
+ fig1, ax1 = plt.subplots()
39
+ fig2, ax2 = plt.subplots()
40
+
41
+ ax1.scatter(x, y)
42
+ ax1.plot(x, y_pred)
43
+
44
+ ax2.scatter(x, y)
45
+ ax2.plot(x, y_pred_ri)
46
+
47
+ for ax in [ax1, ax2]:
48
+ ax.spines['right'].set_visible(False)
49
+ ax.spines['top'].set_visible(False)
50
+
51
+ # Only show ticks on the left and bottom spines
52
+ ax.yaxis.set_ticks_position('left')
53
+ ax.xaxis.set_ticks_position('bottom')
54
+
55
+ ax.set_xlabel("x")
56
+ ax.set_ylabel("y")
57
+
58
+ rmse = np.round(np.sqrt(mean_squared_error(y_pred, y)), 2)
59
+ ax1.set_title(f"Train RMSE: {rmse}")
60
+
61
+ rmse_ri = np.round(np.sqrt(mean_squared_error(y_pred_ri, y)), 2)
62
+ ax2.set_title(f"Train RMSE: {rmse_ri}")
63
+
64
+ with col1:
65
+ st.pyplot(fig1)
66
+
67
+ with col2:
68
+ st.pyplot(fig2)
69
+ hide_streamlit_style = """
70
+ <style>
71
+ #MainMenu {visibility: hidden;}
72
+ footer {visibility: hidden;}
73
+ </style>
74
+ """
75
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)