123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400 |
- # -*- encoding: utf-8 -*-
- """
- @File : SageMakerAiObject.py
- @Time : 2023/11/10 17:52
- @Author : stephen
- @Email : zhangdongming@asj6.wecom.work
- @Software: PyCharm
- """
- import base64
- import json
- import logging
- import threading
- import time
- from io import BytesIO
- import cv2
- import numpy as np
- from PIL import Image, UnidentifiedImageError
- from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
- from AnsjerPush.config import CONFIG_EUR
- from AnsjerPush.config import CONFIG_INFO
- from AnsjerPush.config import PUSH_BUCKET
- from Controller import AiController
- from Object.AiEngineObject import AiEngine
- from Object.OCIObjectStorage import OCIObjectStorage
- from Object.enums.MessageTypeEnum import MessageTypeEnum
- from Service.EquipmentInfoService import EquipmentInfoService
- from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
- from Service.PushService import PushObject
LOGGER = logging.getLogger('time')


class SageMakerAiObject:
    """AI object-detection helpers for device push events.

    Runs YOLO inference on base64-encoded snapshots, maps detected classes to
    tag groups, stores push records, uploads the images to OCI object storage
    and dispatches APP notifications.
    """

    # Detected class name -> tag group id (1 person, 2 animal, 3 vehicle, 4 package).
    _AI_TABLE_GROUPS = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}

    # Tag group id -> notification-bar text.
    _PUSH_TAG_TEXT_CN = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
    _PUSH_TAG_TEXT_EN = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}

    # Error sentinels that AiEngine.yolo_infer may return instead of a result.
    # Kept as a tuple (not a set) because successful results may be unhashable dicts.
    _AI_ERROR_RESULTS = ('e_timeout', 'e_no_model', 'e_connect_fail')

    @staticmethod
    def _decode_bgr_image(b64_str):
        """Decode one base64 image into a BGR numpy array, as expected by OpenCV."""
        with Image.open(BytesIO(base64.b64decode(b64_str))) as image:
            # convert('RGB') normalises L/P/RGBA/CMYK inputs to 3 channels;
            # without it cvtColor fails on greyscale or palette images.
            rgb = np.array(image.convert('RGB'))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    @staticmethod
    def sage_maker_ai_server(uid, base64_list,
                             url='ec2-34-192-147-108.compute-1.amazonaws.com:8001',
                             model_name='pdchv', nms_threshold=0.45, confidence=0.5,
                             client_timeout=3):
        """Run YOLO detection on a batch of base64-encoded images.

        NOTE: ``base64_list`` is modified in place (' ' is replaced with '+');
        callers may rely on that if they reuse the list afterwards.

        :param uid: str, device UID (used for logging)
        :param base64_list: list[str], base64-encoded images
        :param url: str, inference server address
        :param model_name: str, model to load on the server
        :param nms_threshold: float, forwarded to AiEngine.yolo_infer
        :param confidence: float, forwarded to AiEngine.yolo_infer
        :param client_timeout: forwarded to AiEngine.yolo_infer
        :return: the inference result on success; False if the engine is not
                 healthy or returns an error sentinel; 'imageError' on any exception
        """
        ai = None
        try:
            start = time.time()
            ai = AiEngine(url, model_name)
            # NOTE(review): `health` is read as an attribute; if it is a method on
            # AiEngine this check is always truthy — confirm.
            if not (ai.health and ai.set_model(model_name)):
                LOGGER.info('SageMake fail')
                return False
            LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
            # Presumably '+' was turned into ' ' by URL/form decoding upstream.
            for index, item in enumerate(base64_list):
                base64_list[index] = item.replace(' ', '+')
            img_list = np.array([SageMakerAiObject._decode_bgr_image(b64) for b64 in base64_list])
            results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
            end = time.time()
            LOGGER.info(f'uid={uid},{(end - start)}s,sageMaker Result={results}')
            if results in SageMakerAiObject._AI_ERROR_RESULTS:
                return False
            return results
        except UnidentifiedImageError as e:
            LOGGER.info('***sagemakerUnidentifiedImageError***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return 'imageError'
        except Exception as e:
            LOGGER.info('***sagemakerException***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return 'imageError'
        finally:
            # Always release the engine connection, including on early returns and errors.
            if ai is not None:
                try:
                    ai.close()
                except Exception as e:
                    LOGGER.info('***sagemakerCloseException***uid={},errMsg={}'.format(uid, repr(e)))

    @staticmethod
    def get_table_name(uid, ai_result, detect_group):
        """Derive the tag group and bounding boxes from a SageMaker detection result.

        :param uid: str, device UID (used for logging)
        :param ai_result: dict, detection result; each value is an iterable of
                          dicts that may carry a 'class' key
        :param detect_group: str, comma-separated tag group ids the user asked for, e.g. '1,2'
        :return: dict with 'event_type' (int built from the sorted group digits),
                 'label_list' and 'coords' (the raw ai_result); False when no
                 requested label was detected or on error
        """
        try:
            LOGGER.info(f'***get_table_name***uid={uid}')
            event_group = set()
            for pic_value in ai_result.values():
                for item in pic_value:
                    if 'class' not in item:
                        continue
                    group = SageMakerAiObject._AI_TABLE_GROUPS.get(item['class'])
                    if group is None:
                        # Unsupported class: ignore it instead of discarding the whole result.
                        LOGGER.info(f'***get_table_name***uid={uid},unknown class={item["class"]}')
                        continue
                    event_group.add(group)
            if not event_group:
                LOGGER.info(f'***get_table_name***uid={uid},没有识别到标签={sorted(event_group)}')
                return False
            ai_group = {int(group) for group in str(detect_group).split(',') if group.strip()}
            if not event_group & ai_group:
                LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签sageMakerLabel={event_group}')
                return False
            ordered_groups = sorted(event_group)
            # e.g. groups {1, 3} -> event_type 13
            event_type = int(''.join(map(str, ordered_groups)))
            label_list = [AI_IDENTIFICATION_TAGS_DICT[str(group)] for group in ordered_groups]
            LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
            return {'event_type': event_type, 'label_list': label_list,
                    'coords': ai_result}
        except Exception as e:
            LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return False

    @staticmethod
    def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list, notify):
        """Store an AI push event, upload its images and (optionally) notify APP users.

        :param uid: str, device UID
        :param d_push_time: int-like, event time
        :param uid_push_qs: iterable of dicts with keys uid_set__nickname, lang, tz,
                            userID_id, push_type, appBundleId, token_val
                            (presumably a queryset ``.values()`` — confirm)
        :param channel: device channel
        :param res: dict returned by get_table_name, or a falsy value
        :param file_list: list[str], base64 images to upload
        :param notify: bool, whether to send APP notifications in a background thread
        :return: True on success, False otherwise
        """
        try:
            LOGGER.info(f'***save_push_message***uid={uid},res={res}')
            if not res:
                return False
            d_push_time = int(d_push_time)
            uid_push_list = list(uid_push_qs)
            if not uid_push_list:
                LOGGER.info(f'***save_push_message***uid={uid},no push subscribers')
                return False
            nickname = uid_push_list[0]['uid_set__nickname']
            event_type = res['event_type']
            coords = json.dumps(res['coords']) if res['coords'] else ''  # bounding boxes as JSON text
            now_time = int(time.time())
            region = 4 if CONFIG_INFO == CONFIG_EUR else 3
            # Fields shared by every stored push record.
            base_record = {
                'device_uid': uid,
                'device_nick_name': nickname,
                'channel': channel,
                'is_st': 3,
                'storage_location': region,
                'event_type': event_type,
                'event_time': d_push_time,
                'add_time': now_time,
                'border_coords': coords
            }
            push_msg_list = []
            equipment_info_list = []
            saved_user_ids = set()
            # Randomly pick one of the sharded push tables.
            equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info()
            for up in uid_push_list:
                lang = up['lang']
                label_str = SageMakerAiObject.get_push_tag_message(lang, event_type)
                tz = up['tz']
                if tz is None or tz == '':
                    tz = 0
                user_id = up['userID_id']
                # Store one record per user even if the user has several push tokens.
                if user_id not in saved_user_ids:
                    saved_user_ids.add(user_id)
                    equipment_info_list.append(
                        equipment_info_model(**base_record, alarm=label_str, device_user_id=user_id))
                msg_title = PushObject.get_msg_title(nickname=nickname)
                msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang, tz=tz,
                                                      label=label_str)
                push_msg_list.append({
                    'nickname': nickname,
                    'app_bundle_id': up['appBundleId'],
                    'token_val': up['token_val'],
                    'n_time': d_push_time,
                    'event_type': event_type,
                    'msg_title': msg_title,
                    'msg_text': msg_text,
                    'uid': uid,
                    'channel': channel,
                    'push_type': up['push_type']
                })
            if equipment_info_list:
                equipment_info_model.objects.bulk_create(equipment_info_list)
            SageMakerAiObject.upload_image_to_s3(uid, channel, d_push_time, file_list)
            if notify:
                push_thread = threading.Thread(target=SageMakerAiObject.async_app_msg_push,
                                               kwargs={'uid': uid, 'push_msg_list': push_msg_list})
                push_thread.start()
            AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0)  # link the AI tag
            return True
        except Exception as e:
            LOGGER.info('***save_push_message***uid={},errLine={errLine}, errMsg={errMsg}'
                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
            return False

    @staticmethod
    def async_app_msg_push(uid, push_msg_list):
        """Send every queued APP notification; intended to run in a background thread.

        Each item's 'push_type' key is removed and passed separately to
        app_user_message_push; the remaining keys become its keyword arguments.
        """
        if not push_msg_list:
            LOGGER.info(f'***uid={uid}APP推送push_info为空***')
        for item in push_msg_list:
            push_type = item.pop('push_type')
            SageMakerAiObject.app_user_message_push(push_type, **item)
        LOGGER.info(f'***uid={uid}APP推送完成')

    @staticmethod
    def app_user_message_push(push_type, **kwargs):
        """Send one AI-detection notification through the vendor channel for push_type.

        :param push_type: int, 0 APNs, 1 FCM, 2 JPush, 3 Huawei, 4 Xiaomi,
                          5 vivo, 6 OPPO, 7 Meizu
        :param kwargs: notification fields; must include 'uid'
        :return: True when dispatched, False for an unknown type or on error
        """
        uid = kwargs['uid']
        try:
            if push_type == 0:
                PushObject.ios_apns_push(**kwargs)
            elif push_type == 1:
                PushObject.android_fcm_push_v1(**kwargs)
            elif push_type == 2:
                # JPush does not accept 'uid'; pass a copy without it.
                PushObject.android_jpush(**{k: v for k, v in kwargs.items() if k != 'uid'})
            elif push_type == 3:
                HuaweiPushObject().send_push_notify_message(**kwargs)
            elif push_type == 4:
                PushObject.android_xmpush(**kwargs)
            elif push_type == 5:
                PushObject.android_vivopush(**kwargs)
            elif push_type == 6:
                PushObject.android_oppopush(**kwargs)
            elif push_type == 7:
                PushObject.android_meizupush(**kwargs)
            else:
                LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
                return False
            return True
        except Exception as e:
            # Use the captured uid: kwargs may no longer hold it.
            LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(
                uid, e.__traceback__.tb_lineno, repr(e)))
            return False

    @staticmethod
    def get_push_tag_message(lang, event_type):
        """Build the APP notification-bar label text for an event type.

        Each digit of event_type is a tag group id (e.g. 13 -> person + vehicle).
        Unknown digits are skipped.

        :param lang: str, 'cn' selects Chinese text, anything else English
        :param event_type: int or str
        :return: str, labels each followed by a space, e.g. 'person vehicle '
        """
        labels = (SageMakerAiObject._PUSH_TAG_TEXT_CN if lang == 'cn'
                  else SageMakerAiObject._PUSH_TAG_TEXT_EN)
        words = []
        for digit in str(event_type):
            label = labels.get(int(digit)) if digit.isdigit() else None
            if label is None:
                LOGGER.info(f'get_push_tag_message unknown type digit={digit}, event_type={event_type}')
                continue
            words.append(label + ' ')
        return ''.join(words)

    @staticmethod
    def upload_image_to_s3(uid, channel, d_push_time, file_list):
        """Upload base64 images to the push bucket in OCI object storage.

        Despite the name, storage is OCI (EUR region when CONFIG_INFO is
        CONFIG_EUR, otherwise US). Object keys are
        '{uid}/{channel}/{d_push_time}_{index}.jpeg'. Errors are logged, not raised.
        """
        try:
            region = 'eur' if CONFIG_INFO == CONFIG_EUR else 'us'
            oci = OCIObjectStorage(region)
            for index, pic in enumerate(file_list):
                image_data = base64.b64decode(pic)
                object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
                oci.put_object(PUSH_BUCKET, object_key, image_data, 'image/jpeg')
            LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
        except Exception as e:
            LOGGER.error('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
                         .format(uid, e.__traceback__.tb_lineno, repr(e)))

    @staticmethod
    def bird_recognition_demo(resultRsp,
                              image_path="E:/stephen/AnsjerProject/bird_sagemaker/test.png",
                              endpoint_name='ServerlessClsBirdEndpoint1',
                              region_name='us-east-1'):
        """Demo: send a local image to the bird-classification SageMaker endpoint.

        Credentials are resolved through boto3's default provider chain
        (environment variables, shared config, instance role). They must never be
        embedded in source; any key that was previously committed here should be rotated.
        """
        import boto3
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Encode the image as base64 JPEG for the JSON payload.
        image_buffer = BytesIO()
        image.save(image_buffer, format="JPEG")
        image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
        payload = json.dumps({'input': image_base64})
        runtime = boto3.client("sagemaker-runtime", region_name=region_name)
        start_time = time.time()
        response = runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=payload
        )
        end_time = time.time()
        print("耗时: {:.2f}秒".format(end_time - start_time))
        result = json.loads(response['Body'].read().decode())
        print(result)
        return resultRsp.json(0)
|