SageMakerAiObject.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # -*- encoding: utf-8 -*-
  2. """
  3. @File : SageMakerAiObject.py
  4. @Time : 2023/11/10 17:52
  5. @Author : stephen
  6. @Email : zhangdongming@asj6.wecom.work
  7. @Software: PyCharm
  8. """
  9. import base64
  10. import json
  11. import logging
  12. import threading
  13. import time
  14. from io import BytesIO
  15. import cv2
  16. import numpy as np
  17. from PIL import Image, UnidentifiedImageError
  18. from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
  19. from AnsjerPush.config import CONFIG_EUR
  20. from AnsjerPush.config import CONFIG_INFO
  21. from AnsjerPush.config import PUSH_BUCKET
  22. from Controller import AiController
  23. from Object.AiEngineObject import AiEngine
  24. from Object.OCIObjectStorage import OCIObjectStorage
  25. from Object.enums.MessageTypeEnum import MessageTypeEnum
  26. from Service.EquipmentInfoService import EquipmentInfoService
  27. from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
  28. from Service.PushService import PushObject
  29. LOGGER = logging.getLogger('time')
  30. class SageMakerAiObject:
  31. @staticmethod
  32. def sage_maker_ai_server(uid, base64_list):
  33. try:
  34. start = time.time()
  35. url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
  36. model_name = 'pdchv'
  37. ai = AiEngine(url, model_name)
  38. if ai.health and ai.set_model(model_name):
  39. LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
  40. else:
  41. LOGGER.info('SageMake fail')
  42. return False
  43. for i in range(len(base64_list)):
  44. base64_list[i] = base64_list[i].replace(' ', '+')
  45. img_list = list(map(base64.b64decode, base64_list))
  46. img_list = map(BytesIO, img_list)
  47. img_list = map(Image.open, img_list)
  48. img_list = map(np.array, img_list)
  49. img_list = map(lambda x: cv2.cvtColor(x, cv2.COLOR_RGB2BGR), img_list)
  50. img_list = np.array(list(img_list))
  51. nms_threshold = 0.45
  52. confidence = 0.5
  53. client_timeout = 3
  54. results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
  55. end = time.time()
  56. LOGGER.info(f'uid={uid},{(end - start)}s,sageMaker Result={results}')
  57. ai.close()
  58. if results == 'e_timeout':
  59. return False
  60. if results == 'e_no_model':
  61. return False
  62. if results == 'e_connect_fail':
  63. return False
  64. return results
  65. except UnidentifiedImageError as e:
  66. LOGGER.info('***sagemakerUnidentifiedImageError***uid={},errLine={errLine}, errMsg={errMsg}'
  67. .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
  68. return 'imageError'
  69. except Exception as e:
  70. LOGGER.info('***sagemakerException***uid={},errLine={errLine}, errMsg={errMsg}'
  71. .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
  72. return 'imageError'
  73. @staticmethod
  74. def get_table_name(uid, ai_result, detect_group):
  75. """
  76. 根据sageMaker识别结果得到标签组和坐标框
  77. :param uid: str,请求ID
  78. :param ai_result: dict,sageMaker识别结果
  79. :param detect_group: str,指定标签组
  80. :return: dict or None or False
  81. """
  82. try:
  83. LOGGER.info(f'***get_table_name***uid={uid}')
  84. # 定义标签组的映射关系
  85. ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
  86. event_group = []
  87. # 获取识别到的标签以及坐标信息
  88. for pic_key, pic_value in ai_result.items():
  89. for item in pic_value:
  90. if 'class' in item:
  91. event_group.append(ai_table_groups[item['class']])
  92. # 如果没有识别到标签则返回False
  93. if not event_group:
  94. LOGGER.info(f'***get_table_name***uid={uid},没有识别到标签={event_group}')
  95. return False
  96. # 去重标签组并转换为整数类型,同时将指定标签组转换为整数列表
  97. event_group = set(event_group)
  98. ai_group = list(map(int, detect_group.split(',')))
  99. is_success = False
  100. # 判断是否识别到了指定的标签
  101. for item in event_group:
  102. if item in ai_group:
  103. is_success = True
  104. break
  105. # 如果未识别到指定标签则返回False
  106. if not is_success:
  107. LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签sageMakerLabel={event_group}')
  108. return False
  109. # 将标签组转换为整数表示的类型组
  110. event_type = int(''.join(map(str, sorted(event_group))))
  111. label_list = []
  112. # 根据标签组获取对应的标签名称
  113. for item in event_group:
  114. label_list.append(AI_IDENTIFICATION_TAGS_DICT[str(item)])
  115. LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
  116. # 返回包含事件类型、标签名称列表和坐标框的字典对象
  117. return {'event_type': event_type, 'label_list': label_list,
  118. 'coords': ai_result}
  119. except Exception as e:
  120. LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
  121. .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
  122. return False
  123. @staticmethod
  124. def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list, notify):
  125. """
  126. 保存推送消息
  127. """
  128. try:
  129. LOGGER.info(f'***save_push_message***uid={uid},res={res}')
  130. if not res:
  131. return False
  132. d_push_time = int(d_push_time)
  133. # 存储消息以及推送
  134. uid_push_list = [uid_push for uid_push in uid_push_qs]
  135. user_id_list = []
  136. nickname = uid_push_list[0]['uid_set__nickname']
  137. event_type = res['event_type']
  138. coords = json.dumps(res['coords']) if res['coords'] else '' # 坐标框字典转json字符串
  139. now_time = int(time.time())
  140. region = 4 if CONFIG_INFO == CONFIG_EUR else 3
  141. # 推送表存储数据
  142. equipment_info_kwargs = {
  143. 'device_uid': uid,
  144. 'device_nick_name': nickname,
  145. 'channel': channel,
  146. 'is_st': 3,
  147. 'storage_location': region,
  148. 'event_type': event_type,
  149. 'event_time': d_push_time,
  150. 'add_time': now_time,
  151. 'border_coords': coords
  152. }
  153. push_msg_list = []
  154. equipment_info_list = []
  155. equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info() # 随机获取其中一张推送表
  156. for up in uid_push_list:
  157. # 保存推送数据
  158. lang = up['lang']
  159. label_str = SageMakerAiObject().get_push_tag_message(lang, event_type)
  160. equipment_info_kwargs['alarm'] = label_str
  161. tz = up['tz']
  162. if tz is None or tz == '':
  163. tz = 0
  164. user_id = up['userID_id']
  165. if user_id not in user_id_list:
  166. equipment_info_kwargs['device_user_id'] = user_id
  167. equipment_info_list.append(equipment_info_model(**equipment_info_kwargs))
  168. user_id_list.append(user_id)
  169. # 推送
  170. push_type = up['push_type']
  171. app_bundle_id = up['appBundleId']
  172. token_val = up['token_val']
  173. # 推送标题和推送内容
  174. msg_title = PushObject.get_msg_title(nickname=nickname)
  175. msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang, tz=tz,
  176. label=label_str)
  177. kwargs = {
  178. 'nickname': nickname,
  179. 'app_bundle_id': app_bundle_id,
  180. 'token_val': token_val,
  181. 'n_time': d_push_time,
  182. 'event_type': event_type,
  183. 'msg_title': msg_title,
  184. 'msg_text': msg_text,
  185. 'uid': uid,
  186. 'channel': channel,
  187. 'push_type': push_type
  188. }
  189. push_msg_list.append(kwargs)
  190. if equipment_info_list: # 消息存表
  191. equipment_info_model.objects.bulk_create(equipment_info_list)
  192. SageMakerAiObject().upload_image_to_s3(uid, channel, d_push_time, file_list) # 上传到云端
  193. if notify: # 异步APP消息通知
  194. push_thread = threading.Thread(target=SageMakerAiObject.async_app_msg_push,
  195. kwargs={'uid': uid, 'push_msg_list': push_msg_list})
  196. push_thread.start() # APP消息提醒异步推送
  197. AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0) # 关联AI标签
  198. return True
  199. except Exception as e:
  200. LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
  201. .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
  202. return False
  203. @staticmethod
  204. def async_app_msg_push(uid, push_msg_list):
  205. """
  206. APP消息提醒异步推送
  207. """
  208. if not push_msg_list:
  209. LOGGER.info(f'***uid={uid}APP推送push_info为空***')
  210. for item in push_msg_list:
  211. push_type = item['push_type']
  212. item.pop('push_type')
  213. SageMakerAiObject.app_user_message_push(push_type, **item)
  214. LOGGER.info(f'***uid={uid}APP推送完成')
  215. @staticmethod
  216. def app_user_message_push(push_type, **kwargs):
  217. """
  218. ai识别 app用户消息推送
  219. """
  220. uid = kwargs['uid']
  221. try:
  222. if push_type == 0: # ios apns
  223. PushObject.ios_apns_push(**kwargs)
  224. elif push_type == 1: # android gcm
  225. PushObject.android_fcm_push_v1(**kwargs)
  226. elif push_type == 2: # android jpush
  227. kwargs.pop('uid')
  228. PushObject.android_jpush(**kwargs)
  229. elif push_type == 3:
  230. huawei_push_object = HuaweiPushObject()
  231. huawei_push_object.send_push_notify_message(**kwargs)
  232. elif push_type == 4: # android 小米推送
  233. PushObject.android_xmpush(**kwargs)
  234. elif push_type == 5: # android vivo推送
  235. PushObject.android_vivopush(**kwargs)
  236. elif push_type == 6: # android oppo推送
  237. PushObject.android_oppopush(**kwargs)
  238. elif push_type == 7: # android 魅族推送
  239. PushObject.android_meizupush(**kwargs)
  240. else:
  241. LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
  242. return False
  243. return True
  244. except Exception as e:
  245. LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(kwargs['uid']
  246. , e.__traceback__.tb_lineno, repr(e)))
  247. return False
  248. @staticmethod
  249. def get_push_tag_message(lang, event_type):
  250. """
  251. 根据语言以及推送类型得到APP通知栏文案
  252. """
  253. event_type = str(event_type)
  254. types = []
  255. if len(event_type) > 1:
  256. for i in range(1, len(event_type) + 1):
  257. types.append(MessageTypeEnum(int(event_type[i - 1:i])))
  258. else:
  259. types.append(int(event_type))
  260. msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
  261. msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
  262. msg_text = ''
  263. for item in types:
  264. if lang == 'cn':
  265. msg_text += msg_cn.get(item) + ' '
  266. else:
  267. msg_text += msg_en.get(item) + ' '
  268. return msg_text
  269. @staticmethod
  270. def upload_image_to_s3(uid, channel, d_push_time, file_list):
  271. """
  272. 上传图片到S3
  273. """
  274. # if CONFIG_INFO == CONFIG_US or CONFIG_INFO == CONFIG_EUR:
  275. # # 创建 AWS 访问密钥
  276. # aws_key = AWS_ACCESS_KEY_ID[1]
  277. # aws_secret = AWS_SECRET_ACCESS_KEY[1]
  278. # # 创建会话对象
  279. # session = boto3.Session(
  280. # aws_access_key_id=aws_key,
  281. # aws_secret_access_key=aws_secret,
  282. # region_name="us-east-1"
  283. # )
  284. # s3_resource = session.resource("s3")
  285. # bucket_name = "foreignpush"
  286. # else:
  287. # # 存国内
  288. # aws_key = AWS_ACCESS_KEY_ID[0]
  289. # aws_secret = AWS_SECRET_ACCESS_KEY[0]
  290. # session = boto3.Session(
  291. # aws_access_key_id=aws_key,
  292. # aws_secret_access_key=aws_secret,
  293. # region_name="cn-northwest-1")
  294. # s3_resource = session.resource("s3")
  295. # bucket_name = "push"
  296. try:
  297. region = 'eur' if CONFIG_INFO == CONFIG_EUR else 'us'
  298. oci = OCIObjectStorage(region)
  299. # 解码base64字符串为图像数据
  300. for index, pic in enumerate(file_list):
  301. image_data = base64.b64decode(pic)
  302. # 指定存储桶名称和对象键
  303. object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
  304. # OCI上传对象
  305. oci.put_object(PUSH_BUCKET, object_key, image_data, 'image/jpeg')
  306. # AWS获取指定的存储桶
  307. # bucket = s3_resource.Bucket(bucket_name)
  308. # AWS上传图像数据到 S3
  309. # bucket.put_object(Key=object_key, Body=image_data)
  310. LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
  311. except Exception as e:
  312. LOGGER.error('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
  313. .format(uid, e.__traceback__.tb_lineno, repr(e)))
  314. @staticmethod
  315. def bird_recognition_demo(resultRsp):
  316. """
  317. 鸟类识别示例代码
  318. """
  319. from PIL import Image
  320. import base64
  321. from io import BytesIO
  322. import json
  323. import boto3
  324. image_path = "E:/stephen/AnsjerProject/bird_sagemaker/test.png"
  325. image = Image.open(image_path).convert("RGB")
  326. # 将图像转换为 Base64 编码字符串
  327. image_buffer = BytesIO()
  328. image.save(image_buffer, format="JPEG")
  329. image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
  330. input_json = json.dumps({'input': image_base64})
  331. AWS_SERVER_PUBLIC_KEY = 'AKIARBAGHTGOSK37QB4T'
  332. AWS_SERVER_SECRET_KEY = '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'
  333. runtime = boto3.client("sagemaker-runtime",
  334. aws_access_key_id=AWS_SERVER_PUBLIC_KEY,
  335. aws_secret_access_key=AWS_SERVER_SECRET_KEY,
  336. region_name='us-east-1')
  337. endpoint_name = 'ServerlessClsBirdEndpoint1'
  338. content_type = "application/json"
  339. payload = input_json
  340. session = boto3.Session()
  341. # runtime = session.client('sagemaker-runtime')
  342. import time
  343. start_time = time.time()
  344. response = runtime.invoke_endpoint(
  345. EndpointName=endpoint_name,
  346. ContentType=content_type,
  347. Body=payload
  348. )
  349. end_time = time.time()
  350. print("耗时: {:.2f}秒".format(end_time - start_time))
  351. result = json.loads(response['Body'].read().decode())
  352. # result = json.loads(response['Body'].read().decode())
  353. print(result)
  354. return resultRsp.json(0)