# -*- encoding: utf-8 -*-
"""
@File    :   AiEngine.py
@Time    :   2023/2/28 15:15
@Author  :   stephen
@Email   :   zhangdongming@asj6.wecom.work
@Software:   PyCharm

Thin wrapper around a Triton Inference Server gRPC client used to run a
YOLOv5 detection pipeline model.
"""
import ast
import time

import numpy as np
import tritonclient.grpc as grpcclient


class AiEngine:
    """Connect to a Triton server over gRPC and run inference on a chosen model.

    Attributes:
        triton_client: the gRPC client, or ``None`` if it could not be created.
        health: ``True`` when the client was created and the server reported
            itself both live and ready during construction.
        model_name: the model selected via ``set_model``, or ``None``.
    """

    # The only model ``yolo_infer`` knows how to build a request for.
    YOLO_PIPELINE_MODEL = 'AI_5obj_pdcpv_detect_yolov5_pipeline'

    def __init__(self, **kwargs):
        """Create the gRPC client and run a one-off server health check.

        Keyword Args:
            url: Triton gRPC endpoint passed to ``InferenceServerClient``.

        Connection and health-probe failures do not raise; they are reported
        by printing and by setting ``self.health`` to ``False``.
        """
        # Assign every attribute up front so later methods never hit an
        # AttributeError when the client could not be created.
        self.triton_client = None
        self.model_name = None
        self.health = True

        try:
            self.triton_client = grpcclient.InferenceServerClient(
                url=kwargs.get('url'),
                verbose=False,
                ssl=False,
                root_certificates=None,
                private_key=None,
                certificate_chain=None)
            connected = True
        except Exception:
            connected = False

        # [client created, server live, server ready]
        health_check = [connected, False, False]
        if connected:
            try:
                health_check[1] = self.triton_client.is_server_live()
                health_check[2] = self.triton_client.is_server_ready()
            except Exception as err:
                # The client may raise (e.g. server unreachable) instead of
                # returning False; treat that as a failed check.
                print("health_check error: " + str(err))

        if not all(health_check):
            # str() is required: concatenating a list to a str raises TypeError.
            print("health_check failed:" + str(health_check))
            self.health = False

    def set_model(self, model_name):
        """Select ``model_name`` for inference if the server reports it ready.

        Returns:
            True if the model was selected, False otherwise (including when
            no client exists).
        """
        if self.triton_client is None:
            return False
        if not self.triton_client.is_model_ready(model_name):
            return False
        self.model_name = model_name
        return True

    def get_model_name(self, model_name=None):
        """Return the currently selected model name (or ``None``).

        ``model_name`` is ignored; it is accepted only so existing callers
        that pass an argument keep working.
        """
        return self.model_name

    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
        """Run the YOLOv5 pipeline model on a batch of images.

        Args:
            img_arr: numpy array sent as input ``img1``, declared to the server
                as UINT8 with shape ``(batch, 360, 640, 3)``; it must match.
            nms_threshold: scalar NMS threshold, replicated once per image.
            confidence: scalar confidence threshold, replicated once per image.
            client_timeout: forwarded unchanged to ``InferenceServerClient.infer``.

        Returns:
            The Python literal decoded from the first ``OUTPUT0`` element,
            ``'e_no_model'`` if the selected model is not the YOLO pipeline,
            or ``'e_timeout'`` if the inference call raised for any reason.
        """
        if self.model_name != self.YOLO_PIPELINE_MODEL:
            return 'e_no_model'

        batch = len(img_arr)
        # One threshold per image, shaped (batch, 1) as the inputs declare.
        nms_arr = np.full((batch, 1), nms_threshold, dtype=np.float32)
        conf_arr = np.full((batch, 1), confidence, dtype=np.float32)

        inputs = [
            grpcclient.InferInput('img1', [batch, 360, 640, 3], "UINT8"),
            grpcclient.InferInput('nms_threshold', [batch, 1], "FP32"),
            grpcclient.InferInput('confidence', [batch, 1], "FP32"),
        ]
        inputs[0].set_data_from_numpy(img_arr)
        inputs[1].set_data_from_numpy(nms_arr)
        inputs[2].set_data_from_numpy(conf_arr)
        outputs = [grpcclient.InferRequestedOutput('OUTPUT0')]

        try:
            results = self.triton_client.infer(model_name=self.model_name,
                                               inputs=inputs,
                                               outputs=outputs,
                                               client_timeout=client_timeout)
        except Exception:
            # NOTE(review): every failure (not only timeouts) is reported as
            # 'e_timeout'; callers depend on this sentinel, so it is kept.
            return 'e_timeout'

        # The pipeline returns its detections as a UTF-8 string holding a
        # Python literal; literal_eval parses literals only (no code execution).
        result_str = results.as_numpy('OUTPUT0')[0].decode("UTF-8")
        return ast.literal_eval(result_str)

    def close(self):
        """Close the gRPC client; returns ``None`` if it was never created."""
        if self.triton_client is None:
            return None
        return self.triton_client.close()


def func():
    """Print start/end markers around a one-second sleep."""
    print('func start')
    time.sleep(1)
    print('func end')