|
@@ -0,0 +1,343 @@
|
|
|
+# -*- encoding: utf-8 -*-
|
|
|
+"""
|
|
|
+@File : SageMakerAiObject.py
|
|
|
+@Time : 2023/11/10 17:52
|
|
|
+@Author : stephen
|
|
|
+@Email : zhangdongming@asj6.wecom.work
|
|
|
+@Software: PyCharm
|
|
|
+"""
|
|
|
+import base64
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import time
|
|
|
+from io import BytesIO
|
|
|
+
|
|
|
+import boto3
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
|
|
|
+from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
|
|
|
+from Controller import AiController
|
|
|
+from Object.AiEngineObject import AiEngine
|
|
|
+from Object.enums.MessageTypeEnum import MessageTypeEnum
|
|
|
+from Service.EquipmentInfoService import EquipmentInfoService
|
|
|
+from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
|
|
|
+from Service.PushService import PushObject
|
|
|
+
|
|
|
+LOGGER = logging.getLogger('info')
|
|
|
+
|
|
|
+
|
|
|
+class SageMakerAiObject:
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def sage_maker_ai_server(uid, base64_list):
|
|
|
+
|
|
|
+ try:
|
|
|
+ url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
|
|
|
+ model_name = 'pdchv'
|
|
|
+
|
|
|
+ ai = AiEngine(url, model_name)
|
|
|
+
|
|
|
+ if ai.health and ai.set_model(model_name):
|
|
|
+ LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
|
|
|
+ else:
|
|
|
+ LOGGER.info('SageMake fail')
|
|
|
+ return False
|
|
|
+
|
|
|
+ img_list = list(map(base64.b64decode, base64_list))
|
|
|
+ img_list = map(BytesIO, img_list)
|
|
|
+ img_list = map(Image.open, img_list)
|
|
|
+ img_list = map(np.array, img_list)
|
|
|
+ img_list = np.array(list(img_list))
|
|
|
+
|
|
|
+ nms_threshold = 0.45
|
|
|
+ confidence = 0.5
|
|
|
+ client_timeout = 30
|
|
|
+
|
|
|
+ start = time.time()
|
|
|
+ results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
|
|
|
+
|
|
|
+ end = time.time()
|
|
|
+ LOGGER.info(f'uid={uid},{(end - start) * 1000}ms,sageMaker Result={results}')
|
|
|
+
|
|
|
+ ai.close()
|
|
|
+
|
|
|
+ if results == 'e_timeout':
|
|
|
+ return False
|
|
|
+ if results == 'e_no_model':
|
|
|
+ return False
|
|
|
+ if results == 'e_connect_fail':
|
|
|
+ return False
|
|
|
+
|
|
|
+ return results
|
|
|
+ except Exception as e:
|
|
|
+ LOGGER.info('***sage_maker_ai_server***uid={},errLine={errLine}, errMsg={errMsg}'
|
|
|
+ .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_table_name(uid, ai_result):
|
|
|
+ """
|
|
|
+ sageMaker识别结果得到标签组和坐标框
|
|
|
+ {
|
|
|
+ "file_0":[
|
|
|
+ {
|
|
|
+ "x1":551,
|
|
|
+ "x2":639,
|
|
|
+ "y1":179,
|
|
|
+ "y2":274,
|
|
|
+ "Width":0.22867838541666666,
|
|
|
+ "Height":0.14794921875,
|
|
|
+ "Top":0.28017578125,
|
|
|
+ "Left":1.4353841145833333,
|
|
|
+ "confidence":0.86,
|
|
|
+ "class":"vehicle"
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ LOGGER.info(f'***get_table_name***uid={uid},接收到识别信息={ai_result}')
|
|
|
+ ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
|
|
|
+ event_group = []
|
|
|
+
|
|
|
+ # 获取识别到的标签以及坐标信息
|
|
|
+ for pic_key, pic_value in ai_result.items():
|
|
|
+ for item in pic_value:
|
|
|
+ if 'class' in item:
|
|
|
+ event_group.append(ai_table_groups[item['class']])
|
|
|
+
|
|
|
+ if not event_group:
|
|
|
+ LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签={event_group}')
|
|
|
+ return None
|
|
|
+
|
|
|
+ event_group = set(event_group)
|
|
|
+ event_type = int(''.join(map(str, sorted(event_group))))
|
|
|
+ label_list = []
|
|
|
+ for item in event_group: # 识别到类型获取name
|
|
|
+ label_list.append(AI_IDENTIFICATION_TAGS_DICT[str(item)])
|
|
|
+ LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
|
|
|
+
|
|
|
+ return {'event_type': event_type, 'label_list': label_list,
|
|
|
+ 'coords': ai_result}
|
|
|
+ except Exception as e:
|
|
|
+ LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
|
|
|
+ .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list):
|
|
|
+ """
|
|
|
+ 保存推送消息
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ LOGGER.info(f'***save_push_message***uid={uid},res={res}')
|
|
|
+ if not res:
|
|
|
+ return False
|
|
|
+ d_push_time = int(d_push_time)
|
|
|
+ # 存储消息以及推送
|
|
|
+ uid_push_list = [uid_push for uid_push in uid_push_qs]
|
|
|
+ user_id_list = []
|
|
|
+ nickname = uid_push_list[0]['uid_set__nickname']
|
|
|
+ lang = uid_push_list[0]['lang']
|
|
|
+ event_type = res['event_type']
|
|
|
+ label_str = SageMakerAiObject().get_push_tag_message(lang, event_type)
|
|
|
+ coords = json.dumps(res['coords']) if res['coords'] else '' # 坐标框字典转json字符串
|
|
|
+ now_time = int(time.time())
|
|
|
+ # 推送表存储数据
|
|
|
+ equipment_info_kwargs = {
|
|
|
+ 'device_uid': uid,
|
|
|
+ 'device_nick_name': nickname,
|
|
|
+ 'channel': channel,
|
|
|
+ 'is_st': 3,
|
|
|
+ 'storage_location': 2,
|
|
|
+ 'event_type': event_type,
|
|
|
+ 'event_time': d_push_time,
|
|
|
+ 'add_time': now_time,
|
|
|
+ 'alarm': label_str,
|
|
|
+ 'border_coords': coords
|
|
|
+ }
|
|
|
+
|
|
|
+ equipment_info_list = []
|
|
|
+ equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info() # 随机获取其中一张推送表
|
|
|
+ for up in uid_push_list:
|
|
|
+ # 保存推送数据
|
|
|
+ tz = up['tz']
|
|
|
+ if tz is None or tz == '':
|
|
|
+ tz = 0
|
|
|
+ user_id = up['userID_id']
|
|
|
+ if user_id not in user_id_list:
|
|
|
+ equipment_info_kwargs['device_user_id'] = user_id
|
|
|
+ equipment_info_list.append(equipment_info_model(**equipment_info_kwargs))
|
|
|
+ user_id_list.append(user_id)
|
|
|
+
|
|
|
+ # 推送
|
|
|
+ push_type = up['push_type']
|
|
|
+ app_bundle_id = up['appBundleId']
|
|
|
+ token_val = up['token_val']
|
|
|
+ lang = up['lang']
|
|
|
+ # 推送标题和推送内容
|
|
|
+ msg_title = PushObject.get_msg_title(nickname=nickname)
|
|
|
+ msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang, tz=tz,
|
|
|
+ label=label_str)
|
|
|
+ kwargs = {
|
|
|
+ 'nickname': nickname,
|
|
|
+ 'app_bundle_id': app_bundle_id,
|
|
|
+ 'token_val': token_val,
|
|
|
+ 'n_time': d_push_time,
|
|
|
+ 'event_type': event_type,
|
|
|
+ 'msg_title': msg_title,
|
|
|
+ 'msg_text': msg_text,
|
|
|
+ 'uid': uid,
|
|
|
+ 'channel': channel,
|
|
|
+ }
|
|
|
+ SageMakerAiObject().app_user_message_push(push_type, **kwargs)
|
|
|
+
|
|
|
+ if equipment_info_list: # 消息存表
|
|
|
+ equipment_info_model.objects.bulk_create(equipment_info_list)
|
|
|
+ SageMakerAiObject().upload_image_to_s3(uid, channel, d_push_time, file_list)
|
|
|
+ AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0) # 关联AI标签
|
|
|
+ return True
|
|
|
+ except Exception as e:
|
|
|
+ LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
|
|
|
+ .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def app_user_message_push(push_type, **kwargs):
|
|
|
+ """
|
|
|
+ ai识别 app用户消息推送
|
|
|
+ """
|
|
|
+ uid = kwargs['uid']
|
|
|
+ try:
|
|
|
+ if push_type == 0: # ios apns
|
|
|
+ PushObject.ios_apns_push(**kwargs)
|
|
|
+ elif push_type == 1: # android gcm
|
|
|
+ PushObject.android_fcm_push(**kwargs)
|
|
|
+ elif push_type == 2: # android jpush
|
|
|
+ kwargs.pop('uid')
|
|
|
+ PushObject.android_jpush(**kwargs)
|
|
|
+ elif push_type == 3:
|
|
|
+ huawei_push_object = HuaweiPushObject()
|
|
|
+ huawei_push_object.send_push_notify_message(**kwargs)
|
|
|
+ elif push_type == 4: # android 小米推送
|
|
|
+ PushObject.android_xmpush(**kwargs)
|
|
|
+ elif push_type == 5: # android vivo推送
|
|
|
+ PushObject.android_vivopush(**kwargs)
|
|
|
+ elif push_type == 6: # android oppo推送
|
|
|
+ PushObject.android_oppopush(**kwargs)
|
|
|
+ elif push_type == 7: # android 魅族推送
|
|
|
+ PushObject.android_meizupush(**kwargs)
|
|
|
+ else:
|
|
|
+ LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
|
|
|
+ return False
|
|
|
+ return True
|
|
|
+ except Exception as e:
|
|
|
+ LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(kwargs['uid']
|
|
|
+ , e.__traceback__.tb_lineno, repr(e)))
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_push_tag_message(lang, event_type):
|
|
|
+ """
|
|
|
+ 根据语言以及推送类型得到APP通知栏文案
|
|
|
+ """
|
|
|
+ event_type = str(event_type)
|
|
|
+ types = []
|
|
|
+ if len(event_type) > 1:
|
|
|
+ for i in range(1, len(event_type) + 1):
|
|
|
+ types.append(MessageTypeEnum(int(event_type[i - 1:i])))
|
|
|
+ else:
|
|
|
+ types.append(int(event_type))
|
|
|
+ msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
|
|
|
+ msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
|
|
|
+ msg_text = ''
|
|
|
+ for item in types:
|
|
|
+ if lang == 'cn':
|
|
|
+ msg_text += msg_cn.get(item) + ' '
|
|
|
+ else:
|
|
|
+ msg_text += msg_en.get(item) + ' '
|
|
|
+ return msg_text
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def upload_image_to_s3(uid, channel, d_push_time, file_list):
|
|
|
+ """
|
|
|
+ 上传图片到S3
|
|
|
+ """
|
|
|
+ # 创建 AWS 访问密钥
|
|
|
+ aws_key = AWS_ACCESS_KEY_ID[1]
|
|
|
+ aws_secret = AWS_SECRET_ACCESS_KEY[1]
|
|
|
+ # 创建会话对象
|
|
|
+ session = boto3.Session(
|
|
|
+ aws_access_key_id=aws_key,
|
|
|
+ aws_secret_access_key=aws_secret,
|
|
|
+ region_name="us-east-1"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 创建 S3 资源对象
|
|
|
+ s3_resource = session.resource("s3")
|
|
|
+ try:
|
|
|
+ # 解码base64字符串为图像数据
|
|
|
+ for index, pic in enumerate(file_list):
|
|
|
+ image_data = base64.b64decode(pic)
|
|
|
+ # 指定存储桶名称和对象键
|
|
|
+ bucket_name = "foreignpush"
|
|
|
+ object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
|
|
|
+
|
|
|
+ # 获取指定的存储桶
|
|
|
+ bucket = s3_resource.Bucket(bucket_name)
|
|
|
+
|
|
|
+ # 上传图像数据到 S3
|
|
|
+ bucket.put_object(Key=object_key, Body=image_data)
|
|
|
+ LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
|
|
|
+ except Exception as e:
|
|
|
+ LOGGER.info('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
|
|
|
+ .format(uid, e.__traceback__.tb_lineno, repr(e)))
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def bird_recognition_demo(resultRsp):
|
|
|
+ """
|
|
|
+ 鸟类识别示例代码
|
|
|
+ """
|
|
|
+ from PIL import Image
|
|
|
+ import base64
|
|
|
+ from io import BytesIO
|
|
|
+ import json
|
|
|
+ import boto3
|
|
|
+
|
|
|
+ image_path = "E:/stephen/AnsjerProject/bird_sagemaker/test.png"
|
|
|
+ image = Image.open(image_path).convert("RGB")
|
|
|
+
|
|
|
+ # 将图像转换为 Base64 编码字符串
|
|
|
+ image_buffer = BytesIO()
|
|
|
+ image.save(image_buffer, format="JPEG")
|
|
|
+ image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
|
|
|
+ input_json = json.dumps({'input': image_base64})
|
|
|
+
|
|
|
+ AWS_SERVER_PUBLIC_KEY = 'AKIARBAGHTGOSK37QB4T'
|
|
|
+ AWS_SERVER_SECRET_KEY = 'EUbBgXNXV1yIuj1D6QfUA70b5m/SQbuYpW5vqUYx'
|
|
|
+ runtime = boto3.client("sagemaker-runtime",
|
|
|
+ aws_access_key_id=AWS_SERVER_PUBLIC_KEY,
|
|
|
+ aws_secret_access_key=AWS_SERVER_SECRET_KEY,
|
|
|
+ region_name='us-east-1')
|
|
|
+
|
|
|
+ endpoint_name = 'ServerlessClsBirdEndpoint1'
|
|
|
+ content_type = "application/json"
|
|
|
+ payload = input_json
|
|
|
+ session = boto3.Session()
|
|
|
+ # runtime = session.client('sagemaker-runtime')
|
|
|
+ import time
|
|
|
+ start_time = time.time()
|
|
|
+ response = runtime.invoke_endpoint(
|
|
|
+ EndpointName=endpoint_name,
|
|
|
+ ContentType=content_type,
|
|
|
+ Body=payload
|
|
|
+ )
|
|
|
+ end_time = time.time()
|
|
|
+ print("耗时: {:.2f}秒".format(end_time - start_time))
|
|
|
+
|
|
|
+ result = json.loads(response['Body'].read().decode())
|
|
|
+ # result = json.loads(response['Body'].read().decode())
|
|
|
+ print(result)
|
|
|
+ return resultRsp.json(0)
|