浏览代码

修改AI接口对接sageMaker

zhangdongming 1 年之前
父节点
当前提交
fc6d8f0218
共有 3 个文件被更改,包括 402 次插入52 次删除
  1. 38 16
      Controller/AiController.py
  2. 21 36
      Object/AiEngineObject.py
  3. 343 0
      Object/SageMakerAiObject.py

文件差异内容过多而无法显示
+ 38 - 16
Controller/AiController.py


+ 21 - 36
Object/AiEngineObject.py

@@ -1,43 +1,37 @@
-# -*- encoding: utf-8 -*-
-"""
-@File    : AiEngine.py
-@Time    : 2023/2/28 15:15
-@Author  : stephen
-@Email   : zhangdongming@asj6.wecom.work
-@Software: PyCharm
-"""
 import ast
-import time
 
 import numpy as np
 import tritonclient.grpc as grpcclient
 
 
 class AiEngine:
-    def __init__(self, **kwargs):
-
-        # create grpc conncet
+    def __init__(self, url, model_name):
         try:
             self.triton_client = grpcclient.InferenceServerClient(
-                url=kwargs.get('url'),
+                url=url,
                 verbose=False,
                 ssl=False,
                 root_certificates=None,
                 private_key=None,
-                certificate_chain=None)
-            conncet = True
+                certificate_chain=None
+            )
+            self.health = self.check_health()
         except Exception as e:
-            conncet = False
-        self.health = True
+            self.health = False
+
+        self.model_name = model_name
+
+    def check_health(self):
+        conncet = True
         health_check = [
             conncet,
             self.triton_client.is_server_live(),
             self.triton_client.is_server_ready()
         ]
         if False in health_check:
-            print("health_check failed:" + health_check)
-            self.health = False
-        self.model_name = None
+            print("Health check failed:", health_check)
+            return False
+        return True
 
     def set_model(self, model_name):
         if self.triton_client.is_model_ready(model_name):
@@ -46,31 +40,29 @@ class AiEngine:
         else:
             return False
 
-    def get_model_name(self, model_name):
-        return self.model_name
-
     def yolo_infer(self, img_arr, nms_threshold, confidence, client_timeout):
-        if self.model_name == 'AI_5obj_pdcpv_detect_yolov5_pipeline':
+        if self.model_name == 'pdchv':
             pic_num_sum = len(img_arr)
             nms_threshold = np.array([[nms_threshold]] * pic_num_sum, dtype=np.float32)
             confidence = np.array([[confidence]] * pic_num_sum, dtype=np.float32)
             inputs = [
-                grpcclient.InferInput('img1', [pic_num_sum, 360, 640, 3], "UINT8"),
+                grpcclient.InferInput('img', [pic_num_sum, 360, 640, 3], "UINT8"),
                 grpcclient.InferInput('nms_threshold', [pic_num_sum, 1], "FP32"),
                 grpcclient.InferInput('confidence', [pic_num_sum, 1], "FP32")
             ]
-            outputs = [grpcclient.InferRequestedOutput('OUTPUT0')]
+            outputs = [grpcclient.InferRequestedOutput('result')]
 
             inputs[0].set_data_from_numpy(img_arr)
             inputs[1].set_data_from_numpy(nms_threshold)
             inputs[2].set_data_from_numpy(confidence)
-
             try:
                 results = self.triton_client.infer(model_name=self.model_name, inputs=inputs, outputs=outputs,
                                                    client_timeout=client_timeout)
             except Exception as e:
-                return 'e_timeout'
-            result_str = results.as_numpy('OUTPUT0')[0].decode("UTF-8")
+                print(repr(e))
+                return 'e_connect_fail'
+
+            result_str = results.as_numpy('result')[0].decode("UTF-8")
             results = ast.literal_eval(result_str)
             return results
         else:
@@ -78,10 +70,3 @@ class AiEngine:
 
     def close(self):
         return self.triton_client.close()
-
-
-def func():
-    print('func start')
-    time.sleep(1)
-    print('func end')
-

+ 343 - 0
Object/SageMakerAiObject.py

@@ -0,0 +1,343 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    : SageMakerAiObject.py
+@Time    : 2023/11/10 17:52
+@Author  : stephen
+@Email   : zhangdongming@asj6.wecom.work
+@Software: PyCharm
+"""
+import base64
+import json
+import logging
+import time
+from io import BytesIO
+
+import boto3
+import numpy as np
+from PIL import Image
+
+from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
+from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
+from Controller import AiController
+from Object.AiEngineObject import AiEngine
+from Object.enums.MessageTypeEnum import MessageTypeEnum
+from Service.EquipmentInfoService import EquipmentInfoService
+from Service.HuaweiPushService.HuaweiPushService import HuaweiPushObject
+from Service.PushService import PushObject
+
+LOGGER = logging.getLogger('info')
+
+
+class SageMakerAiObject:
+
+    @staticmethod
+    def sage_maker_ai_server(uid, base64_list):
+
+        try:
+            url = 'ec2-34-192-147-108.compute-1.amazonaws.com:8001'
+            model_name = 'pdchv'
+
+            ai = AiEngine(url, model_name)
+
+            if ai.health and ai.set_model(model_name):
+                LOGGER.info(f'uid={uid} Health check passed and Model set successfully')
+            else:
+                LOGGER.info('SageMake fail')
+                return False
+
+            img_list = list(map(base64.b64decode, base64_list))
+            img_list = map(BytesIO, img_list)
+            img_list = map(Image.open, img_list)
+            img_list = map(np.array, img_list)
+            img_list = np.array(list(img_list))
+
+            nms_threshold = 0.45
+            confidence = 0.5
+            client_timeout = 30
+
+            start = time.time()
+            results = ai.yolo_infer(img_list, nms_threshold, confidence, client_timeout)
+
+            end = time.time()
+            LOGGER.info(f'uid={uid},{(end - start) * 1000}ms,sageMaker Result={results}')
+
+            ai.close()
+
+            if results == 'e_timeout':
+                return False
+            if results == 'e_no_model':
+                return False
+            if results == 'e_connect_fail':
+                return False
+
+            return results
+        except Exception as e:
+            LOGGER.info('***sage_maker_ai_server***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def get_table_name(uid, ai_result):
+        """
+        sageMaker识别结果得到标签组和坐标框
+        {
+            "file_0":[
+                {
+                    "x1":551,
+                    "x2":639,
+                    "y1":179,
+                    "y2":274,
+                    "Width":0.22867838541666666,
+                    "Height":0.14794921875,
+                    "Top":0.28017578125,
+                    "Left":1.4353841145833333,
+                    "confidence":0.86,
+                    "class":"vehicle"
+                }
+            ]
+        }
+        """
+        try:
+            LOGGER.info(f'***get_table_name***uid={uid},接收到识别信息={ai_result}')
+            ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
+            event_group = []
+
+            # 获取识别到的标签以及坐标信息
+            for pic_key, pic_value in ai_result.items():
+                for item in pic_value:
+                    if 'class' in item:
+                        event_group.append(ai_table_groups[item['class']])
+
+            if not event_group:
+                LOGGER.info(f'***get_table_name***uid={uid},没有识别到指定标签={event_group}')
+                return None
+
+            event_group = set(event_group)
+            event_type = int(''.join(map(str, sorted(event_group))))
+            label_list = []
+            for item in event_group:  # 识别到类型获取name
+                label_list.append(AI_IDENTIFICATION_TAGS_DICT[str(item)])
+            LOGGER.info(f'***get_table_name***uid={uid},类型组={event_type},名称组={label_list}')
+
+            return {'event_type': event_type, 'label_list': label_list,
+                    'coords': ai_result}
+        except Exception as e:
+            LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def save_push_message(uid, d_push_time, uid_push_qs, channel, res, file_list):
+        """
+        保存推送消息
+        """
+        try:
+            LOGGER.info(f'***save_push_message***uid={uid},res={res}')
+            if not res:
+                return False
+            d_push_time = int(d_push_time)
+            # 存储消息以及推送
+            uid_push_list = [uid_push for uid_push in uid_push_qs]
+            user_id_list = []
+            nickname = uid_push_list[0]['uid_set__nickname']
+            lang = uid_push_list[0]['lang']
+            event_type = res['event_type']
+            label_str = SageMakerAiObject().get_push_tag_message(lang, event_type)
+            coords = json.dumps(res['coords']) if res['coords'] else ''  # 坐标框字典转json字符串
+            now_time = int(time.time())
+            # 推送表存储数据
+            equipment_info_kwargs = {
+                'device_uid': uid,
+                'device_nick_name': nickname,
+                'channel': channel,
+                'is_st': 3,
+                'storage_location': 2,
+                'event_type': event_type,
+                'event_time': d_push_time,
+                'add_time': now_time,
+                'alarm': label_str,
+                'border_coords': coords
+            }
+
+            equipment_info_list = []
+            equipment_info_model = EquipmentInfoService.randoms_choice_equipment_info()  # 随机获取其中一张推送表
+            for up in uid_push_list:
+                # 保存推送数据
+                tz = up['tz']
+                if tz is None or tz == '':
+                    tz = 0
+                user_id = up['userID_id']
+                if user_id not in user_id_list:
+                    equipment_info_kwargs['device_user_id'] = user_id
+                    equipment_info_list.append(equipment_info_model(**equipment_info_kwargs))
+                    user_id_list.append(user_id)
+
+                # 推送
+                push_type = up['push_type']
+                app_bundle_id = up['appBundleId']
+                token_val = up['token_val']
+                lang = up['lang']
+                # 推送标题和推送内容
+                msg_title = PushObject.get_msg_title(nickname=nickname)
+                msg_text = PushObject.get_ai_msg_text(channel=channel, n_time=d_push_time, lang=lang, tz=tz,
+                                                      label=label_str)
+                kwargs = {
+                    'nickname': nickname,
+                    'app_bundle_id': app_bundle_id,
+                    'token_val': token_val,
+                    'n_time': d_push_time,
+                    'event_type': event_type,
+                    'msg_title': msg_title,
+                    'msg_text': msg_text,
+                    'uid': uid,
+                    'channel': channel,
+                }
+                SageMakerAiObject().app_user_message_push(push_type, **kwargs)
+
+            if equipment_info_list:  # 消息存表
+                equipment_info_model.objects.bulk_create(equipment_info_list)
+            SageMakerAiObject().upload_image_to_s3(uid, channel, d_push_time, file_list)
+            AiController.AiView().save_cloud_ai_tag(uid, d_push_time, event_type, 0)  # 关联AI标签
+            return True
+        except Exception as e:
+            LOGGER.info('***get_table_name***uid={},errLine={errLine}, errMsg={errMsg}'
+                        .format(uid, errLine=e.__traceback__.tb_lineno, errMsg=repr(e)))
+            return False
+
+    @staticmethod
+    def app_user_message_push(push_type, **kwargs):
+        """
+            ai识别 app用户消息推送
+        """
+        uid = kwargs['uid']
+        try:
+            if push_type == 0:  # ios apns
+                PushObject.ios_apns_push(**kwargs)
+            elif push_type == 1:  # android gcm
+                PushObject.android_fcm_push(**kwargs)
+            elif push_type == 2:  # android jpush
+                kwargs.pop('uid')
+                PushObject.android_jpush(**kwargs)
+            elif push_type == 3:
+                huawei_push_object = HuaweiPushObject()
+                huawei_push_object.send_push_notify_message(**kwargs)
+            elif push_type == 4:  # android 小米推送
+                PushObject.android_xmpush(**kwargs)
+            elif push_type == 5:  # android vivo推送
+                PushObject.android_vivopush(**kwargs)
+            elif push_type == 6:  # android oppo推送
+                PushObject.android_oppopush(**kwargs)
+            elif push_type == 7:  # android 魅族推送
+                PushObject.android_meizupush(**kwargs)
+            else:
+                LOGGER.info(f'uid={uid},{push_type}推送类型不存在')
+                return False
+            return True
+        except Exception as e:
+            LOGGER.info('ai推送消息异常,uid={},errLine:{}, errMsg:{}'.format(kwargs['uid']
+                                                                       , e.__traceback__.tb_lineno, repr(e)))
+            return False
+
+    @staticmethod
+    def get_push_tag_message(lang, event_type):
+        """
+        根据语言以及推送类型得到APP通知栏文案
+        """
+        event_type = str(event_type)
+        types = []
+        if len(event_type) > 1:
+            for i in range(1, len(event_type) + 1):
+                types.append(MessageTypeEnum(int(event_type[i - 1:i])))
+        else:
+            types.append(int(event_type))
+        msg_cn = {1: '人', 2: '动物', 3: '车', 4: '包裹'}
+        msg_en = {1: 'person', 2: 'animal', 3: 'vehicle', 4: 'package'}
+        msg_text = ''
+        for item in types:
+            if lang == 'cn':
+                msg_text += msg_cn.get(item) + ' '
+            else:
+                msg_text += msg_en.get(item) + ' '
+        return msg_text
+
+    @staticmethod
+    def upload_image_to_s3(uid, channel, d_push_time, file_list):
+        """
+        上传图片到S3
+        """
+        # 创建 AWS 访问密钥
+        aws_key = AWS_ACCESS_KEY_ID[1]
+        aws_secret = AWS_SECRET_ACCESS_KEY[1]
+        # 创建会话对象
+        session = boto3.Session(
+            aws_access_key_id=aws_key,
+            aws_secret_access_key=aws_secret,
+            region_name="us-east-1"
+        )
+
+        # 创建 S3 资源对象
+        s3_resource = session.resource("s3")
+        try:
+            # 解码base64字符串为图像数据
+            for index, pic in enumerate(file_list):
+                image_data = base64.b64decode(pic)
+                # 指定存储桶名称和对象键
+                bucket_name = "foreignpush"
+                object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
+
+                # 获取指定的存储桶
+                bucket = s3_resource.Bucket(bucket_name)
+
+                # 上传图像数据到 S3
+                bucket.put_object(Key=object_key, Body=image_data)
+            LOGGER.info(f'uid={uid},base64上传缩略图到S3成功')
+        except Exception as e:
+            LOGGER.info('***upload_image_to_s3,uid={},errLine:{}, errMsg:{}'
+                        .format(uid, e.__traceback__.tb_lineno, repr(e)))
+
+    @staticmethod
+    def bird_recognition_demo(resultRsp):
+        """
+        鸟类识别示例代码
+        """
+        from PIL import Image
+        import base64
+        from io import BytesIO
+        import json
+        import boto3
+
+        image_path = "E:/stephen/AnsjerProject/bird_sagemaker/test.png"
+        image = Image.open(image_path).convert("RGB")
+
+        # 将图像转换为 Base64 编码字符串
+        image_buffer = BytesIO()
+        image.save(image_buffer, format="JPEG")
+        image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
+        input_json = json.dumps({'input': image_base64})
+
+        AWS_SERVER_PUBLIC_KEY = 'AKIARBAGHTGOSK37QB4T'
+        AWS_SERVER_SECRET_KEY = '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'
+        runtime = boto3.client("sagemaker-runtime",
+                               aws_access_key_id=AWS_SERVER_PUBLIC_KEY,
+                               aws_secret_access_key=AWS_SERVER_SECRET_KEY,
+                               region_name='us-east-1')
+
+        endpoint_name = 'ServerlessClsBirdEndpoint1'
+        content_type = "application/json"
+        payload = input_json
+        session = boto3.Session()
+        # runtime = session.client('sagemaker-runtime')
+        import time
+        start_time = time.time()
+        response = runtime.invoke_endpoint(
+            EndpointName=endpoint_name,
+            ContentType=content_type,
+            Body=payload
+        )
+        end_time = time.time()
+        print("耗时: {:.2f}秒".format(end_time - start_time))
+
+        result = json.loads(response['Body'].read().decode())
+        # result = json.loads(response['Body'].read().decode())
+        print(result)
+        return resultRsp.json(0)

部分文件因为文件数量过多而无法显示