mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-11 18:09:55 +08:00
156 lines
6.2 KiB
Python
156 lines
6.2 KiB
Python
import time
|
||
import uuid
|
||
|
||
from google import genai
|
||
from google.genai import types
|
||
import base64
|
||
|
||
from app.core.config import settings
|
||
from app.core.logger import get_image_create_logger
|
||
from app.core.uploader import ImageUploaderFactory
|
||
from app.schemas.openai_models import ImageGenerationRequest
|
||
|
||
# Module-level logger for image creation, obtained from app.core.logger.
logger = get_image_create_logger()
|
||
|
||
|
||
class ImageCreateService:
    """Generate images via the Google GenAI SDK and shape OpenAI-style responses."""

    # Aspect ratios accepted in a ``{ratio:...}`` prompt directive (order is
    # preserved because it appears in the error message).
    _VALID_RATIOS = ("1:1", "3:4", "4:3", "9:16", "16:9")

    # OpenAI-style ``size`` values mapped to the aspect ratio sent to the model.
    _SIZE_TO_ASPECT_RATIO = {
        "1024x1024": "1:1",
        "1792x1024": "16:9",
        "1024x1792": "9:16",
    }

    def __init__(self, aspect_ratio="1:1"):
        """Read model/key settings and remember a fallback aspect ratio.

        Args:
            aspect_ratio: Ratio used by ``parse_prompt_parameters`` when no
                explicit default is given. ``generate_images`` overwrites it
                on every call with the ratio it actually used.
        """
        self.image_model = settings.CREATE_IMAGE_MODEL
        self.paid_key = settings.PAID_KEY
        self.aspect_ratio = aspect_ratio

    def parse_prompt_parameters(self, prompt: str, default_n=1, default_ratio=None) -> tuple:
        """Extract ``{n:...}`` and ``{ratio:...}`` directives from *prompt*.

        Supported directives:
            - ``{n:COUNT}``   e.g. ``{n:2}`` requests 2 images (1-4 allowed).
            - ``{ratio:W:H}`` e.g. ``{ratio:16:9}`` uses a 16:9 aspect ratio.

        Args:
            prompt: Raw prompt text, possibly containing directives.
            default_n: Image count returned when the prompt has no ``{n:...}``.
            default_ratio: Aspect ratio returned when the prompt has no
                ``{ratio:...}``; ``None`` falls back to ``self.aspect_ratio``.

        Returns:
            ``(cleaned_prompt, n, aspect_ratio)``. Each matched directive is
            removed from the prompt, which is then whitespace-stripped.

        Raises:
            ValueError: ``n`` outside 1-4, or a ratio not in ``_VALID_RATIOS``.
        """
        import re

        n = default_n
        aspect_ratio = self.aspect_ratio if default_ratio is None else default_ratio

        n_match = re.search(r"\{n:(\d+)\}", prompt)
        if n_match:
            n = int(n_match.group(1))
            if n < 1 or n > 4:
                raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
            prompt = prompt.replace(n_match.group(0), '').strip()

        ratio_match = re.search(r"\{ratio:(\d+:\d+)\}", prompt)
        if ratio_match:
            aspect_ratio = ratio_match.group(1)
            if aspect_ratio not in self._VALID_RATIOS:
                raise ValueError(
                    f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(self._VALID_RATIOS)}"
                )
            prompt = prompt.replace(ratio_match.group(0), '').strip()

        return prompt, n, aspect_ratio

    @classmethod
    def _aspect_ratio_for_size(cls, size):
        """Map an OpenAI-style ``size`` string to an aspect ratio, or raise ValueError."""
        ratio = cls._SIZE_TO_ASPECT_RATIO.get(size)
        if ratio is None:
            raise ValueError(
                f"Invalid size: {size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
            )
        return ratio

    @staticmethod
    def _create_uploader():
        """Build the uploader selected by ``settings.UPLOAD_PROVIDER``, or raise ValueError."""
        provider = settings.UPLOAD_PROVIDER
        if provider == "smms":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.SMMS_SECRET_TOKEN,
            )
        if provider == "picgo":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.PICGO_API_KEY,
            )
        if provider == "cloudflare_imgbed":
            return ImageUploaderFactory.create(
                provider=provider,
                base_url=settings.CLOUDFLARE_IMGBED_URL,
                auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
            )
        raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")

    def generate_images(self, request: "ImageGenerationRequest"):
        """Generate images for *request* and return an OpenAI-style payload.

        ``request.size`` selects the base aspect ratio. ``{n:...}`` and
        ``{ratio:...}`` directives in ``request.prompt`` override the count
        and ratio.

        Side effects: ``request.prompt`` is replaced by the directive-free
        prompt, ``request.n`` by the effective image count, and
        ``self.aspect_ratio`` by the ratio sent to the model.

        Returns:
            ``{"created": <unix time>, "data": [...]}``. Each item holds either
            ``b64_json`` (when ``request.response_format == "b64_json"``) or an
            uploaded ``url``, plus ``revised_prompt``.

        Raises:
            ValueError: Unsupported size, invalid prompt directives, or an
                unsupported upload provider.
            Exception: The API returned no generated images.
        """
        size_ratio = self._aspect_ratio_for_size(request.size)
        cleaned_prompt, n, aspect_ratio = self.parse_prompt_parameters(
            request.prompt, default_n=request.n, default_ratio=size_ratio
        )
        request.prompt = cleaned_prompt
        # A prompt directive (including {n:1}) overrides the requested count;
        # otherwise request.n is kept as-is.
        request.n = n
        # Kept in sync for callers that read the last-used ratio; the API call
        # below uses the local value.
        self.aspect_ratio = aspect_ratio

        client = genai.Client(api_key=self.paid_key)
        response = client.models.generate_images(
            model=self.image_model,
            prompt=request.prompt,
            config=types.GenerateImagesConfig(
                number_of_images=request.n,
                output_mime_type="image/png",
                aspect_ratio=aspect_ratio,
                safety_filter_level="BLOCK_LOW_AND_ABOVE",
                person_generation="ALLOW_ADULT",
            ),
        )

        if not response.generated_images:
            raise Exception("I can't generate these images")

        images_data = []
        if request.response_format == "b64_json":
            for generated_image in response.generated_images:
                image_bytes = generated_image.image.image_bytes
                images_data.append({
                    "b64_json": base64.b64encode(image_bytes).decode('utf-8'),
                    "revised_prompt": request.prompt,
                })
        else:
            # One uploader serves every image of this request.
            image_uploader = self._create_uploader()
            for generated_image in response.generated_images:
                image_bytes = generated_image.image.image_bytes
                filename = f"{time.strftime('%Y/%m/%d')}/{uuid.uuid4().hex[:8]}.png"
                upload_response = image_uploader.upload(image_bytes, filename)
                images_data.append({
                    "url": f"{upload_response.data.url}",
                    "revised_prompt": request.prompt,
                })

        return {
            "created": int(time.time()),  # Current timestamp
            "data": images_data,
        }

    @staticmethod
    def _format_markdown_images(image_datas) -> str:
        """Render image payload items as newline-separated markdown images.

        Items with a ``url`` link to it. Other items are expected to carry
        ``b64_json`` and are embedded as a PNG data URL (the service always
        requests ``image/png``). An empty list yields ``""``.
        """
        markdown_images = []
        for index, image_data in enumerate(image_datas, start=1):
            if 'url' in image_data:
                source = image_data['url']
            else:
                # base64 payload: embed it as a data URL
                source = f"data:image/png;base64,{image_data['b64_json']}"
            markdown_images.append(f"![Generated Image {index}]({source})")
        return "\n".join(markdown_images)

    def generate_images_chat(self, request: "ImageGenerationRequest") -> str:
        """Generate images for *request* and return them as markdown for chat replies."""
        response = self.generate_images(request)
        return self._format_markdown_images(response["data"])
|