# AiController.py
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import threading
  6. import time
  7. import numpy as np
  8. from PIL import Image
  9. from django.views.generic.base import View
  10. from AnsjerPush.config import BASE_DIR
  11. from Model.models import UidPushModel, AiService, Device_Info, VodHlsTag, VodHlsTagType
  12. from Object.AiEngineObject import AiEngine
  13. from Object.ETkObject import ETkObject
  14. from Object.ResponseObject import ResponseObject
  15. from Object.enums.MessageTypeEnum import MessageTypeEnum
  16. from Object.utils import LocalDateTimeUtil
  17. from Service.CommonService import CommonService
  18. from Service.EquipmentInfoService import EquipmentInfoService
  19. from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
  20. from Service.PushService import PushObject
  21. LOGGING = logging.getLogger('info')
  22. CLOUD_BASED_AI_URL = '34.192.147.108:8001'
  23. MODEL_NAME = 'AI_5obj_pdcpv_detect_yolov5_pipeline'
  24. # 建立长连接
  25. ai_connect = ''# AiEngine(url=CLOUD_BASED_AI_URL)
  26. HEALTH = False
  27. # 检查连通性、推理服务器状态
  28. # if ai_connect.health:
  29. # HEALTH = True
  30. # print('健康状况通过')
  31. # # 设定模型
  32. # if ai_connect.set_model(MODEL_NAME):
  33. # print('设置模型通过')
  34. class AiView(View):
  35. def get(self, request, *args, **kwargs):
  36. request.encoding = 'utf-8'
  37. operation = kwargs.get('operation')
  38. return self.validation(request.GET, operation)
  39. def post(self, request, *args, **kwargs):
  40. request.encoding = 'utf-8'
  41. operation = kwargs.get('operation')
  42. return self.validation(request.POST, operation)
  43. def validation(self, request_dict, operation):
  44. response = ResponseObject()
  45. if operation == 'identification': # ai识别推送
  46. return self.identification(request_dict, response)
  47. else:
  48. return response.json(414)
  49. @staticmethod
  50. def identification(request_dict, response):
  51. """
  52. ai识别推送
  53. @param request_dict: 请求数据
  54. @request_dict etk: uid token
  55. @request_dict n_time: 设备的当前时间
  56. @request_dict channel: 通道
  57. @request_dict fileOne: 图片一
  58. @request_dict fileTwo: 图片二
  59. @request_dict fileThree: 图片三
  60. @param response: 响应
  61. @return: response
  62. """
  63. etk = request_dict.get('etk', None)
  64. n_time = request_dict.get('n_time', None)
  65. channel = request_dict.get('channel', '1')
  66. file_one = request_dict.get('fileOne', None)
  67. file_two = request_dict.get('fileTwo', None)
  68. file_three = request_dict.get('fileThree', None)
  69. if not all([etk, n_time]):
  70. return response.json(444)
  71. # 解密etk并判断uid长度
  72. eto = ETkObject(etk)
  73. uid = eto.uid
  74. LOGGING.info('---进入ai识别推送接口--- etk:{}, uid:{}'.format(etk, uid))
  75. receive_time = int(time.time())
  76. file_list = [file_one, file_two, file_three]
  77. # 查询设备是否有使用中的ai服务
  78. ai_service_qs = AiService.objects.filter(uid=uid, detect_status=1, use_status=1, endTime__gt=receive_time). \
  79. values('detect_group')
  80. if not ai_service_qs.exists():
  81. return response.json(173)
  82. # 查询推送相关数据
  83. uid_push_qs = UidPushModel.objects.filter(uid_set__uid=uid). \
  84. values('push_type', 'appBundleId', 'token_val', 'lang', 'tz', 'userID_id')
  85. if not uid_push_qs.exists():
  86. return response.json(173)
  87. # 查询设备数据
  88. device_info_qs = Device_Info.objects.filter(UID=uid).first()
  89. nickname = uid if device_info_qs is None else device_info_qs.NickName
  90. now_time = int(time.time())
  91. try:
  92. dir_path = os.path.join(BASE_DIR, 'static/ai/' + uid + '/' + str(n_time))
  93. if not os.path.exists(dir_path):
  94. os.makedirs(dir_path)
  95. file_path_list = []
  96. for i, val in enumerate(file_list):
  97. val = val.replace(' ', '+')
  98. val = base64.b64decode(val)
  99. file_path = "{dir_path}/{n_time}_{i}.jpg".format(dir_path=dir_path, n_time=n_time, i=i)
  100. file_path_list.append(file_path)
  101. with open(file_path, 'wb') as f:
  102. f.write(val)
  103. f.close()
  104. ai_view = AiView()
  105. ai_results = ai_view.image_aI_recognition(file_path_list, 0.45, 0.55, 1)
  106. if not ai_results:
  107. CommonService.del_path(dir_path)
  108. return response.json(0)
  109. event_type = ai_view.get_cloud_recognition_tag(ai_results)
  110. if event_type == 0:
  111. CommonService.del_path(dir_path)
  112. return response.json(0)
  113. new_bounding_box_dict = ai_view.get_pic_coordinates(ai_results)
  114. # 上传缩略图到s3
  115. file_dict = {}
  116. for i, val in enumerate(file_path_list):
  117. # 封面图
  118. file_dict[val] = '{}/{}/{}_{}.jpeg'.format(uid, channel, n_time, i)
  119. upload_images_thread = threading.Thread(target=CommonService.upload_images, args=(file_dict, dir_path))
  120. upload_images_thread.start()
  121. # 存储消息以及推送
  122. uid_push_list = [uid_push for uid_push in uid_push_qs]
  123. eq_list = []
  124. user_id_list = []
  125. local_date_time = ''
  126. lang = uid_push_list[0]['lang']
  127. label_str = ai_view.get_tag_message(lang, event_type)
  128. for up in uid_push_list:
  129. # 保存推送数据
  130. tz = up['tz']
  131. if tz is None or tz == '':
  132. tz = 0
  133. local_date_time = CommonService.get_now_time_str(n_time=n_time, tz=tz, lang='cn')[:10]
  134. user_id = up['userID_id']
  135. if user_id not in user_id_list:
  136. eq_list.append(EquipmentInfoService.get_equipment_info_obj(
  137. local_date_time,
  138. device_user_id=user_id,
  139. event_time=n_time,
  140. event_type=event_type,
  141. device_uid=uid,
  142. device_nick_name=nickname,
  143. channel=channel,
  144. alarm=label_str,
  145. is_st=3,
  146. receive_time=receive_time,
  147. add_time=now_time,
  148. storage_location=2,
  149. border_coords=new_bounding_box_dict
  150. ))
  151. user_id_list.append(user_id)
  152. # 推送
  153. push_type = up['push_type']
  154. app_bundle_id = up['appBundleId']
  155. token_val = up['token_val']
  156. lang = up['lang']
  157. # 推送标题和推送内容
  158. msg_title = PushObject.get_msg_title(nickname=nickname)
  159. msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=n_time, lang=lang, tz=tz, label=label_str)
  160. kwargs = {
  161. 'nickname': nickname,
  162. 'app_bundle_id': app_bundle_id,
  163. 'token_val': token_val,
  164. 'n_time': n_time,
  165. 'event_type': event_type,
  166. 'msg_title': msg_title,
  167. 'msg_text': msg_text,
  168. 'uid': uid,
  169. 'channel': channel,
  170. }
  171. try:
  172. # 推送消息
  173. if push_type == 0: # ios apns
  174. PushObject.ios_apns_push(**kwargs)
  175. elif push_type == 1: # android gcm
  176. PushObject.android_fcm_push(**kwargs)
  177. elif push_type == 2: # android jpush
  178. kwargs.pop('uid')
  179. PushObject.android_jpush(**kwargs)
  180. elif push_type == 3:
  181. huawei_push_object = HuaweiPushObject()
  182. huawei_push_object.send_push_notify_message(**kwargs)
  183. elif push_type == 4: # android 小米推送
  184. PushObject.android_xmpush(**kwargs)
  185. elif push_type == 5: # android vivo推送
  186. PushObject.android_vivopush(**kwargs)
  187. elif push_type == 6: # android oppo推送
  188. PushObject.android_oppopush(**kwargs)
  189. elif push_type == 7: # android 魅族推送
  190. PushObject.android_meizupush(**kwargs)
  191. except Exception as e:
  192. LOGGING.info('ai推送消息异常,errLine:{}, errMsg:{}'.format(e.__traceback__.tb_lineno, repr(e)))
  193. continue
  194. AiView.save_cloud_ai_tag(uid, int(n_time), event_type)
  195. week = LocalDateTimeUtil.date_to_week(local_date_time)
  196. EquipmentInfoService.equipment_info_bulk_create(week, eq_list)
  197. return response.json(0)
  198. except Exception as e:
  199. LOGGING.info('---ai识别推送异常---:{}'.format(repr(e)))
  200. data = {
  201. 'errLine': e.__traceback__.tb_lineno,
  202. 'errMsg': repr(e)
  203. }
  204. return response.json(48, data)
  205. @classmethod
  206. def image_aI_recognition(cls, input_name_arr, nms_threshold, confidence, client_timeout):
  207. """
  208. 自有图片云模型识别
  209. :param input_name_arr: 推理图片地址名
  210. :param nms_threshold: nms置信度
  211. :param confidence: 目标置信度(一般只用调整这个)
  212. :param client_timeout: 超时时间(秒为单位)
  213. :return: results 推理结果
  214. """
  215. try:
  216. if not HEALTH:
  217. LOGGING.info('AI health:{}'.format(HEALTH))
  218. return {}
  219. # 推理张数(一次最多推理128张!)
  220. # 图片名称(这里可以改成内存)注意改完之后要检查input_tmp的【类型(type)、尺寸(shape)】是否和之前的一致
  221. # 输入尺寸固定640wx360h,如需变动可以联系我们,我们这边做resize会快
  222. input_name_arr = np.array(list(map(np.array, map(Image.open, input_name_arr))))
  223. # 推理
  224. results = ai_connect.yolo_infer(input_name_arr, nms_threshold, confidence, client_timeout)
  225. # 报错返回
  226. if results == 'e_timeout':
  227. raise Exception('推理超时')
  228. elif results == 'e_no_model':
  229. raise Exception('没有设置模型')
  230. LOGGING.info('云上模型推理结果:{}'.format(results))
  231. return results
  232. except Exception as e:
  233. LOGGING.info('云模型AI识别失败,errLine:{}, errMsg:{}'.format(e.__traceback__.tb_lineno, repr(e)))
  234. return {}
  235. @classmethod
  236. def get_cloud_recognition_tag(cls, results):
  237. """
  238. 根据推理结果
  239. 返回组合类型,0则推理失败,1:人,2:车,3:宠物,4:包裹
  240. """
  241. if not results:
  242. return 0
  243. # 返回字典,可以复制到绘制python脚本中查看结果
  244. tag_list = []
  245. event_dict = {1: 1, 0: 2, 3: 2, 2: 3, 4: 4}
  246. for k, v in results.items():
  247. if len(v) > 0:
  248. for item in v:
  249. tag_list.append(event_dict.get(item['classID'], 0))
  250. if tag_list:
  251. tag_list = set(tag_list)
  252. event_type = ''
  253. for val in tag_list:
  254. event_type += str(val)
  255. return int(event_type)
  256. else:
  257. return 0
  258. @classmethod
  259. def get_pic_coordinates(cls, results):
  260. """
  261. 获取识别图片坐标
  262. """
  263. try:
  264. ai_dict = {}
  265. for i in range(3):
  266. ai_dict['file_' + str(i)] = results['pic_' + str(i)]
  267. return json.dumps(ai_dict)
  268. except Exception as e:
  269. LOGGING.info('AI推理结果解析异常详情,errLine:{}, errMsg:{}'.format(e.__traceback__.tb_lineno, repr(e)))
  270. return ''
  271. @staticmethod
  272. def get_tag_message(lang, event_type):
  273. event_type = str(event_type)
  274. types = []
  275. if len(event_type) > 1:
  276. for i in range(1, len(event_type) + 1):
  277. types.append(MessageTypeEnum(int(event_type[i - 1:i])))
  278. else:
  279. types.append(int(event_type))
  280. msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
  281. msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
  282. msg_text = ''
  283. for item in types:
  284. if lang == 'cn':
  285. msg_text += msg_cn.get(item) + ' '
  286. else:
  287. msg_text += msg_en.get(item) + ' '
  288. return msg_text
  289. @classmethod
  290. def save_cloud_ai_tag(cls, uid, event_time, types):
  291. """
  292. 保存云存AI标签
  293. """
  294. try:
  295. types = str(types)
  296. if not types:
  297. return False
  298. n_time = int(time.time())
  299. vod_hls_tag = {"uid": uid, "ai_event_time": event_time, "created_time": n_time}
  300. vod_tag_vo = VodHlsTag.objects.create(**vod_hls_tag)
  301. tag_list = []
  302. if len(types) > 1:
  303. for i in range(1, len(types) + 1):
  304. ai_type = MessageTypeEnum(int(types[i - 1:i]))
  305. vod_tag_type_vo = VodHlsTagType(tag_id=vod_tag_vo.id, created_time=n_time, type=ai_type.value)
  306. tag_list.append(vod_tag_type_vo)
  307. else:
  308. ai_type = MessageTypeEnum(int(types))
  309. vod_tag_type_vo = {"tag_id": vod_tag_vo.id, "created_time": n_time, "type": ai_type.value}
  310. VodHlsTagType.objects.create(**vod_tag_type_vo)
  311. if tag_list:
  312. VodHlsTagType.objects.bulk_create(tag_list)
  313. return True
  314. except Exception as e:
  315. print('AI标签存储异常详情,errLine:{}, errMsg:{}'.format(e.__traceback__.tb_lineno, repr(e)))
  316. return False