File size: 1,450 Bytes
0598e08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
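"""Script to create the model artifact.

Trains a linear classifier (scikit-learn ``SGDClassifier`` with its default
hinge loss, i.e. a linear SVM rather than a logistic regression), tunes its
``penalty`` and ``alpha`` hyperparameters with a 5-fold cross-validated grid
search on a synthetic dataset, and stores the best model in a pickle file.
"""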
import pickle
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
# One seed shared by dataset generation and the estimator, so repeated runs of
# this script see the same data and the same SGD initialization/shuffling.
SEED: int = 0
def get_data():
    """Return a synthetic classification dataset.

    Delegates to scikit-learn's ``make_classification`` with 1000 samples and
    the module-level ``SEED``, so every call yields the same data.

    Returns:
        A ``(X, y)`` tuple: the feature matrix and the class labels.
    """
    dataset = make_classification(n_samples=1000, random_state=SEED)
    features, labels = dataset
    return features, labels
def get_model(**kwargs):
    """Build an unfitted ``SGDClassifier`` seeded with ``SEED``.

    Any keyword arguments are applied through ``set_params`` after
    construction, so they may override any estimator parameter (including
    ``random_state``). Invalid parameter names are rejected by ``set_params``.
    """
    # set_params returns the estimator itself, so the call can be chained.
    return SGDClassifier(random_state=SEED).set_params(**kwargs)
def get_hparams():
    """Return the hyperparameter grid explored by the grid search.

    Keys are ``SGDClassifier`` parameter names and each value lists the
    candidate settings: 2 penalties x 3 regularization strengths, for
    6 combinations in total. A fresh dict is returned on every call.
    """
    return dict(
        penalty=['l1', 'l2'],
        alpha=[1e-5, 1e-4, 1e-3],
    )
def grid_search(model, X, y, hparams, *, cv=5, scoring='accuracy'):
    """Run an exhaustive cross-validated search over ``hparams``.

    Args:
        model: Unfitted scikit-learn estimator to tune.
        X: Feature matrix.
        y: Target labels.
        hparams: Parameter grid mapping parameter names to candidate values,
            in the form accepted by ``GridSearchCV``'s ``param_grid``.
        cv: Cross-validation strategy forwarded to ``GridSearchCV``
            (an int fold count or a splitter). Defaults to 5 folds.
        scoring: Metric used to rank candidates, forwarded to
            ``GridSearchCV``. Defaults to ``'accuracy'``.

    Returns:
        The fitted ``GridSearchCV`` instance. ``best_score_``,
        ``best_params_`` and ``best_estimator_`` are read from it by callers.
    """
    # cv/scoring are keyword-only with the previous hard-coded values as
    # defaults, so existing positional callers behave exactly as before.
    search = GridSearchCV(model, hparams, cv=cv, scoring=scoring)
    search.fit(X, y)
    return search
def train(model, X, y, hparams):
    """Tune ``model`` over ``hparams`` and return the best estimator.

    Runs ``grid_search`` with its default settings, prints the best mean
    cross-validated score as a percentage with one decimal place and the
    winning parameter combination, then returns ``best_estimator_``.
    """
    search = grid_search(model, X, y, hparams=hparams)
    best_pct = 100 * search.best_score_
    print(f"Best accuracy: {best_pct:.1f}%")
    print(f"Best parameters: {search.best_params_}")
    return search.best_estimator_
def save_model(model, filename):
    """Pickle ``model`` into ``filename`` and report where it was written.

    The file is opened in binary write mode, so any existing file at that
    path is overwritten. The confirmation message is printed only after
    the file has been closed.
    """
    with open(filename, mode='wb') as out_file:
        pickle.dump(model, out_file)
    message = f"Stored model in '{filename}'"
    print(message)
def main():
    """Generate data, tune the classifier, and write ``model.pickle``.

    The output path is relative, so the pickle is written to the current
    working directory.
    """
    features, labels = get_data()
    estimator = get_model()
    grid = get_hparams()
    best = train(estimator, features, labels, hparams=grid)
    save_model(best, 'model.pickle')
# Train and store the model only when run as a script, not on import.
if __name__ == '__main__':
    main()
|