# AiEngineObject.py (2.5 KB)
  1. import ast
  2. import numpy as np
  3. import tritonclient.grpc as grpcclient
  4. class AiEngine:
  5. def __init__(self, url, model_name):
  6. try:
  7. self.triton_client = grpcclient.InferenceServerClient(
  8. url=url,
  9. verbose=False,
  10. ssl=False,
  11. root_certificates=None,
  12. private_key=None,
  13. certificate_chain=None
  14. )
  15. self.health = self.check_health()
  16. except Exception as e:
  17. self.health = False
  18. self.model_name = model_name
  19. def check_health(self):
  20. conncet = True
  21. health_check = [
  22. conncet,
  23. self.triton_client.is_server_live(),
  24. self.triton_client.is_server_ready()
  25. ]
  26. if False in health_check:
  27. print("Health check failed:", health_check)
  28. return False
  29. return True
  30. def set_model(self, model_name):
  31. if self.triton_client.is_model_ready(model_name):
  32. self.model_name = model_name
  33. return True
  34. else:
  35. return False
  36. def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
  37. if self.model_name == 'pdchv':
  38. pic_num_sum = len(img_arr)
  39. nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
  40. confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)
  41. inputs = [
  42. grpcclient.InferInput('img', [pic_num_sum, 360, 640, 3], "UINT8"),
  43. grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
  44. grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32")
  45. ]
  46. outputs = [grpcclient.InferRequestedOutput('result')]
  47. inputs[0].set_data_from_numpy(img_arr)
  48. inputs[1].set_data_from_numpy(nms_threshold)
  49. inputs[2].set_data_from_numpy(confidence)
  50. try:
  51. results = self.triton_client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs,
  52. client_timeout=client_timeout)
  53. except Exception as e:
  54. print(repr(e))
  55. return 'e_connect_fail'
  56. result_str = results.as_numpy('result')[0].decode("UTF-8")
  57. results = ast.literal_eval(result_str)
  58. return results
  59. else:
  60. return 'e_no_model'
  61. def close(self):
  62. return self.triton_client.close()