Ver código fonte

AI服务替换自有云模型识别

zhangdongming 2 anos atrás
pai
commit
7d2a31c065
2 arquivos alterados com 187 adições e 22 exclusões
  1. 95 22
      Controller/AiController.py
  2. 92 0
      Object/AiEngineObject.py

+ 95 - 22
Controller/AiController.py

@@ -5,20 +5,25 @@ import os
 import threading
 import threading
 import time
 import time
 
 
+import numpy as np
+from PIL import Image
 from django.views.generic.base import View
 from django.views.generic.base import View
 
 
 from AnsjerPush.config import BASE_DIR
 from AnsjerPush.config import BASE_DIR
 from Model.models import UidPushModel, AiService, Device_Info, VodHlsTag, VodHlsTagType
 from Model.models import UidPushModel, AiService, Device_Info, VodHlsTag, VodHlsTagType
-from Object.AiImageObject import ImageProcessingObject
+from Object.AiEngineObject import AiEngine
 from Object.ETkObject import ETkObject
 from Object.ETkObject import ETkObject
 from Object.ResponseObject import ResponseObject
 from Object.ResponseObject import ResponseObject
 from Object.enums.MessageTypeEnum import MessageTypeEnum
 from Object.enums.MessageTypeEnum import MessageTypeEnum
 from Object.utils import LocalDateTimeUtil
 from Object.utils import LocalDateTimeUtil
-from Object.utils.AmazonRekognitionUtil import AmazonRekognitionUtil
 from Service.CommonService import CommonService
 from Service.CommonService import CommonService
 from Service.EquipmentInfoService import EquipmentInfoService
 from Service.EquipmentInfoService import EquipmentInfoService
 from Service.PushService import PushObject
 from Service.PushService import PushObject
 
 
+LOGGING = logging.getLogger('info')
+CLOUD_BASED_AI_URL = '34.192.147.108:8001'
+MODEL_NAME = 'AI_5obj_pdcpv_detect_yolov5_pipeline'
+
 
 
 class AiView(View):
 class AiView(View):
     def get(self, request, *args, **kwargs):
     def get(self, request, *args, **kwargs):
@@ -105,30 +110,34 @@ class AiView(View):
                 with open(file_path, 'wb') as f:
                 with open(file_path, 'wb') as f:
                     f.write(val)
                     f.write(val)
                     f.close()
                     f.close()
-
-            image_size = 0  # 每张小图片的大小,等于0是按原图大小进行合并
-            image_row = 1  # 合并成一张图后,一行有几个小图
-            image_processing_obj = ImageProcessingObject(dir_path, image_size, image_row)
-            image_processing_obj.merge_images()
+            ai_view = AiView()
+            ai_results = ai_view.image_aI_recognition(file_path_list, 0.45, 0.45, 50)
+            if not ai_results:
+                return response.json(0)
+            event_type = ai_view.get_cloud_recognition_tag(ai_results)
+            if event_type == 0:
+                return response.json(0)
+            # image_size = 0  # 每张小图片的大小,等于0是按原图大小进行合并
+            # image_row = 1  # 合并成一张图后,一行有几个小图
+            # image_processing_obj = ImageProcessingObject(dir_path, image_size, image_row)
+            # image_processing_obj.merge_images()
 
 
             # 获取识别结果
             # 获取识别结果
-            aws_rekognition = AmazonRekognitionUtil()
-            with open(dir_path + '.jpg', 'rb') as f:
-                rekognition_res = aws_rekognition.detect_labels(f.read())
+            # aws_rekognition = AmazonRekognitionUtil()
+            # with open(dir_path + '.jpg', 'rb') as f:
+            #     rekognition_res = aws_rekognition.detect_labels(f.read())
 
 
-            if not rekognition_res:
-                return response.json(0)
+            # if not rekognition_res:
+            #     return response.json(0)
 
 
-            label_dict = image_processing_obj.handle_rekognition_res(detect_group, rekognition_res)
-            if not label_dict['label_list']:
-                # 需要删除图片
-                # photo.close()
-                # self.del_path(os.path.join(BASE_DIR, 'static/ai/' + uid))
-                return response.json(0)
-
-            event_type = label_dict['event_type']
-            label_str = ','.join(label_dict['label_list'])
-            new_bounding_box_dict = label_dict['new_bounding_box_dict']
+            # label_dict = image_processing_obj.handle_rekognition_res(detect_group, rekognition_res)
+            # if not label_dict['label_list']:
+            #     # 需要删除图片
+            #     # photo.close()
+            #     # self.del_path(os.path.join(BASE_DIR, 'static/ai/' + uid))
+            #     return response.json(0)
+            label_str = ''  # ','.join(label_dict['label_list'])
+            new_bounding_box_dict = ''  # label_dict['new_bounding_box_dict']
 
 
             # 上传缩略图到s3
             # 上传缩略图到s3
             file_dict = {}
             file_dict = {}
@@ -213,6 +222,70 @@ class AiView(View):
             }
             }
             return response.json(48, data)
             return response.json(48, data)
 
 
+    @classmethod
+    def image_aI_recognition(cls, input_name_arr, nms_threshold, confidence, client_timeout):
+        """
+        自有图片云模型识别
+        :param input_name_arr: 推理图片地址名
+        :param nms_threshold: nms置信度
+        :param confidence: 目标置信度(一般只用调整这个)
+        :param client_timeout: 超时时间(秒为单位)
+        :return: results 推理结果
+        """
+        try:
+            t = time.time()
+            # 建立长连接
+            ai = AiEngine(url=CLOUD_BASED_AI_URL)
+            # 检查连通性、推理服务器状态
+            if ai.health:
+                LOGGING.info('健康状况通过')
+            # 设定模型
+            if ai.set_model(MODEL_NAME):
+                LOGGING.info('设置模型通过')
+            # 推理张数(一次最多推理128张!)
+            # 图片名称(这里可以改成内存)注意改完之后要检查input_tmp的【类型(type)、尺寸(shape)】是否和之前的一致
+            # 输入尺寸固定640wx360h,如需变动可以联系我们,我们这边做resize会快
+            input_name_arr = np.array(list(map(np.array, map(Image.open, input_name_arr))))
+            # 推理
+            results = ai.yolo_infer(input_name_arr, nms_threshold, confidence, client_timeout)
+            # 推理完请关闭长连接
+            ai.close()
+            LOGGING.info(f'coast:{time.time() - t:.4f}s')
+            # 报错返回
+            if results == 'e_timeout':
+                raise Exception('推理超时')
+            elif results == 'e_no_model':
+                raise Exception('没有设置模型')
+            LOGGING.info('云上模型推理结果:{}'.format(results))
+            return results
+        except Exception as e:
+            LOGGING.info('云模型AI识别失败,errLine:{}, errMsg:{}'.format(e.__traceback__.tb_lineno, repr(e)))
+            return {}
+
+    @classmethod
+    def get_cloud_recognition_tag(cls, results):
+        """
+        根据推理结果
+        返回组合类型,0则推理失败,1:人,2:车,3:宠物,4:包裹
+        """
+        if not results:
+            return 0
+        # 返回字典,可以复制到绘制python脚本中查看结果
+        tag_list = []
+        event_dict = {1: 1, 0: 2, 3: 2, 2: 3, 4: 4}
+        for k, v in results.items():
+            if len(v) > 0:
+                for item in v:
+                    tag_list.append(event_dict.get(item['classID'], 0))
+        if tag_list:
+            tag_list = set(tag_list)
+            event_type = ''
+            for val in tag_list:
+                event_type += str(val)
+            return int(event_type)
+        else:
+            return 0
+
     @classmethod
     @classmethod
     def save_cloud_ai_tag(cls, uid, event_time, types):
     def save_cloud_ai_tag(cls, uid, event_time, types):
         """
         """

+ 92 - 0
Object/AiEngineObject.py

@@ -0,0 +1,92 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    : AiEngine.py
+@Time    : 2023/2/28 15:15
+@Author  : stephen
+@Email   : zhangdongming@asj6.wecom.work
+@Software: PyCharm
+"""
+import ast
+import time
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+
+class AiEngine:
+    def __init__(self, **kwargs):
+
+        # create grpc conncet
+        try:
+            self.triton_client = grpcclient.InferenceServerClient(
+                url=kwargs.get('url'),
+                verbose=False,
+                ssl=False,
+                root_certificates=None,
+                private_key=None,
+                certificate_chain=None)
+            conncet = True
+        except Exception as e:
+            conncet = False
+        self.health = True
+        health_check = [
+            conncet,
+            self.triton_client.is_server_live(),
+            self.triton_client.is_server_ready()
+        ]
+        if False in health_check:
+            print("health_check failed:" + health_check)
+            self.health = False
+        self.model_name = None
+
+    def set_model(self, model_name):
+        if self.triton_client.is_model_ready(model_name):
+            self.model_name = model_name
+            return True
+        else:
+            return False
+
+    def get_model_name(self, model_name):
+        return self.model_name
+
+    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
+        if self.model_name == 'AI_5obj_pdcpv_detect_yolov5_pipeline':
+            pic_num_sum = len(img_arr)
+            nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
+            confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)
+            inputs = [
+                grpcclient.InferInput('img1', [pic_num_sum, 360, 640, 3], "UINT8"),
+                grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
+                grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32")
+            ]
+            outputs = [grpcclient.InferRequestedOutput('OUTPUT0')]
+
+            inputs[0].set_data_from_numpy(img_arr)
+            inputs[1].set_data_from_numpy(nms_threshold)
+            inputs[2].set_data_from_numpy(confidence)
+
+            try:
+                results = self.triton_client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs,
+                                                   client_timeout=client_timeout)
+            except Exception as e:
+                return 'e_timeout'
+            result_str = results.as_numpy('OUTPUT0')[0].decode("UTF-8")
+            results = ast.literal_eval(result_str)
+            return results
+        else:
+            return 'e_no_model'
+
+    def close(self):
+        return self.triton_client.close()
+
+
+def func():
+    print('func start')
+    time.sleep(1)
+    print('func end')
+
+
+if __name__ == "__main__":
+    t = time.time()
+    func()
+    print(f'coast:{time.time() - t:.4f}s')