wewrite/dist/openclaw/toolkit/image_gen.py
wangzhuc 25d6a44082 feat: add article content extraction with anti-scraping fallback
- New `scripts/fetch_article.py`: extract WeChat article content as Markdown
  with three-level fetch strategy (requests → Playwright → manual HTML)
- Refactor `learn_theme.py` to reuse `fetch_article.fetch_html()`, removing
  duplicate fetch logic
- Update SKILL.md: add "学习这篇文章/导入范文" auxiliary function
- Update README.md: add article extraction to feature table and directory tree

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 16:34:13 +00:00

754 lines
27 KiB
Python

#!/usr/bin/env python3
"""
AI image generation module for WeWrite.
Supports multiple providers via a simple abstraction:
- doubao-seedream (Volcengine Ark) — default, good for Chinese prompts
- openai (DALL-E 3) — broad availability
- gemini (Google Gemini Imagen) — multimodal image generation
- dashscope (Alibaba Tongyi Wanxiang) — good for Chinese prompts
- minimax — Chinese provider
- replicate — open-source models
- azure_openai — Azure-hosted DALL-E
- openrouter — multi-model proxy
- jimeng (ByteDance) — good for Chinese prompts
- Custom providers via ImageProvider base class
Usage as CLI:
python3 image_gen.py --prompt "描述" --output cover.png
python3 image_gen.py --prompt "描述" --output cover.png --size cover
python3 image_gen.py --prompt "描述" --output cover.png --provider gemini
Usage as module:
from image_gen import generate_image
path = generate_image("prompt text", "output.png", size="cover")
"""
import abc
import argparse
import base64
import hashlib
import hmac
import json
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
import requests
import yaml
# --- Config ---
# Candidate config.yaml locations, listed in the order they are checked.
CONFIG_PATHS = [
    base_dir.joinpath("config.yaml")
    for base_dir in (
        Path.cwd(),
        Path(__file__).parent.parent,  # skill root
        Path(__file__).parent,  # toolkit dir
        Path.home().joinpath(".config", "wewrite"),
    )
]
def _load_config() -> dict:
    """Parse and return the first config.yaml that exists in CONFIG_PATHS.

    Returns:
        The parsed YAML document, or an empty dict when no candidate file
        exists or the first existing file parses to an empty/falsy value.
    """
    found = next((path for path in CONFIG_PATHS if path.exists()), None)
    if found is None:
        return {}
    with found.open("r", encoding="utf-8") as handle:
        return yaml.safe_load(handle) or {}
# --- Size presets ---
# Cover: 2.35:1, the WeChat cover aspect ratio
# Article: 16:9 landscape in-article illustration
# Vertical: 9:16 portrait
_DEFAULT = "1792x1024"
_DEFAULT_V = "1024x1792"
_DEFAULT_SQ = "1024x1024"

# Providers that all share one generic size per preset (doubao has its own).
_GENERIC_SIZE_PROVIDERS = (
    "openai", "gemini", "dashscope", "minimax", "replicate",
    "azure_openai", "openrouter", "jimeng",
)

# preset name -> (doubao size, size used by every other provider)
_PRESET_SIZE_TABLE = {
    "cover": ("2952x1256", _DEFAULT),
    "article": ("2560x1440", _DEFAULT),
    "vertical": ("1088x2560", _DEFAULT_V),
    "square": ("2048x2048", _DEFAULT_SQ),
}

# preset name -> {provider key -> concrete "WxH" size}; "doubao" stays the
# first key of every inner mapping.
SIZE_PRESETS = {
    preset: {
        "doubao": doubao_size,
        **dict.fromkeys(_GENERIC_SIZE_PROVIDERS, generic_size),
    }
    for preset, (doubao_size, generic_size) in _PRESET_SIZE_TABLE.items()
}
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB


def _compress_image(raw_bytes: bytes, max_size: int) -> bytes:
    """Re-encode an image as JPEG, lowering quality until it fits max_size.

    Args:
        raw_bytes: Encoded source image in any format Pillow can open.
        max_size: Target upper bound for the result, in bytes.

    Returns:
        JPEG bytes from the first quality step (90 down to 50) that fits.
        If no step fits, the quality-50 encoding is returned even though it
        is still larger than max_size (best effort).
    """
    from io import BytesIO

    from PIL import Image

    img = Image.open(BytesIO(raw_bytes))
    # The JPEG writer only accepts a few modes. Palette ("P"), alpha
    # ("RGBA", "LA") and other modes would make save() fail, so normalize
    # everything that is not already JPEG-compatible to RGB.
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")
    encoded = b""
    for quality in (90, 80, 70, 60, 50):
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=quality, optimize=True)
        encoded = buf.getvalue()
        if len(encoded) <= max_size:
            break
    return encoded
def _size_to_aspect(size: str) -> str:
"""Convert 'WxH' to nearest standard aspect ratio string."""
if ":" in size:
return size
try:
w, h = (int(x) for x in size.split("x", 1))
except ValueError:
return "16:9"
ratio = w / h
for ar, val in [("1:1", 1.0), ("16:9", 16/9), ("9:16", 9/16),
("4:3", 4/3), ("3:4", 3/4), ("3:2", 3/2), ("2:3", 2/3)]:
if abs(ratio - val) < 0.15:
return ar
return "16:9"
def _download_image(url: str) -> bytes:
    """Fetch url with a 60 second timeout and return the response body.

    Raises whatever requests raises for connection problems, and an HTTP
    error (via raise_for_status) for 4xx/5xx responses.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    body = response.content
    return body
# --- Provider abstraction ---
class ImageProvider(abc.ABC):
    """Abstract interface implemented by every image generation backend."""

    @property
    @abc.abstractmethod
    def provider_key(self) -> str:
        """Key under which this provider's sizes are stored in SIZE_PRESETS."""

    @abc.abstractmethod
    def generate(self, prompt: str, size: str) -> bytes:
        """Generate an image for prompt at size and return its raw bytes."""

    def resolve_size(self, preset: str) -> str:
        """Translate a size preset into this provider's concrete size string.

        Values that are not preset names are returned unchanged. When the
        preset has no entry for this provider, the preset's first listed
        size is used.
        """
        key = self.provider_key
        per_provider = SIZE_PRESETS.get(preset)
        if per_provider is None:
            return preset
        if key in per_provider:
            return per_provider[key]
        return next(iter(per_provider.values()))
# --- Providers ---
class DoubaoProvider(ImageProvider):
    """doubao-seedream via Volcengine Ark API."""

    provider_key = "doubao"

    def __init__(self, api_key: str, model: str = "doubao-seedream-5-0-260128",
                 base_url: str = "https://ark.cn-beijing.volces.com/api/v3", **_kw):
        # Extra keyword arguments are accepted and ignored.
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Request one image from the images/generations endpoint.

        Args:
            prompt: Text prompt sent to the model.
            size: Size string forwarded as the request's "size" field.

        Returns:
            The bytes downloaded from the image URL in the response.

        Raises:
            ValueError: On a non-200 status, a non-JSON body, or a response
                that carries no image URL.
        """
        resp = requests.post(
            f"{self._base_url}/images/generations",
            headers={"Content-Type": "application/json",
                     "Authorization": f"Bearer {self._api_key}"},
            json={"model": self._model, "prompt": prompt,
                  "response_format": "url", "size": size,
                  "stream": False, "watermark": False},
            timeout=120,
        )
        try:
            data = resp.json()
        except ValueError:
            # Non-JSON bodies (e.g. a gateway error page) would otherwise hide
            # the HTTP status behind a bare decode error.
            raise ValueError(
                f"Doubao error ({resp.status_code}): {resp.text[:200]}"
            ) from None
        if resp.status_code != 200:
            error = data.get("error") if isinstance(data, dict) else None
            message = error.get("message") if isinstance(error, dict) else None
            raise ValueError(f"Doubao error ({resp.status_code}): "
                             f"{message or str(data)}")
        items = data.get("data") if isinstance(data, dict) else None
        # An empty or missing "data" list must end in the descriptive error
        # below, not an IndexError.
        first = items[0] if isinstance(items, list) and items else {}
        url = first.get("url") if isinstance(first, dict) else None
        if not url:
            raise ValueError(f"No image URL: {data}")
        return _download_image(url)
class OpenAIProvider(ImageProvider):
    """OpenAI DALL-E 3 provider."""

    provider_key = "openai"

    def __init__(self, api_key: str, model: str = "dall-e-3",
                 base_url: str = "https://api.openai.com/v1", **_kw):
        # Extra keyword arguments are accepted and ignored.
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Request one image from the images/generations endpoint.

        Args:
            prompt: Text prompt sent to the model.
            size: Size string forwarded as the request's "size" field.

        Returns:
            The bytes downloaded from the image URL in the response.

        Raises:
            ValueError: On a non-200 status, a non-JSON body, or a response
                that carries no image URL.
        """
        resp = requests.post(
            f"{self._base_url}/images/generations",
            headers={"Content-Type": "application/json",
                     "Authorization": f"Bearer {self._api_key}"},
            json={"model": self._model, "prompt": prompt,
                  "n": 1, "size": size, "response_format": "url"},
            timeout=120,
        )
        try:
            data = resp.json()
        except ValueError:
            # Keep the HTTP status visible when the body is not JSON.
            raise ValueError(
                f"OpenAI error ({resp.status_code}): {resp.text[:200]}"
            ) from None
        if resp.status_code != 200:
            error = data.get("error") if isinstance(data, dict) else None
            message = error.get("message") if isinstance(error, dict) else None
            raise ValueError(f"OpenAI error ({resp.status_code}): "
                             f"{message or str(data)}")
        items = data.get("data") if isinstance(data, dict) else None
        # Guard against an empty "data" list so the caller sees the
        # descriptive error below instead of an IndexError.
        first = items[0] if isinstance(items, list) and items else {}
        url = first.get("url") if isinstance(first, dict) else None
        if not url:
            raise ValueError(f"No image URL: {data}")
        return _download_image(url)
class GeminiProvider(ImageProvider):
    """Google Gemini Imagen provider."""

    provider_key = "gemini"

    def __init__(self, api_key: str, model: str = "gemini-3.1-flash-image-preview",
                 base_url: str = "https://generativelanguage.googleapis.com/v1beta", **_kw):
        # Extra keyword arguments are accepted and ignored.
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Generate an image through the generateContent endpoint.

        The request has no size field, so a 'WxH' size is appended to the
        prompt as a textual hint.

        Args:
            prompt: Text prompt sent to the model.
            size: 'WxH' size hint; values without "x" are ignored.

        Returns:
            The decoded bytes of the first inline image part in the response.

        Raises:
            ValueError: On a non-200 status or when the response contains no
                inline image.
        """
        if "x" in size:
            w, h = size.split("x", 1)
            prompt = f"{prompt}\n\nGenerate this image at {w}x{h} resolution."
        resp = requests.post(
            f"{self._base_url}/models/{self._model}:generateContent",
            headers={"Content-Type": "application/json",
                     "x-goog-api-key": self._api_key},
            json={"contents": [{"parts": [{"text": prompt}]}],
                  "generationConfig": {"responseModalities": ["TEXT", "IMAGE"]}},
            timeout=120,
        )
        if resp.status_code != 200:
            msg = resp.text[:200]
            try:
                msg = resp.json().get("error", {}).get("message", msg)
            except (ValueError, AttributeError):
                # Body was not JSON or not shaped as expected; keep raw text.
                pass
            raise ValueError(f"Gemini error ({resp.status_code}): {msg}")
        payload = resp.json()
        # "candidates" may be missing or an empty list; treat both as
        # "no image" instead of failing with an IndexError.
        candidates = payload.get("candidates") or [{}]
        parts = (candidates[0].get("content") or {}).get("parts") or []
        for part in parts:
            inline = part.get("inlineData")
            if inline and inline.get("mimeType", "").startswith("image/"):
                return base64.b64decode(inline["data"])
        raise ValueError("No image in Gemini response")
class DashScopeProvider(ImageProvider):
    """Alibaba Tongyi Wanxiang image generation through the DashScope API."""

    provider_key = "dashscope"

    def __init__(self, api_key: str, model: str = "qwen-image-2.0-pro",
                 base_url: str = "https://dashscope.aliyuncs.com/api/v1", **_kw):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Generate one image and return its bytes.

        The image reference is read from output.result_image when present,
        otherwise from the first content item of the first choice that has
        an "image" key. A reference starting with "http" is downloaded;
        anything else is decoded as base64.
        """
        endpoint = f"{self._base_url}/services/aigc/multimodal-generation/generation"
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }
        request_body = {
            "model": self._model,
            "input": {"messages": [{"role": "user", "content": [{"text": prompt}]}]},
            "parameters": {
                "prompt_extend": False,
                "size": size.replace("x", "*"),  # DashScope uses "W*H"
                "watermark": False,
            },
        }
        resp = requests.post(endpoint, headers=request_headers,
                             json=request_body, timeout=120)
        data = resp.json()
        if resp.status_code != 200:
            raise ValueError(f"DashScope error ({resp.status_code}): "
                             f"{data.get('message', str(data))}")
        output = data.get("output", {})
        image_ref = output.get("result_image")
        if not image_ref:
            choices = output.get("choices", [])
            items = choices[0].get("message", {}).get("content", []) if choices else []
            image_ref = next(
                (item["image"] for item in items if "image" in item), image_ref
            )
        if not image_ref:
            raise ValueError(f"No image in DashScope response: {data}")
        if image_ref.startswith("http"):
            return _download_image(image_ref)
        return base64.b64decode(image_ref)
class MiniMaxProvider(ImageProvider):
    """Image generation through the MiniMax image_generation endpoint."""

    provider_key = "minimax"

    def __init__(self, api_key: str, model: str = "image-01",
                 base_url: str = "https://api.minimax.io/v1", **_kw):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Generate one image and return the decoded base64 payload.

        size is parsed as 'WxH'; when it cannot be parsed the request falls
        back to 1024 by 1024.
        """
        try:
            width, height = map(int, size.split("x", 1))
        except ValueError:
            width, height = 1024, 1024
        request_body = {
            "model": self._model,
            "prompt": prompt,
            "response_format": "base64",
            "width": width,
            "height": height,
            "n": 1,
        }
        resp = requests.post(
            f"{self._base_url}/image_generation",
            headers={"Content-Type": "application/json",
                     "Authorization": f"Bearer {self._api_key}"},
            json=request_body,
            timeout=120,
        )
        data = resp.json()
        if resp.status_code != 200:
            raise ValueError(f"MiniMax error ({resp.status_code}): {data}")
        encoded_images = data.get("data", {}).get("image_base64", [])
        if encoded_images:
            return base64.b64decode(encoded_images[0])
        raise ValueError(f"No image in MiniMax response: {data}")
class ReplicateProvider(ImageProvider):
    """Replicate API — supports many open-source image models."""

    provider_key = "replicate"
    _POLL_INTERVAL = 2  # seconds slept between status polls
    _POLL_TIMEOUT = 300  # seconds before polling gives up

    def __init__(self, api_key: str, model: str = "google/nano-banana-pro",
                 base_url: str = "https://api.replicate.com/v1", **_kw):
        # Extra keyword arguments are accepted and ignored.
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def generate(self, prompt: str, size: str) -> bytes:
        """Create a prediction, wait for it to finish, and download the image.

        Args:
            prompt: Text prompt sent as the model input.
            size: 'WxH' or aspect ratio; converted to an aspect ratio string.

        Returns:
            The bytes downloaded from the prediction's output URL.

        Raises:
            ValueError: On a rejected request, a failed or canceled
                prediction, a polling timeout, a missing polling URL, or an
                output that contains no image URL.
        """
        aspect = _size_to_aspect(size)
        headers = {"Content-Type": "application/json",
                   "Authorization": f"Bearer {self._api_key}",
                   "Prefer": "wait=60"}
        resp = requests.post(
            f"{self._base_url}/models/{self._model}/predictions",
            headers=headers,
            json={"input": {"prompt": prompt, "aspect_ratio": aspect,
                            "number_of_images": 1, "output_format": "png"}},
            timeout=120,
        )
        data = resp.json()
        if resp.status_code not in (200, 201):
            raise ValueError(f"Replicate error ({resp.status_code}): {data}")
        # Poll if not completed yet
        poll_url = (data.get("urls") or {}).get("get")
        deadline = time.monotonic() + self._POLL_TIMEOUT
        while data.get("status") not in ("succeeded", "failed", "canceled"):
            if not poll_url:
                # Without a status URL there is nothing to poll; fail clearly
                # instead of passing None to requests.
                raise ValueError(f"Replicate returned no polling URL: {data}")
            if time.monotonic() > deadline:
                raise ValueError("Replicate polling timeout")
            time.sleep(self._POLL_INTERVAL)
            data = requests.get(poll_url, headers=headers, timeout=30).json()
        if data.get("status") != "succeeded":
            raise ValueError(f"Replicate failed: {data.get('error')}")
        output = data.get("output")
        if isinstance(output, list):
            # An empty list falls through to the descriptive error below.
            output = output[0] if output else None
        if isinstance(output, dict):
            output = output.get("url", output.get("uri"))
        if not output or not isinstance(output, str):
            raise ValueError(f"No image URL in Replicate output: {data}")
        return _download_image(output)
class AzureOpenAIProvider(ImageProvider):
    """Azure-hosted OpenAI DALL-E."""

    provider_key = "azure_openai"

    def __init__(self, api_key: str, model: str = "dall-e-3",
                 base_url: str = "", deployment: str = "", **_kw):
        # The deployment name falls back to the model name when not given.
        self._api_key = api_key
        self._deployment = deployment or model
        self._base_url = base_url.rstrip("/")

    def generate(self, prompt: str, size: str) -> bytes:
        """Generate one image through the deployment's images endpoint.

        Args:
            prompt: Text prompt sent to the model.
            size: Size string forwarded as the request's "size" field.

        Returns:
            The downloaded image when the response carries a URL, otherwise
            the decoded "b64_json" payload.

        Raises:
            ValueError: When base_url is not configured, on a non-200 status,
                on a non-JSON body, or when the response holds no image.
        """
        if not self._base_url:
            raise ValueError("Azure OpenAI requires base_url "
                             "(e.g. https://YOUR-RESOURCE.openai.azure.com/openai)")
        resp = requests.post(
            f"{self._base_url}/deployments/{self._deployment}"
            f"/images/generations?api-version=2025-04-01-preview",
            headers={"Content-Type": "application/json",
                     "api-key": self._api_key},
            json={"prompt": prompt, "size": size, "n": 1, "quality": "medium"},
            timeout=120,
        )
        try:
            data = resp.json()
        except ValueError:
            # Keep the HTTP status visible when the body is not JSON.
            raise ValueError(
                f"Azure OpenAI error ({resp.status_code}): {resp.text[:200]}"
            ) from None
        if resp.status_code != 200:
            raise ValueError(f"Azure OpenAI error ({resp.status_code}): {data}")
        items = data.get("data") if isinstance(data, dict) else None
        # Empty or missing "data" must reach the descriptive error below
        # rather than raise IndexError.
        item = items[0] if isinstance(items, list) and items else {}
        if not isinstance(item, dict):
            item = {}
        if item.get("url"):
            return _download_image(item["url"])
        if item.get("b64_json"):
            return base64.b64decode(item["b64_json"])
        raise ValueError(f"No image in Azure response: {data}")
class OpenRouterProvider(ImageProvider):
    """OpenRouter — multi-model proxy using chat completions format."""

    provider_key = "openrouter"

    def __init__(self, api_key: str, model: str = "google/gemini-3.1-flash-image-preview",
                 base_url: str = "https://openrouter.ai/api/v1", **_kw):
        # Extra keyword arguments are accepted and ignored.
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    @staticmethod
    def _image_ref(entry):
        """Return the URL or data URI carried by an image entry, else None.

        Accepts a bare string, {"url": ...}, or {"image_url": {"url": ...}}.
        """
        if isinstance(entry, str):
            return entry
        if not isinstance(entry, dict):
            return None
        nested = entry.get("image_url")
        if isinstance(nested, dict) and isinstance(nested.get("url"), str):
            return nested["url"]
        if isinstance(nested, str):
            return nested
        direct = entry.get("url")
        return direct if isinstance(direct, str) else None

    @staticmethod
    def _load_ref(ref):
        """Turn a data URI or http(s) URL into image bytes; None if neither."""
        if ref.startswith("data:"):
            _, b64 = ref.split(",", 1)
            return base64.b64decode(b64)
        if ref.startswith("http"):
            return _download_image(ref)
        return None

    def generate(self, prompt: str, size: str) -> bytes:
        """Generate an image through the chat/completions endpoint.

        Args:
            prompt: Text prompt sent as the single user message.
            size: 'WxH' or aspect ratio; converted to an aspect ratio string.

        Returns:
            Bytes of the first image found in message.images, or failing
            that in the image items of message.content.

        Raises:
            ValueError: On a non-200 status or when no image can be found.
        """
        aspect = _size_to_aspect(size)
        resp = requests.post(
            f"{self._base_url}/chat/completions",
            headers={"Content-Type": "application/json",
                     "Authorization": f"Bearer {self._api_key}"},
            json={
                "model": self._model,
                "messages": [{"role": "user", "content": prompt}],
                "modalities": ["image"],
                "stream": False,
                "image_config": {"aspect_ratio": aspect},
                "provider": {"require_parameters": True},
            },
            timeout=120,
        )
        data = resp.json()
        if resp.status_code != 200:
            raise ValueError(f"OpenRouter error ({resp.status_code}): {data}")
        # Extract image from multiple possible locations
        choices = data.get("choices") or [{}]
        message = choices[0].get("message") or {}
        # Path 1: images array. Entries may be plain strings or objects such
        # as {"type": "image_url", "image_url": {"url": "data:..."}}; calling
        # str methods directly on an object entry would raise AttributeError.
        for entry in message.get("images") or []:
            ref = self._image_ref(entry)
            if ref:
                image = self._load_ref(ref)
                if image is not None:
                    return image
        # Path 2: content array with image items
        content = message.get("content")
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") in ("image", "image_url"):
                    ref = self._image_ref(item)
                    if ref:
                        image = self._load_ref(ref)
                        if image is not None:
                            return image
        raise ValueError(f"No image in OpenRouter response: {data}")
class JimengProvider(ImageProvider):
    """ByteDance Jimeng — async submit + poll with HMAC-SHA256 auth."""

    provider_key = "jimeng"
    _POLL_INTERVAL = 2  # seconds slept before each result query
    _POLL_MAX_ATTEMPTS = 60  # result queries before giving up
    # Task states that mean "keep polling" when a query succeeds without images.
    _PENDING_STATUSES = ("in_queue", "generating")

    def __init__(self, api_key: str, secret_key: str = "",
                 model: str = "jimeng_t2i_v40",
                 base_url: str = "https://visual.volcengineapi.com", **_kw):
        # api_key is used as the access key id; secret_key signs requests.
        self._access_key = api_key
        self._secret_key = secret_key
        self._model = model
        self._base_url = base_url

    def _sign(self, method: str, path: str, query: str,
              headers: dict, payload: bytes) -> dict:
        """Generate Volcengine HMAC-SHA256 signed headers.

        Builds a canonical request from the method, path, query string, the
        given headers and the SHA-256 of the payload, derives a signing key
        from the secret key (date, region, service, "request"), and returns
        a copy of headers with "Authorization" and "X-Date" added.
        """
        # NOTE(review): only the headers passed in are signed; X-Date is
        # added after signing. Confirm the service accepts an unsigned X-Date.
        now = datetime.now(timezone.utc)
        date_stamp = now.strftime("%Y%m%d")
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        signed_headers_list = sorted(k.lower() for k in headers)
        signed_headers_str = ";".join(signed_headers_list)
        canonical = "\n".join([
            method, path, query,
            "".join(f"{k.lower()}:{headers[k]}\n" for k in sorted(headers)),
            signed_headers_str,
            hashlib.sha256(payload).hexdigest(),
        ])
        region = "cn-north-1"
        service = "cv"
        scope = f"{date_stamp}/{region}/{service}/request"
        string_to_sign = "\n".join([
            "HMAC-SHA256", amz_date, scope,
            hashlib.sha256(canonical.encode()).hexdigest(),
        ])

        def _hmac(key: bytes, msg: str) -> bytes:
            return hmac.new(key, msg.encode(), hashlib.sha256).digest()

        k_date = _hmac(self._secret_key.encode(), date_stamp)
        k_region = _hmac(k_date, region)
        k_service = _hmac(k_region, service)
        k_signing = _hmac(k_service, "request")
        signature = hmac.new(k_signing, string_to_sign.encode(),
                             hashlib.sha256).hexdigest()
        auth = (f"HMAC-SHA256 Credential={self._access_key}/{scope}, "
                f"SignedHeaders={signed_headers_str}, Signature={signature}")
        return {**headers, "Authorization": auth, "X-Date": amz_date}

    def _request(self, action: str, body: dict) -> dict:
        """POST a signed JSON request for action and return the parsed reply.

        Raises:
            ValueError: When the HTTP status is not 200.
        """
        payload = json.dumps(body).encode()
        path = "/"
        query = f"Action={action}&Version=2022-08-31"
        headers = {
            "Content-Type": "application/json",
            "Host": self._base_url.replace("https://", "").replace("http://", ""),
        }
        signed = self._sign("POST", path, query, headers, payload)
        resp = requests.post(
            f"{self._base_url}/?{query}",
            headers=signed, data=payload, timeout=120,
        )
        data = resp.json()
        if resp.status_code != 200:
            raise ValueError(f"Jimeng error ({resp.status_code}): {data}")
        return data

    def generate(self, prompt: str, size: str) -> bytes:
        """Submit a generation task and poll until an image is available.

        Args:
            prompt: Text prompt for the task.
            size: 'WxH' size; falls back to 1024 by 1024 when unparsable.

        Returns:
            The decoded base64 image when present in the result, otherwise
            the bytes downloaded from the first image URL.

        Raises:
            ValueError: When secret_key is missing, no task id is returned,
                the task fails or finishes without image data, or polling
                runs out of attempts.
        """
        if not self._secret_key:
            raise ValueError("Jimeng requires both api_key (access_key_id) "
                             "and secret_key (secret_access_key)")
        w, h = 1024, 1024
        try:
            w, h = (int(x) for x in size.split("x", 1))
        except ValueError:
            pass
        # Submit task
        submit = self._request("CVSync2AsyncSubmitTask", {
            "req_key": self._model, "prompt": prompt,
            "width": w, "height": h,
        })
        task_id = (submit.get("data") or {}).get("task_id")
        if not task_id:
            raise ValueError(f"No task_id from Jimeng: {submit}")
        # Poll for result
        for _ in range(self._POLL_MAX_ATTEMPTS):
            time.sleep(self._POLL_INTERVAL)
            result = self._request("CVSync2AsyncGetResult", {
                "req_key": self._model, "task_id": task_id,
            })
            code = result.get("code")
            data = result.get("data") or {}
            status = data.get("status")
            if code == 10000:
                b64_list = data.get("binary_data_base64") or []
                if b64_list:
                    return base64.b64decode(b64_list[0])
                urls = data.get("image_urls") or []
                if urls:
                    return _download_image(urls[0])
                if status in self._PENDING_STATUSES:
                    # The query succeeded but the task is still queued or
                    # rendering; keep polling instead of giving up early.
                    continue
                raise ValueError(f"No image data in Jimeng result: {result}")
            if code and status in ("failed", "canceled"):
                raise ValueError(f"Jimeng task failed: {result}")
        raise ValueError("Jimeng polling timeout")
# --- Provider registry ---
# Registry mapping each provider's config name to its implementing class.
PROVIDERS = {
    provider_cls.provider_key: provider_cls
    for provider_cls in (
        DoubaoProvider,
        OpenAIProvider,
        GeminiProvider,
        DashScopeProvider,
        MiniMaxProvider,
        ReplicateProvider,
        AzureOpenAIProvider,
        OpenRouterProvider,
        JimengProvider,
    )
}
def _build_provider_from_entry(entry: dict) -> ImageProvider:
"""Build a single ImageProvider from a provider config entry."""
provider_name = entry.get("provider", "doubao")
api_key = entry.get("api_key")
if not api_key:
raise ValueError(f"No api_key for provider '{provider_name}'")
provider_cls = PROVIDERS.get(provider_name)
if not provider_cls:
raise ValueError(
f"Unknown provider: '{provider_name}'. "
f"Available: {', '.join(PROVIDERS.keys())}"
)
kwargs = {"api_key": api_key}
if entry.get("model"):
kwargs["model"] = entry["model"]
if entry.get("base_url"):
kwargs["base_url"] = entry["base_url"]
if entry.get("secret_key"):
kwargs["secret_key"] = entry["secret_key"]
if entry.get("deployment"):
kwargs["deployment"] = entry["deployment"]
return provider_cls(**kwargs)
def _build_provider_chain(config: dict) -> list[ImageProvider]:
"""Build an ordered list of providers to try.
Supports two config formats:
- Legacy: image.provider + image.api_key (single provider)
- New: image.providers (list, tried in order with auto-fallback)
"""
img_cfg = config.get("image", {})
providers_list = img_cfg.get("providers")
if providers_list and isinstance(providers_list, list):
chain = []
for entry in providers_list:
try:
chain.append(_build_provider_from_entry(entry))
except ValueError:
continue # skip misconfigured entries
if not chain:
raise ValueError(
"No valid providers in image.providers list. "
"Each entry needs 'provider' and 'api_key'."
)
return chain
# Legacy single-provider format
api_key = img_cfg.get("api_key")
if not api_key:
raise ValueError(
"image.api_key not set in config.yaml. "
"Configure your API key to enable image generation."
)
return [_build_provider_from_entry(img_cfg)]
def _build_provider(config: dict) -> ImageProvider:
    """Return the first provider of the configured chain.

    Backward-compatible entry point for callers that expect exactly one
    provider built from config.yaml.
    """
    first_provider, *_fallbacks = _build_provider_chain(config)
    return first_provider
# --- Public API ---
def generate_image(
    prompt: str,
    output_path: str,
    size: str = "cover",
    config: dict = None,
) -> str:
    """
    Generate an image using configured providers with auto-fallback.

    Tries each provider in order. If one fails, falls back to the next.
    Supports both single-provider (legacy) and multi-provider config.

    Args:
        prompt: Image generation prompt (Chinese or English).
        output_path: Where to save the image. Missing parent directories
            are created.
        size: Size preset ("cover", "article", "vertical", "square") or explicit "WxH".
        config: Optional config dict. If None, loads from config.yaml.

    Returns:
        The output file path.

    Raises:
        ValueError: When the configuration yields no usable provider, or
            when every provider fails. In the latter case the last provider
            error is attached as the cause.
    """
    if config is None:
        config = _load_config()
    chain = _build_provider_chain(config)
    last_error = None
    for provider in chain:
        resolved_size = provider.resolve_size(size)
        try:
            raw_bytes = provider.generate(prompt, resolved_size)
            if not raw_bytes:
                # An empty payload would be written as a 0-byte file and
                # reported as success; treat it as a provider failure.
                raise ValueError("provider returned empty image data")
        except Exception as e:
            last_error = e
            print(
                f"Provider '{provider.provider_key}' failed: {e}. "
                f"Trying next...",
                file=sys.stderr,
            )
            continue
        # Compress if over 5MB (WeChat upload limit)
        if len(raw_bytes) > MAX_FILE_SIZE:
            raw_bytes = _compress_image(raw_bytes, MAX_FILE_SIZE)
        output = Path(output_path)
        output.parent.mkdir(parents=True, exist_ok=True)
        output.write_bytes(raw_bytes)
        return str(output)
    # Chain the last provider error so its traceback is not lost.
    raise ValueError(
        f"All providers failed. Last error: {last_error}"
    ) from last_error
def main():
    """CLI entry point: parse arguments, generate the image, report the path.

    Prints "Image saved: <path>" on success. On any error, prints the
    message to stderr and exits with status 1.
    """
    ap = argparse.ArgumentParser(description="Generate images using AI")
    ap.add_argument("--prompt", required=True, help="Image generation prompt")
    ap.add_argument("--output", required=True, help="Output file path")
    ap.add_argument("--size", default="cover",
                    help="Size: cover, article, vertical, square, or WxH")
    ap.add_argument("--provider", default=None,
                    help=f"Override provider ({', '.join(PROVIDERS)})")
    args = ap.parse_args()
    try:
        config = _load_config()
        if args.provider:
            img_cfg = config.get("image")
            if not isinstance(img_cfg, dict):
                img_cfg = {}
                config["image"] = img_cfg
            entries = img_cfg.get("providers")
            if isinstance(entries, list) and entries:
                # With a providers list the legacy "provider" key is never
                # read, so apply the override by narrowing the list instead.
                matching = [
                    entry for entry in entries
                    if isinstance(entry, dict)
                    and entry.get("provider", "doubao") == args.provider
                ]
                if not matching:
                    raise ValueError(
                        f"Provider '{args.provider}' is not configured "
                        f"in image.providers"
                    )
                img_cfg["providers"] = matching
            else:
                img_cfg["provider"] = args.provider
        path = generate_image(args.prompt, args.output, size=args.size, config=config)
        print(f"Image saved: {path}")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()