nielsgl committed on
Commit
a255cdf
1 Parent(s): a2cb80c

update project

Browse files
Files changed (7) hide show
  1. .pre-commit-config.yaml +34 -0
  2. .python-version +1 -0
  3. app.py +186 -0
  4. poetry.lock +0 -0
  5. poetry.toml +2 -0
  6. pyproject.toml +53 -0
  7. requirements.txt +74 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: v4.4.0
6
+ hooks:
7
+ - id: trailing-whitespace
8
+ - id: end-of-file-fixer
9
+ - id: check-yaml
10
+ # - id: check-added-large-files
11
+ - repo: https://github.com/psf/black
12
+ rev: 23.3.0
13
+ hooks:
14
+ # - id: black
15
+ - id: black-jupyter
16
+ - repo: https://github.com/pycqa/isort
17
+ rev: 5.12.0
18
+ hooks:
19
+ - id: isort
20
+ name: isort (python)
21
+ - repo: https://github.com/asottile/pyupgrade
22
+ rev: v3.3.1
23
+ hooks:
24
+ - id: pyupgrade
25
+ args: [--py311-plus]
26
+ - repo: https://github.com/nbQA-dev/nbQA
27
+ rev: 1.7.0
28
+ hooks:
29
+ - id: nbqa-isort
30
+ - id: nbqa-black
31
+ - id: nbqa-pyupgrade
32
+ args: [--py311-plus]
33
+ default_language_version:
34
+ python: python3.11
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11.1
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from sklearn import datasets
5
+ from sklearn.linear_model import LogisticRegression
6
+ from sklearn.preprocessing import StandardScaler
7
+
8
# Seeded generator kept at module level; nothing in the visible code reads it.
rng = np.random.default_rng(0)

# Load the digits data set and standardize every feature column.
X, y = datasets.load_digits(return_X_y=True)
X = StandardScaler().fit_transform(X)

# Binary target: digits 0-4 become class 0, digits 5-9 become class 1.
y = (y > 4).astype(int)

# The Elastic-Net L1 weight (l1_ratio) is not fixed here; it is passed in
# by the caller of make_regression.

# Markdown shown at the top of the demo page.
md_description = """
# L1 Penalty and Sparsity in Logistic Regression

Comparison of the sparsity (percentage of zero coefficients) of solutions when L1, L2 and Elastic-Net penalty are used for different values of C. We can see that large values of C give more freedom to the model. Conversely, smaller values of C constrain the model more. In the L1 penalty case, this leads to sparser solutions. As expected, the Elastic-Net penalty sparsity is between that of L1 and L2.

We classify 8x8 images of digits into two classes: 0-4 against 5-9. The visualization shows coefficients of the models for varying C.
"""
26
+
27
+
28
def make_regression(l1_ratio):
    """Fit L1, Elastic-Net and L2 logistic regressions and plot their coefficients.

    For each C in (1, 0.1, 0.01) three ``LogisticRegression`` models (saga
    solver, tol=0.01) are fit on the module-level ``X`` and ``y``. The absolute
    coefficients of each model are drawn as an 8x8 image in a 3x3 grid:
    one row per C, columns ordered L1 / Elastic-Net / L2.

    Parameters
    ----------
    l1_ratio : float
        Forwarded to the Elastic-Net model as its ``l1_ratio``.

    Returns
    -------
    tuple
        ``(fig, log_out, example)`` where ``fig`` is the matplotlib figure,
        ``log_out`` is a Markdown string with the sparsity and score report for
        every C, and ``example`` is the Markdown returned by
        ``make_example(l1_ratio)``.
    """
    fig, axes = plt.subplots(3, 3)

    # One report per value of C; previously only the last one survived the loop.
    report_sections = []

    # Set regularization parameter
    for i, (C, axes_row) in enumerate(zip((1, 0.1, 0.01), axes)):
        # Increase tolerance for short training time
        clf_l1_LR = LogisticRegression(C=C, penalty="l1", tol=0.01, solver="saga")
        clf_l2_LR = LogisticRegression(C=C, penalty="l2", tol=0.01, solver="saga")
        clf_en_LR = LogisticRegression(
            C=C, penalty="elasticnet", solver="saga", l1_ratio=l1_ratio, tol=0.01
        )
        clf_l1_LR.fit(X, y)
        clf_l2_LR.fit(X, y)
        clf_en_LR.fit(X, y)

        coef_l1_LR = clf_l1_LR.coef_.ravel()
        coef_l2_LR = clf_l2_LR.coef_.ravel()
        coef_en_LR = clf_en_LR.coef_.ravel()

        # coef_l1_LR contains zeros due to the
        # L1 sparsity inducing norm
        sparsity_l1_LR = np.mean(coef_l1_LR == 0) * 100
        sparsity_l2_LR = np.mean(coef_l2_LR == 0) * 100
        sparsity_en_LR = np.mean(coef_en_LR == 0) * 100

        # Score each model once and reuse the value for both stdout and the log.
        score_l1_LR = clf_l1_LR.score(X, y)
        score_en_LR = clf_en_LR.score(X, y)
        score_l2_LR = clf_l2_LR.score(X, y)

        # Every number uses ".2f"; the original L1 line used "2f" (a width of 2
        # with the default six decimals), which was inconsistent with the rest.
        report = "\n".join(
            (
                f"C={C:.2f}",
                f"{'Sparsity with L1 penalty:':<40} {sparsity_l1_LR:.2f}%",
                f"{'Sparsity with Elastic-Net penalty:':<40} {sparsity_en_LR:.2f}%",
                f"{'Sparsity with L2 penalty:':<40} {sparsity_l2_LR:.2f}%",
                f"{'Score with L1 penalty:':<40} {score_l1_LR:.2f}",
                f"{'Score with Elastic-Net penalty:':<40} {score_en_LR:.2f}",
                f"{'Score with L2 penalty:':<40} {score_l2_LR:.2f}",
            )
        )
        print(report)
        report_sections.append(report)

        if i == 0:
            axes_row[0].set_title("L1 penalty")
            axes_row[1].set_title(f"Elastic-Net\nl1/l2_ratio = {l1_ratio}")
            axes_row[2].set_title("L2 penalty")

        for ax, coefs in zip(axes_row, [coef_l1_LR, coef_en_LR, coef_l2_LR]):
            ax.imshow(
                np.abs(coefs.reshape(8, 8)),
                interpolation="nearest",
                cmap="binary",
                vmax=1,
                vmin=0,
            )
            ax.set_xticks(())
            ax.set_yticks(())

        axes_row[0].set_ylabel(f"{C=}")

    # Fenced so that Markdown keeps the line breaks and the column alignment
    # produced by the ":<40" padding.
    log_out = "```\n" + "\n\n".join(report_sections) + "\n```"

    return fig, log_out, make_example(l1_ratio)
90
+
91
+
92
def make_example(l1_ratio):
    """Return Markdown with a standalone script that reproduces the demo.

    Parameters
    ----------
    l1_ratio : float
        Value written into the generated script as its ``l1_ratio`` assignment,
        so the snippet matches the value the caller is currently using.

    Returns
    -------
    str
        A Markdown string: one introductory sentence followed by a fenced
        Python code block.
    """
    # Doubled braces are literal braces in the generated code; the only value
    # substituted here is the l1_ratio assignment. The generated title f-string
    # deliberately refers to the snippet's own l1_ratio variable.
    return f"""
With the following code you can reproduce this example with the current values of the sliders and the same data in a notebook:

```python
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

X, y = datasets.load_digits(return_X_y=True)

X = StandardScaler().fit_transform(X)

# classify small against large digits
y = (y > 4).astype(int)

l1_ratio = {l1_ratio}  # L1 weight in the Elastic-Net regularization

fig, axes = plt.subplots(3, 3)

# Set regularization parameter
for i, (C, axes_row) in enumerate(zip((1, 0.1, 0.01), axes)):
    # Increase tolerance for short training time
    clf_l1_LR = LogisticRegression(C=C, penalty="l1", tol=0.01, solver="saga")
    clf_l2_LR = LogisticRegression(C=C, penalty="l2", tol=0.01, solver="saga")
    clf_en_LR = LogisticRegression(
        C=C, penalty="elasticnet", solver="saga", l1_ratio=l1_ratio, tol=0.01
    )
    clf_l1_LR.fit(X, y)
    clf_l2_LR.fit(X, y)
    clf_en_LR.fit(X, y)

    coef_l1_LR = clf_l1_LR.coef_.ravel()
    coef_l2_LR = clf_l2_LR.coef_.ravel()
    coef_en_LR = clf_en_LR.coef_.ravel()

    # coef_l1_LR contains zeros due to the
    # L1 sparsity inducing norm

    sparsity_l1_LR = np.mean(coef_l1_LR == 0) * 100
    sparsity_l2_LR = np.mean(coef_l2_LR == 0) * 100
    sparsity_en_LR = np.mean(coef_en_LR == 0) * 100

    print(f"C={{C:.2f}}")
    print(f"{{'Sparsity with L1 penalty:':<40}} {{sparsity_l1_LR:.2f}}%")
    print(f"{{'Sparsity with Elastic-Net penalty:':<40}} {{sparsity_en_LR:.2f}}%")
    print(f"{{'Sparsity with L2 penalty:':<40}} {{sparsity_l2_LR:.2f}}%")
    print(f"{{'Score with L1 penalty:':<40}} {{clf_l1_LR.score(X, y):.2f}}")
    print(f"{{'Score with Elastic-Net penalty:':<40}} {{clf_en_LR.score(X, y):.2f}}")
    print(f"{{'Score with L2 penalty:':<40}} {{clf_l2_LR.score(X, y):.2f}}")

    if i == 0:
        axes_row[0].set_title("L1 penalty")
        axes_row[1].set_title(f"Elastic-Net\\nl1/l2_ratio = {{l1_ratio}}")
        axes_row[2].set_title("L2 penalty")

    for ax, coefs in zip(axes_row, [coef_l1_LR, coef_en_LR, coef_l2_LR]):
        ax.imshow(
            np.abs(coefs.reshape(8, 8)),
            interpolation="nearest",
            cmap="binary",
            vmax=1,
            vmin=0,
        )
        ax.set_xticks(())
        ax.set_yticks(())

    axes_row[0].set_ylabel(f"{{C=}}")
plt.show()
```
"""
168
+
169
+
170
# Page layout: description on top, controls beside the outputs, then the
# reproducible code example underneath.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(md_description)

    with gr.Row():
        with gr.Column():
            ratio_slider = gr.Slider(
                label="L1/L2 ratio",
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.5,
            )
            button = gr.Button(value="Generate")
        with gr.Column():
            plot = gr.Plot(label="Output")
            log = gr.Markdown("")

    with gr.Row():
        # Pre-render the example for the slider's initial value.
        example = gr.Markdown(make_example(ratio_slider.value))

    # The button and the slider trigger the same handler with the same wiring.
    event_wiring = dict(
        fn=make_regression,
        inputs=[ratio_slider],
        outputs=[plot, log, example],
    )
    button.click(**event_wiring)
    ratio_slider.change(**event_wiring)

demo.launch()
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
poetry.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [virtualenvs]
2
+ in-project = true
pyproject.toml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "sklearn-decision-tree-regression"
3
+ version = "0.1.0"
4
+ description = "Hugging Face Scikit Learn Demos"
5
+ authors = ["Niels van Galen Last <nvangalenlast@gmail.com>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ # packages = [{ include = "huggingface_sklearn" }]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = ">=3.8.9,<3.12"
12
+ numpy = "^1.24.2"
13
+ scikit-learn = "^1.2.2"
14
+ matplotlib = "^3.7.1"
15
+ plotly = "^5.14.0"
16
+ gradio = "^3.24.1"
17
+
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ black = { extras = ["jupyter"], version = "^23.3.0" }
21
+ isort = "^5.12.0"
22
+ pre-commit = "^3.2.1"
23
+ pylint = "^2.17.1"
24
+ pytest = "^7.2.2"
25
+ jupyterlab = "^3.6.3"
26
+ jupyterlab-widgets = "^3.0.7"
27
+ ipywidgets = "^8.0.6"
28
+
29
+ [build-system]
30
+ requires = ["poetry-core"]
31
+ build-backend = "poetry.core.masonry.api"
32
+
33
+ [tool.black]
34
+ line-length = 100
35
+ target_version = ['py311']
36
+ include = '\.py$'
37
+
38
+ [tool.isort]
39
+ profile = "black"
40
+ # force_single_line = "false"
41
+ force_sort_within_sections = "true"
42
+ line_length = 100
43
+
44
+ [tool.pylint]
45
+ [tool.pylint.messages_control]
46
+ #line-too-long='off'
47
+ disable = """
48
+ invalid-name,
49
+ logging-fstring-interpolation,
50
+ missing-class-docstring,
51
+ missing-function-docstring,
52
+ missing-module-docstring,
53
+ """
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==22.1.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
2
+ aiohttp==3.8.4 ; python_full_version >= "3.8.9" and python_version < "3.12"
3
+ aiosignal==1.3.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
4
+ altair==4.2.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
5
+ anyio==3.6.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
6
+ async-timeout==4.0.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
7
+ attrs==22.2.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
8
+ certifi==2022.12.7 ; python_full_version >= "3.8.9" and python_version < "3.12"
9
+ charset-normalizer==3.1.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
10
+ click==8.1.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
11
+ colorama==0.4.6 ; python_full_version >= "3.8.9" and python_version < "3.12" and platform_system == "Windows"
12
+ contourpy==1.0.7 ; python_full_version >= "3.8.9" and python_version < "3.12"
13
+ cycler==0.11.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
14
+ entrypoints==0.4 ; python_full_version >= "3.8.9" and python_version < "3.12"
15
+ fastapi==0.95.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
16
+ ffmpy==0.3.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
17
+ filelock==3.10.7 ; python_full_version >= "3.8.9" and python_version < "3.12"
18
+ fonttools==4.39.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
19
+ frozenlist==1.3.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
20
+ fsspec==2023.3.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
21
+ gradio-client==0.0.5 ; python_full_version >= "3.8.9" and python_version < "3.12"
22
+ gradio==3.24.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
23
+ h11==0.14.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
24
+ httpcore==0.16.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
25
+ httpx==0.23.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
26
+ huggingface-hub==0.13.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
27
+ idna==3.4 ; python_full_version >= "3.8.9" and python_version < "3.12"
28
+ importlib-resources==5.12.0 ; python_full_version >= "3.8.9" and python_version < "3.10"
29
+ jinja2==3.1.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
30
+ joblib==1.2.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
31
+ jsonschema==4.17.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
32
+ kiwisolver==1.4.4 ; python_full_version >= "3.8.9" and python_version < "3.12"
33
+ linkify-it-py==2.0.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
34
+ markdown-it-py==2.2.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
35
+ markdown-it-py[linkify]==2.2.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
36
+ markupsafe==2.1.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
37
+ matplotlib==3.7.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
38
+ mdit-py-plugins==0.3.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
39
+ mdurl==0.1.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
40
+ multidict==6.0.4 ; python_full_version >= "3.8.9" and python_version < "3.12"
41
+ numpy==1.24.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
42
+ orjson==3.8.9 ; python_full_version >= "3.8.9" and python_version < "3.12"
43
+ packaging==23.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
44
+ pandas==1.5.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
45
+ pillow==9.5.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
46
+ pkgutil-resolve-name==1.3.10 ; python_full_version >= "3.8.9" and python_version < "3.9"
47
+ plotly==5.14.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
48
+ pydantic==1.10.7 ; python_full_version >= "3.8.9" and python_version < "3.12"
49
+ pydub==0.25.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
50
+ pyparsing==3.0.9 ; python_full_version >= "3.8.9" and python_version < "3.12"
51
+ pyrsistent==0.19.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
52
+ python-dateutil==2.8.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
53
+ python-multipart==0.0.6 ; python_full_version >= "3.8.9" and python_version < "3.12"
54
+ pytz==2023.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
55
+ pyyaml==6.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
56
+ requests==2.28.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
57
+ rfc3986[idna2008]==1.5.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
58
+ scikit-learn==1.2.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
59
+ scipy==1.9.3 ; python_full_version >= "3.8.9" and python_version < "3.12"
60
+ semantic-version==2.10.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
61
+ six==1.16.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
62
+ sniffio==1.3.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
63
+ starlette==0.26.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
64
+ tenacity==8.2.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
65
+ threadpoolctl==3.1.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
66
+ toolz==0.12.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
67
+ tqdm==4.65.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
68
+ typing-extensions==4.5.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
69
+ uc-micro-py==1.0.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
70
+ urllib3==1.26.15 ; python_full_version >= "3.8.9" and python_version < "3.12"
71
+ uvicorn==0.21.1 ; python_full_version >= "3.8.9" and python_version < "3.12"
72
+ websockets==11.0 ; python_full_version >= "3.8.9" and python_version < "3.12"
73
+ yarl==1.8.2 ; python_full_version >= "3.8.9" and python_version < "3.12"
74
+ zipp==3.15.0 ; python_full_version >= "3.8.9" and python_version < "3.10"