chansung's picture
hello
efb03ac
raw
history blame
No virus
3.13 kB
import os
from absl import logging
from tfx import v1 as tfx
from tfx import proto
from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner as runner
from tfx.orchestration.data_types import RuntimeParameter
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.proto import tuner_pb2
from pipeline import configs
from pipeline import pipeline
"""
RuntimeParameter could be injected with TFX CLI
:
--runtime-parameter output-config='{}' \
--runtime-parameter input-config='{"splits": [{"name": "train", "pattern": "span-[12]/train/*.tfrecord"}, {"name": "val", "pattern": "span-[12]/test/*.tfrecord"}]}'
OR it could be injected programmatically
:
import json
from kfp.v2.google import client
pipelines_client = client.AIPlatformClient(
project_id=GOOGLE_CLOUD_PROJECT, region=GOOGLE_CLOUD_REGION,
)
_ = pipelines_client.create_run_from_job_spec(
PIPELINE_DEFINITION_FILE,
enable_caching=False,
parameter_values={
"input-config": json.dumps(
{
"splits": [
{"name": "train", "pattern": "span-[12]/train/*.tfrecord"},
{"name": "val", "pattern": "span-[12]/test/*.tfrecord"},
]
}
),
"output-config": json.dumps({}),
},
)
"""
def run():
    """Build the TFX pipeline and hand it to the Kubeflow V2 DAG runner.

    The runner is configured with ``configs.PIPELINE_IMAGE`` as its default
    image and ``<PIPELINE_NAME>_pipeline.json`` as its output filename
    (presumably the compiled job spec is written there -- confirm against
    the TFX runner documentation).

    ExampleGen's input/output configs are exposed as string-typed
    ``RuntimeParameter``s named ``input-config`` and ``output-config`` so
    they can be overridden when a run is submitted.
    """
    dag_runner = runner.KubeflowV2DagRunner(
        config=runner.KubeflowV2DagRunnerConfig(
            default_image=configs.PIPELINE_IMAGE
        ),
        output_filename=configs.PIPELINE_NAME + "_pipeline.json",
    )

    # NOTE(review): this default nests the splits under an "input_config"
    # key, whereas the usage examples in this module's notes pass
    # {"splits": [...]} at the top level -- confirm which shape is expected.
    input_config_param = RuntimeParameter(
        name="input-config",
        default='{"input_config": {"splits": [{"name":"train", "pattern":"span-1/train/*"}, {"name":"eval", "pattern":"span-1/test/*"}]}}',
        ptype=str,
    )
    output_config_param = RuntimeParameter(
        name="output-config",
        default="{}",
        ptype=str,
    )

    # Values taken from the project config for the user-supplied
    # preprocessing, training and tuning functions.
    module_fns = {
        "preprocessing_fn": configs.PREPROCESSING_FN,
        "training_fn": configs.TRAINING_FN,
        "cloud_tuner_fn": configs.CLOUD_TUNER_FN,
    }

    train_args = trainer_pb2.TrainArgs(num_steps=configs.TRAIN_NUM_STEPS)
    eval_args = trainer_pb2.EvalArgs(num_steps=configs.EVAL_NUM_STEPS)
    tuner_args = tuner_pb2.TuneArgs(
        num_parallel_trials=configs.NUM_PARALLEL_TRIALS
    )

    pipeline_def = pipeline.create_pipeline(
        input_config=input_config_param,
        output_config=output_config_param,
        pipeline_name=configs.PIPELINE_NAME,
        pipeline_root=configs.PIPELINE_ROOT,
        data_path=configs.DATA_PATH,
        modules=module_fns,
        train_args=train_args,
        eval_args=eval_args,
        tuner_args=tuner_args,
        ai_platform_training_args=configs.GCP_AI_PLATFORM_TRAINING_ARGS,
        ai_platform_tuner_args=configs.GCP_AI_PLATFORM_TUNER_ARGS,
        ai_platform_serving_args=configs.GCP_AI_PLATFORM_SERVING_ARGS,
        gh_release_args=configs.GH_RELEASE_ARGS,
        hf_model_release_args=configs.HF_MODEL_RELEASE_ARGS,
    )

    dag_runner.run(pipeline_def)
if __name__ == "__main__":
    # Script entry point: set absl logging verbosity to INFO, then run.
    logging.set_verbosity(logging.INFO)
    run()