# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared test utilities for mocking Triton Inference Server clients."""
import json
import logging
from typing import Union

import numpy as np
import wrapt
from google.protobuf import json_format  # pytype: disable=pyi-error
from tritonclient.grpc import InferenceServerClient as SyncGrpcInferenceServerClient
from tritonclient.grpc import model_config_pb2, service_pb2
from tritonclient.http import InferenceServerClient as SyncHttpInferenceServerClient
from tritonclient.http.aio import InferenceServerClient as AsyncioHttpInferenceServerClient

from pytriton.model_config import DeviceKind
from pytriton.model_config.generator import ModelConfigGenerator
from pytriton.model_config.triton_model_config import TensorSpec, TritonModelConfig

_LOGGER = logging.getLogger(__name__)

# Reference configuration for an "AddSub" model with batching enabled
# (max batch size 16); the leading -1 in each shape is the batch dimension.
ADD_SUB_WITH_BATCHING_MODEL_CONFIG = TritonModelConfig(
    model_name="AddSub",
    model_version=1,
    max_batch_size=16,
    instance_group={DeviceKind.KIND_CPU: 1},
    inputs=[
        TensorSpec(name="a", shape=(-1, 1), dtype=np.float32),
        TensorSpec(name="b", shape=(-1, 1), dtype=np.float32),
    ],
    outputs=[
        TensorSpec(name="add", shape=(-1, 1), dtype=np.float32),
        TensorSpec(name="sub", shape=(-1, 1), dtype=np.float32),
    ],
    backend_parameters={"shared-memory-socket": "dummy/path"},
)

# The same model with batching disabled; tensor shapes drop the batch dimension.
ADD_SUB_WITHOUT_BATCHING_MODEL_CONFIG = TritonModelConfig(
    model_name="AddSub",
    model_version=1,
    batching=False,
    instance_group={DeviceKind.KIND_CPU: 1},
    inputs=[
        TensorSpec(name="a", shape=(1,), dtype=np.float32),
        TensorSpec(name="b", shape=(1,), dtype=np.float32),
    ],
    outputs=[
        TensorSpec(name="add", shape=(1,), dtype=np.float32),
        TensorSpec(name="sub", shape=(1,), dtype=np.float32),
    ],
    backend_parameters={"shared-memory-socket": "dummy/path"},
)
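
# Illustrative sketch: ModelConfigGenerator renders a TritonModelConfig into the
# plain-dict form that Triton's model-configuration endpoints return, which is
# what the patch helpers below feed back through mocked clients:
#
#     config_dict = ModelConfigGenerator(ADD_SUB_WITH_BATCHING_MODEL_CONFIG).get_config()
#     # config_dict carries the fields of Triton's ModelConfig message,
#     # e.g. the model name "AddSub" and the inputs/outputs declared above.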

GRPC_LOCALHOST_URL = "grpc://localhost:8001"
HTTP_LOCALHOST_URL_NO_SCHEME = "localhost:8000"
HTTP_LOCALHOST_URL = f"http://{HTTP_LOCALHOST_URL_NO_SCHEME}"

EXPECTED_KWARGS_DEFAULT = {
    "model_name": ADD_SUB_WITH_BATCHING_MODEL_CONFIG.model_name,
    "model_version": "",
    "request_id": "0",
    "parameters": None,
    "headers": None,
}

_TritonClientType = Union[
    AsyncioHttpInferenceServerClient, SyncHttpInferenceServerClient, SyncGrpcInferenceServerClient
]
_HttpTritonClientType = Union[AsyncioHttpInferenceServerClient, SyncHttpInferenceServerClient]
_GrpcTritonClientType = SyncGrpcInferenceServerClient


def patch_client__server_up_and_ready(
    mocker, base_triton_client: _TritonClientType, ready_server: bool = True, live_server: bool = True
):
    """Patch the given client class so the server reports the requested ready/live state."""
    mocker.patch.object(base_triton_client, base_triton_client.is_server_ready.__name__).return_value = ready_server
    mocker.patch.object(base_triton_client, base_triton_client.is_server_live.__name__).return_value = live_server
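
# Example usage (a minimal sketch, assuming pytest with the pytest-mock plugin
# supplying the ``mocker`` fixture; the test function is hypothetical):
#
#     def test_waits_for_server(mocker):
#         patch_client__server_up_and_ready(mocker, SyncGrpcInferenceServerClient)
#         client = SyncGrpcInferenceServerClient(url="localhost:8001")
#         assert client.is_server_live() and client.is_server_ready()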


def patch_http_client__model_up_and_ready(
    mocker,
    model_config: TritonModelConfig,
    base_triton_client: _HttpTritonClientType,
    ready: bool = True,
):
    """Patch an HTTP client class so the given model reports ready and returns its config."""
    mocker.patch.object(base_triton_client, base_triton_client.is_model_ready.__name__).return_value = ready
    # HTTP clients return the model configuration as a plain dict.
    model_config_dict = ModelConfigGenerator(model_config).get_config()
    mock_get_model_config = mocker.patch.object(base_triton_client, base_triton_client.get_model_config.__name__)
    mock_get_model_config.return_value = model_config_dict
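
# Example usage (sketch, same pytest-mock assumption; the test is hypothetical).
# The two helpers are typically combined so both server and model report ready:
#
#     def test_http_model(mocker):
#         patch_client__server_up_and_ready(mocker, SyncHttpInferenceServerClient)
#         patch_http_client__model_up_and_ready(
#             mocker, ADD_SUB_WITH_BATCHING_MODEL_CONFIG, SyncHttpInferenceServerClient
#         )
#         client = SyncHttpInferenceServerClient(url=HTTP_LOCALHOST_URL_NO_SCHEME)
#         assert client.is_model_ready("AddSub")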


def patch_grpc_client__model_up_and_ready(
    mocker,
    model_config: TritonModelConfig,
    base_triton_client: _GrpcTritonClientType,
    ready: bool = True,
):
    """Patch the gRPC client class so the given model reports ready and returns its config."""

    def new_is_model_ready(model_name, model_version="", headers=None, parameters=None):
        # Report ready only for the expected model name and a matching (or unset) version.
        return (
            ready
            and model_name == model_config.model_name
            and (model_version == "" or model_version == str(model_config.model_version))
        )

    mocker.patch.object(base_triton_client, base_triton_client.is_model_ready.__name__, side_effect=new_is_model_ready)

    # The gRPC client wraps the configuration in a ModelConfigResponse message;
    # round-trip through protobuf JSON to mimic the shape of a real response.
    model_config_dict = ModelConfigGenerator(model_config).get_config()
    model_config_protobuf = json_format.ParseDict(model_config_dict, model_config_pb2.ModelConfig())
    response = service_pb2.ModelConfigResponse(config=model_config_protobuf)
    response_dict = json.loads(json_format.MessageToJson(response, preserving_proto_field_name=True))
    mock_get_model_config = mocker.patch.object(base_triton_client, base_triton_client.get_model_config.__name__)
    mock_get_model_config.return_value = response_dict
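
# Example usage (sketch, same pytest-mock assumption). Unlike the HTTP variant,
# the patched ``is_model_ready`` here checks the requested model name and
# version, so a mismatched name reports not-ready:
#
#     def test_grpc_model(mocker):
#         patch_grpc_client__model_up_and_ready(
#             mocker, ADD_SUB_WITH_BATCHING_MODEL_CONFIG, SyncGrpcInferenceServerClient
#         )
#         client = SyncGrpcInferenceServerClient(url="localhost:8001")
#         assert client.is_model_ready("AddSub")
#         assert not client.is_model_ready("OtherModel")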


@wrapt.decorator
def patch_server_model_addsub_no_batch_ready(wrapped, _instance, _args, kwargs):
    """Decorate a test so the non-batching AddSub model reports up and ready over gRPC."""
    mocker = kwargs["mocker"]
    patch_client__server_up_and_ready(mocker, SyncGrpcInferenceServerClient)
    patch_grpc_client__model_up_and_ready(mocker, ADD_SUB_WITHOUT_BATCHING_MODEL_CONFIG, SyncGrpcInferenceServerClient)
    return wrapped(mocker)
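
# Example usage (sketch): applied to a pytest test, the decorator performs the
# patching and passes the ``mocker`` fixture through to the wrapped test:
#
#     @patch_server_model_addsub_no_batch_ready
#     def test_infer(mocker):
#         client = SyncGrpcInferenceServerClient(url="localhost:8001")
#         assert client.is_model_ready("AddSub")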