Просмотр исходного кода

sageMaker增加判断是否识别到了指定的标签

zhangdongming 1 год назад
Родитель
Сommit
39ff5422e6
2 измененных файлов с 31 добавлено и 22 удалено
  1. 1 1
      Controller/AiController.py
  2. 30 21
      Object/SageMakerAiObject.py

+ 1 - 1
Controller/AiController.py

@@ -125,7 +125,7 @@ class AiView(View):
                 sage_maker = SageMakerAiObject()
                 ai_result = sage_maker.sage_maker_ai_server(uid, file_list)  # 图片base64识别AI标签
                 if ai_result:
-                    res = sage_maker.get_table_name(uid, ai_result)
+                    res = sage_maker.get_table_name(uid, ai_result, AiServiceQuery[0]['detect_group'])
                     if not res:
                         return response.json(0)
                     sage_maker.save_push_message(uid, n_time, uid_push_qs, channel, res, file_list)

+ 30 - 21
Object/SageMakerAiObject.py

@@ -79,28 +79,17 @@ class SageMakerAiObject:
             return False
 
     @staticmethod
-    def get_table_name(uid, ai_result):
+    def get_table_name(uid, ai_result, detect_group):
         """
-        sageMaker识别结果得到标签组和坐标框
-        {
-            "file_0":[
-                {
-                    "x1":551,
-                    "x2":639,
-                    "y1":179,
-                    "y2":274,
-                    "Width":0.22867838541666666,
-                    "Height":0.14794921875,
-                    "Top":0.28017578125,
-                    "Left":1.4353841145833333,
-                    "confidence":0.86,
-                    "class":"vehicle"
-                }
-            ]
-        }
+            根据sageMaker识别结果得到标签组和坐标框
+            :param uid: str,请求ID
+            :param ai_result: dict,sageMaker识别结果
+            :param detect_group: str,指定标签组
+            :return: dict or None or False
         """
         try:
             LOGGER.info(f'***get_table_name***uid={uid}')
+            # 定义标签组的映射关系
             ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
             event_group = []
 
@@ -110,17 +99,37 @@ class SageMakerAiObject:
                     if 'class' in item:
                         event_group.append(ai_table_groups[item['class']])
 
+            # 如果没有识别到标签则返回False
             if not event_group:
-                LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签={event_group}')
-                return None
+                LOGGER.info(f'***get_table_name***uid={uid},没有识别到标签={event_group}')
+                return False
 
+            # 去重标签组并转换为整数类型,同时将指定标签组转换为整数列表
             event_group = set(event_group)
+            ai_group = list(map(int, detect_group.split(',')))
+            is_success = False
+
+            # 判断是否识别到了指定的标签
+            for item in event_group:
+                if item in ai_group:
+                    is_success = True
+                    break
+
+            # 如果未识别到指定标签则返回False
+            if not is_success:
+                LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签sageMakerLabel={event_group}')
+                return False
+
+            # 将标签组转换为整数表示的类型组
             event_type = int(''.join(map(str, sorted(event_group))))
             label_list = []
-            for item in event_group:  # 识别到类型获取name
+
+            # 根据标签组获取对应的标签名称
+            for item in event_group:
                 label_list.append(AI_IDENTIFICATION_TAGS_DICT[str(item)])
             LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
 
+            # 返回包含事件类型、标签名称列表和坐标框的字典对象
             return {'event_type': event_type, 'label_list': label_list,
                     'coords': ai_result}
         except Exception as e: