Spaces:
Runtime error
Runtime error
File size: 1,728 Bytes
c73381c 7f1aa39 c73381c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# {% include 'template/license_header' %}
import mlflow
import pandas as pd
from sklearn.base import ClassifierMixin
from sklearn.tree import DecisionTreeClassifier
from typing_extensions import Annotated
from zenml import ArtifactConfig, step
from zenml.client import Client
from zenml.logger import get_logger
# Module-level logger, obtained through ZenML's logging helper.
logger = get_logger(__name__)
# Experiment tracker of the currently active ZenML stack, resolved once at
# import time.
# NOTE(review): the step decorator in this file passes the literal name
# "mlflow" rather than reading this variable -- confirm whether it is still
# needed or whether the decorator should use it.
experiment_tracker = Client().active_stack.experiment_tracker
@step(enable_cache=False, experiment_tracker="mlflow", step_operator="sagemaker-eu")
def model_trainer(
    dataset_trn: pd.DataFrame,
) -> Annotated[ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True)]:
    """Configure and train a model on the training dataset.

    This is an example of a model training step that takes in a dataset
    artifact previously loaded and pre-processed by other steps in your
    pipeline, then configures and trains a model on it. The model is then
    returned as a step output artifact.

    The label column name is currently fixed to ``"target"``; every other
    column of ``dataset_trn`` is used as a feature.

    Args:
        dataset_trn: The preprocessed train dataset. It must contain a
            ``"target"`` column holding the labels.

    Returns:
        The trained model artifact.

    Raises:
        KeyError: If ``dataset_trn`` has no ``"target"`` column.
    """
    # Use the dataset to fetch the target
    # context = get_step_context()
    # target = context.inputs["dataset_trn"].run_metadata['target'].value
    target = "target"

    # Autologging only instruments estimator calls made *after* it has been
    # enabled, so it has to be switched on before ``fit``. Enabling it after
    # training (as was done previously) records nothing for this run.
    mlflow.sklearn.autolog()

    # Initialize the model with the hyperparameters indicated in the step
    # parameters and train it on the training set.
    model = DecisionTreeClassifier()
    logger.info(f"Training model {model}...")
    model.fit(
        dataset_trn.drop(columns=[target]),
        dataset_trn[target],
    )

    # Explicitly log the fitted model under a stable, descriptive artifact
    # path in the tracked run.
    mlflow.sklearn.log_model(model, "breast_cancer_classifier_model")
    return model
|