123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # -*- encoding: utf-8 -*-
- """
- @File : AiEngine.py
- @Time : 2023/2/28 15:15
- @Author : stephen
- @Email : zhangdongming@asj6.wecom.work
- @Software: PyCharm
- """
- import ast
- import time
- import numpy as np
- import tritonclient.grpc as grpcclient
class AiEngine:
    """Small wrapper around a Triton Inference Server gRPC client.

    The constructor creates the client and records in ``self.health`` whether
    the client could be created and the server reports itself live and ready.
    A model is selected with ``set_model`` and then driven with ``yolo_infer``.
    """

    # The only model name that ``yolo_infer`` knows how to drive.
    _YOLO_PIPELINE_MODEL = 'AI_5obj_pdcpv_detect_yolov5_pipeline'
    # Per-image dimensions declared for the 'img1' input
    # (presumably height, width, channels -- TODO confirm against the model config).
    _YOLO_IMAGE_SHAPE = (360, 640, 3)

    def __init__(self, **kwargs):
        """Create the gRPC client and run the health checks.

        Args:
            **kwargs: Only ``url`` is read; it is handed to
                ``grpcclient.InferenceServerClient`` as the server address.

        Failures do not raise. If the client cannot be created, or a status
        call fails or returns a false value, ``self.health`` is set to
        ``False`` and the individual check results are printed.
        """
        self.model_name = None
        # Assigned up front so the attribute exists even when construction fails.
        self.triton_client = None
        try:
            self.triton_client = grpcclient.InferenceServerClient(
                url=kwargs.get('url'),
                verbose=False,
                ssl=False,
                root_certificates=None,
                private_key=None,
                certificate_chain=None)
            connected = True
        except Exception:
            connected = False

        health_check = [
            connected,
            self._probe('is_server_live'),
            self._probe('is_server_ready'),
        ]
        self.health = all(health_check)
        if not self.health:
            # str() is required: concatenating a list to a str raises TypeError.
            print("health_check failed:" + str(health_check))

    def _probe(self, method_name):
        """Call a no-argument status method on the client.

        Returns the result coerced to ``bool``; ``False`` when there is no
        client or the call raises.
        """
        if self.triton_client is None:
            return False
        try:
            return bool(getattr(self.triton_client, method_name)())
        except Exception:
            return False

    def set_model(self, model_name):
        """Select ``model_name`` if the server reports it ready.

        Returns:
            ``True`` when the model is ready and has been stored in
            ``self.model_name``; ``False`` otherwise (including when no
            client exists). The previously selected model is left untouched
            on failure.
        """
        if self.triton_client is None:
            return False
        if self.triton_client.is_model_ready(model_name):
            self.model_name = model_name
            return True
        return False

    def get_model_name(self, model_name=None):
        """Return the currently selected model name (``None`` if unset).

        ``model_name`` is ignored; it is kept, now optional, only so that
        existing callers that pass an argument keep working.
        """
        return self.model_name

    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
        """Run the YOLO pipeline model on a batch of images.

        Args:
            img_arr: Batch of images as a numpy array. It is sent as the
                'img1' input declared as UINT8 with shape
                ``[len(img_arr), 360, 640, 3]``; the array is not validated
                here beyond what the client library checks.
            nms_threshold: Scalar repeated once per image and sent as FP32.
            confidence: Scalar repeated once per image and sent as FP32.
            client_timeout: Forwarded to the client's ``infer`` call.

        Returns:
            The Python literal parsed from the first element of 'OUTPUT0';
            the string ``'e_no_model'`` when the selected model is not the
            supported pipeline; the string ``'e_timeout'`` when the ``infer``
            call raises (any exception, not only timeouts).
        """
        if self.model_name != self._YOLO_PIPELINE_MODEL:
            return 'e_no_model'

        pic_num_sum = len(img_arr)
        # One (1,)-shaped row per image, matching the [batch, 1] inputs below.
        nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
        confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)

        inputs = [
            grpcclient.InferInput('img1', [pic_num_sum, *self._YOLO_IMAGE_SHAPE], "UINT8"),
            grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
            grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32"),
        ]
        outputs = [grpcclient.InferRequestedOutput('OUTPUT0')]
        inputs[0].set_data_from_numpy(img_arr)
        inputs[1].set_data_from_numpy(nms_threshold)
        inputs[2].set_data_from_numpy(confidence)

        try:
            results = self.triton_client.infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs,
                client_timeout=client_timeout)
        except Exception:
            return 'e_timeout'

        # The output element is bytes holding the text of a Python literal;
        # literal_eval parses it without executing code.
        result_str = results.as_numpy('OUTPUT0')[0].decode("UTF-8")
        return ast.literal_eval(result_str)

    def close(self):
        """Close the underlying client; a no-op returning ``None`` if none exists."""
        if self.triton_client is None:
            return None
        return self.triton_client.close()
def func():
    """Print a start marker, sleep for one second, then print an end marker."""
    pause_seconds = 1
    print('func start')
    time.sleep(pause_seconds)
    print('func end')
|