André Fernandes
added predict method for a classic ML classifier model (wip)
48b647e
raw
history blame contribute delete
896 Bytes
from typing import Literal, Optional
from joblib import dump
from sklearn import datasets
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier
def load_dataset(dataset_name: Literal["iris", "other"]):
    """Return the named dataset as a ``(features, targets)`` pair.

    Only ``"iris"`` is wired up; it is obtained via
    ``sklearn.datasets.load_iris(return_X_y=True)``.

    Args:
        dataset_name: Identifier of the dataset to load.

    Returns:
        A 2-tuple ``(X, y)``: the feature matrix and the target vector.

    Raises:
        NotImplementedError: For any ``dataset_name`` other than ``"iris"``.
    """
    if dataset_name != "iris":
        # Every other name is a placeholder until more loaders are added.
        raise NotImplementedError()
    features, targets = datasets.load_iris(return_X_y=True)
    return features, targets
def train_ml_classifier(X, y, output_file: Optional[str] = None):
    """Fit a min-max scaling + decision-tree pipeline and optionally save it.

    Args:
        X: Training features, in any form accepted by ``Pipeline.fit``.
        y: Training targets aligned with ``X``.
        output_file: If not ``None``, the fitted pipeline is serialized to
            this path with ``joblib.dump``; otherwise nothing is written.

    Returns:
        The fitted ``sklearn.pipeline.Pipeline``, so callers can use the
        model directly (e.g. ``model.predict(...)``) without reloading it.
    """
    steps = [
        ("scaling", MinMaxScaler()),
        # Fixed seed so the fitted tree is reproducible across runs.
        ("classifier", DecisionTreeClassifier(random_state=42)),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(X, y)
    if output_file is not None:
        dump(pipeline, output_file)
    # Previously the fitted model was discarded here, so callers got None.
    return pipeline
if __name__ == '__main__':
    # Script entry point: train on Iris and write the fitted pipeline to a
    # path relative to the current working directory.
    model_path = './iris_v1.joblib'
    features, labels = load_dataset('iris')
    model = train_ml_classifier(features, labels, output_file=model_path)