123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- import ast
- import numpy as np
- import tritonclient.grpc as grpcclient
class AiEngine:
    """Small wrapper around a Triton Inference Server gRPC client.

    Construction never raises. If the client cannot be created, or the
    server fails its health check, ``self.health`` is False. The other
    methods then report failure through their return values instead of
    raising ``AttributeError`` on a missing client.
    """

    # The only model name yolo_infer will run; any other model yields
    # 'e_no_model'.
    _YOLO_MODEL = 'pdchv'
    # Per-image (height, width, channels) declared for the 'img' input.
    # img_arr is expected to match it.
    _IMG_SHAPE = (360, 640, 3)

    def __init__(self, url, model_name):
        """Create a gRPC client for ``url`` and remember ``model_name``.

        :param url: Triton gRPC endpoint, passed unchanged to
            ``grpcclient.InferenceServerClient``.
        :param model_name: model used by :meth:`yolo_infer`. It is not
            validated here; use :meth:`set_model` for a readiness check.
        """
        # Assigned before the try so every other method can detect a
        # missing client instead of hitting AttributeError.
        self.triton_client = None
        self.model_name = model_name
        try:
            self.triton_client = grpcclient.InferenceServerClient(
                url=url,
                verbose=False,
                ssl=False,
                root_certificates=None,
                private_key=None,
                certificate_chain=None,
            )
        except Exception as e:
            # Best effort: a failed connection is reported via self.health,
            # not raised to the caller.
            print("Failed to create Triton client:", repr(e))
            self.triton_client = None
            self.health = False
        else:
            self.health = self.check_health()

    def check_health(self):
        """Return True only if the server reports itself both live and ready.

        Returns False, and prints the reason, in three cases:
        - there is no client;
        - either probe returns a falsy value;
        - either probe raises (for example, the server is unreachable).
        """
        if self.triton_client is None:
            print("Health check failed: no Triton client")
            return False
        try:
            live = self.triton_client.is_server_live()
            ready = self.triton_client.is_server_ready()
        except Exception as e:
            print("Health check failed:", repr(e))
            return False
        if not (live and ready):
            print("Health check failed:", {'live': live, 'ready': ready})
            return False
        return True

    def set_model(self, model_name):
        """Select ``model_name`` if the server reports it ready.

        :returns: True if the model was selected. Returns False if there is
            no client, the model is not ready, or the readiness query raises.
            ``self.model_name`` is left unchanged whenever False is returned.
        """
        if self.triton_client is None:
            return False
        try:
            ready = self.triton_client.is_model_ready(model_name)
        except Exception as e:
            print(repr(e))
            return False
        if not ready:
            return False
        self.model_name = model_name
        return True

    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
        """Run the YOLO model on a batch of images.

        :param img_arr: batch of images, declared to Triton as UINT8 with
            shape ``[N, 360, 640, 3]`` where ``N = len(img_arr)``.
        :param nms_threshold: scalar NMS threshold, sent once per image as
            float32.
        :param confidence: scalar confidence threshold, sent once per image
            as float32.
        :param client_timeout: forwarded to ``triton_client.infer``.
        :returns: one of the following.
            - ``'e_no_model'`` if the current model is not the YOLO model.
            - ``'e_connect_fail'`` if there is no client or the infer call
              raises.
            - Otherwise, the Python literal parsed from the first element of
              the ``'result'`` output.
        """
        if self.model_name != self._YOLO_MODEL:
            return 'e_no_model'
        if self.triton_client is None:
            return 'e_connect_fail'

        batch = len(img_arr)
        inputs = [
            grpcclient.InferInput('img', [batch, *self._IMG_SHAPE], "UINT8"),
            grpcclient.InferInput('nms_threshold', [batch, 1], "FP32"),
            grpcclient.InferInput('confidence', [batch, 1], "FP32"),
        ]
        inputs[0].set_data_from_numpy(img_arr)
        # Each scalar threshold is repeated once per image, shape (N, 1).
        inputs[1].set_data_from_numpy(
            np.full((batch, 1), nms_threshold, dtype=np.float32))
        inputs[2].set_data_from_numpy(
            np.full((batch, 1), confidence, dtype=np.float32))
        outputs = [grpcclient.InferRequestedOutput('result')]

        try:
            results = self.triton_client.infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs,
                client_timeout=client_timeout,
            )
        except Exception as e:
            print(repr(e))
            return 'e_connect_fail'

        # NOTE(review): assumes 'result' is a BYTES tensor whose first element
        # is the repr of a Python literal. literal_eval does not execute code,
        # but the payload is still server-controlled; confirm the format.
        result_str = results.as_numpy('result')[0].decode("UTF-8")
        return ast.literal_eval(result_str)

    def close(self):
        """Close the underlying client, if one was created."""
        if self.triton_client is None:
            return None
        return self.triton_client.close()
|