#!/usr/bin/env python3
"""
AI image generation module for WeWrite.

Supports multiple providers via a simple abstraction:
  - doubao-seedream (Volcengine Ark) — default, good for Chinese prompts
  - openai (DALL-E 3) — broad availability
  - gemini (Google Gemini Imagen) — multimodal image generation
  - Custom providers via ImageProvider base class

Usage as CLI:
    python3 image_gen.py --prompt "description" --output cover.png
    python3 image_gen.py --prompt "description" --output cover.png --size cover
    python3 image_gen.py --prompt "description" --output cover.png --provider gemini

Usage as module:
    from image_gen import generate_image
    path = generate_image("prompt text", "output.png", size="cover")
"""

import abc
import argparse
import base64
import json
import re
import sys
from pathlib import Path
from typing import Optional

import requests
import yaml

# --- Config ---

# Candidate config locations, searched in order; the first existing file wins.
CONFIG_PATHS = [
    Path.cwd() / "config.yaml",
    Path(__file__).parent.parent / "config.yaml",  # skill root
    Path(__file__).parent / "config.yaml",  # toolkit dir
    Path.home() / ".config" / "wewrite" / "config.yaml",
]


def _load_config() -> dict:
    """Load the first config file found in CONFIG_PATHS.

    Returns:
        The parsed YAML mapping, or an empty dict when no config file exists
        or the file is empty.

    Raises:
        ValueError: If the file's top-level YAML value is not a mapping.
    """
    for path in CONFIG_PATHS:
        if path.exists():
            with open(path, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f) or {}
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"{path} must contain a YAML mapping at the top level, "
                    f"got {type(loaded).__name__}"
                )
            return loaded
    return {}


def _image_section(config: dict) -> dict:
    """Return config["image"], treating a missing, null, or non-mapping value as empty."""
    section = config.get("image")
    return section if isinstance(section, dict) else {}


# --- Size presets ---
# Cover: 2.35:1, the WeChat cover aspect ratio
# Article: 16:9 landscape in-article illustration
# Vertical: 9:16 portrait

SIZE_PRESETS = {
    "cover": {"doubao": "2952x1256", "openai": "1792x1024", "gemini": "1792x1024"},
    "article": {"doubao": "2560x1440", "openai": "1792x1024", "gemini": "1792x1024"},
    "vertical": {"doubao": "1088x2560", "openai": "1024x1792", "gemini": "1024x1792"},
    "square": {"doubao": "2048x2048", "openai": "1024x1024", "gemini": "1024x1024"},
}

MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB

# Modes that are saved as JPEG without conversion; any other mode is
# converted to RGB first.
_JPEG_MODES = ("RGB", "L", "CMYK")

# Matches an explicit "<width><sep><height>" size with a tolerant separator.
_WXH_RE = re.compile(r"^\s*([0-9]+)\s*[xX\u00d7*]\s*([0-9]+)\s*$")


def _compress_image(raw_bytes: bytes, max_size: int) -> bytes:
    """Re-encode an image as JPEG, lowering quality until it fits max_size.

    Args:
        raw_bytes: Encoded image bytes in any format Pillow can open.
        max_size: Target upper bound for the result, in bytes.

    Returns:
        JPEG bytes at the highest tried quality (90, 80, 70, 60, 50) that fits
        under max_size. If even quality 50 is too large, the quality-50
        encoding is returned anyway (best effort), so the result may still
        exceed max_size.
    """
    # Imported lazily so Pillow is only needed when compression is required.
    from io import BytesIO
    from PIL import Image

    img = Image.open(BytesIO(raw_bytes))
    # Normalise every mode outside _JPEG_MODES (RGBA, P, LA, ...) to RGB
    # before saving as JPEG.
    if img.mode not in _JPEG_MODES:
        img = img.convert("RGB")

    encoded = b""
    for quality in (90, 80, 70, 60, 50):
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        encoded = buf.getvalue()
        if len(encoded) <= max_size:
            break
    return encoded


# --- Shared HTTP helpers ---

def _decode_api_response(resp: requests.Response, provider_label: str) -> dict:
    """Parse a provider's JSON response, raising ValueError on any API failure.

    Args:
        resp: The HTTP response returned by the provider endpoint.
        provider_label: Human-readable provider name used in error messages.

    Returns:
        The decoded JSON object of a 200 response.

    Raises:
        ValueError: If the body is not JSON, the status is not 200, or the
            decoded JSON is not an object.
    """
    try:
        data = resp.json()
    except ValueError:
        # Gateways and proxies may answer with HTML or an empty body; report
        # the status and a short excerpt instead of a bare decode error.
        snippet = (resp.text or "")[:200]
        raise ValueError(
            f"{provider_label} API error ({resp.status_code}): "
            f"non-JSON response: {snippet!r}"
        ) from None

    if resp.status_code != 200:
        error = data.get("error") if isinstance(data, dict) else None
        msg = error.get("message") if isinstance(error, dict) else None
        if not msg:
            msg = json.dumps(data, ensure_ascii=False)
        raise ValueError(f"{provider_label} API error ({resp.status_code}): {msg}")

    if not isinstance(data, dict):
        raise ValueError(
            f"{provider_label} API error ({resp.status_code}): "
            f"unexpected response: {json.dumps(data, ensure_ascii=False)}"
        )
    return data


def _download_first_image(data: dict) -> bytes:
    """Download the image referenced by the first entry of data["data"].

    Args:
        data: Decoded response whose "data" list holds entries with a "url".

    Returns:
        The raw bytes fetched from the first entry's URL.

    Raises:
        ValueError: If the response lists no image or the entry has no URL.
        requests.HTTPError: If downloading the image returns an error status.
    """
    image_data = data.get("data", [])
    if not image_data:
        raise ValueError(f"No image returned: {json.dumps(data, ensure_ascii=False)}")

    image_url = image_data[0].get("url")
    if not image_url:
        raise ValueError(f"No image URL in response: {json.dumps(data, ensure_ascii=False)}")

    img_resp = requests.get(image_url, timeout=60)
    img_resp.raise_for_status()
    return img_resp.content


# --- Provider abstraction ---

class ImageProvider(abc.ABC):
    """Base class for image generation providers."""

    @abc.abstractmethod
    def generate(self, prompt: str, size: str) -> bytes:
        """Generate an image and return raw bytes.

        Args:
            prompt: Image description (Chinese or English).
            size: Resolved size string (e.g. "1792x1024").

        Returns:
            Raw image bytes.
        """
        ...

    def resolve_size(self, preset: str) -> str:
        """Resolve a size preset to a concrete size string for this provider.

        A preset name found in SIZE_PRESETS maps to this provider's entry, or
        to the preset's first listed size when the provider has no entry.
        Any other string is returned unchanged and assumed to be an explicit
        "WxH" value.
        """
        sizes = SIZE_PRESETS.get(preset)
        if sizes is None:
            return preset  # assume explicit WxH
        return sizes.get(self.provider_key, next(iter(sizes.values())))

    @property
    @abc.abstractmethod
    def provider_key(self) -> str:
        """Short identifier used for size preset lookup."""
        ...


class DoubaoProvider(ImageProvider):
    """doubao-seedream via Volcengine Ark API."""

    provider_key = "doubao"

    def __init__(self, api_key: str, model: str = "doubao-seedream-5-0-260128",
                 base_url: str = "https://ark.cn-beijing.volces.com/api/v3"):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Request one image from the Ark images endpoint and download it.

        Raises:
            ValueError: On an API error or a response without an image URL.
            requests.RequestException: On transport or download failures.
        """
        body = {
            "model": self._model,
            "prompt": prompt,
            "response_format": "url",
            "size": size,
            "stream": False,
            "watermark": False,
        }
        resp = requests.post(
            f"{self._base_url}/images/generations",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
            },
            json=body,
            timeout=120,
        )
        data = _decode_api_response(resp, "Doubao")
        return _download_first_image(data)


class OpenAIProvider(ImageProvider):
    """OpenAI DALL-E 3 provider."""

    provider_key = "openai"

    def __init__(self, api_key: str, model: str = "dall-e-3",
                 base_url: str = "https://api.openai.com/v1"):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    @staticmethod
    def _normalize_size(size: str) -> str:
        """Rewrite a width/height pair into the canonical "WxH" spelling.

        Accepts "x", "X", the multiplication sign, or "*" as the separator,
        with optional surrounding whitespace. Strings that are not a
        width/height pair are returned unchanged.
        """
        match = _WXH_RE.match(size)
        if match is None:
            return size
        return f"{match.group(1)}x{match.group(2)}"

    def generate(self, prompt: str, size: str) -> bytes:
        """Request one image from the images endpoint and download it.

        Raises:
            ValueError: On an API error or a response without an image URL.
            requests.RequestException: On transport or download failures.
        """
        body = {
            "model": self._model,
            "prompt": prompt,
            "n": 1,
            # DALL-E 3 expects size as "WxH" format
            "size": self._normalize_size(size),
            "response_format": "url",
        }
        resp = requests.post(
            f"{self._base_url}/images/generations",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
            },
            json=body,
            timeout=120,
        )
        data = _decode_api_response(resp, "OpenAI")
        return _download_first_image(data)


class GeminiProvider(ImageProvider):
    """Google Gemini Imagen provider."""

    provider_key = "gemini"

    def __init__(self, api_key: str, model: str = "gemini-3.1-flash-image-preview",
                 base_url: str = "https://generativelanguage.googleapis.com/v1beta"):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Request an image via generateContent and decode the inline payload.

        The size argument is accepted for interface compatibility but is not
        included in the request body.

        Raises:
            ValueError: On an API error, or when the response has no
                candidates or no inline image part.
            requests.RequestException: On transport failures.
        """
        body = {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {"responseModalities": ["TEXT", "IMAGE"]},
        }
        headers = {
            "Content-Type": "application/json",
            # Sent as a header rather than a URL query parameter so the key
            # does not appear in URLs echoed by exception messages.
            "x-goog-api-key": self._api_key,
        }
        # The context manager closes the session's connections when done.
        with requests.Session() as session:
            # Do not pick up proxy/auth settings from the environment here.
            session.trust_env = False
            resp = session.post(
                f"{self._base_url}/models/{self._model}:generateContent",
                headers=headers,
                json=body,
                timeout=120,
            )
            data = _decode_api_response(resp, "Gemini")

        candidates = data.get("candidates", [])
        if not candidates:
            raise ValueError("No candidates in Gemini response")

        content = candidates[0].get("content") or {}
        for part in content.get("parts") or []:
            inline_data = part.get("inlineData")
            if inline_data and inline_data.get("mimeType", "").startswith("image/"):
                return base64.b64decode(inline_data["data"])

        raise ValueError("No image found in Gemini response parts")


# --- Provider registry ---

PROVIDERS = {
    "doubao": DoubaoProvider,
    "openai": OpenAIProvider,
    "gemini": GeminiProvider,
}


def _build_provider(config: dict) -> ImageProvider:
    """Build an ImageProvider from config.yaml's image section.

    Reads "provider" (default "doubao"), "api_key", and the optional "model"
    and "base_url" overrides from config["image"].

    Raises:
        ValueError: If the API key is missing or the provider is unknown.
    """
    img_cfg = _image_section(config)
    provider_name = img_cfg.get("provider", "doubao")
    api_key = img_cfg.get("api_key")

    if not api_key:
        raise ValueError(
            f"image.api_key not set in config.yaml. "
            f"Configure your {provider_name} API key to enable image generation."
        )

    provider_cls = PROVIDERS.get(provider_name)
    if not provider_cls:
        raise ValueError(
            f"Unknown image provider: '{provider_name}'. "
            f"Available: {', '.join(PROVIDERS.keys())}"
        )

    kwargs = {"api_key": api_key}
    # Only forward overrides that are actually set so class defaults apply.
    if img_cfg.get("model"):
        kwargs["model"] = img_cfg["model"]
    if img_cfg.get("base_url"):
        kwargs["base_url"] = img_cfg["base_url"]

    return provider_cls(**kwargs)


# --- Public API ---

def generate_image(
    prompt: str,
    output_path: str,
    size: str = "cover",
    config: Optional[dict] = None,
) -> str:
    """
    Generate an image using the configured provider.

    Args:
        prompt: Image generation prompt (Chinese or English).
        output_path: Where to save the image.
        size: Size preset ("cover", "article", "vertical", "square") or explicit "WxH".
        config: Optional config dict. If None, loads from config.yaml.

    Returns:
        The output file path.

    Raises:
        ValueError: If the provider is misconfigured or the API reports an error.
    """
    if config is None:
        config = _load_config()

    provider = _build_provider(config)
    resolved_size = provider.resolve_size(size)
    raw_bytes = provider.generate(prompt, resolved_size)

    # Compress if over 5MB (WeChat upload limit).
    # NOTE: the compressed bytes are JPEG regardless of output_path's extension.
    if len(raw_bytes) > MAX_FILE_SIZE:
        raw_bytes = _compress_image(raw_bytes, MAX_FILE_SIZE)

    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_bytes(raw_bytes)

    return str(output)


def main():
    """CLI entry point: parse arguments, generate the image, report the path.

    Exits with status 1 and prints the error to stderr on any failure.
    """
    parser = argparse.ArgumentParser(
        description="Generate images using AI (doubao-seedream, OpenAI DALL-E, Gemini Imagen, etc.)"
    )
    parser.add_argument("--prompt", required=True, help="Image generation prompt")
    parser.add_argument("--output", required=True, help="Output file path")
    parser.add_argument(
        "--size",
        default="cover",
        help="Size: cover, article, vertical, square, or WxH",
    )
    parser.add_argument(
        "--provider",
        default=None,
        help="Override provider (doubao, openai, gemini). Default: from config.yaml",
    )
    args = parser.parse_args()

    try:
        config = _load_config()
        if args.provider:
            # Rebuild the section so a null "image:" entry does not break the override.
            config["image"] = {**_image_section(config), "provider": args.provider}
        path = generate_image(args.prompt, args.output, size=args.size, config=config)
        print(f"Image saved: {path}")
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()