import base64
import re
import time
import uuid

from google import genai
from google.genai import types

from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory

logger = get_image_create_logger()

# Inline prompt directives, e.g. "{n:2}" and "{ratio:16:9}".
_N_PATTERN = re.compile(r"\{n:(\d+)\}")
_RATIO_PATTERN = re.compile(r"\{ratio:(\d+:\d+)\}")

# OpenAI-style `size` values accepted by generate_images, mapped to the
# aspect ratio passed to the image model.
_SIZE_TO_ASPECT_RATIO = {
    "1024x1024": "1:1",
    "1792x1024": "16:9",
    "1024x1792": "9:16",
}


class ImageCreateService:
    """Generate images with the configured image model and return OpenAI-style payloads."""

    def __init__(self, aspect_ratio="1:1"):
        self.image_model = settings.CREATE_IMAGE_MODEL
        self.paid_key = settings.PAID_KEY
        # Default ratio used by parse_prompt_parameters; generate_images
        # overwrites it with the ratio used for each request.
        self.aspect_ratio = aspect_ratio

    @staticmethod
    def _extract_prompt_parameters(prompt: str) -> tuple:
        """Strip inline directives from *prompt* and return the values found.

        Returns:
            ``(cleaned_prompt, n, aspect_ratio)``; ``n`` and ``aspect_ratio``
            are ``None`` when the corresponding directive is absent.

        Raises:
            ValueError: if ``n`` is outside 1..4 or the ratio is not one of
                VALID_IMAGE_RATIOS.
        """
        n = None
        aspect_ratio = None

        n_match = _N_PATTERN.search(prompt)
        if n_match:
            n = int(n_match.group(1))
            if n < 1 or n > 4:
                raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
            prompt = prompt.replace(n_match.group(0), "").strip()

        ratio_match = _RATIO_PATTERN.search(prompt)
        if ratio_match:
            aspect_ratio = ratio_match.group(1)
            if aspect_ratio not in VALID_IMAGE_RATIOS:
                raise ValueError(
                    f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
                )
            prompt = prompt.replace(ratio_match.group(0), "").strip()

        return prompt, n, aspect_ratio

    def parse_prompt_parameters(self, prompt: str) -> tuple:
        """Parse inline parameters embedded in the prompt.

        Supported directives:
            - ``{n:count}``   e.g. ``{n:2}`` generates 2 images (1-4 allowed)
            - ``{ratio:r}``   e.g. ``{ratio:16:9}`` uses a 16:9 aspect ratio

        Returns:
            ``(cleaned_prompt, n, aspect_ratio)`` with directives removed;
            ``n`` defaults to 1 and ``aspect_ratio`` to ``self.aspect_ratio``
            when the directive is absent.

        Raises:
            ValueError: on an out-of-range ``n`` or an unsupported ratio.
        """
        cleaned, n, aspect_ratio = self._extract_prompt_parameters(prompt)
        return (
            cleaned,
            1 if n is None else n,
            self.aspect_ratio if aspect_ratio is None else aspect_ratio,
        )

    @staticmethod
    def _create_uploader():
        """Build the image uploader selected by ``settings.UPLOAD_PROVIDER``.

        Raises:
            ValueError: if the provider is not smms, picgo or cloudflare_imgbed.
        """
        provider = settings.UPLOAD_PROVIDER
        if provider == "smms":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.SMMS_SECRET_TOKEN,
            )
        if provider == "picgo":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.PICGO_API_KEY,
            )
        if provider == "cloudflare_imgbed":
            return ImageUploaderFactory.create(
                provider=provider,
                base_url=settings.CLOUDFLARE_IMGBED_URL,
                auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
            )
        raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")

    def generate_images(self, request: ImageGenerationRequest):
        """Generate images for *request* and return an OpenAI-style response dict.

        The request's ``size`` selects the base aspect ratio. Inline
        ``{n:..}`` / ``{ratio:..}`` directives in the prompt override
        ``request.n`` and the ratio. ``request.prompt`` is replaced with the
        cleaned prompt.

        Returns:
            ``{"created": <unix time>, "data": [...]}``. Each item holds either
            ``b64_json`` (when ``response_format == "b64_json"``) or an uploaded
            ``url``, plus ``revised_prompt``.

        Raises:
            ValueError: unsupported size, invalid prompt directive, or
                unsupported upload provider.
            Exception: when the model returns no images.
        """
        aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(request.size)
        if aspect_ratio is None:
            raise ValueError(
                f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
            )

        cleaned_prompt, prompt_n, prompt_ratio = self._extract_prompt_parameters(
            request.prompt
        )
        request.prompt = cleaned_prompt
        # Any explicit directive wins, including {n:1}.
        if prompt_n is not None:
            request.n = prompt_n
        if prompt_ratio is not None:
            aspect_ratio = prompt_ratio
        # Kept in sync for callers that read the last-used ratio off the instance.
        self.aspect_ratio = aspect_ratio

        client = genai.Client(api_key=self.paid_key)
        response = client.models.generate_images(
            model=self.image_model,
            prompt=request.prompt,
            config=types.GenerateImagesConfig(
                number_of_images=request.n,
                output_mime_type="image/png",
                aspect_ratio=aspect_ratio,
                safety_filter_level="BLOCK_LOW_AND_ABOVE",
                person_generation="ALLOW_ADULT",
            ),
        )

        if not response.generated_images:
            raise Exception("I can't generate these images")

        image_uploader = None  # created lazily, once, only if a URL is needed
        images_data = []
        for generated_image in response.generated_images:
            image_bytes = generated_image.image.image_bytes
            if request.response_format == "b64_json":
                images_data.append(
                    {
                        "b64_json": base64.b64encode(image_bytes).decode("utf-8"),
                        "revised_prompt": request.prompt,
                    }
                )
                continue

            if image_uploader is None:
                image_uploader = self._create_uploader()
            current_date = time.strftime("%Y/%m/%d")
            filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
            upload_response = image_uploader.upload(image_bytes, filename)
            images_data.append(
                {
                    "url": f"{upload_response.data.url}",
                    "revised_prompt": request.prompt,
                }
            )

        return {
            "created": int(time.time()),  # Current timestamp
            "data": images_data,
        }

    def generate_images_chat(self, request: ImageGenerationRequest) -> str:
        """Generate images and render them as newline-separated Markdown images.

        URL results are linked directly. Base64 results are embedded as
        ``data:image/png;base64,...`` URLs. Returns ``""`` if there is no data.
        """
        response = self.generate_images(request)
        markdown_images = []
        for index, image_data in enumerate(response["data"], start=1):
            if "url" in image_data:
                source = image_data["url"]
            else:
                # Base64 payload: embed it as a data URL.
                source = f"data:image/png;base64,{image_data['b64_json']}"
            markdown_images.append(f"![Generated Image {index}]({source})")
        return "\n".join(markdown_images)