import ast

import numpy as np
import tritonclient.grpc as grpcclient


class AiEngine:
    """Thin wrapper around a Triton gRPC client for running a YOLO model.

    Attributes set by the constructor:
        triton_client: the ``grpcclient.InferenceServerClient`` instance, or
            ``None`` when the client could not be created.
        health: ``True`` when the initial health check passed, else ``False``.
        model_name: name of the model that ``yolo_infer`` sends requests to.
    """

    # The only model name ``yolo_infer`` knows how to build a request for.
    _YOLO_MODEL_NAME = 'pdchv'
    # Per-image dimensions declared for the 'img' input (everything after
    # the batch dimension).
    _YOLO_IMAGE_DIMS = (360, 640, 3)

    def __init__(self, url, model_name):
        """Create the gRPC client for ``url`` and record the server health.

        Never raises: any failure while creating the client or checking the
        server is printed and reflected as ``self.health = False``.

        Args:
            url: server address passed straight to ``InferenceServerClient``.
            model_name: initial value of ``self.model_name``.
        """
        # Bind the attribute up front so the other methods can test for a
        # missing client instead of failing with AttributeError.
        self.triton_client = None
        self.model_name = model_name
        try:
            self.triton_client = grpcclient.InferenceServerClient(
                url=url,
                verbose=False,
                ssl=False,
                root_certificates=None,
                private_key=None,
                certificate_chain=None,
            )
            self.health = self.check_health()
        except Exception as e:
            print(repr(e))
            self.health = False

    def check_health(self):
        """Return True when a client exists and the server is live and ready.

        The three flags (client created, server live, server ready) are
        printed when any of them is False. An exception raised by the
        liveness/readiness queries is printed and reported as unhealthy.
        """
        connected = self.triton_client is not None
        if not connected:
            # Without a client the server cannot be queried at all.
            health_check = [connected, False, False]
        else:
            try:
                health_check = [
                    connected,
                    self.triton_client.is_server_live(),
                    self.triton_client.is_server_ready(),
                ]
            except Exception as e:
                print(repr(e))
                return False
        if False in health_check:
            print("Health check failed:", health_check)
            return False
        return True

    def set_model(self, model_name):
        """Switch to ``model_name`` if the server reports it ready.

        Returns:
            True when ``self.model_name`` was updated. False when there is no
            client or the server says the model is not ready; in that case
            ``self.model_name`` is left untouched.
        """
        if self.triton_client is None:
            return False
        if not self.triton_client.is_model_ready(model_name):
            return False
        self.model_name = model_name
        return True

    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
        """Run a batched inference request against the current model.

        Args:
            img_arr: batch of images handed to ``set_data_from_numpy`` for
                the 'img' input, which is declared as UINT8 with shape
                ``[len(img_arr), 360, 640, 3]``.
            nms_threshold: scalar replicated once per image into the FP32
                'nms_threshold' input.
            confidence: scalar replicated once per image into the FP32
                'confidence' input.
            client_timeout: forwarded to ``InferenceServerClient.infer``.

        Returns:
            ``'e_no_model'`` when ``self.model_name`` is not 'pdchv';
            ``'e_connect_fail'`` when there is no client or the ``infer``
            call raises; otherwise the Python literal parsed from the first
            element of the 'result' output.

        Raises:
            Whatever ``set_data_from_numpy``, the UTF-8 decode, or
            ``ast.literal_eval`` raise on malformed inputs or outputs; only
            the ``infer`` call itself is guarded.
        """
        if self.model_name != self._YOLO_MODEL_NAME:
            return 'e_no_model'
        if self.triton_client is None:
            return 'e_connect_fail'

        batch_size = len(img_arr)
        # np.full keeps the (batch, 1) shape even for an empty batch, which
        # matches the shape declared on the InferInput objects below.
        nms_column = np.full((batch_size, 1), nms_threshold, dtype=np.float32)
        confidence_column = np.full(
            (batch_size, 1), confidence, dtype=np.float32
        )

        img_input = grpcclient.InferInput(
            'img', [batch_size, *self._YOLO_IMAGE_DIMS], "UINT8"
        )
        nms_input = grpcclient.InferInput(
            'nms_threshold', [batch_size, 1], "FP32"
        )
        confidence_input = grpcclient.InferInput(
            'confidence', [batch_size, 1], "FP32"
        )
        img_input.set_data_from_numpy(img_arr)
        nms_input.set_data_from_numpy(nms_column)
        confidence_input.set_data_from_numpy(confidence_column)

        inputs = [img_input, nms_input, confidence_input]
        outputs = [grpcclient.InferRequestedOutput('result')]

        try:
            response = self.triton_client.infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs,
                client_timeout=client_timeout,
            )
        except Exception as e:
            print(repr(e))
            return 'e_connect_fail'

        # The first element of 'result' is UTF-8 bytes holding a Python
        # literal; literal_eval parses it without executing code.
        result_str = response.as_numpy('result')[0].decode("UTF-8")
        return ast.literal_eval(result_str)

    def close(self):
        """Close the underlying client; a no-op when none was created."""
        if self.triton_client is None:
            return None
        return self.triton_client.close()