Completely change the structure of the project
- .vscode/settings.json +1 -0
- README.md +4 -0
- app/__main__.py +6 -0
- app/cli.py +144 -0
- app/constants.py +27 -11
- app/gui.py +39 -73
- app/model.py +273 -113
- app/utils.py +0 -164
- deprecated/__init__.py +0 -0
- deprecated/main.py +0 -44
- deprecated/train.py +0 -152
- justfile +4 -6
- notebook.ipynb +152 -0
- poetry.lock +0 -0
- pyproject.toml +2 -1
.vscode/settings.json
CHANGED
@@ -23,5 +23,6 @@
     "**/__pycache__": true,
     "**/.ruff_cache": true,
     "**/.venv": true,
+    "**/.cache": true,
   }
 }
README.md
CHANGED
@@ -7,6 +7,10 @@ Sentiment Analysis
 3. Run `just install` to install the dependencies
 4. Run `just run --help` to see the available commands
 
+### Datasets
+- [Sentiment140](https://www.kaggle.com/datasets/kazanova/sentiment140)
+- [IMDb](https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews)
+- [Amazon Reviews](https://www.kaggle.com/datasets/bittlingmayer/amazonreviews)
 
 ### TODO
 - [ ] CLI using `click` (commands: predict, train, evaluate) with settings set via flags or environment variables
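Note: judging by the paths defined in the new app/constants.py, the loaders presumably expect the downloaded files under DATA_DIR (default `data/`) with these exact names; renaming Kaggle's archives to match is left to the user:

    data/sentiment140.csv
    data/amazonreviews.test.txt.bz2
    data/amazonreviews.train.txt.bz2
    data/imdb50k.csv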
app/__main__.py
ADDED
@@ -0,0 +1,6 @@
+from __future__ import annotations
+
+from app.cli import cli_wrapper as cli
+
+if __name__ == "__main__":
+    cli()
app/cli.py
ADDED
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Literal
+
+import click
+
+__all__ = ["cli_wrapper"]
+
+ERROR_STR = click.style("ERROR", fg="red")
+DONE_STR = click.style("DONE", fg="green")
+POSITIVE_STR = click.style("POSITIVE", fg="green")
+NEUTRAL_STR = click.style("NEUTRAL", fg="yellow")
+NEGATIVE_STR = click.style("NEGATIVE", fg="red")
+
+
+@click.group()
+def cli() -> None: ...
+
+
+@cli.command()
+@click.option(
+    "--model",
+    "model_path",
+    required=True,
+    help="Path to the trained model",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
+)
+@click.option(
+    "--share/--no-share",
+    default=False,
+    help="Whether to create a shareable link",
+)
+def gui(model_path: Path, share: bool) -> None:
+    """Launch the Gradio GUI"""
+    from app.gui import launch_gui
+
+    launch_gui(model_path, share)
+
+
+@cli.command()
+@click.option(
+    "--model",
+    "model_path",
+    required=True,
+    help="Path to the trained model",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
+)
+@click.argument("text", nargs=-1)
+def predict(model_path: Path, text: list[str]) -> None:
+    """Perform sentiment analysis on the provided text.
+
+    Note: Piped input takes precedence over the text argument
+    """
+    import sys
+
+    import joblib
+
+    text = " ".join(text).strip()
+    if not sys.stdin.isatty():
+        piped_text = sys.stdin.read().strip()
+        text = piped_text or text
+
+    if not text:
+        click.echo(f"{ERROR_STR}: No text provided")
+        return
+
+    click.echo("Loading model... ", nl=False)
+    model = joblib.load(model_path)
+    click.echo(DONE_STR)
+
+    click.echo("Performing sentiment analysis... ", nl=False)
+    prediction = model.predict([text])[0]
+    if prediction == 0:
+        sentiment = NEGATIVE_STR
+    elif prediction == 1:
+        sentiment = POSITIVE_STR
+    else:
+        sentiment = NEUTRAL_STR
+    click.echo(sentiment)
+
+
+@cli.command()
+@click.option(
+    "--dataset",
+    required=True,
+    help="Dataset to train the model on",
+    type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
+)
+@click.option(
+    "--max-features",
+    default=20000,
+    help="Maximum number of features",
+    show_default=True,
+    type=click.IntRange(1, None),
+)
+@click.option(
+    "--seed",
+    default=42,
+    help="Random seed (-1 for random seed)",
+    show_default=True,
+    type=click.IntRange(-1, None),
+)
+def train(
+    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
+    max_features: int,
+    seed: int,
+) -> None:
+    """Train the model on the provided dataset"""
+    import joblib
+
+    from app.constants import MODELS_DIR
+    from app.model import create_model, load_data, train_model
+
+    model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
+    if model_path.exists():
+        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
+
+    click.echo("Preprocessing dataset... ", nl=False)
+    text_data, label_data = load_data(dataset)
+    click.echo(DONE_STR)
+
+    click.echo("Creating model... ", nl=False)
+    model = create_model(max_features, seed=None if seed == -1 else seed)
+    click.echo(DONE_STR)
+
+    click.echo("Training model... ", nl=False)
+    accuracy = train_model(model, text_data, label_data)
+    joblib.dump(model, model_path)
+    click.echo(DONE_STR)
+
+    click.echo("Model accuracy: ")
+    click.secho(f"{accuracy:.2%}", fg="blue")
+
+    # TODO: Add hyperparameter options
+    # TODO: Random/grid search for finding best classifier and hyperparameters
+
+
+def cli_wrapper() -> None:
+    cli(max_content_width=120)
+
+
+if __name__ == "__main__":
+    cli_wrapper()
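For reference, the new CLI would be invoked along these lines; the model filename follows the `{dataset}_tfidf_ft-{max_features}.pkl` pattern used by `train`, so treat the concrete path as illustrative:

    python -m app train --dataset imdb50k --max-features 20000 --seed 42
    python -m app predict --model models/imdb50k_tfidf_ft-20000.pkl "I loved this movie"
    echo "I loved this movie" | python -m app predict --model models/imdb50k_tfidf_ft-20000.pkl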
app/constants.py
CHANGED
@@ -1,16 +1,32 @@
+from __future__ import annotations
+
+import os
 from pathlib import Path
 
-MODELS_DIR: Path = Path("models")
-CACHE_DIR: Path = Path("cache")
-CHECKPOINT_PATH: Path = CACHE_DIR / "pipeline.pkl"
+CACHE_DIR = Path(os.getenv("CACHE_DIR", ".cache"))
+DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
+MODELS_DIR = Path(os.getenv("MODELS_DIR", "models"))
+
+SENTIMENT140_PATH = DATA_DIR / "sentiment140.csv"
+SENTIMENT140_URL = "https://www.kaggle.com/datasets/kazanova/sentiment140"
+
+AMAZONREVIEWS_PATH = (DATA_DIR / "amazonreviews.test.txt.bz2", DATA_DIR / "amazonreviews.train.txt.bz2")
+AMAZONREVIEWS_URL = "https://www.kaggle.com/datasets/bittlingmayer/amazonreviews"
 
+IMDB50K_PATH = DATA_DIR / "imdb50k.csv"
+IMDB50K_URL = "https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews"
 
+URL_REGEX = r"(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?\/[a-zA-Z0-9]{2,}|((https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?)|(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}(\.[a-zA-Z0-9]{2,})?"  # https://www.freecodecamp.org/news/how-to-write-a-regular-expression-for-a-url/
+EMOTICON_MAP = {
+    "SMILE": [":)", ":-)", ": )", ":D", ":-D", ": D", ";)", ";-)", "; )", ":>", ":->", ": >", ":]", ":-]", ": ]"],
+    "LOVE": ["<3", ":*", ":-*", ": *"],
+    "WINK": [";)", ";-)", "; )", ";>", ";->", "; >"],
+    "FROWN": [":(", ":-(", ": (", ":[", ":-[", ": ["],
+    "CRY": [":'(", ": (", ":' (", ":'[", ":' ["],
+    "SURPRISE": [":O", ":-O", ": O", ":0", ":-0", ": 0", ":o", ":-o", ": o"],
+    "ANGRY": [">:(", ">:-(", "> :(", ">:["],
+}
 
+CACHE_DIR.mkdir(exist_ok=True, parents=True)
+DATA_DIR.mkdir(exist_ok=True, parents=True)
+MODELS_DIR.mkdir(exist_ok=True, parents=True)
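Since every directory is resolved via `os.getenv`, the defaults can presumably be overridden per invocation without touching the code; the paths below are illustrative:

    DATA_DIR=/srv/datasets MODELS_DIR=/srv/models python -m app train --dataset sentiment140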
app/gui.py
CHANGED
@@ -1,92 +1,58 @@
 from __future__ import annotations
 
+import os
+from functools import lru_cache
+from typing import TYPE_CHECKING
 
 import gradio as gr
+import joblib
 
+if TYPE_CHECKING:
+    from sklearn.pipeline import Pipeline
 
+__all__ = ["launch_gui"]
 
-TOKENIZER_EXT = ".tokenizer.pkl"
-MODEL_EXT = ".model.pkl"
-POSITIVE_LABEL = "Positive 😊"
-NEGATIVE_LABEL = "Negative 😤"
-REFRESH_SYMBOL = "🔄"
-
-
-def load_style() -> str:
-    if not CSS_PATH.is_file():
-        return ""
-
-    with Path.open(CSS_PATH) as f:
-        return f.read()
-
-    return POSITIVE_LABEL if pred else NEGATIVE_LABEL
-
-
-def train_wrapper() -> None:
-    msg = "Training is not supported in the GUI."
-    raise NotImplementedError(msg)
-
-
-def evaluate_wrapper() -> None:
-    msg = "Evaluation is not supported in the GUI."
-    raise NotImplementedError(msg)
-
-
-with gr.Blocks(css=load_style()) as demo:
-    gr.Markdown("## Sentiment Analysis")
-        placeholder="Enter text here",
-        key="input-textbox",
-    )
-
-    with gr.Row(elem_classes="justify-between"):
-        clear_btn = gr.ClearButton([textbox, output], value="Clear 🧹")
-        analyze_btn = gr.Button(
-            "Analyze 🔍",
-            variant="primary",
-            interactive=False,
-        )
-
-    model_selector = gr.Dropdown(
-        choices=[mdl.stem[: -len(".model")] for mdl in MODELS_DIR.glob(f"*{MODEL_EXT}")],
-        label="Model",
-        key="model-selector",
-    )
-
-    # Event handlers
-    textbox.input(
-        fn=lambda text: gr.update(interactive=bool(text.strip())),
-        inputs=[textbox],
-        outputs=[analyze_btn],
-    )
-    analyze_btn.click(
-        fn=predict_wrapper,
-        inputs=[textbox, tokenizer_selector, model_selector],
-        outputs=[output],
-    )
-
-    demo.launch()
+POSITIVE_LABEL = "Positive 😊"
+NEUTRAL_LABEL = "Neutral 😐"
+NEGATIVE_LABEL = "Negative 😤"
+
+
+@lru_cache(maxsize=1)
+def load_model() -> Pipeline:
+    """Load the trained model and cache it."""
+    model_path = os.environ.get("MODEL_PATH", None)
+    if model_path is None:
+        msg = "MODEL_PATH environment variable not set"
+        raise ValueError(msg)
+    return joblib.load(model_path)
+
+
+def sentiment_analysis(text: str) -> str:
+    """Perform sentiment analysis on the provided text."""
+    model = load_model()
+    prediction = model.predict([text])[0]
+
+    if prediction == 0:
+        return NEGATIVE_LABEL
+    if prediction == 1:
+        return POSITIVE_LABEL
+    return NEUTRAL_LABEL
+
+
+demo = gr.Interface(
+    fn=sentiment_analysis,
+    inputs="text",
+    outputs="label",
+    title="Sentiment Analysis",
+)
+
+
+def launch_gui(model_path: str, share: bool) -> None:
+    """Launch the Gradio GUI."""
+    os.environ["MODEL_PATH"] = model_path
+    demo.launch(share=share)
+
+
+if __name__ == "__main__":
+    demo.launch()
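Given the `gui` command added in app/cli.py, the interface would presumably be launched as (model path illustrative):

    python -m app gui --model models/imdb50k_tfidf_ft-20000.pkl --share

Passing the model path through the `MODEL_PATH` environment variable lets `demo` stay a module-level `gr.Interface` while `load_model` defers the actual `joblib.load` until the first prediction.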
app/model.py
CHANGED
@@ -1,144 +1,304 @@
 from __future__ import annotations
 
+import bz2
+import re
 import warnings
-from typing import TYPE_CHECKING, Sequence
+from typing import Literal
 
+import pandas as pd
+from joblib import Memory
+from nltk.stem import WordNetLemmatizer
+from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
 from sklearn.pipeline import Pipeline
 
+from app.constants import (
+    AMAZONREVIEWS_PATH,
+    AMAZONREVIEWS_URL,
+    CACHE_DIR,
+    EMOTICON_MAP,
+    IMDB50K_PATH,
+    IMDB50K_URL,
+    SENTIMENT140_PATH,
+    SENTIMENT140_URL,
+    URL_REGEX,
+)
+
+__all__ = ["load_data", "create_model", "train_model"]
 
-def export_to_file(pipeline: Pipeline, path: Path) -> None:
-    joblib.dump(pipeline, path)
-
-
-def train_tokenizer(x: list[str], y: list[int], cache: joblib.Memory) -> Pipeline:
-    # TODO: In the future, allow for different tokenizers
-    pipeline = Pipeline(
-        [
-            (
-                "vectorize",
-                CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_TOKENIZER_FEATURES),
-            ),
-            ("tfidf", TfidfTransformer()),
-        ],
-        memory=cache,
-    )
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        pipeline.fit(x, y)
-
-    return pipeline
-
-
-def train_model(x: list[str], y: list[int], cache: joblib.Memory, rs: RandomState) -> Pipeline:
-    # TODO: In the future, allow for different classifiers
-    pipeline = Pipeline(
-        [
-            ("clf", LogisticRegression(max_iter=CLF_MAX_ITER, random_state=rs)),
-        ],
-        memory=cache,
-    )
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")  # Ignore joblib warnings
-        pipeline.fit(x, y)
-
-    return pipeline
+
+class TextCleaner(BaseEstimator, TransformerMixin):
+    def __init__(
+        self,
+        *,
+        replace_url: bool = True,
+        replace_hashtag: bool = True,
+        replace_emoticon: bool = True,
+        replace_emoji: bool = True,
+        lowercase: bool = True,
+        character_threshold: int = 2,
+        remove_special_characters: bool = True,
+        remove_extra_spaces: bool = True,
+    ):
+        self.replace_url = replace_url
+        self.replace_hashtag = replace_hashtag
+        self.replace_emoticon = replace_emoticon
+        self.replace_emoji = replace_emoji
+        self.lowercase = lowercase
+        self.character_threshold = character_threshold
+        self.remove_special_characters = remove_special_characters
+        self.remove_extra_spaces = remove_extra_spaces
+
+    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextCleaner:
+        return self
+
+    def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
+        # Replace URLs, hashtags, emoticons, and emojis
+        data = [re.sub(URL_REGEX, "URL", text) for text in data] if self.replace_url else data
+        data = [re.sub(r"#\w+", "HASHTAG", text) for text in data] if self.replace_hashtag else data
+
+        # Replace emoticons
+        if self.replace_emoticon:
+            for word, emoticons in EMOTICON_MAP.items():
+                for emoticon in emoticons:
+                    data = [text.replace(emoticon, f"EMOTE_{word}") for text in data]
+
+        # Basic text cleaning
+        data = [text.lower() for text in data] if self.lowercase else data  # Lowercase
+        threshold_pattern = re.compile(rf"\b\w{{1,{self.character_threshold}}}\b")
+        data = (
+            [re.sub(threshold_pattern, "", text) for text in data] if self.character_threshold > 0 else data
+        )  # Remove short words
+        data = (
+            [re.sub(r"[^a-zA-Z0-9\s]", "", text) for text in data] if self.remove_special_characters else data
+        )  # Remove special characters
+        data = [re.sub(r"\s+", " ", text) for text in data] if self.remove_extra_spaces else data  # Remove extra spaces
+
+        # Remove leading and trailing whitespace
+        return [text.strip() for text in data]
+
+
+class TextLemmatizer(BaseEstimator, TransformerMixin):
+    def __init__(self):
+        self.lemmatizer = WordNetLemmatizer()
+
+    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextLemmatizer:
+        return self
+
+    def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
+        return [self.lemmatizer.lemmatize(text) for text in data]
+
+
+def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
+    """Load the sentiment140 dataset and make it suitable for use.
+
+    Args:
+        include_neutral: Whether to include neutral sentiment
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    if not SENTIMENT140_PATH.exists():
+        msg = (
+            f"Sentiment140 dataset not found at: '{SENTIMENT140_PATH}'\n"
+            "Please download the dataset from:\n"
+            f"{SENTIMENT140_URL}"
+        )
+        raise FileNotFoundError(msg)
+
+    # Load the dataset
+    data = pd.read_csv(
+        SENTIMENT140_PATH,
+        encoding="ISO-8859-1",
+        names=[
+            "target",  # 0 = negative, 2 = neutral, 4 = positive
+            "id",  # The id of the tweet
+            "date",  # The date of the tweet
+            "flag",  # The query, NO_QUERY if not present
+            "user",  # The user that tweeted
+            "text",  # The text of the tweet
+        ],
+    )
+
+    # Ignore rows with neutral sentiment
+    if not include_neutral:
+        data = data[data["target"] != 2]
+
+    # Map sentiment values
+    data["sentiment"] = data["target"].map(
+        {
+            0: 0,  # Negative
+            4: 1,  # Positive
+            2: 2,  # Neutral
+        },
+    )
+
+    # Return as lists
+    return data["text"].tolist(), data["sentiment"].tolist()
+
+
+def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
+    """Load the amazonreviews dataset and make it suitable for use.
+
+    Args:
+        merge: Whether to merge the test and train datasets (otherwise ignore test)
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    test_exists = AMAZONREVIEWS_PATH[0].exists() or not merge
+    train_exists = AMAZONREVIEWS_PATH[1].exists()
+    if not (test_exists and train_exists):
+        msg = (
+            f"Amazonreviews dataset not found at: '{AMAZONREVIEWS_PATH[0]}' and '{AMAZONREVIEWS_PATH[1]}'\n"
+            "Please download the dataset from:\n"
+            f"{AMAZONREVIEWS_URL}"
+        )
+        raise FileNotFoundError(msg)
+
+    # Load the datasets
+    with bz2.BZ2File(AMAZONREVIEWS_PATH[1]) as train_file:
+        train_data = [line.decode("utf-8") for line in train_file]
+
+    test_data = []
+    if merge:
+        with bz2.BZ2File(AMAZONREVIEWS_PATH[0]) as test_file:
+            test_data = [line.decode("utf-8") for line in test_file]
+
+    # Merge the datasets
+    data = train_data + test_data
+
+    # Split the data into labels and text
+    labels, texts = zip(*(line.split(" ", 1) for line in data))
+
+    # Map sentiment values
+    sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
+
+    # Return as lists
+    return texts, sentiments
+
+
+def load_imdb50k() -> tuple[list[str], list[int]]:
+    """Load the imdb50k dataset and make it suitable for use.
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    if not IMDB50K_PATH.exists():
+        msg = (
+            f"IMDB50K dataset not found at: '{IMDB50K_PATH}'\n"
+            "Please download the dataset from:\n"
+            f"{IMDB50K_URL}"
+        )  # fmt: off
+        raise FileNotFoundError(msg)
+
+    # Load the dataset
+    data = pd.read_csv(IMDB50K_PATH)
+
+    # Map sentiment values
+    data["sentiment"] = data["sentiment"].map(
+        {
+            "positive": 1,
+            "negative": 0,
+        },
+    )
+
+    # Return as lists
+    return data["review"].tolist(), data["sentiment"].tolist()
+
+
+def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
+    """Load and preprocess the specified dataset.
+
+    Args:
+        dataset: Dataset to load
+
+    Returns:
+        Text and label data
+
+    Raises:
+        ValueError: If the dataset is not recognized
+    """
+    match dataset:
+        case "sentiment140":
+            return load_sentiment140(include_neutral=False)
+        case "amazonreviews":
+            return load_amazonreviews(merge=True)
+        case "imdb50k":
+            return load_imdb50k()
+        case _:
+            msg = f"Unknown dataset: {dataset}"
+            raise ValueError(msg)
+
+
+def create_model(
+    max_features: int,
+    seed: int | None = None,
+) -> Pipeline:
+    """Create a sentiment analysis model.
+
+    Args:
+        max_features: Maximum number of features
+        seed: Random seed (None for random seed)
+
+    Returns:
+        Untrained model
+    """
+    return Pipeline(
+        [
+            # Text preprocessing
+            ("clean", TextCleaner()),
+            ("lemma", TextLemmatizer()),
+            # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
+            ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=max_features)),
+            ("tfidf", TfidfTransformer()),
+            # Classifier
+            ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
+        ],
+        memory=Memory(CACHE_DIR, verbose=0),
+    )
+
+
+def train_model(
+    model: Pipeline,
+    text_data: list[str],
+    label_data: list[int],
+    seed: int = 42,
+) -> float:
+    """Train the sentiment analysis model.
+
+    Args:
+        model: Untrained model
+        text_data: Text data
+        label_data: Label data
+        seed: Random seed (None for random seed)
+
+    Returns:
+        Accuracy score
+    """
+    text_train, text_test, label_train, label_test = train_test_split(
+        text_data,
+        label_data,
+        test_size=0.2,
+        random_state=seed,
+    )
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        model.fit(text_train, label_train)
+
+    return model.score(text_test, label_test)
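Taken together, the new app/model.py API composes as in this minimal sketch, mirroring what the `train` command does; it assumes the chosen dataset already sits under DATA_DIR and that the NLTK wordnet corpus is available locally, and the dump path is illustrative:

    import joblib

    from app.model import create_model, load_data, train_model

    # Load and preprocess one of the three supported datasets
    text_data, label_data = load_data("imdb50k")

    # Build the cleaning -> lemmatizing -> vectorizing -> classifying pipeline
    model = create_model(max_features=20000, seed=42)

    # Fit on an 80/20 split and report held-out accuracy
    accuracy = train_model(model, text_data, label_data)
    print(f"Accuracy: {accuracy:.2%}")

    joblib.dump(model, "models/imdb50k_tfidf_ft-20000.pkl")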
app/utils.py
DELETED
@@ -1,164 +0,0 @@
-"""Utility functions"""
-
-from __future__ import annotations
-
-import itertools
-import re
-import warnings
-from collections import deque
-from enum import Enum
-from functools import lru_cache
-from threading import Event, Lock
-from typing import Any
-
-from joblib import Memory
-from numpy.random import RandomState
-
-from constants import CACHE_DIR, DEFAULT_SEED
-
-__all__ = ["colorize", "wrap_queued_call", "get_random_state", "get_cache_memory"]
-
-
-ANSI_RESET = 0
-
-
-class Color(Enum):
-    """ANSI color codes."""
-
-    BLACK = 30
-    RED = 31
-    GREEN = 32
-    YELLOW = 33
-    BLUE = 34
-    MAGENTA = 35
-    CYAN = 36
-    WHITE = 37
-
-
-class Style(Enum):
-    """ANSI style codes."""
-
-    BOLD = 1
-    DIM = 2
-    ITALIC = 3
-    UNDERLINE = 4
-    BLINK = 5
-    INVERTED = 7
-    HIDDEN = 8
-
-
-# https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
-class FIFOLock:
-    def __init__(self):
-        self._lock = Lock()
-        self._inner_lock = Lock()
-        self._pending_threads = deque()
-
-    def acquire(self, blocking: bool = True) -> bool:
-        with self._inner_lock:
-            lock_acquired = self._lock.acquire(False)
-            if lock_acquired:
-                return True
-            if not blocking:
-                return False
-
-            release_event = Event()
-            self._pending_threads.append(release_event)
-
-        release_event.wait()
-        return self._lock.acquire()
-
-    def release(self) -> None:
-        with self._inner_lock:
-            if self._pending_threads:
-                release_event = self._pending_threads.popleft()
-                release_event.set()
-
-            self._lock.release()
-
-    __enter__ = acquire
-
-    def __exit__(self, _t, _v, _tb):  # noqa: ANN001
-        self.release()
-
-
-@lru_cache(maxsize=1)
-def get_queue_lock() -> FIFOLock:
-    return FIFOLock()
-
-
-@lru_cache(maxsize=1)
-def get_random_state(seed: int = DEFAULT_SEED) -> RandomState:
-    return RandomState(seed)
-
-
-@lru_cache(maxsize=1)
-def get_cache_memory() -> Memory:
-    return Memory(CACHE_DIR, verbose=0)
-
-
-def to_ansi(code: int) -> str:
-    """Convert an integer to an ANSI escape code."""
-    return f"\033[{code}m"
-
-
-@lru_cache(maxsize=None)
-def get_ansi_color(color: Color, bright: bool = False, background: bool = False) -> str:
-    """Get ANSI color code for the specified color, brightness and background."""
-    code = color.value
-    if bright:
-        code += 60
-    if background:
-        code += 10
-    return to_ansi(code)
-
-
-def replace_color_tag(color: Color, text: str) -> None:
-    """Replace both dark and light color tags for background and foreground."""
-    for bright, bg in itertools.product([False, True], repeat=2):
-        tag = f"{'BG_' if bg else ''}{'BRIGHT_' if bright else ''}{color.name}"
-        text = text.replace(f"[{tag}]", get_ansi_color(color, bright=bright, background=bg))
-        text = text.replace(f"[/{tag}]", to_ansi(ANSI_RESET))
-
-    return text
-
-
-@lru_cache(maxsize=256)
-def colorize(text: str, strip: bool = True) -> str:
-    """Format text with ANSI color codes using tags [COLOR], [BG_COLOR] and [STYLE].
-    Reset color/style with [/TAG].
-    Escape with double brackets [[]]. Strip leading and trailing whitespace if strip=True.
-    """
-    # replace foreground and background color tags
-    for color in Color:
-        text = replace_color_tag(color, text)
-
-    # replace style tags
-    for style in Style:
-        text = text.replace(f"[{style.name}]", to_ansi(style.value)).replace(f"[/{style.name}]", to_ansi(ANSI_RESET))
-
-    # if there are any tags left, remove them and throw a warning
-    pat1 = re.compile(r"((?<!\[)\[)([^\[\]]*)(\](?!\]))")
-    for match in pat1.finditer(text):
-        color = match.group(1)
-        text = text.replace(match.group(0), "")
-        warnings.warn(f"Invalid color tag: {color!r}", UserWarning, stacklevel=2)
-
-    # escape double brackets
-    pat2 = re.compile(r"\[\[[^\[\]\v]+\]\]")
-    text = pat2.sub("", text)
-
-    # reset color/style at the end
-    text += to_ansi(ANSI_RESET)
-
-    return text.strip() if strip else text
-
-
-# https://github.com/AUTOMATIC1111/stable-diffusion-webui/modules/call_queue.py
-def wrap_queued_call(func: callable) -> callable:
-    def f(*args, **kwargs) -> Any:  # noqa: ANN003, ANN002
-        with get_queue_lock():
-            return func(*args, **kwargs)
-
-    return f
deprecated/__init__.py
DELETED
File without changes
deprecated/main.py
DELETED
@@ -1,44 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-
-import click
-import joblib
-
-from app.utils import colorize
-
-
-@click.group()
-def cli() -> None: ...
-
-
-@cli.command("predict")
-@click.option(
-    "-m",
-    "--model",
-    "model_path",
-    default="models/model.pkl",
-    help="Path to the model file.",
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
-)
-@click.argument("text", nargs=-1)
-def predict(model_path: Path, text: list[str]) -> None:
-    input_text = " ".join(text).strip()
-    if not input_text:
-        click.echo("[RED]Error[/RED]: Input text is empty.")
-        return
-
-    # Load the model
-    click.echo("Loading model... ", nl=False)
-    model = joblib.load(model_path)
-    click.echo(colorize("[GREEN]DONE"))
-
-    # Run the model
-    click.echo("Performing sentiment analysis... ", nl=False)
-    prediction = model.predict([input_text])
-    sentiment = "[GREEN]POSITIVE" if prediction[0] == 1 else "[RED]NEGATIVE"
-    click.echo(colorize(sentiment))
-
-
-if __name__ == "__main__":
-    cli()
deprecated/train.py
DELETED
@@ -1,152 +0,0 @@
-from __future__ import annotations
-
-import warnings
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-import click
-import joblib
-import pandas as pd
-from numpy.random import RandomState
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score, classification_report
-from sklearn.model_selection import train_test_split
-from sklearn.pipeline import Pipeline
-
-if TYPE_CHECKING:
-    from sklearn.base import BaseEstimator
-
-SEED = 42
-DATASET_PATH = Path("data/training.1600000.processed.noemoticon.csv")
-STOPWORDS_PATH = Path("data/stopwords-en.txt")
-CHECKPOINT_PATH = Path("cache/pipeline.pkl")
-MODELS_DIR = Path("models")
-CACHE_DIR = Path("cache")
-MAX_FEATURES = 10000  # 500000
-
-# Make sure paths exist
-MODELS_DIR.mkdir(parents=True, exist_ok=True)
-CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-# Memory cache for sklearn pipelines
-mem = joblib.Memory(CACHE_DIR, verbose=0)
-
-# TODO: use xgboost
-
-
-def get_random_state(seed: int = SEED) -> RandomState:
-    return RandomState(seed)
-
-
-def load_data() -> tuple[list[str], list[int]]:
-    """The model takes in a list of strings and a list of integers where 1 is positive sentiment and 0 is negative sentiment."""
-    data = pd.read_csv(
-        DATASET_PATH,
-        encoding="ISO-8859-1",
-        names=[
-            "target",  # 0 = negative, 2 = neutral, 4 = positive
-            "id",  # The id of the tweet
-            "date",  # The date of the tweet
-            "flag",  # The query, NO_QUERY if not present
-            "user",  # The user that tweeted
-            "text",  # The text of the tweet
-        ],
-    )
-
-    # Ignore rows with neutral sentiment
-    data = data[data["target"] != 2]
-
-    # Create new column called "sentiment" with 1 for positive and 0 for negative
-    data["sentiment"] = data["target"] == 4
-
-    # Drop the columns we don't need
-    # data = data.drop(columns=["target", "id", "date", "flag", "user"])  # NOTE: No need, since we return the columns we need
-
-    # Return as lists
-    return list(data["text"]), list(data["sentiment"])
-
-
-def create_pipeline(clf: BaseEstimator) -> Pipeline:
-    return Pipeline(
-        [
-            # Preprocess
-            # ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_FEATURES)),
-            # ("tfidf", TfidfTransformer()),
-            ("vectorize", TfidfVectorizer(ngram_range=(1, 2), max_features=MAX_FEATURES)),
-            # Classifier
-            ("clf", clf),
-        ],
-        memory=mem,
-    )
-
-
-def evaluate_pipeline(pipeline: Pipeline, x: list[str], y: list[int]) -> float:
-    y_pred = pipeline.predict(x)
-    report = classification_report(y, y_pred)
-    click.echo(report)
-
-    # TODO: Confusion matrix
-
-    return accuracy_score(y, y_pred)
-
-
-def export_pipeline(pipeline: Pipeline, name: str) -> None:
-    model_path = MODELS_DIR / f"{name}.pkl"
-    joblib.dump(pipeline, model_path)
-    click.echo(f"Model exported to {model_path!r}")
-
-
-@click.command()
-@click.option("--retrain", is_flag=True, help="Train the model even if a checkpoint exists.")
-@click.option("--evaluate", is_flag=True, help="Evaluate the model.")
-@click.option("--flush-cache", is_flag=True, help="Clear sklearn cache.")
-@click.option("--seed", type=int, default=SEED, help="Random seed.")
-def train(retrain: bool, evaluate: bool, flush_cache: bool, seed: int) -> None:
-    rng = get_random_state(seed)
-
-    # Clear sklearn cache
-    if flush_cache:
-        click.echo("Clearing cache... ", nl=False)
-        mem.clear(warn=False)
-        click.echo("DONE")
-
-    # Load and split data
-    click.echo("Loading data... ", nl=False)
-    x, y = load_data()
-    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=rng)
-    click.echo("DONE")
-
-    # Train model
-    if retrain or not CHECKPOINT_PATH.exists():
-        click.echo("Training model... ", nl=False)
-        clf = LogisticRegression(max_iter=1000, random_state=rng)
-        model = create_pipeline(clf)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")  # Ignore joblib warnings
-            model.fit(x_train, y_train)
-        joblib.dump(model, CHECKPOINT_PATH)
-        click.echo("DONE")
-    else:
-        click.echo("Loading model... ", nl=False)
-        model = joblib.load(CHECKPOINT_PATH)
-        click.echo("DONE")
-
-    # Evaluate model
-    if evaluate:
-        evaluate_pipeline(model, x_test, y_test)
-
-    # Quick test
-    test_text = ["I love this movie", "I hate this movie"]
-    click.echo("Quick test:")
-    for text in test_text:
-        click.echo(f"\t{'positive' if model.predict([text])[0] else 'negative'}: {text}")
-
-    # Export model
-    click.echo("Exporting model... ", nl=False)
-    export_pipeline(model, "logistic_regression")
-    click.echo("DONE")
-
-
-if __name__ == "__main__":
-    train()
justfile
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env just --justfile
 
 @default:
+    just --list
 
 @lint:
     poetry run pre-commit run --all-files

@@ -16,8 +16,6 @@
 @requirements:
     poetry export -f requirements.txt --output requirements.txt --without dev
 
-@gui:
-    poetry run gradio app/gui.py
+[no-exit-message]
+@app *ARGS:
+    poetry run python -m app {{ARGS}}
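With the new recipe, the CLI is presumably reachable through just as well, forwarding all arguments to `python -m app` inside the poetry environment; the model path is illustrative:

    just app train --dataset sentiment140
    just app predict --model models/sentiment140_tfidf_ft-20000.pkl "what a great day"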
notebook.ipynb
ADDED
@@ -0,0 +1,152 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Sentiment Analysis"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from __future__ import annotations\n",
+    "\n",
+    "import re\n",
+    "from functools import cache\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "import seaborn as sns"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data: pd.DataFrame = None  # TODO: load dataset\n",
+    "stopwords: set[str] = None  # TODO: load stopwords"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Explore the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot the distribution\n",
+    "_, ax = plt.subplots(figsize=(6, 4))\n",
+    "data[\"sentiment\"].value_counts().plot(kind=\"bar\", ax=ax)\n",
+    "ax.set_xticklabels([\"Negative\", \"Positive\"], rotation=0)\n",
+    "ax.set_xlabel(\"Sentiment\")\n",
+    "ax.grid(False)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@cache\n",
+    "def extract_words(text: str) -> list[str]:\n",
+    "    return re.findall(r\"(\\b[^\\s]+\\b)\", text.lower())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Extract words and count them\n",
+    "words = data[\"text\"].apply(extract_words).explode()\n",
+    "word_counts = words.value_counts().reset_index()\n",
+    "word_counts.columns = [\"word\", \"count\"]\n",
+    "word_counts.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot the most common words\n",
+    "_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
+    "\n",
+    "sns.barplot(data=word_counts.head(10), x=\"count\", y=\"word\", ax=ax1)\n",
+    "ax1.set_title(\"Most common words\")\n",
+    "ax1.grid(False)\n",
+    "ax1.tick_params(axis=\"x\", rotation=45)\n",
+    "\n",
+    "ax2.set_title(\"Most common words (excluding stopwords)\")\n",
+    "sns.barplot(\n",
+    "    data=word_counts[~word_counts[\"word\"].isin(stopwords)].head(10),\n",
+    "    x=\"count\",\n",
+    "    y=\"word\",\n",
+    "    ax=ax2,\n",
+    ")\n",
+    "ax2.grid(False)\n",
+    "ax2.tick_params(axis=\"x\", rotation=45)\n",
+    "ax2.set_ylabel(\"\")\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Find best classifier"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Find best hyperparameters"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
poetry.lock
CHANGED
Binary files a/poetry.lock and b/poetry.lock differ
pyproject.toml
CHANGED
@@ -1,13 +1,14 @@
 [tool.poetry]
 name = "sentiment-analysis"
 package-mode = false
-packages = [{ include = "app" }]
 
 [tool.poetry.dependencies]
 python = "^3.12"
 click = "^8.1.7"
 scikit-learn = "^1.4.2"
 gradio = "^4.31.0"
+colorama = "^0.4.6"
+nltk = "^3.8.1"
 
 [tool.poetry.group.train.dependencies]
 pandas = "^2.2.2"