|
import json
import os

from absl import logging

from tfx import v1 as tfx
from tfx import proto
from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner as runner
from tfx.orchestration.data_types import RuntimeParameter
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.proto import tuner_pb2

from pipeline import configs
from pipeline import pipeline
|
|
|
""" |
|
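RuntimeParameter values can be injected with the TFX CLI. The input-config
value is a JSON object with a top-level `splits` list containing one
{name, pattern} entry per split: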
|
--runtime-parameter output-config='{}' \ |
|
--runtime-parameter input-config='{"splits": [{"name": "train", "pattern": "span-[12]/train/*.tfrecord"}, {"name": "val", "pattern": "span-[12]/test/*.tfrecord"}]}' |
|
|
|
Alternatively, the values can be injected programmatically:
|
import json |
|
from kfp.v2.google import client |
|
|
|
pipelines_client = client.AIPlatformClient( |
|
project_id=GOOGLE_CLOUD_PROJECT, region=GOOGLE_CLOUD_REGION, |
|
) |
|
_ = pipelines_client.create_run_from_job_spec( |
|
PIPELINE_DEFINITION_FILE, |
|
enable_caching=False, |
|
parameter_values={ |
|
"input-config": json.dumps( |
|
{ |
|
"splits": [ |
|
{"name": "train", "pattern": "span-[12]/train/*.tfrecord"}, |
|
{"name": "val", "pattern": "span-[12]/test/*.tfrecord"}, |
|
] |
|
} |
|
), |
|
"output-config": json.dumps({}), |
|
}, |
|
) |
|
""" |
|
|
|
|
|
def run():
    """Compile the TFX pipeline with the Kubeflow V2 DAG runner.

    Builds the pipeline with ``pipeline.create_pipeline`` using values from
    ``configs`` and hands it to ``runner.KubeflowV2DagRunner``, configured to
    emit the compiled definition as ``<configs.PIPELINE_NAME>_pipeline.json``.

    ``input_config`` and ``output_config`` are passed as ``RuntimeParameter``
    objects named ``input-config`` and ``output-config``. Their values can be
    overridden when a run is launched (see the usage note near the top of
    this file); the defaults below apply only when no override is supplied.
    """
    runner_config = runner.KubeflowV2DagRunnerConfig(
        default_image=configs.PIPELINE_IMAGE
    )
    kfp_runner = runner.KubeflowV2DagRunner(
        config=runner_config,
        output_filename=f"{configs.PIPELINE_NAME}_pipeline.json",
    )

    # Default for the "input-config" runtime parameter. It uses the same
    # shape as the override examples in the usage note: a top-level "splits"
    # list of {name, pattern} entries (the JSON form of ExampleGen's Input
    # message). Serializing with json.dumps guarantees well-formed JSON.
    # By default only files under span-1/ are matched.
    # NOTE(review): confirm the split names "train"/"eval" are the ones the
    # components built in pipeline.create_pipeline expect.
    default_input_config = json.dumps(
        {
            "splits": [
                {"name": "train", "pattern": "span-1/train/*"},
                {"name": "eval", "pattern": "span-1/test/*"},
            ]
        }
    )

    tfx_pipeline = pipeline.create_pipeline(
        input_config=RuntimeParameter(
            name="input-config",
            default=default_input_config,
            ptype=str,
        ),
        output_config=RuntimeParameter(
            name="output-config",
            default="{}",
            ptype=str,
        ),
        pipeline_name=configs.PIPELINE_NAME,
        pipeline_root=configs.PIPELINE_ROOT,
        data_path=configs.DATA_PATH,
        modules={
            "preprocessing_fn": configs.PREPROCESSING_FN,
            "training_fn": configs.TRAINING_FN,
            "cloud_tuner_fn": configs.CLOUD_TUNER_FN,
        },
        train_args=trainer_pb2.TrainArgs(num_steps=configs.TRAIN_NUM_STEPS),
        eval_args=trainer_pb2.EvalArgs(num_steps=configs.EVAL_NUM_STEPS),
        tuner_args=tuner_pb2.TuneArgs(
            num_parallel_trials=configs.NUM_PARALLEL_TRIALS
        ),
        ai_platform_training_args=configs.GCP_AI_PLATFORM_TRAINING_ARGS,
        ai_platform_tuner_args=configs.GCP_AI_PLATFORM_TUNER_ARGS,
        ai_platform_serving_args=configs.GCP_AI_PLATFORM_SERVING_ARGS,
        gh_release_args=configs.GH_RELEASE_ARGS,
        hf_model_release_args=configs.HF_MODEL_RELEASE_ARGS,
    )

    kfp_runner.run(tfx_pipeline)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: enable INFO-level absl logging first so messages
    # emitted during run() are visible, then build and compile the pipeline.
    logging.set_verbosity(logging.INFO)
    run()
|
|