Spaces:
Sleeping
Sleeping
File size: 896 Bytes
48b647e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from typing import Literal, Optional
from joblib import dump
from sklearn import datasets
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier
def load_dataset(dataset_name: Literal["iris", "other"]):
    """Load a dataset by name and return its features and targets.

    Only ``"iris"`` is implemented; it is loaded through
    ``sklearn.datasets.load_iris(return_X_y=True)``.

    Args:
        dataset_name: Name of the dataset to load.

    Returns:
        A ``(X, y)`` tuple holding the first and second items returned
        by ``load_iris``.

    Raises:
        NotImplementedError: If ``dataset_name`` is anything other than
            ``"iris"``.
    """
    if dataset_name != "iris":
        # Name the rejected value so the failure is self-explanatory.
        raise NotImplementedError(
            f"Dataset {dataset_name!r} is not supported; "
            "only 'iris' is implemented."
        )
    X, y = datasets.load_iris(return_X_y=True)
    return X, y
def train_ml_classifier(X, y, output_file: Optional[str] = None):
    """Fit a scaling + decision-tree pipeline and optionally save it.

    The pipeline applies ``MinMaxScaler`` followed by
    ``DecisionTreeClassifier(random_state=42)``.

    Args:
        X: Training features, passed straight to ``Pipeline.fit``.
        y: Training targets, passed straight to ``Pipeline.fit``.
        output_file: Path the fitted pipeline is written to with
            ``joblib.dump``. Nothing is written when ``None``.

    Returns:
        The fitted ``Pipeline``.
    """
    steps = [
        ('scaling', MinMaxScaler()),
        ('classifier', DecisionTreeClassifier(random_state=42)),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(X, y)
    if output_file is not None:
        dump(pipeline, output_file)
    # Hand the fitted model back: the script entry point assigns this
    # function's result, which would otherwise always be None.
    return pipeline
if __name__ == '__main__':
    # Script entry point: load iris, train on it, and write the result
    # to ./iris_v1.joblib.
    X, y = load_dataset(dataset_name='iris')
    model = train_ml_classifier(X=X, y=y, output_file='./iris_v1.joblib')
|