fix: improve GeminiProvider — size hint, API key in header, error handling

- Append size instruction to prompt (Gemini has no native size param)
- Move API key from URL query string to x-goog-api-key header
- Check status_code before parsing JSON to handle non-JSON error responses
- Remove unnecessary Session with trust_env=False
- Remove f-prefix from strings with no interpolation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
wangzhuc 2026-04-01 13:13:22 +08:00
parent 6e0ff85f30
commit eb3115d537
2 changed files with 37 additions and 19 deletions

View file

@ -225,26 +225,35 @@ class GeminiProvider(ImageProvider):
self._base_url = base_url self._base_url = base_url
def generate(self, prompt: str, size: str) -> bytes: def generate(self, prompt: str, size: str) -> bytes:
# Append size instruction to prompt (Gemini doesn't have a native size param)
if "x" in size:
w, h = size.split("x", 1)
prompt = f"{prompt}\n\nGenerate this image at {w}x{h} resolution."
body = { body = {
"contents": [{"parts": [{"text": prompt}]}], "contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {"responseModalities": ["TEXT", "IMAGE"]}, "generationConfig": {"responseModalities": ["TEXT", "IMAGE"]},
} }
session = requests.Session() resp = requests.post(
session.trust_env = False f"{self._base_url}/models/{self._model}:generateContent",
resp = session.post( headers={
f"{self._base_url}/models/{self._model}:generateContent?key={self._api_key}", "Content-Type": "application/json",
headers={"Content-Type": "application/json"}, "x-goog-api-key": self._api_key,
},
json=body, json=body,
timeout=120, timeout=120,
) )
data = resp.json()
if resp.status_code != 200: if resp.status_code != 200:
error = data.get("error", {}) try:
msg = error.get("message", json.dumps(data, ensure_ascii=False)) error = resp.json().get("error", {})
msg = error.get("message", resp.text[:200])
except (ValueError, KeyError):
msg = resp.text[:200]
raise ValueError(f"Gemini API error ({resp.status_code}): {msg}") raise ValueError(f"Gemini API error ({resp.status_code}): {msg}")
data = resp.json()
candidates = data.get("candidates", []) candidates = data.get("candidates", [])
if not candidates: if not candidates:
raise ValueError(f"No candidates in Gemini response") raise ValueError("No candidates in Gemini response")
parts = candidates[0].get("content", {}).get("parts", []) parts = candidates[0].get("content", {}).get("parts", [])
for part in parts: for part in parts:
inline_data = part.get("inlineData") inline_data = part.get("inlineData")

View file

@ -225,32 +225,41 @@ class GeminiProvider(ImageProvider):
self._base_url = base_url self._base_url = base_url
def generate(self, prompt: str, size: str) -> bytes: def generate(self, prompt: str, size: str) -> bytes:
# Append size instruction to prompt (Gemini doesn't have a native size param)
if "x" in size:
w, h = size.split("x", 1)
prompt = f"{prompt}\n\nGenerate this image at {w}x{h} resolution."
body = { body = {
"contents": [{"parts": [{"text": prompt}]}], "contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {"responseModalities": ["TEXT", "IMAGE"]}, "generationConfig": {"responseModalities": ["TEXT", "IMAGE"]},
} }
session = requests.Session() resp = requests.post(
session.trust_env = False f"{self._base_url}/models/{self._model}:generateContent",
resp = session.post( headers={
f"{self._base_url}/models/{self._model}:generateContent?key={self._api_key}", "Content-Type": "application/json",
headers={"Content-Type": "application/json"}, "x-goog-api-key": self._api_key,
},
json=body, json=body,
timeout=120, timeout=120,
) )
data = resp.json()
if resp.status_code != 200: if resp.status_code != 200:
error = data.get("error", {}) try:
msg = error.get("message", json.dumps(data, ensure_ascii=False)) error = resp.json().get("error", {})
msg = error.get("message", resp.text[:200])
except (ValueError, KeyError):
msg = resp.text[:200]
raise ValueError(f"Gemini API error ({resp.status_code}): {msg}") raise ValueError(f"Gemini API error ({resp.status_code}): {msg}")
data = resp.json()
candidates = data.get("candidates", []) candidates = data.get("candidates", [])
if not candidates: if not candidates:
raise ValueError(f"No candidates in Gemini response") raise ValueError("No candidates in Gemini response")
parts = candidates[0].get("content", {}).get("parts", []) parts = candidates[0].get("content", {}).get("parts", [])
for part in parts: for part in parts:
inline_data = part.get("inlineData") inline_data = part.get("inlineData")
if inline_data and inline_data.get("mimeType", "").startswith("image/"): if inline_data and inline_data.get("mimeType", "").startswith("image/"):
return base64.b64decode(inline_data["data"]) return base64.b64decode(inline_data["data"])
raise ValueError(f"No image found in Gemini response parts") raise ValueError("No image found in Gemini response parts")
# --- Provider registry --- # --- Provider registry ---