SageMakerAiObject.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # -*- encoding: utf-8 -*-
  2. """
  3. @File : SageMakerAiObject.py
  4. @Time : 2023/11/10 17:52
  5. @Author : stephen
  6. @Email : zhangdongming@asj6.wecom.work
  7. @Software: PyCharm
  8. """
  9. import base64
  10. import json
  11. import logging
  12. import time
  13. from io import BytesIO
  14. import boto3
  15. import numpy as np
  16. from PIL import Image
  17. from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
  18. from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
  19. from Controller import AiController
  20. from Object.AiEngineObject import AiEngine
  21. from Object.enums.MessageTypeEnum import MessageTypeEnum
  22. from Service.EquipmentInfoService import EquipmentInfoService
  23. from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
  24. from Service.PushService import PushObject
  # Named logger 'info', shared by the methods of this module.
  LOGGER = logging.getLogger(name='info')
  26. class SageMakerAiObject:
  @staticmethod
  def sage_maker_ai_server(uid, base64_list):
      """
      Run YOLO detection on a device's snapshots via the AI inference engine.

      :param uid: device UID, used only for logging.
      :param base64_list: iterable of base64-encoded image strings.
      :return: the value returned by ``AiEngine.yolo_infer`` on success; False when the
               health/model check fails, when inference returns one of the sentinel
               strings 'e_timeout' / 'e_no_model' / 'e_connect_fail', or when any
               exception is raised.
      """
      ai = None
      try:
          url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
          model_name = 'pdchv'
          ai = AiEngine(url, model_name)
          # NOTE(review): `health` is read as an attribute, not called. Confirm it is a
          # property on AiEngine; if it is a plain method this check is always truthy.
          if ai.health and ai.set_model(model_name):
              LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
          else:
              LOGGER.info('SageMake fail')
              return False
          # base64 -> bytes -> PIL image -> ndarray, stacked into a single batch array.
          img_list = np.array([
              np.array(Image.open(BytesIO(base64.b64decode(b64_str))))
              for b64_str in base64_list
          ])
          nms_threshold = 0.45
          confidence = 0.5
          client_timeout = 30
          start = time.time()
          results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
          end = time.time()
          LOGGER.info(f'uid={uid},{(end - start) * 1000}ms,sageMaker Result={results}')
          # The engine reports failures as sentinel strings rather than raising.
          if isinstance(results, str) and results in ('e_timeout', 'e_no_model', 'e_connect_fail'):
              return False
          return results
      except Exception as e:
          LOGGER.info('***sage_maker_ai_server***uid={},errLine={errLine}, errMsg={errMsg}'
                      .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
          return False
      finally:
          # Release the engine on every path. Previously the early return and the
          # exception path skipped ai.close() and leaked the connection.
          if ai is not None:
              try:
                  ai.close()
              except Exception as close_err:
                  LOGGER.info(f'***sage_maker_ai_server***uid={uid},close failed,errMsg={close_err!r}')
  @staticmethod
  def get_table_name(uid, ai_result):
      """
      Derive label groups and bounding boxes from a SageMaker recognition result.

      Example ``ai_result``:
      {
          "file_0": [
              {
                  "x1": 551,
                  "x2": 639,
                  "y1": 179,
                  "y2": 274,
                  "Width": 0.22867838541666666,
                  "Height": 0.14794921875,
                  "Top": 0.28017578125,
                  "Left": 1.4353841145833333,
                  "confidence": 0.86,
                  "class": "vehicle"
              }
          ]
      }

      Supported classes map to group codes: person -> 1, cat/dog -> 2, vehicle -> 3,
      package -> 4. Detections whose class is not in this table are ignored.

      :param uid: device UID, used only for logging.
      :param ai_result: mapping of picture key -> list of detection dicts.
      :return: dict with 'event_type' (sorted group codes concatenated into an int,
               e.g. person + vehicle -> 13), 'label_list' (names looked up in
               AI_IDENTIFICATION_TAGS_DICT, in the same order) and 'coords' (the
               unchanged ``ai_result``); None when no supported class was detected;
               False if an unexpected exception occurs.
      """
      try:
          LOGGER.info(f'***get_table_name***uid={uid},接收到识别信息={ai_result}')
          ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
          found_groups = set()
          # Collect the group code of every supported class across all pictures.
          for pic_value in ai_result.values():
              for item in pic_value or ():
                  if 'class' not in item:
                      continue
                  group = ai_table_groups.get(item['class'])
                  # Unknown classes used to raise KeyError and discard the whole result.
                  if group is not None:
                      found_groups.add(group)
          event_group = sorted(found_groups)
          if not event_group:
              LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签={event_group}')
              return None
          event_type = int(''.join(map(str, event_group)))
          label_list = [AI_IDENTIFICATION_TAGS_DICT[str(group)] for group in event_group]
          LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
          return {'event_type': event_type, 'label_list': label_list,
                  'coords': ai_result}
      except Exception as e:
          LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
                      .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
          return False
  @staticmethod
  def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list):
      """
      Store the AI detection message for each bound user and push it to their apps.

      Steps:
      1. Build one equipment-info row per distinct user.
      2. Push a notification for every record in ``uid_push_qs``.
      3. Bulk-insert the rows.
      4. Upload the snapshots to S3.
      5. Link the cloud AI tag.

      :param uid: device UID.
      :param d_push_time: event timestamp; coerced with int().
      :param uid_push_qs: iterable of dict-like push records. Keys read here are
                          'uid_set__nickname', 'lang', 'tz', 'userID_id', 'push_type',
                          'appBundleId' and 'token_val'.
      :param channel: device channel.
      :param res: result of ``get_table_name``; 'event_type' and 'coords' are read.
      :param file_list: base64-encoded snapshots passed to ``upload_image_to_s3``.
      :return: True on success. False when ``res`` is falsy, when there are no push
               records, or when an exception occurs.
      """
      try:
          LOGGER.info(f'***save_push_message***uid={uid},res={res}')
          if not res:
              return False
          d_push_time = int(d_push_time)
          uid_push_list = list(uid_push_qs)
          if not uid_push_list:
              LOGGER.info(f'***save_push_message***uid={uid},no push records')
              return False
          nickname = uid_push_list[0]['uid_set__nickname']
          event_type = res['event_type']
          # The stored alarm text is rendered in the first record's language.
          label_str = SageMakerAiObject.get_push_tag_message(uid_push_list[0]['lang'], event_type)
          coords = json.dumps(res['coords']) if res['coords'] else ''  # bounding-box dict -> JSON string
          now_time = int(time.time())
          # Column values shared by every stored row; device_user_id is added per user.
          equipment_info_kwargs = {
              'device_uid': uid,
              'device_nick_name': nickname,
              'channel': channel,
              'is_st': 3,
              'storage_location': 2,
              'event_type': event_type,
              'event_time': d_push_time,
              'add_time': now_time,
              'alarm': label_str,
              'border_coords': coords,
          }
          equipment_info_list = []
          # Randomly pick one of the push-record tables.
          equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info()
          seen_user_ids = set()
          msg_title = PushObject.get_msg_title(nickname=nickname)
          for up in uid_push_list:
              tz = up['tz']
              if tz is None or tz == '':
                  tz = 0
              user_id = up['userID_id']
              # Store one row per user, even if the user has several push tokens.
              if user_id not in seen_user_ids:
                  seen_user_ids.add(user_id)
                  equipment_info_list.append(
                      equipment_info_model(**equipment_info_kwargs, device_user_id=user_id))
              user_lang = up['lang']
              # Render the label in this record's language, matching msg_text's `lang`.
              # Previously every push reused the first record's language.
              user_label = SageMakerAiObject.get_push_tag_message(user_lang, event_type)
              msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=user_lang,
                                                    tz=tz, label=user_label)
              push_kwargs = {
                  'nickname': nickname,
                  'app_bundle_id': up['appBundleId'],
                  'token_val': up['token_val'],
                  'n_time': d_push_time,
                  'event_type': event_type,
                  'msg_title': msg_title,
                  'msg_text': msg_text,
                  'uid': uid,
                  'channel': channel,
              }
              SageMakerAiObject.app_user_message_push(up['push_type'], **push_kwargs)
          if equipment_info_list:
              equipment_info_model.objects.bulk_create(equipment_info_list)
          SageMakerAiObject.upload_image_to_s3(uid, channel, d_push_time, file_list)
          AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0)  # link the AI tag
          return True
      except Exception as e:
          LOGGER.info('***save_push_message***uid={},errLine={errLine}, errMsg={errMsg}'
                      .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
          return False
  @staticmethod
  def app_user_message_push(push_type, **kwargs):
      """
      Dispatch an AI-detection notification to the push channel selected by ``push_type``.

      Channel codes: 0 iOS APNs, 1 Android FCM, 2 Android JPush, 3 Huawei, 4 Xiaomi,
      5 vivo, 6 OPPO, 7 Meizu.

      :param push_type: integer channel code (see above).
      :param kwargs: payload forwarded to the channel sender; must contain 'uid'.
                     JPush receives the payload without 'uid'.
      :return: True if the sender returned without raising; False for an unknown
               channel code or when the sender raised.
      :raises KeyError: if ``kwargs`` has no 'uid'.
      """
      uid = kwargs['uid']
      try:
          if push_type == 0:  # iOS APNs
              PushObject.ios_apns_push(**kwargs)
          elif push_type == 1:  # Android FCM
              PushObject.android_fcm_push(**kwargs)
          elif push_type == 2:  # Android JPush
              # JPush is called without 'uid'. Use a copy so kwargs stays intact for logging.
              jpush_kwargs = {key: value for key, value in kwargs.items() if key != 'uid'}
              PushObject.android_jpush(**jpush_kwargs)
          elif push_type == 3:  # Huawei
              huawei_push_object = HuaweiPushObject()
              huawei_push_object.send_push_notify_message(**kwargs)
          elif push_type == 4:  # Android Xiaomi
              PushObject.android_xmpush(**kwargs)
          elif push_type == 5:  # Android vivo
              PushObject.android_vivopush(**kwargs)
          elif push_type == 6:  # Android OPPO
              PushObject.android_oppopush(**kwargs)
          elif push_type == 7:  # Android Meizu
              PushObject.android_meizupush(**kwargs)
          else:
              LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
              return False
          return True
      except Exception as e:
          # Use the local uid. Previously the handler read kwargs['uid'] after the JPush
          # branch had popped it, so the log call itself raised KeyError.
          LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(
              uid, e.__traceback__.tb_lineno, repr(e)))
          return False
  214. @staticmethod
  215. def get_push_tag_message(lang, event_type):
  216. """
  217. 根据语言以及推送类型得到APP通知栏文案
  218. """
  219. event_type = str(event_type)
  220. types = []
  221. if len(event_type) > 1:
  222. for i in range(1, len(event_type) + 1):
  223. types.append(MessageTypeEnum(int(event_type[i - 1:i])))
  224. else:
  225. types.append(int(event_type))
  226. msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
  227. msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
  228. msg_text = ''
  229. for item in types:
  230. if lang == 'cn':
  231. msg_text += msg_cn.get(item) + ' '
  232. else:
  233. msg_text += msg_en.get(item) + ' '
  234. return msg_text
  @staticmethod
  def upload_image_to_s3(uid, channel, d_push_time, file_list):
      """
      Upload base64-encoded snapshots to the 'foreignpush' S3 bucket (best effort).

      Objects are stored under ``{uid}/{channel}/{d_push_time}_{index}.jpeg``.
      Any error (session/credential setup, base64 decoding, upload) is logged and
      swallowed so the caller's flow continues.

      :param uid: device UID (first key segment).
      :param channel: device channel (second key segment).
      :param d_push_time: event timestamp used in the object name.
      :param file_list: iterable of base64-encoded image strings.
      :return: None.
      """
      try:
          # Session setup sits inside the try so credential/config errors are also
          # handled best-effort instead of escaping to the caller.
          session = boto3.Session(
              aws_access_key_id=AWS_ACCESS_KEY_ID[1],
              aws_secret_access_key=AWS_SECRET_ACCESS_KEY[1],
              region_name="us-east-1"
          )
          # The bucket handle is loop-invariant, so resolve it once.
          bucket = session.resource("s3").Bucket("foreignpush")
          for index, pic in enumerate(file_list):
              image_data = base64.b64decode(pic)  # base64 string -> raw image bytes
              object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
              bucket.put_object(Key=object_key, Body=image_data)
          LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
      except Exception as e:
          LOGGER.info('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
                      .format(uid, e.__traceback__.tb_lineno, repr(e)))
  @staticmethod
  def bird_recognition_demo(resultRsp, image_path="E:/stephen/AnsjerProject/bird_sagemaker/test.png"):
      """
      Demo: classify a local image with the serverless bird-classification SageMaker endpoint.

      Reads ``image_path`` and re-encodes it as JPEG. It then sends
      ``{"input": <base64 JPEG>}`` to the 'ServerlessClsBirdEndpoint1' endpoint and
      prints the call latency and the decoded JSON response.

      Credentials are resolved by boto3's default credential chain, not by keys
      embedded in source.

      :param resultRsp: response helper; ``resultRsp.json(0)`` is returned.
      :param image_path: image to classify (defaults to the original demo file).
      :return: ``resultRsp.json(0)``.
      """
      # SECURITY: an access key pair was previously hard-coded here. It was committed to
      # source control and must be treated as leaked (deactivate/rotate it).
      with Image.open(image_path) as img:
          image = img.convert("RGB")
      # Convert the image to a base64-encoded JPEG string.
      image_buffer = BytesIO()
      image.save(image_buffer, format="JPEG")
      image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
      payload = json.dumps({'input': image_base64})
      runtime = boto3.Session(region_name='us-east-1').client("sagemaker-runtime")
      start_time = time.time()
      response = runtime.invoke_endpoint(
          EndpointName='ServerlessClsBirdEndpoint1',
          ContentType="application/json",
          Body=payload,
      )
      end_time = time.time()
      print("耗时: {:.2f}秒".format(end_time - start_time))
      result = json.loads(response['Body'].read().decode())
      print(result)
      return resultRsp.json(0)