瀏覽代碼

修改ai图片根据配置保存到不同区域

zhangdongming 1 年之前
父節點
當前提交
c46a3c704e
共有 1 個文件被更改,包括 24 次插入13 次删除
  1. 24 13
      Object/SageMakerAiObject.py

+ 24 - 13
Object/SageMakerAiObject.py

@@ -18,7 +18,8 @@ import numpy as np
 from PIL import Image
 
 from AnsjerPush.Config.aiConfig import AI_IDENTIFICATION_TAGS_DICT
-from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
+from AnsjerPush.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, CONFIG_US, CONFIG_EUR
+from AnsjerPush.config import CONFIG_INFO
 from Controller import AiController
 from Object.AiEngineObject import AiEngine
 from Object.enums.MessageTypeEnum import MessageTypeEnum
@@ -286,24 +287,34 @@ class SageMakerAiObject:
         """
         上传图片到S3
         """
-        # 创建 AWS 访问密钥
-        aws_key = AWS_ACCESS_KEY_ID[1]
-        aws_secret = AWS_SECRET_ACCESS_KEY[1]
-        # 创建会话对象
-        session = boto3.Session(
-            aws_access_key_id=aws_key,
-            aws_secret_access_key=aws_secret,
-            region_name="us-east-1"
-        )
+        if CONFIG_INFO == CONFIG_US or CONFIG_INFO == CONFIG_EUR:
+            # 创建 AWS 访问密钥
+            aws_key = AWS_ACCESS_KEY_ID[1]
+            aws_secret = AWS_SECRET_ACCESS_KEY[1]
+            # 创建会话对象
+            session = boto3.Session(
+                aws_access_key_id=aws_key,
+                aws_secret_access_key=aws_secret,
+                region_name="us-east-1"
+            )
+            s3_resource = session.resource("s3")
+            bucket_name = "foreignpush"
+        else:
+            # 存国内
+            aws_key = AWS_ACCESS_KEY_ID[0]
+            aws_secret = AWS_SECRET_ACCESS_KEY[0]
+            session = boto3.Session(
+                aws_access_key_id=aws_key,
+                aws_secret_access_key=aws_secret,
+                region_name="cn-northwest-1")
+            s3_resource = session.resource("s3")
+            bucket_name = "push"
 
-        # 创建 S3 资源对象
-        s3_resource = session.resource("s3")
         try:
             # 解码base64字符串为图像数据
             for index, pic in enumerate(file_list):
                 image_data = base64.b64decode(pic)
                 # 指定存储桶名称和对象键
-                bucket_name = "foreignpush"
                 object_key = f'{uid}/{channel}/{d_push_time}_{index}.jpeg'
 
                 # 获取指定的存储桶