Prechádzať zdrojové kódy

新增ai接入Nova模型识别

zhangdongming 3 týždňov pred
rodič
commit
d7eb3eb7fa

+ 72 - 0
Controller/AiController.py

@@ -38,8 +38,23 @@ from Object.enums.MessageTypeEnum import MessageTypeEnum
 from Service.CommonService import CommonService
 from Service.DevicePushService import DevicePushService
 from Service.EquipmentInfoService import EquipmentInfoService
+from Object.NovaImageTagObject import NovaImageTagObject
+from datetime import datetime
 
 TIME_LOGGER = logging.getLogger('time')
+# 1. 声明一个全局变量,用于存放创建好的实例
+nova_image_tag_instance = None
+
+# 2. 根据条件,创建实例并赋值给这个全局变量
+# 这段代码会放在文件的最顶部,它只在 Gunicorn worker 进程启动时执行一次
+if CONFIG_INFO == CONFIG_EUR:
+    nova_image_tag_instance = NovaImageTagObject(
+        AWS_ACCESS_KEY_ID[1], AWS_SECRET_ACCESS_KEY[1], 'eu-west-2'
+    )
+else:
+    nova_image_tag_instance = NovaImageTagObject(
+        AWS_ACCESS_KEY_ID[1], AWS_SECRET_ACCESS_KEY[1], 'us-east-1'
+    )
 
 
 # AI服务
@@ -141,7 +156,22 @@ class AiView(View):
 
             # APP推送提醒状态
             notify = self.is_ai_push(uid, notify_data) if is_push else is_push
+            nova_key = f'PUSH:NOVA:LITE:{uid}'
+            nova = redis_obj.get_data(nova_key)
+            if nova:  # AWS AI模型
+                sage_maker = SageMakerAiObject()
 
+                # AI nova识别异步存表&推送
+                push_thread = threading.Thread(
+                    target=self.async_detection_image_label,
+                    kwargs={'sage_maker': sage_maker, 'uid': uid, 'n_time': n_time, 'uid_push_qs': uid_push_qs,
+                            'channel': channel, 'file_list': file_list, 'notify': notify, 'detect_group': detect_group})
+                push_thread.start()
+
+                self.add_push_cache(APP_NOTIFY_KEY, redis_obj, push_cache_data,
+                                    uid_push_qs[0]['uid_set__new_detect_interval'])
+                redis_obj.set_data(ai_key, uid, 60)
+                return response.json(0)
             if ai_server == 'sageMaker':  # 自建模型sageMaker AI
                 sage_maker = SageMakerAiObject()
                 ai_result = sage_maker.sage_maker_ai_server(uid, file_list)  # 图片base64识别AI标签
@@ -714,3 +744,45 @@ class AiView(View):
         except Exception as e:
             logger.info('{}识别后存S3与DynamoDB失败:{}'.format(uid, repr(e)))
             return False
+
+    @staticmethod
+    def async_detection_image_label(sage_maker,uid, n_time, uid_push_qs,
+                            channel, file_list, notify,detect_group):
+        final_results = AiView.get_nova_tag_recognition(file_list, uid)
+        if not final_results:
+            return
+        res = sage_maker.get_table_name(uid, final_results, detect_group)
+        sage_maker.save_push_message(uid, n_time, uid_push_qs, channel, res, file_list, notify)
+
+
+
+    @staticmethod
+    def get_nova_tag_recognition(base64_images, uid):
+        TIME_LOGGER.info(f'uid:{uid} Nova标签识别开始')
+        try:
+            redis_obj = RedisObject(db=6)
+            now = datetime.now()
+            year_month_day = now.strftime("%Y%m%d")  # 如20241005(单日标识)
+            daily_total_key = f"api:recognition:nova:daily:{year_month_day}:total"
+            # 使用incr命令保证原子操作
+            redis_obj.incr(daily_total_key, 1, 3600 * 24 * 7)
+        except Exception as e:
+            # Redis计数失败不影响主业务(降级处理),仅打日志
+            TIME_LOGGER.error(f"{uid}Redis统计接口访问量失败:{repr(e)}")
+        # 定义要检测的类别
+        categories_to_detect = ["person", "car", "pet", "package"]
+        # --- 记录process_image_batch执行时间 ---
+        start_time = time.perf_counter()  # 高精度计时
+        final_results = nova_image_tag_instance.process_image_batch(
+            base64_images,
+            categories_to_detect,
+            uid
+        )
+        # 无论是否发生异常,都记录执行时间
+        end_time = time.perf_counter()
+        execution_time = end_time - start_time
+        TIME_LOGGER.info(
+            f'uid:{uid} Nova标签识别批量处理完成 | '
+            f'执行时间: {execution_time:.4f}秒'
+        )
+        return final_results

+ 172 - 0
Controller/DrawBoxesOnImageController.py

@@ -0,0 +1,172 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    : DrawBoxesOnImageController.py
+@Time    : 2025/9/2 09:21
+@Author  : stephen
+@Email   : zhangdongming@asj6.wecom.work
+@Software: PyCharm
+"""
+import base64
+import io
+import json
+import os
+import time
+
+from PIL import Image, ImageDraw, ImageFont
+from django.views import View
+
+from AnsjerPush.config import BASE_DIR
+from Object.ResponseObject import ResponseObject
+
+
+class DrawBoxesOnImageView(View):
+    def get(self, request, *args, **kwargs):
+        request.encoding = 'utf-8'
+        operation = kwargs.get('operation')
+        return self.validation(request.GET, request, operation)
+
+    def post(self, request, *args, **kwargs):
+        request.encoding = 'utf-8'
+        operation = kwargs.get('operation')
+        return self.validation(request.POST, request, operation)
+
+    def validation(self, request_dict, request, operation):
+        response = ResponseObject()
+        if operation is None:
+            return response.json(444, 'error path')
+        elif operation == 'drawBoxesFromRequest':
+            return self.draw_boxes_from_request(request_dict, response)
+        return response.json(414)
+
+    def draw_boxes_from_request(self, request_dict, response: ResponseObject):
+        """
+        处理绘制边界框的API请求。
+        预期请求体 (JSON):
+        {
+            "fileOne": "base64_string_1",
+            "fileTwo": "base64_string_2",
+            "fileThree": "base64_string_3", // 可选
+            "coordinates": { ... } // 包含所有图片坐标的JSON对象
+        }
+        """
+        # 1. 定义图片保存的相对路径和绝对路径
+        #    请确保 settings.BASE_DIR 指向您的Django项目根目录
+        relative_save_dir = os.path.join('static', 'ai')
+        absolute_save_dir = os.path.join(BASE_DIR, relative_save_dir)
+
+        # 确保目标目录存在
+        os.makedirs(absolute_save_dir, exist_ok=True)
+
+        # 2. 从请求中提取图片和坐标数据
+        base64_images = {
+            "file_0": request_dict.get("fileOne"),
+            "file_1": request_dict.get("fileTwo"),
+            "file_2": request_dict.get("fileThree"),
+        }
+
+        # 坐标可以直接是字典,也可以是JSON字符串,这里做兼容处理
+        coordinates_data = request_dict.get("coordinates")
+        if isinstance(coordinates_data, str):
+            try:
+                coordinates_data = json.loads(coordinates_data)
+            except json.JSONDecodeError:
+                return response.json(400, "coordinates字段不是一个有效的JSON字符串")
+
+        if not coordinates_data or not isinstance(coordinates_data, dict):
+            return response.json(400, "缺少或无效的 'coordinates' 数据")
+
+        processed_files = []
+        errors = []
+
+        # 3. 循环处理每张图片
+        for i in range(3):
+            image_key = f"file_{i}"
+
+            base64_str = base64_images.get(image_key)
+            detections = coordinates_data.get(image_key)
+
+            # 如果图片和对应的坐标都存在,则进行处理
+            if base64_str and detections:
+                # 生成一个唯一的文件名以避免冲突
+                timestamp = int(time.time() * 1000)
+                output_filename = f"result_{image_key}_{timestamp}.jpg"
+                output_path = os.path.join(absolute_save_dir, output_filename)
+
+                # 调用绘图函数
+                success = self.draw_boxes_on_image(base64_str, detections, output_path)
+
+                if success:
+                    # 返回可供前端访问的静态文件URL
+                    static_url = os.path.join(BASE_DIR, 'static', 'ai', output_filename).replace("\\", "/")
+                    processed_files.append({"source": image_key, "url": static_url})
+                else:
+                    errors.append(f"处理 {image_key} 失败")
+
+        # 4. 根据处理结果返回响应
+        if not processed_files and not errors:
+            return response.json(400, "未提供任何有效的图片和坐标数据进行处理")
+        elif errors:
+            return response.json(500, {"errors": errors, "success": processed_files})
+        else:
+            return response.json(0, {"processed_files": processed_files})
+
+    def draw_boxes_on_image(self, base64_image_string: str, detections: list, output_path: str) -> bool:
+        """
+        在一个Base64编码的图片上根据提供的检测坐标绘制边界框,并保存到指定路径。
+
+        Args:
+            base64_image_string (str): 图片的Base64编码字符串。
+            detections (list): 一个包含检测对象信息的字典列表。
+                               每个字典应包含 'class' 和 'Left', 'Top', 'Width', 'Height' 键。
+            output_path (str): 带有标注的图片的完整保存路径。
+
+        Returns:
+            bool: 如果成功保存图片则返回 True,否则返回 False。
+        """
+        if not base64_image_string or not detections:
+            print("错误:未提供图片数据或检测坐标。")
+            return False
+
+        try:
+            # 1. 解码Base64图片并加载
+            image_bytes = base64.b64decode(base64_image_string)
+            image = Image.open(io.BytesIO(image_bytes))
+            draw = ImageDraw.Draw(image)
+            img_width, img_height = image.size
+
+            # 2. 尝试加载字体,如果失败则使用默认字体
+            try:
+                # 确保服务器上有这个字体文件,或者换成一个您确定存在的字体路径
+                font = ImageFont.truetype("arial.ttf", size=20)
+            except IOError:
+                print("警告: 未找到arial.ttf字体,将使用默认字体。")
+                font = ImageFont.load_default()
+
+            colors = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]
+
+            # 3. 遍历所有检测框并绘制
+            for idx, item in enumerate(detections):
+                label = item.get("class", "unknown")
+
+                left_ratio = float(item["Left"])
+                top_ratio = float(item["Top"])
+                width_ratio = float(item["Width"])
+                height_ratio = float(item["Height"])
+
+                x1 = int(left_ratio * img_width)
+                y1 = int(top_ratio * img_height)
+                x2 = int((left_ratio + width_ratio) * img_width)
+                y2 = int((top_ratio + height_ratio) * img_height)
+
+                color = colors[idx % len(colors)]
+                draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
+                draw.text((x1 + 4, y1 + 2), label, fill=color, font=font)
+
+            # 4. 保存绘制好的图片
+            image.save(output_path)
+            print(f"检测结果图片已成功保存到: {output_path}")
+            return True
+
+        except Exception as e:
+            print(f"在绘制和保存图片时发生错误: {e}")
+            return False

+ 167 - 0
Object/NovaImageTagObject.py

@@ -0,0 +1,167 @@
+# -*- encoding: utf-8 -*-
+"""
+@File    : NovaImageTagObject.py
+@Time    : 2025/8/29 09:03
+@Author  : stephen
+@Email   : zhangdongming@asj6.wecom.work
+@Software: PyCharm
+"""
+import base64
+import imghdr
+import json
+import logging
+import re
+
+import boto3
+
+LOGGER = logging.getLogger('time')
+
+# --- 配置信息 ---
+MODEL_ID = "us.amazon.nova-lite-v1:0"
+
+
+class NovaImageTagObject(object):
+    def __init__(self, aws_access_key_id, secret_access_key, region_name):
+        self.bedrock = boto3.client(
+            'bedrock-runtime',
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=secret_access_key,
+            region_name=region_name
+        )
+
+    @staticmethod
+    def safe_json_load(json_string):
+        """
+        一个更健壮的JSON解析函数,尝试修复常见的模型输出格式问题。
+        """
+        try:
+            # 寻找被代码块包围的JSON
+            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', json_string)
+            if json_match:
+                json_string = json_match.group(1)
+
+            # 寻找常规的JSON对象或数组
+            json_match = re.search(r'\{.*\}|\[.*\]', json_string, re.DOTALL)
+            if json_match:
+                json_string = json_match.group(0)
+
+            return json.loads(json_string)
+        except json.JSONDecodeError:
+            LOGGER.error("JSON解析失败,尝试修复...")
+            try:
+                json_string = re.sub(r"(\w+):", r'"\1":', json_string)
+                json_string = json_string.replace("'", '"')
+                return json.loads(json_string)
+            except Exception as e:
+                LOGGER.error(f"无法解析模型返回的JSON: {e}")
+                return None
+        except Exception as e:
+            LOGGER.error(f"发生未知解析错误: {e}")
+            return None
+
+    @staticmethod
+    def format_and_convert_detections(nova_detections: list) -> list:
+        """
+        将Nova模型返回的坐标转换为您指定的详细格式,包含原始坐标和Rekognition比例。
+        """
+        formatted_results = []
+        if not isinstance(nova_detections, list):
+            return []
+
+        for item in nova_detections:
+            if not isinstance(item, dict): continue
+
+            label = list(item.keys())[0]
+            nx1, ny1, nx2, ny2 = item[label]
+
+            left = nx1 / 1000.0
+            top = ny1 / 1000.0
+            width = (nx2 - nx1) / 1000.0
+            height = (ny2 - ny1) / 1000.0
+
+            formatted_results.append({
+                "x1": nx1, "x2": nx2, "y1": ny1, "y2": ny2,
+                "Width": f"{width:.5f}", "Height": f"{height:.5f}",
+                "Top": f"{top:.5f}", "Left": f"{left:.5f}",
+                "class": label
+            })
+        return formatted_results
+
+    def process_image_batch(self, base64_images: list, categories: list, uid=''):
+        """
+        通过单次API调用处理一批图片,并返回结构化的检测结果。
+        """
+        if not base64_images:
+            LOGGER.error(f"{uid}错误: 未提供图片数据。")
+            return {}
+
+        image_contents = []
+        img_bytes_list = []
+        for b64_image in base64_images:
+            try:
+                img_bytes = base64.b64decode(b64_image)
+                img_type = imghdr.what(None, h=img_bytes)
+                if img_type.lower() not in ["jpeg", "jpg", "png", "webp"]:
+                    raise ValueError(f"不支持的图片格式: {img_type}")
+                image_contents.append({"image": {"format": img_type, "source": {"bytes": img_bytes}}})
+                img_bytes_list.append(img_bytes)
+            except Exception as e:
+                LOGGER.error(f"{uid}处理图片时出错,已跳过: {repr(e)}")
+                img_bytes_list.append(None)  # 添加占位符以保持索引一致
+
+        if not image_contents:
+            LOGGER.error(f"{uid}错误: 所有图片均无法处理。")
+            return {}
+
+        category_str = ", ".join([f'"{cat.lower()}"' for cat in categories])
+        num_images = len(image_contents)
+
+        # --- 关键改动:为多图片设计的全新Prompt ---
+        prompt = f"""
+    You have been provided with {num_images} images. Analyze each image sequentially.
+    For each image, detect bounding boxes of objects from the following categories: {category_str}.
+    Your output MUST be a single, valid JSON object.
+    The keys of this object should be "image_0", "image_1", ..., "image_{num_images - 1}", corresponding to the first, second, and subsequent images provided.
+    The value for each key must be a list of detected objects for that specific image. If no objects are detected in an image, the value should be an empty list [].
+    Use a 1000x1000 coordinate system for the bounding boxes.
+
+    Example output format for {num_images} images:
+    {{
+      "image_0": [{{"person": [100, 150, 200, 350]}}, {{"car": [400, 500, 600, 700]}}],
+      "image_1": [],
+      "image_2": [{{"package": [300, 300, 400, 400]}}]
+    }}
+    """
+
+        messages = [{"role": "user", "content": image_contents + [{"text": prompt}]}]
+
+        try:
+            response = self.bedrock.converse(
+                modelId=MODEL_ID,
+                messages=messages,
+                inferenceConfig={"temperature": 0.0, "maxTokens": 4096, "topP": 1.0},
+            )
+            model_output = response["output"]["message"]["content"][0]["text"]
+            LOGGER.info(f"\n--- {uid}模型对整个批次的原始输出 ---\n{model_output}")
+
+            # 解析模型返回的包含所有图片结果的JSON对象
+            batch_results = self.safe_json_load(model_output)
+            if not batch_results or not isinstance(batch_results, dict):
+                LOGGER.error(f"{uid}模型未返回预期的字典格式结果。")
+                return {}
+
+            # --- 核心逻辑:将批处理结果映射回您的格式 ---
+            final_output_dict = {}
+            for i in range(len(base64_images)):
+                # 从批处理结果中获取当前图片的数据,如果不存在则默认为空列表
+                nova_detections = batch_results.get(f"image_{i}", [])
+
+                # 转换为您最终需要的格式
+                detailed_results = self.format_and_convert_detections(nova_detections)
+                final_output_dict[f"file_{i}"] = detailed_results
+
+            return final_output_dict
+
+        except Exception as e:
+            LOGGER.error(f"{uid}调用Bedrock模型或处理过程中发生错误: {repr(e)}")
+            return {}

+ 21 - 2
Object/RedisObject.py

@@ -2,7 +2,7 @@ import redis
 from AnsjerPush.config import REDIS_ADDRESS
 
 # 本地调试把注释打开
-# REDIS_ADDRESS = '127.0.0.1'
+REDIS_ADDRESS = '127.0.0.1'
 
 
 class RedisObject:
@@ -107,4 +107,23 @@ class RedisObject:
             return val.decode('utf-8') if val else None
         except Exception as e:
             print(f"Redis hget error: {repr(e)}")
-            return None
+            return None
+
+    def incr(self, key, amount=1, ttl=0):
+        """
+        增加计数器的值
+        :param key: 键名,用于存储计数器的 Redis 键
+        :param amount: 增加的数量,默认为 1
+        :param ttl: 键的过期时间(秒)
+        :return: 更新后的计数值,若发生异常则返回 False
+        """
+        try:
+            # 增加计数器
+            result = self.CONN.incrby(key, amount)
+            # 设置过期时间
+            if ttl > 0:
+                self.CONN.expire(key, ttl)
+            return result
+        except Exception as e:
+            print(repr(e))
+            return False

+ 1 - 1
Object/SageMakerAiObject.py

@@ -100,7 +100,7 @@ class SageMakerAiObject:
         try:
             LOGGER.info(f'***get_table_name***uid={uid}')
             # 定义标签组的映射关系
-            ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'vehicle': 3, 'package': 4}
+            ai_table_groups = {'person': 1, 'cat': 2, 'dog': 2, 'pet': 2, 'vehicle': 3, 'car': 3, 'package': 4}
             event_group = []
 
             # 获取识别到的标签以及坐标信息