#!/usr/bin/env python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
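
"""Streaming gRPC client for Triton Inference Server.

Sends an image URL/path and a text prompt to the "ensemble_mllm" ensemble
over a gRPC stream and prints the decoded text chunks as they arrive; an
empty OUTPUT chunk marks the end of the stream. Based on NVIDIA's sequence
streaming client example.
"""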

import argparse
import queue
import sys
import threading
import time
import uuid
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

FLAGS = None


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


# Define the callback function. Note that the last two parameters must be
# result and error. InferenceServerClient provides the result of an
# inference as a grpcclient.InferResult object in result. For a successful
# inference, error will be None; otherwise it will be a
# tritonclient.utils.InferenceServerException holding the error details.
def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="10.199.14.151:8001",
        help="Inference server URL and its gRPC port. Default is 10.199.14.151:8001.",
    )
    parser.add_argument(
        "-t",
        "--stream-timeout",
        type=float,
        required=False,
        default=None,
        help="Stream timeout in seconds. Default is None.",
    )
    # parser.add_argument(
    #     "-d",
    #     "--dyna",
    #     action="store_true",
    #     required=False,
    #     default=False,
    #     help="Assume dynamic sequence model",
    # )
    # parser.add_argument(
    #     "-o",
    #     "--offset",
    #     type=int,
    #     required=False,
    #     default=0,
    #     help="Add offset to sequence ID used",
    # )
    FLAGS = parser.parse_args()

    # # We use custom "sequence" models which take 1 input
    # # value. The output is the accumulated value of the inputs. See
    # # src/custom/sequence.
    # int_sequence_model_name = (
    #     "simple_dyna_sequence" if FLAGS.dyna else "simple_sequence"
    # )
    # string_sequence_model_name = (
    #     "simple_string_dyna_sequence" if FLAGS.dyna else "simple_sequence"
    # )

    model_name = "ensemble_mllm"
    model_version = ""
    batch_size = 1
    # img_url = "https://s3plus.sankuai.com/automl-pkgs/0000.jpeg"
    img_url = "/workdir/yanghandi/gradio_demo/static/0000.jpeg"
    # img_url = "https://s3plus.sankuai.com/automl-pkgs/0003.jpeg"
    text = "详细描述一下这张图片"  # "Describe this image in detail."
    sequence_id = 100
    int_sequence_id0 = sequence_id
    result_list = []

    user_data = UserData()

    # It is advisable to use the client object within a with..as clause
    # when sending streaming requests. This ensures the client is closed
    # when the block inside with exits.
    with grpcclient.InferenceServerClient(
        url=FLAGS.url, verbose=FLAGS.verbose
    ) as triton_client:
        try:
            # Establish the stream; responses are delivered to the callback
            # as they are produced by the server.
            triton_client.start_stream(
                callback=partial(callback, user_data),
                stream_timeout=FLAGS.stream_timeout,
            )
            # Create the input tensors. Triton BYTES inputs are passed as
            # numpy bytes arrays shaped to the model's input dims.
            inputs = []
            img_url_bytes = img_url.encode("utf-8")
            img_url_bytes = np.array(img_url_bytes, dtype=bytes)
            img_url_bytes = img_url_bytes.reshape([1, -1])
            inputs.append(
                grpcclient.InferInput("IMAGE_URL", img_url_bytes.shape, "BYTES")
            )
            inputs[0].set_data_from_numpy(img_url_bytes)

            text_bytes = text.encode("utf-8")
            text_bytes = np.array(text_bytes, dtype=bytes)
            text_bytes = text_bytes.reshape([1, -1])
            text_input = text_bytes
            inputs.append(grpcclient.InferInput("TEXT", text_input.shape, "BYTES"))
            inputs[1].set_data_from_numpy(text_input)

            outputs = []
            outputs.append(grpcclient.InferRequestedOutput("OUTPUT"))

            # Issue the asynchronous sequence inference.
            triton_client.async_stream_infer(
                model_name=model_name,
                inputs=inputs,
                outputs=outputs,
                request_id="{}".format(sequence_id),
                sequence_id=sequence_id,
                sequence_start=True,
                sequence_end=True,
            )
        except InferenceServerException as error:
            print(error)
            sys.exit(1)

        # Retrieve results. Responses arrive on the callback thread and are
        # pushed into user_data._completed_requests; an empty OUTPUT chunk
        # marks the end of the stream.
        recv_count = 0
        while True:
            data_item = user_data._completed_requests.get()
            if type(data_item) == InferenceServerException:
                print("InferenceServerException: ", data_item)
                sys.exit(1)
            this_id = data_item.get_response().id.split("_")[0]
            if int(this_id) != int_sequence_id0:
                print("unexpected sequence id returned by the server: {}".format(this_id))
                sys.exit(1)
            result = data_item.as_numpy("OUTPUT")
            if len(result[0][0]) == 0:
                # An empty chunk signals the end of the generated text.
                break
            result_list.append(result)
            recv_count = recv_count + 1
            result_str = "".join([item[0][0].decode("utf-8") for item in result_list])
            print(f"{len(result_list)}: {result_str}")

    print("PASS: Sequence")
    print(result_str)