# -*- encoding: utf-8 -*-
"""
@File    : SageMakerAiObject.py
@Time    : 2023/11/10 17:52
@Author  : stephen
@Email   : zhangdongming@asj6.wecom.work
@Software: PyCharm
"""
import base64
import json
import logging
import threading
import time
from io import BytesIO

import boto3
import numpy as np
from PIL import Image

from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, CONFIG_US, CONFIG_EUR
from AnsjerPush.config import CONFIG_INFO
from Controller import AiController
from Object.AiEngineObject import AiEngine
from Object.enums.MessageTypeEnum import MessageTypeEnum
from Service.EquipmentInfoService import EquipmentInfoService
from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
from Service.PushService import PushObject

LOGGER = logging.getLogger('info')


class SageMakerAiObject:
    """AI detection pipeline: inference, label grouping, persistence, S3 upload and APP push."""

    @staticmethod
    def sage_maker_ai_server(uid, base64_list):
        """
        Run YOLO inference on a batch of base64-encoded images through the AI engine.

        NOTE: ``base64_list`` is modified in place -- every ' ' is turned back into '+'
        (a '+' in form-encoded base64 arrives as a space). Callers that later re-decode
        the same list (e.g. for the S3 thumbnail upload) may rely on this, so keep it.

        :param uid: str, device UID (used for logging)
        :param base64_list: list[str], base64-encoded images; they are stacked into one
            numpy batch, so they are expected to share the same dimensions
        :return: the engine's inference result, or False on health/model failure,
            an engine error code, or any exception
        """
        ai = None
        try:
            url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
            model_name = 'pdchv'
            ai = AiEngine(url, model_name)
            # NOTE(review): if AiEngine.health is a method rather than a property, this
            # truthiness test always passes -- confirm against AiEngine.
            if not (ai.health and ai.set_model(model_name)):
                LOGGER.info('SageMake fail')
                return False
            LOGGER.info(f'uid={uid} Health check passed and Model set successfully')

            for index, encoded in enumerate(base64_list):
                base64_list[index] = encoded.replace(' ', '+')
            # base64 -> bytes -> PIL image -> ndarray, then stack into one batch array
            img_batch = np.array([np.array(Image.open(BytesIO(base64.b64decode(encoded))))
                                  for encoded in base64_list])

            nms_threshold = 0.45
            confidence = 0.5
            client_timeout = 3
            start = time.time()
            results = ai.yolo_infer(img_batch, nms_threshold, confidence, client_timeout)
            end = time.time()
            LOGGER.info(f'uid={uid},{(end - start) * 1000}ms,sageMaker Result={results}')

            # The engine reports failures as sentinel strings instead of raising.
            if isinstance(results, str) and results in ('e_timeout', 'e_no_model', 'e_connect_fail'):
                return False
            return results
        except Exception as e:
            LOGGER.info('***sage_maker_ai_server***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return False
        finally:
            # Always release the engine connection, including on the failure paths.
            if ai is not None:
                try:
                    ai.close()
                except Exception as e:
                    LOGGER.info('***sage_maker_ai_server***uid={},close failed,errMsg={}'.format(uid, repr(e)))

    @staticmethod
    def get_table_name(uid, ai_result, detect_group):
        """
        Derive the event type and label names from a SageMaker detection result.

        :param uid: str, device UID (used for logging)
        :param ai_result: dict mapping a picture key to a list of detection dicts;
            a detection may carry a 'class' label such as 'person' or 'dog'
        :param detect_group: str, comma-separated label-group ids to alert on, e.g. '1,3'
        :return: dict with 'event_type' (int whose decimal digits are the sorted detected
            group ids, e.g. 13 for person + vehicle), 'label_list' (label names) and
            'coords' (the raw ai_result); False when nothing / nothing requested was
            detected or on error
        """
        try:
            LOGGER.info(f'***get_table_name***uid={uid}')
            # Detection class -> label group id (cats and dogs share the animal group 2)
            ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
            event_group = set()
            for detections in ai_result.values():
                for detection in detections:
                    if 'class' not in detection:
                        continue
                    group = ai_table_groups.get(detection['class'])
                    if group is None:
                        # An unmapped class used to raise KeyError and discard the whole
                        # result; skip it so the recognised labels still count.
                        LOGGER.info(f'***get_table_name***uid={uid},unmapped class={detection["class"]}')
                        continue
                    event_group.add(group)

            if not event_group:
                LOGGER.info(f'***get_table_name***uid={uid},没有识别到标签={event_group}')
                return False

            ai_group = set(map(int, detect_group.split(',')))
            if not event_group & ai_group:
                LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签sageMakerLabel={event_group}')
                return False

            # Group ids are single digits (1-4), so concatenating them is unambiguous.
            sorted_groups = sorted(event_group)
            event_type = int(''.join(map(str, sorted_groups)))
            label_list = [AI_IDENTIFICATION_TAGS_DICT[str(group)] for group in sorted_groups]
            LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
            return {'event_type': event_type, 'label_list': label_list, 'coords': ai_result}
        except Exception as e:
            LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return False

    @staticmethod
    def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list):
        """
        Save an AI-detection event, upload its thumbnails and notify the APP users.

        One equipment-info row is stored per distinct user (bulk-inserted into a
        randomly chosen equipment-info table), the thumbnails are uploaded to S3,
        APP notifications are sent from a background thread, and the AI tag is
        attached via ``AiController.AiView().save_cloud_ai_tag``.

        :param uid: str, device UID
        :param d_push_time: event time, converted with int() -- presumably a Unix
            timestamp in seconds (TODO confirm)
        :param uid_push_qs: iterable of dicts (presumably a Django ``values()`` queryset)
            with keys 'uid_set__nickname', 'lang', 'tz', 'userID_id', 'push_type',
            'appBundleId' and 'token_val'
        :param channel: device channel
        :param res: result of get_table_name (uses 'event_type' and 'coords'); falsy
            values are rejected
        :param file_list: list[str], base64-encoded thumbnails
        :return: True on success, False when res is falsy, no push rows exist, or a step raised
        """
        try:
            LOGGER.info(f'***save_push_message***uid={uid},res={res}')
            if not res:
                return False
            d_push_time = int(d_push_time)
            uid_push_list = list(uid_push_qs)
            if not uid_push_list:
                LOGGER.info(f'***save_push_message***uid={uid},uid_push_qs is empty')
                return False

            nickname = uid_push_list[0]['uid_set__nickname']
            event_type = res['event_type']
            # Bounding boxes are stored as a JSON string ('' when there are none)
            coords = json.dumps(res['coords']) if res['coords'] else ''
            now_time = int(time.time())
            # Columns shared by every stored equipment-info row
            base_kwargs = {
                'device_uid': uid,
                'device_nick_name': nickname,
                'channel': channel,
                'is_st': 3,
                'storage_location': 2,
                'event_type': event_type,
                'event_time': d_push_time,
                'add_time': now_time,
                'border_coords': coords,
            }
            equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info()
            msg_title = PushObject.get_msg_title(nickname=nickname)

            equipment_info_list = []
            push_msg_list = []
            saved_user_ids = set()
            for up in uid_push_list:
                lang = up['lang']
                label_str = SageMakerAiObject.get_push_tag_message(lang, event_type)
                tz = up['tz']
                if tz is None or tz == '':
                    tz = 0

                # The same user can appear in several rows; store the message once per user.
                user_id = up['userID_id']
                if user_id not in saved_user_ids:
                    saved_user_ids.add(user_id)
                    equipment_info_list.append(
                        equipment_info_model(**base_kwargs, alarm=label_str, device_user_id=user_id))

                msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang,
                                                      tz=tz, label=label_str)
                push_msg_list.append({
                    'nickname': nickname,
                    'app_bundle_id': up['appBundleId'],
                    'token_val': up['token_val'],
                    'n_time': d_push_time,
                    'event_type': event_type,
                    'msg_title': msg_title,
                    'msg_text': msg_text,
                    'uid': uid,
                    'channel': channel,
                    'push_type': up['push_type'],
                })

            if equipment_info_list:
                equipment_info_model.objects.bulk_create(equipment_info_list)
            SageMakerAiObject.upload_image_to_s3(uid, channel, d_push_time, file_list)
            # APP notifications are sent asynchronously so this call does not wait on them
            push_thread = threading.Thread(target=SageMakerAiObject.async_app_msg_push,
                                           kwargs={'uid': uid, 'push_msg_list': push_msg_list})
            push_thread.start()
            AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0)
            return True
        except Exception as e:
            LOGGER.info('***save_push_message***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return False

    @staticmethod
    def async_app_msg_push(uid, push_msg_list):
        """
        Send the prepared APP notifications one by one (started on a worker thread).

        :param uid: str, device UID (used for logging)
        :param push_msg_list: list[dict], push kwargs; each dict's 'push_type' entry is
            removed and passed to app_user_message_push separately
        """
        if not push_msg_list:
            LOGGER.info(f'***uid={uid}APP推送push_info为空***')
        for push_kwargs in push_msg_list:
            push_type = push_kwargs.pop('push_type')
            SageMakerAiObject.app_user_message_push(push_type, **push_kwargs)
        LOGGER.info(f'***uid={uid}APP推送完成')

    @staticmethod
    def app_user_message_push(push_type, **kwargs):
        """
        Dispatch one AI-detection notification to the matching push channel.

        push_type: 0 iOS APNs, 1 Android FCM, 2 JPush, 3 Huawei, 4 Xiaomi,
        5 vivo, 6 OPPO, 7 Meizu.

        :param push_type: int, push channel id
        :param kwargs: push arguments; must include 'uid'
        :return: True when the push call returned without raising, False for an
            unknown push_type or when the push call raised
        """
        uid = kwargs['uid']
        try:
            if push_type == 0:  # iOS APNs
                PushObject.ios_apns_push(**kwargs)
            elif push_type == 1:  # Android FCM
                PushObject.android_fcm_push(**kwargs)
            elif push_type == 2:  # Android JPush -- called without 'uid'
                kwargs.pop('uid')
                PushObject.android_jpush(**kwargs)
            elif push_type == 3:  # Huawei
                HuaweiPushObject().send_push_notify_message(**kwargs)
            elif push_type == 4:  # Android Xiaomi push
                PushObject.android_xmpush(**kwargs)
            elif push_type == 5:  # Android vivo push
                PushObject.android_vivopush(**kwargs)
            elif push_type == 6:  # Android OPPO push
                PushObject.android_oppopush(**kwargs)
            elif push_type == 7:  # Android Meizu push
                PushObject.android_meizupush(**kwargs)
            else:
                LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
                return False
            return True
        except Exception as e:
            # Use the captured uid: the JPush branch has already removed it from kwargs.
            LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(uid, e.__traceback__.tb_lineno, repr(e)))
            return False

    @staticmethod
    def get_push_tag_message(lang, event_type):
        """
        Build the APP notification-bar label text for an event type.

        Each decimal digit of ``event_type`` is one label group
        (1 person, 2 animal, 3 vehicle, 4 package); e.g. 13 -> 'person vehicle '.
        Every label is followed by a single space (including the last one).

        :param lang: str, 'cn' selects Chinese labels, anything else English
        :param event_type: int or str, concatenated label-group digits
        :return: str, space-terminated label names
        :raises KeyError/ValueError: when a digit is not a known label group
        """
        msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
        msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
        names = msg_cn if lang == 'cn' else msg_en
        return ''.join(names[int(digit)] + ' ' for digit in str(event_type))

    @staticmethod
    def upload_image_to_s3(uid, channel, d_push_time, file_list):
        """
        Upload base64-encoded thumbnails to S3 (best effort: upload errors are logged).

        US/EU deployments use bucket 'foreignpush' in us-east-1 with the second
        configured key pair; all other deployments use bucket 'push' in
        cn-northwest-1 with the first key pair. Objects are stored as
        ``{uid}/{channel}/{d_push_time}_{index}.jpeg``.

        :param uid: str, device UID
        :param channel: device channel
        :param d_push_time: event time used in the object key
        :param file_list: list[str], base64-encoded image data
        """
        if CONFIG_INFO == CONFIG_US or CONFIG_INFO == CONFIG_EUR:
            key_index, region_name, bucket_name = 1, 'us-east-1', 'foreignpush'
        else:
            # Mainland China deployment
            key_index, region_name, bucket_name = 0, 'cn-northwest-1', 'push'
        session = boto3.Session(
            aws_access_key_id=AWS_ACCESS_KEY_ID[key_index],
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY[key_index],
            region_name=region_name,
        )
        s3_resource = session.resource('s3')
        try:
            bucket = s3_resource.Bucket(bucket_name)
            for index, pic in enumerate(file_list):
                object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
                bucket.put_object(Key=object_key, Body=base64.b64decode(pic))
            LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
        except Exception as e:
            LOGGER.info('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
                        .format(uid, e.__traceback__.tb_lineno, repr(e)))

    @staticmethod
    def bird_recognition_demo(resultRsp, image_path='E:/stephen/AnsjerProject/bird_sagemaker/test.png'):
        """
        Demo: classify a local image with the 'ServerlessClsBirdEndpoint1' SageMaker endpoint.

        SECURITY: an AWS access key pair used to be hard-coded in this method and must be
        treated as leaked -- deactivate/rotate it in IAM. Credentials are now resolved by
        boto3's default provider chain (environment variables, shared credentials file,
        or instance/role credentials).

        :param resultRsp: response helper; the value of ``resultRsp.json(0)`` is returned
        :param image_path: path of the image to classify
        :return: resultRsp.json(0)
        """
        image = Image.open(image_path).convert('RGB')
        # Re-encode as JPEG and send it base64-encoded inside a JSON payload
        image_buffer = BytesIO()
        image.save(image_buffer, format='JPEG')
        image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8')
        payload = json.dumps({'input': image_base64})

        runtime = boto3.client('sagemaker-runtime', region_name='us-east-1')
        start_time = time.time()
        response = runtime.invoke_endpoint(
            EndpointName='ServerlessClsBirdEndpoint1',
            ContentType='application/json',
            Body=payload,
        )
        end_time = time.time()
        print("耗时: {:.2f}秒".format(end_time - start_time))
        result = json.loads(response['Body'].read().decode())
        print(result)
        return resultRsp.json(0)