AiEngineObject.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # -*- encoding: utf-8 -*-
  2. """
  3. @File : AiEngine.py
  4. @Time : 2023/2/28 15:15
  5. @Author : stephen
  6. @Email : zhangdongming@asj6.wecom.work
  7. @Software: PyCharm
  8. """
  9. import ast
  10. import time
  11. import numpy as np
  12. import tritonclient.grpc as grpcclient
  13. class AiEngine:
  14. def __init__(self, **kwargs):
  15. # create grpc conncet
  16. try:
  17. self.triton_client = grpcclient.InferenceServerClient(
  18. url=kwargs.get('url'),
  19. verbose=False,
  20. ssl=False,
  21. root_certificates=None,
  22. private_key=None,
  23. certificate_chain=None)
  24. conncet = True
  25. except Exception as e:
  26. conncet = False
  27. self.health = True
  28. health_check = [
  29. conncet,
  30. self.triton_client.is_server_live(),
  31. self.triton_client.is_server_ready()
  32. ]
  33. if False in health_check:
  34. print("health_check failed:" + health_check)
  35. self.health = False
  36. self.model_name = None
  37. def set_model(self, model_name):
  38. if self.triton_client.is_model_ready(model_name):
  39. self.model_name = model_name
  40. return True
  41. else:
  42. return False
  43. def get_model_name(self, model_name):
  44. return self.model_name
  45. def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
  46. if self.model_name == 'AI_5obj_pdcpv_detect_yolov5_pipeline':
  47. pic_num_sum = len(img_arr)
  48. nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
  49. confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)
  50. inputs = [
  51. grpcclient.InferInput('img1', [pic_num_sum, 360, 640, 3], "UINT8"),
  52. grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
  53. grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32")
  54. ]
  55. outputs = [grpcclient.InferRequestedOutput('OUTPUT0')]
  56. inputs[0].set_data_from_numpy(img_arr)
  57. inputs[1].set_data_from_numpy(nms_threshold)
  58. inputs[2].set_data_from_numpy(confidence)
  59. try:
  60. results = self.triton_client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs,
  61. client_timeout=client_timeout)
  62. except Exception as e:
  63. return 'e_timeout'
  64. result_str = results.as_numpy('OUTPUT0')[0].decode("UTF-8")
  65. results = ast.literal_eval(result_str)
  66. return results
  67. else:
  68. return 'e_no_model'
  69. def close(self):
  70. return self.triton_client.close()
  71. def func():
  72. print('func start')
  73. time.sleep(1)
  74. print('func end')