# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
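"""Test utilities for mocking Triton Inference Server clients and model configurations."""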
import json
import logging
from typing import Union
import numpy as np
import wrapt
from google.protobuf import json_format # pytype: disable=pyi-error
from tritonclient.grpc import InferenceServerClient as SyncGrpcInferenceServerClient
from tritonclient.grpc import model_config_pb2, service_pb2
from tritonclient.http import InferenceServerClient as SyncHttpInferenceServerClient
from tritonclient.http.aio import InferenceServerClient as AsyncioHttpInferenceServerClient
from pytriton.model_config import DeviceKind
from pytriton.model_config.generator import ModelConfigGenerator
from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig
_LOGGER = logging.getLogger(__name__)

# AddSub model configuration used as a test fixture, with batching enabled (max_batch_size=16).
ADD_SUB_WITH_BATCHING_MODEL_CONFIG = TritonModelConfig(
model_name="AddSub",
model_version=1,
max_batch_size=16,
instance_group={DeviceKind.KIND_CPU: 1},
inputs=[
TensorSpec(name="a", shape=(-1, 1), dtype=np.float32),
TensorSpec(name="b", shape=(-1, 1), dtype=np.float32),
],
outputs=[
TensorSpec(name="add", shape=(-1, 1), dtype=np.float32),
TensorSpec(name="sub", shape=(-1, 1), dtype=np.float32),
],
backend_parameters={"shared-memory-socket": "dummy/path"},
)

# AddSub model configuration used as a test fixture, with batching disabled.
ADD_SUB_WITHOUT_BATCHING_MODEL_CONFIG = TritonModelConfig(
model_name="AddSub",
model_version=1,
batching=False,
instance_group={DeviceKind.KIND_CPU: 1},
inputs=[
TensorSpec(name="a", shape=(1,), dtype=np.float32),
TensorSpec(name="b", shape=(1,), dtype=np.float32),
],
outputs=[
TensorSpec(name="add", shape=(1,), dtype=np.float32),
TensorSpec(name="sub", shape=(1,), dtype=np.float32),
],
backend_parameters={"shared-memory-socket": "dummy/path"},
)

GRPC_LOCALHOST_URL = "grpc://localhost:8001"
HTTP_LOCALHOST_URL_NO_SCHEME = "localhost:8000"
HTTP_LOCALHOST_URL = f"http://{HTTP_LOCALHOST_URL_NO_SCHEME}"

# Default keyword arguments expected by tests for calls against the AddSub model.
EXPECTED_KWARGS_DEFAULT = {
"model_name": ADD_SUB_WITH_BATCHING_MODEL_CONFIG.model_name,
"model_version": "",
"request_id": "0",
"parameters": None,
"headers": None,
}

# Type aliases covering the Triton client flavors patched by the helpers below.
_TritonClientType = Union[
    AsyncioHttpInferenceServerClient, SyncHttpInferenceServerClient, SyncGrpcInferenceServerClient
]
_HttpTritonClientType = Union[AsyncioHttpInferenceServerClient, SyncHttpInferenceServerClient]
_GrpcTritonClientType = SyncGrpcInferenceServerClient


def patch_client__server_up_and_ready(
    mocker, base_triton_client: _TritonClientType, ready_server: bool = True, live_server: bool = True
):
    """Patch ``is_server_ready`` and ``is_server_live`` on the given client class to return fixed values."""
    mocker.patch.object(base_triton_client, base_triton_client.is_server_ready.__name__).return_value = ready_server
    mocker.patch.object(base_triton_client, base_triton_client.is_server_live.__name__).return_value = live_server


def patch_http_client__model_up_and_ready(
    mocker,
    model_config: TritonModelConfig,
    base_triton_client: _HttpTritonClientType,
    ready: bool = True,
):
    """Patch the HTTP client class so the given model reports as ready and returns its generated config."""
    mocker.patch.object(base_triton_client, base_triton_client.is_model_ready.__name__).return_value = ready

    model_config_dict = ModelConfigGenerator(model_config).get_config()
    mock_get_model_config = mocker.patch.object(base_triton_client, base_triton_client.get_model_config.__name__)
    mock_get_model_config.return_value = model_config_dict


def patch_grpc_client__model_up_and_ready(
    mocker,
    model_config: TritonModelConfig,
    base_triton_client: _GrpcTritonClientType,
    ready: bool = True,
):
    """Patch the gRPC client class so the given model reports as ready and returns its config as a response dict."""

    def new_is_model_ready(model_name, model_version="", headers=None, parameters=None):
        # Report readiness only for the configured model name and a matching (or unspecified) version.
        return (
            ready
            and model_name == model_config.model_name
            and (model_version == "" or model_version == str(model_config.model_version))
        )

    mocker.patch.object(base_triton_client, base_triton_client.is_model_ready.__name__, side_effect=new_is_model_ready)

    # The gRPC client exposes the model config wrapped in a ModelConfigResponse message, so build the
    # generated config, wrap it, and convert it back to a dict for the mocked return value.
    model_config_dict = ModelConfigGenerator(model_config).get_config()
    model_config_protobuf = json_format.ParseDict(model_config_dict, model_config_pb2.ModelConfig())
    response = service_pb2.ModelConfigResponse(config=model_config_protobuf)
    response_dict = json.loads(json_format.MessageToJson(response, preserving_proto_field_name=True))
    mock_get_model_config = mocker.patch.object(base_triton_client, base_triton_client.get_model_config.__name__)
    mock_get_model_config.return_value = response_dict


@wrapt.decorator
def patch_server_model_addsub_no_batch_ready(wrapped, _instance, _args, kwargs):
    """Decorator patching the gRPC client so the server and the AddSub model report as up and ready."""
    mocker = kwargs["mocker"]
    patch_client__server_up_and_ready(mocker, SyncGrpcInferenceServerClient)
    patch_grpc_client__model_up_and_ready(mocker, ADD_SUB_WITH_BATCHING_MODEL_CONFIG, SyncGrpcInferenceServerClient)
    return wrapped(mocker)
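

# A minimal usage sketch, assuming these helpers are called from pytest tests where the pytest-mock
# plugin supplies the ``mocker`` fixture. The test name below is hypothetical and only illustrates
# how the patch helpers are meant to be combined:
#
#     def test_grpc_server_and_model_report_ready(mocker):
#         patch_client__server_up_and_ready(mocker, SyncGrpcInferenceServerClient)
#         patch_grpc_client__model_up_and_ready(
#             mocker, ADD_SUB_WITH_BATCHING_MODEL_CONFIG, SyncGrpcInferenceServerClient
#         )
#         client = SyncGrpcInferenceServerClient(url="localhost:8001")
#         assert client.is_server_ready()
#         assert client.is_model_ready(ADD_SUB_WITH_BATCHING_MODEL_CONFIG.model_name)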