
Add face-covered detection push type; improve AI recognition by adding SageMaker

zhangdongming 1 year ago
parent
commit
4a26fd7a88

+ 10 - 2
AnsjerPush/Config/aiConfig.py

@@ -21,6 +21,13 @@ AI_IDENTIFICATION_TAGS_DICT = {
     '4': 'Package'
 }
 
+# The device reports the algorithm type as a decimal value; when ai_type in the uid_set table is not empty, the value is converted to binary
+# If the device reports a combined type, e.g. motion detection plus abnormal sound, it sends 1 + 32 = 33, which is 100001 in binary
+# Counting from the right, bit 1 is motion and bit 2 is human shape; 1 means that tag was recognized and 0 means it was not, so 100001 means motion and abnormal sound were both recognized
+# 1 (motion detection), 2 (human shape), 4 (vehicle), 8 (face), 16 (pet), 32 (abnormal sound), 64 (area intrusion), 128 (area exit),
+# 256 (loitering detection), 512 (absence detection), 1024 (passing detection), 2048 (cry detection),
+# 4096 (gesture detection), 8192 (flame alarm), 16384 (baby face-covered detection)
+# The dict keys below are the types defined by the device and the values are the APP tag types, e.g. if the device reports 1, the value returned to the APP is 51
 DEVICE_EVENT_TYPE = {
     1: 51,
     2: 57,
@@ -35,7 +42,8 @@ DEVICE_EVENT_TYPE = {
     1024: 66,
     2048: 67,
     4096: 68,
-    8192: 69
+    8192: 69,
+    16384: 70
 }
 
-ALGORITHM_COMBO_TYPES = [51, 57, 58, 60, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69]
+ALGORITHM_COMBO_TYPES = [51, 57, 58, 60, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
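For reference, a minimal sketch (not part of the commit) of how a combined decimal value reported by the device decodes into APP tag types under the bit layout described in the comment above; the helper name is hypothetical.

# Hypothetical helper illustrating the bitmask described above.
def decode_device_event(value, device_event_type=DEVICE_EVENT_TYPE):
    """Return the APP tag types whose bits are set in the device-reported value."""
    return [app_tag for device_bit, app_tag in device_event_type.items()
            if value & device_bit]

# 33 = 1 + 32 = 0b100001 -> motion detection (1) plus abnormal sound (32),
# so this returns DEVICE_EVENT_TYPE[1] (i.e. 51) and whatever DEVICE_EVENT_TYPE maps 32 to.
decode_device_event(33)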

+ 12 - 2
AnsjerPush/config.py

@@ -163,7 +163,8 @@ EVENT_DICT_CN = {
     },
     67: '检测到哭声 ',
     68: '检测到手势 ',
-    69: '检测到火焰 '
+    69: '检测到火焰 ',
+    70: '看不到宝宝的脸了,快去看看!'
 }
 
 EVENT_DICT = {
@@ -183,5 +184,14 @@ EVENT_DICT = {
     },
     67: 'Cries detected ',
     68: 'Gesture detected ',
-    69: 'Flame detected '
+    69: 'Flame detected ',
+    70: 'Can\'t see the baby\'s face now, go take a look!'
+}
+
+# Xiaomi push notification channel IDs
+XM_PUSH_CHANNEL_ID = {
+    'push_to_talk': 111934,        # one-touch call
+    'device_reminder': 104551,     # device reminder
+    'service_reminder': 104552,    # service reminder
+    'sys_notification': 104553     # system notification
 }
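A hedged sketch (not part of the commit) of how the new face-covered type threads through the two mappings touched in this commit: the device reports 16384, DEVICE_EVENT_TYPE in aiConfig.py maps it to APP type 70, and EVENT_DICT / EVENT_DICT_CN supply the notification copy.

from AnsjerPush.Config.aiConfig import DEVICE_EVENT_TYPE
from AnsjerPush.config import EVENT_DICT, EVENT_DICT_CN

app_type = DEVICE_EVENT_TYPE[16384]   # -> 70 (baby face-covered detection)
print(EVENT_DICT[app_type])           # English push text for type 70
print(EVENT_DICT_CN[app_type])        # Chinese push text for type 70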

File diff suppressed because it is too large
+ 38 - 17
Controller/AiController.py


+ 72 - 0
Object/AiEngineObject.py

@@ -0,0 +1,72 @@
+import ast
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+
+class AiEngine:
+    def __init__(self, url, model_name):
+        try:
+            self.triton_client = grpcclient.InferenceServerClient(
+                url=url,
+                verbose=False,
+                ssl=False,
+                root_certificates=None,
+                private_key=None,
+                certificate_chain=None
+            )
+            self.health = self.check_health()
+        except Exception as e:
+            self.health = False
+
+        self.model_name = model_name
+
+    def check_health(self):
+        connect = True  # the gRPC client object was created successfully
+        health_check = [
+            connect,
+            self.triton_client.is_server_live(),
+            self.triton_client.is_server_ready()
+        ]
+        if False in health_check:
+            print("Health check failed:", health_check)
+            return False
+        return True
+
+    def set_model(self, model_name):
+        if self.triton_client.is_model_ready(model_name):
+            self.model_name = model_name
+            return True
+        else:
+            return False
+
+    def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
+        if self.model_name == 'pdchv':
+            pic_num_sum = len(img_arr)
+            nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
+            confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)
+            inputs = [
+                grpcclient.InferInput('img', [pic_num_sum, 360, 640, 3], "UINT8"),
+                grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
+                grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32")
+            ]
+            outputs = [grpcclient.InferRequestedOutput('result')]
+
+            inputs[0].set_data_from_numpy(img_arr)
+            inputs[1].set_data_from_numpy(nms_threshold)
+            inputs[2].set_data_from_numpy(confidence)
+            try:
+                results = self.triton_client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs,
+                                                   client_timeout=client_timeout)
+            except Exception as e:
+                print(repr(e))
+                return 'e_connect_fail'
+
+            result_str = results.as_numpy('result')[0].decode("UTF-8")
+            results = ast.literal_eval(result_str)
+            return results
+        else:
+            return 'e_no_model'
+
+    def close(self):
+        return self.triton_client.close()
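A hedged usage sketch (not part of the commit) for AiEngine: the Triton URL and the 'pdchv' model name come from the code added in this commit, while the sample image path is hypothetical; the model expects batches of 640x360 RGB frames as a single UINT8 array.

import numpy as np
from PIL import Image
from Object.AiEngineObject import AiEngine

engine = AiEngine('ec2-34-192-147-108.compute-1.amazonaws.com:8001', 'pdchv')
if engine.health and engine.set_model('pdchv'):
    frame = np.asarray(Image.open('sample.jpg').convert('RGB').resize((640, 360)))
    batch = np.stack([frame]).astype(np.uint8)       # shape (1, 360, 640, 3)
    results = engine.yolo_infer(batch, nms_threshold=0.45, confidence=0.5, client_timeout=3)
    print(results)                                   # parsed detections, or an 'e_*' error string
    engine.close()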

+ 213 - 0
Object/AiImageObject.py

@@ -0,0 +1,213 @@
+import logging
+import os
+import PIL.Image as Image
+
+from AnsjerPush.Config.aiConfig import LABEL_DICT, AI_IDENTIFICATION_TAGS_DICT
+
+
+class ImageProcessingObject:
+    # Image stitching / processing helper class
+    def __init__(self, image_dir_path, image_size, image_row):
+        self.image_dir_path = image_dir_path
+        self.image_size = image_size
+        self.image_row = image_row
+        self.image_info_dict = {}
+
+    def merge_images(self):
+        """
+        Merge the images under image_dir_path into one composite image
+        """
+
+        # Collect the paths of all images under the image directory
+        image_full_path_list = self.get_image_full_path_list(self.image_dir_path)
+        image_full_path_list.sort()
+
+        image_save_path = r'{}.jpg'.format(self.image_dir_path)  # path of the merged output image
+        # Compute the number of rows in the merged image
+        if len(image_full_path_list) % self.image_row == 0:
+            image_row_num = len(image_full_path_list) // self.image_row
+        else:
+            image_row_num = len(image_full_path_list) // self.image_row + 1
+
+        x_list = []
+        y_list = []
+        for img_file in image_full_path_list:
+            img_x, img_y = self.get_new_img_xy(img_file, self.image_size)
+            x_list.append(img_x)
+            y_list.append(img_y)
+
+        x_new = int(x_list[len(x_list) // 5 * 4])
+        y_new = int(y_list[len(y_list) // 5 * 4])
+        self.composite_images(self.image_row, self.image_size, image_row_num, image_full_path_list, image_save_path,
+                              x_new, y_new)
+
+        self.image_info_dict = {'width': x_list[0], 'height': sum(y_list), 'num': len(y_list)}
+
+    def get_image_full_path_list(self, image_dir_path):
+        """
+        Recursively collect the full paths of all images under a directory
+        @param image_dir_path: image directory path
+        @return: image_full_path_list
+        """
+        file_name_list = os.listdir(image_dir_path)
+        image_full_path_list = []
+        for file_name_one in file_name_list:
+            file_one_path = os.path.join(image_dir_path, file_name_one)
+            if os.path.isfile(file_one_path):
+                image_full_path_list.append(file_one_path)
+            else:
+                img_path_list = self.get_image_full_path_list(file_one_path)
+                image_full_path_list.extend(img_path_list)
+        return image_full_path_list
+
+    @staticmethod
+    def get_new_img_xy(img_file, image_size):
+        """
+        Get the scaled width and height of an image in pixels
+        @param img_file: image file
+        @param image_size: target width (0 keeps the original size)
+        @return : x_s (width in pixels), y_s (height in pixels)
+        """
+        im = Image.open(img_file)
+        if image_size == 0:     # 0 means keep the original proportions
+            (x_s, y_s) = im.size
+            return x_s, y_s
+        else:
+            (x, y) = im.size
+            lv = round(x / image_size, 2) + 0.01
+            x_s = x // lv
+            y_s = y // lv
+            return x_s, y_s
+
+    def composite_images(self, image_row, image_size, image_row_num, image_names, image_save_path, x_new, y_new):
+        """
+        Paste the source images onto a single canvas
+        @param image_row: number of images per row
+        @param image_size: target width of each image
+        @param image_row_num: number of rows in the merged image
+        @param image_names: image file paths
+        @param image_save_path: path to save the merged image
+        @param x_new: width of each tile in pixels
+        @param y_new: height of each tile in pixels
+        @return: None
+        """
+        to_image = Image.new('RGB', (image_row * x_new, image_row_num * y_new))  # create a new blank canvas
+        # Iterate and paste each image into its slot in order
+        total_num = 0
+        for y in range(1, image_row_num + 1):
+            for x in range(1, image_row + 1):
+                from_image = self.resize_by_width(image_names[image_row * (y - 1) + x - 1], image_size)
+                to_image.paste(from_image, ((x - 1) * x_new, (y - 1) * y_new))
+                total_num += 1
+                if total_num == len(image_names):
+                    break
+        to_image.save(image_save_path)  # save the merged image
+
+    @staticmethod
+    def resize_by_width(infile, image_size):
+        """按照宽度进行所需比例缩放"""
+        im = Image.open(infile)
+        if image_size != 0:
+            (x, y) = im.size
+            lv = round(x / image_size, 2) + 0.01
+            x_s = int(x // lv)
+            y_s = int(y // lv)
+            print("x_s", x_s, y_s)
+            out = im.resize((x_s, y_s), Image.ANTIALIAS)
+            return out
+        else:
+            (x_s, y_s) = im.size
+            print("x_s", x_s, y_s)
+            out = im.resize((x_s, y_s), Image.ANTIALIAS)
+            return out
+
+    def handle_rekognition_res(self, detect_group, rekognition_res):
+        """
+        Process the Rekognition result, match it against the requested detection types and return label and bounding-box info
+        @param detect_group: detection types selected by the user, comma separated
+        @param rekognition_res: Rekognition response
+        @return: label_dict
+        """
+        logger = logging.getLogger('info')
+        labels = rekognition_res['Labels']
+        logger.info('--------recognized labels-------:{}'.format(labels))
+
+        label_name = []
+        label_list = []
+
+        # Collect all recognized label names
+        for label in labels:
+            label_name.append(label['Name'])
+            for Parents in label['Parents']:
+                label_name.append(Parents['Name'])
+
+        logger.info('------label names------:{}'.format(label_name))
+
+        # Drop AI types the user did not select and work out the final recognition result
+        user_detect_list = detect_group.split(',')
+        user_detect_list = [i.strip() for i in user_detect_list]
+        conform_label_list = []
+        conform_detect_group = set()
+        for key, label_type_val in LABEL_DICT.items():
+            if key in user_detect_list:
+                for label in label_type_val:
+                    if label in label_name:
+                        conform_detect_group.add(key)
+                        conform_label_list.append(label)
+
+        # Collect the bounding boxes of the matched labels
+        bounding_box_list = []
+        for label in labels:
+            if label['Name'] in conform_label_list:
+                for label_instance in label['Instances']:
+                    bounding_box_list.append(label_instance['BoundingBox'])
+
+        # Work out which single frame each bounding box belongs to and recompute its relative position
+        merge_image_height = self.image_info_dict['height']
+        single_height = merge_image_height // self.image_info_dict['num']
+        new_bounding_box_dict = {
+            'file_0': [],
+            'file_1': [],
+            'file_2': []
+        }
+
+        for k, val in enumerate(bounding_box_list):
+            bounding_box_top = merge_image_height * val['Top']
+            # Find which frame's vertical range contains the current box
+            box_dict = {}
+            for i in range(self.image_info_dict['num']):
+                top_min = i * single_height
+                top_max = (i + 1) * single_height
+                if top_min <= bounding_box_top <= top_max:
+                    box_dict['Width'] = val['Width']
+                    box_dict['Height'] = merge_image_height * val['Height'] / single_height
+                    # Subtract the height of the preceding i frames
+                    box_dict['Top'] = ((merge_image_height * val['Top']) - (i * single_height)) / single_height
+                    box_dict['Left'] = val['Left']
+                    new_bounding_box_dict['file_{i}'.format(i=i)].append(box_dict)
+
+        # Build the return data
+        if not conform_detect_group:  # no matching label was recognized
+            event_type = ''
+            label_list = []
+        else:
+            conform_detect_group = list(conform_detect_group)
+            if len(conform_detect_group) > 1:
+                conform_detect_group.sort()
+                # Aggregate the recognized tags
+                for label_key in conform_detect_group:
+                    label_list.append(AI_IDENTIFICATION_TAGS_DICT[label_key])
+                event_type = ''.join(conform_detect_group)  # combined type, e.g. groups '1' and '3' -> '13'
+            else:
+                label_list.append(AI_IDENTIFICATION_TAGS_DICT[conform_detect_group[0]])
+                event_type = conform_detect_group[0]
+
+        logger.info('------conform_detect_group------ {}'.format(conform_detect_group))
+
+        label_dict = {
+            'event_type': event_type,
+            'label_list': label_list,
+            'new_bounding_box_dict': new_bounding_box_dict
+        }
+        logger.info('------label_dict------ {}'.format(label_dict))
+        return label_dict
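A hedged sketch (not part of the commit) of how ImageProcessingObject is meant to be combined with the Rekognition utility also added in this commit; the frame directory and the detect_group value are hypothetical.

from Object.AiImageObject import ImageProcessingObject
from Object.utils.AmazonRekognitionUtil import AmazonRekognitionUtil

processor = ImageProcessingObject('/tmp/ai_frames/sample', image_size=0, image_row=1)
processor.merge_images()                              # writes /tmp/ai_frames/sample.jpg

with open('/tmp/ai_frames/sample.jpg', 'rb') as f:
    rekognition_res = AmazonRekognitionUtil().detect_labels(f.read())

# detect_group '1,3' would mean the user enabled person and vehicle detection
label_dict = processor.handle_rekognition_res('1,3', rekognition_res)
print(label_dict['event_type'], label_dict['label_list'])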

+ 384 - 0
Object/SageMakerAiObject.py

@@ -0,0 +1,384 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    : SageMakerAiObject.py
+@Time    : 2023/11/10 17:52
+@Author  : stephen
+@Email   : zhangdongming@asj6.wecom.work
+@Software: PyCharm
+"""
+import base64
+import json
+import logging
+import threading
+import time
+from io import BytesIO
+
+import boto3
+import numpy as np
+from PIL import Image
+
+from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
+from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, CONFIG_US, CONFIG_EUR
+from AnsjerPush.config import CONFIG_INFO
+from Controller import AiController
+from Object.AiEngineObject import AiEngine
+from Object.enums.MessageTypeEnum import MessageTypeEnum
+from Service.EquipmentInfoService import EquipmentInfoService
+from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
+from Service.PushService import PushObject
+
+LOGGER = logging.getLogger('info')
+
+
+class SageMakerAiObject:
+
+    @staticmethod
+    def sage_maker_ai_server(uid, base64_list):
+
+        try:
+            url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
+            model_name = 'pdchv'
+
+            ai = AiEngine(url, model_name)
+
+            if ai.health and ai.set_model(model_name):
+                LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
+            else:
+                LOGGER.info(f'uid={uid} SageMaker health check or set_model failed')
+                return False
+
+            img_list = list(map(base64.b64decode, base64_list))
+            img_list = map(BytesIO, img_list)
+            img_list = map(Image.open, img_list)
+            img_list = map(np.array, img_list)
+            img_list = np.array(list(img_list))
+
+            nms_threshold = 0.45
+            confidence = 0.5
+            client_timeout = 3
+
+            start = time.time()
+            results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
+
+            end = time.time()
+            LOGGER.info(f'uid={uid},{(end - start) * 1000}ms,sageMaker Result={results}')
+
+            ai.close()
+
+            if results == 'e_timeout':
+                return False
+            if results == 'e_no_model':
+                return False
+            if results == 'e_connect_fail':
+                return False
+
+            return results
+        except Exception as e:
+            LOGGER.info('***sage_maker_ai_server***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def get_table_name(uid, ai_result, detect_group):
+        """
+            Derive the tag groups and bounding boxes from the SageMaker recognition result
+            :param uid: str, request UID
+            :param ai_result: dict, SageMaker recognition result
+            :param detect_group: str, tag groups selected by the user
+            :return: dict or None or False
+        """
+        try:
+            LOGGER.info(f'***get_table_name***uid={uid}')
+            # Mapping from detected class names to tag groups
+            ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
+            event_group = []
+
+            # Collect the recognized tag groups (coordinates stay in ai_result)
+            for pic_key, pic_value in ai_result.items():
+                for item in pic_value:
+                    if 'class' in item:
+                        event_group.append(ai_table_groups[item['class']])
+
+            # Return False if no tag was recognized
+            if not event_group:
+                LOGGER.info(f'***get_table_name***uid={uid},no tag recognized={event_group}')
+                return False
+
+            # Deduplicate the tag groups and convert the user-selected groups to a list of ints
+            event_group = set(event_group)
+            ai_group = list(map(int, detect_group.split(',')))
+            is_success = False
+
+            # Check whether any of the user-selected tags was recognized
+            for item in event_group:
+                if item in ai_group:
+                    is_success = True
+                    break
+
+            # Return False if none of the selected tags was recognized
+            if not is_success:
+                LOGGER.info(f'***get_table_name***uid={uid},selected tag not recognized,sageMakerLabel={event_group}')
+                return False
+
+            # Join the sorted tag groups into one integer event type, e.g. groups 1 and 3 -> 13
+            event_type = int(''.join(map(str, sorted(event_group))))
+            label_list = []
+
+            # Look up the tag names for the recognized groups
+            for item in event_group:
+                label_list.append(AI_IDENTIFICATION_TAGS_DICT[str(item)])
+            LOGGER.info(f'***get_table_name***uid={uid},event type group={event_type},label names={label_list}')
+
+            # Return a dict with the event type, label name list and bounding boxes
+            return {'event_type': event_type, 'label_list': label_list,
+                    'coords': ai_result}
+        except Exception as e:
+            LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list):
+        """
+        Store the push message and send the push
+        """
+        try:
+            LOGGER.info(f'***save_push_message***uid={uid},res={res}')
+            if not res:
+                return False
+            d_push_time = int(d_push_time)
+            # Store the message and push it
+            uid_push_list = [uid_push for uid_push in uid_push_qs]
+            user_id_list = []
+            nickname = uid_push_list[0]['uid_set__nickname']
+            lang = uid_push_list[0]['lang']
+            event_type = res['event_type']
+            label_str = SageMakerAiObject().get_push_tag_message(lang, event_type)
+            coords = json.dumps(res['coords']) if res['coords'] else ''  # serialize the bounding-box dict to a JSON string
+            now_time = int(time.time())
+            # Data to store in the push table
+            equipment_info_kwargs = {
+                'device_uid': uid,
+                'device_nick_name': nickname,
+                'channel': channel,
+                'is_st': 3,
+                'storage_location': 2,
+                'event_type': event_type,
+                'event_time': d_push_time,
+                'add_time': now_time,
+                'alarm': label_str,
+                'border_coords': coords
+            }
+            push_msg_list = []
+            equipment_info_list = []
+            equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info()  # randomly pick one of the push tables
+            for up in uid_push_list:
+                # Save the push record
+                tz = up['tz']
+                if tz is None or tz == '':
+                    tz = 0
+                user_id = up['userID_id']
+                if user_id not in user_id_list:
+                    equipment_info_kwargs['device_user_id'] = user_id
+                    equipment_info_list.append(equipment_info_model(**equipment_info_kwargs))
+                    user_id_list.append(user_id)
+
+                # Push
+                push_type = up['push_type']
+                app_bundle_id = up['appBundleId']
+                token_val = up['token_val']
+                lang = up['lang']
+                # Push title and body
+                msg_title = PushObject.get_msg_title(nickname=nickname)
+                msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang, tz=tz,
+                                                      label=label_str)
+                kwargs = {
+                    'nickname': nickname,
+                    'app_bundle_id': app_bundle_id,
+                    'token_val': token_val,
+                    'n_time': d_push_time,
+                    'event_type': event_type,
+                    'msg_title': msg_title,
+                    'msg_text': msg_text,
+                    'uid': uid,
+                    'channel': channel,
+                    'push_type': push_type
+                }
+                push_msg_list.append(kwargs)
+
+            if equipment_info_list:  # store the messages in the table
+                equipment_info_model.objects.bulk_create(equipment_info_list)
+
+            SageMakerAiObject().upload_image_to_s3(uid, channel, d_push_time, file_list)  # upload to the cloud
+
+            push_thread = threading.Thread(target=SageMakerAiObject.async_app_msg_push,
+                                           kwargs={'uid': uid, 'push_msg_list': push_msg_list})
+            push_thread.start()  # push APP notifications asynchronously
+
+            AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0)  # associate the AI tag
+            return True
+        except Exception as e:
+            LOGGER.info('***save_push_message***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def async_app_msg_push(uid, push_msg_list):
+        """
+        Push APP notifications asynchronously
+        """
+        if not push_msg_list:
+            LOGGER.info(f'***uid={uid} APP push: push_msg_list is empty***')
+        for item in push_msg_list:
+            push_type = item['push_type']
+            item.pop('push_type')
+            SageMakerAiObject.app_user_message_push(push_type, **item)
+        LOGGER.info(f'***uid={uid} APP push finished')
+
+    @staticmethod
+    def app_user_message_push(push_type, **kwargs):
+        """
+            AI recognition: push a message to the APP user
+        """
+        uid = kwargs['uid']
+        try:
+            if push_type == 0:  # ios apns
+                PushObject.ios_apns_push(**kwargs)
+            elif push_type == 1:  # android gcm
+                PushObject.android_fcm_push(**kwargs)
+            elif push_type == 2:  # android jpush
+                kwargs.pop('uid')
+                PushObject.android_jpush(**kwargs)
+            elif push_type == 3:
+                huawei_push_object = HuaweiPushObject()
+                huawei_push_object.send_push_notify_message(**kwargs)
+            elif push_type == 4:  # android Xiaomi push
+                PushObject.android_xmpush(**kwargs)
+            elif push_type == 5:  # android vivo push
+                PushObject.android_vivopush(**kwargs)
+            elif push_type == 6:  # android OPPO push
+                PushObject.android_oppopush(**kwargs)
+            elif push_type == 7:  # android Meizu push
+                PushObject.android_meizupush(**kwargs)
+            else:
+                LOGGER.info(f'uid={uid},push type {push_type} does not exist')
+                return False
+            return True
+        except Exception as e:
+            LOGGER.info('AI push message error,uid={},errLine:{}, errMsg:{}'.format(uid,
+                        e.__traceback__.tb_lineno, repr(e)))
+            return False
+
+    @staticmethod
+    def get_push_tag_message(lang, event_type):
+        """
+        Build the APP notification text from the language and push type
+        """
+        event_type = str(event_type)
+        types = []
+        if len(event_type) > 1:
+            for i in range(1, len(event_type) + 1):
+                types.append(MessageTypeEnum(int(event_type[i - 1:i])))
+        else:
+            types.append(int(event_type))
+        msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
+        msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
+        msg_text = ''
+        for item in types:
+            if lang == 'cn':
+                msg_text += msg_cn.get(item) + ' '
+            else:
+                msg_text += msg_en.get(item) + ' '
+        return msg_text
+
+    @staticmethod
+    def upload_image_to_s3(uid, channel, d_push_time, file_list):
+        """
+        Upload images to S3
+        """
+        if CONFIG_INFO == CONFIG_US or CONFIG_INFO == CONFIG_EUR:
+            # AWS access keys
+            aws_key = AWS_ACCESS_KEY_ID[1]
+            aws_secret = AWS_SECRET_ACCESS_KEY[1]
+            # Create the session object
+            session = boto3.Session(
+                aws_access_key_id=aws_key,
+                aws_secret_access_key=aws_secret,
+                region_name="us-east-1"
+            )
+            s3_resource = session.resource("s3")
+            bucket_name = "foreignpush"
+        else:
+            # Store in the China (cn-northwest-1) region
+            aws_key = AWS_ACCESS_KEY_ID[0]
+            aws_secret = AWS_SECRET_ACCESS_KEY[0]
+            session = boto3.Session(
+                aws_access_key_id=aws_key,
+                aws_secret_access_key=aws_secret,
+                region_name="cn-northwest-1")
+            s3_resource = session.resource("s3")
+            bucket_name = "push"
+
+        try:
+            # Decode the base64 strings into image data
+            for index, pic in enumerate(file_list):
+                image_data = base64.b64decode(pic)
+                # Bucket name and object key
+                object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
+
+                # Get the target bucket
+                bucket = s3_resource.Bucket(bucket_name)
+
+                # Upload the image data to S3
+                bucket.put_object(Key=object_key, Body=image_data)
+            LOGGER.info(f'uid={uid},base64 thumbnails uploaded to S3 successfully')
+        except Exception as e:
+            LOGGER.info('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
+                        .format(uid, e.__traceback__.tb_lineno, repr(e)))
+
+    @staticmethod
+    def bird_recognition_demo(resultRsp):
+        """
+        Bird recognition demo code
+        """
+        from PIL import Image
+        import base64
+        from io import BytesIO
+        import json
+        import boto3
+
+        image_path = "E:/stephen/AnsjerProject/bird_sagemaker/test.png"
+        image = Image.open(image_path).convert("RGB")
+
+        # Convert the image to a Base64-encoded string
+        image_buffer = BytesIO()
+        image.save(image_buffer, format="JPEG")
+        image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
+        input_json = json.dumps({'input': image_base64})
+
+        AWS_SERVER_PUBLIC_KEY = 'AKIARBAGHTGOSK37QB4T'
+        AWS_SERVER_SECRET_KEY = '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'
+        runtime = boto3.client("sagemaker-runtime",
+                               aws_access_key_id=AWS_SERVER_PUBLIC_KEY,
+                               aws_secret_access_key=AWS_SERVER_SECRET_KEY,
+                               region_name='us-east-1')
+
+        endpoint_name = 'ServerlessClsBirdEndpoint1'
+        content_type = "application/json"
+        payload = input_json
+        session = boto3.Session()
+        # runtime = session.client('sagemaker-runtime')
+        import time
+        start_time = time.time()
+        response = runtime.invoke_endpoint(
+            EndpointName=endpoint_name,
+            ContentType=content_type,
+            Body=payload
+        )
+        end_time = time.time()
+        print("耗时: {:.2f}秒".format(end_time - start_time))
+
+        result = json.loads(response['Body'].read().decode())
+        print(result)
+        return resultRsp.json(0)
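A hedged sketch (not part of the commit) wiring sage_maker_ai_server to get_table_name; the uid, frame file and detect_group value are hypothetical placeholders, and the frame is assumed to already be a 640x360 JPEG as the 'pdchv' model expects.

import base64
from Object.SageMakerAiObject import SageMakerAiObject

uid = 'SAMPLE-UID'
with open('/tmp/frame_0.jpg', 'rb') as f:
    frames = [base64.b64encode(f.read()).decode('utf-8')]

results = SageMakerAiObject.sage_maker_ai_server(uid, frames)
if results:
    res = SageMakerAiObject.get_table_name(uid, results, detect_group='1,3')
    if res:
        print(res['event_type'], res['label_list'])   # e.g. a combined type like 13 plus the matching tag names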

+ 40 - 0
Object/utils/AmazonRekognitionUtil.py

@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+"""
+@Author : Rocky
+@Time : 2022/12/12 14:28
+@File :AmazonRekognitionUtil.py
+"""
+import boto3
+
+
+class AmazonRekognitionUtil:
+    # doc: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rekognition.html
+    def __init__(self):
+        # Always use the US East (us-east-1) service
+        self.region_name = 'us-east-1'
+        self.aws_access_key_id = 'AKIA2E67UIMD6JD6TN3J'
+        self.aws_secret_access_key = '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'
+
+        self.client = boto3.client(
+            'rekognition',
+            region_name=self.region_name,
+            aws_access_key_id=self.aws_access_key_id,
+            aws_secret_access_key=self.aws_secret_access_key,
+        )
+
+    def detect_labels(self, image):
+        """
+        Detect labels in an image
+        @param image: image bytes
+        @return: rekognition_res
+        """
+        rekognition_res = self.client.detect_labels(
+            Image={'Bytes': image},
+            MaxLabels=50,
+            MinConfidence=80
+        )
+        try:
+            assert rekognition_res['ResponseMetadata']['HTTPStatusCode'] == 200
+        except AssertionError:
+            return {}
+        return rekognition_res
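A hedged usage sketch (not part of the commit) for the Rekognition wrapper; the image path is hypothetical.

from Object.utils.AmazonRekognitionUtil import AmazonRekognitionUtil

with open('/tmp/merged.jpg', 'rb') as f:
    res = AmazonRekognitionUtil().detect_labels(f.read())

for label in res.get('Labels', []):
    print(label['Name'], label['Confidence'])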

Some files were not shown because too many files changed in this diff