"""
Gemini adapter.

Supports the Google Gemini large-model API.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator

from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class GeminiAdapter(BaseAdapter):
    """
    Google Gemini API adapter.

    Uses Gemini's native REST format (``models/{model}:generateContent`` and
    ``models/{model}:streamGenerateContent``) rather than an OpenAI-style
    chat-completions endpoint.
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    # Field prefix of server-sent-event data lines (requested via ``alt=sse``).
    _SSE_DATA_PREFIX = "data:"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "gemini-1.5-pro",
        use_proxy: bool = True,  # Gemini usually needs a proxy to be reachable
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        """
        Args:
            api_key: Gemini API key, sent as the ``key`` query parameter.
            base_url: API root; falls back to ``DEFAULT_BASE_URL`` when empty.
            model: Model name inserted into the request path.
            use_proxy: Forwarded to ``get_http_client``; defaults to True.
            proxy_config: Optional proxy settings forwarded to the HTTP client.
            timeout: Request timeout forwarded to the HTTP client.
            **kwargs: Passed through unchanged to ``BaseAdapter.__init__``.
        """
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    def _convert_messages_to_gemini(
        self,
        messages: List[ChatMessage]
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Convert standard messages into Gemini's request shape.

        System messages are merged (newline-separated) into a single system
        instruction. Every other message becomes a ``contents`` entry: role
        ``"user"`` stays ``"user"``, and any other role is mapped to
        ``"model"``.

        Args:
            messages: Standard chat message list.

        Returns:
            ``(system_instruction, contents)``. ``system_instruction`` is
            stripped and is ``""`` when there are no system messages.
        """
        system_parts: List[str] = []
        contents: List[Dict[str, Any]] = []

        for msg in messages:
            if msg.role == "system":
                system_parts.append(msg.content)
                continue
            role = "user" if msg.role == "user" else "model"
            contents.append({
                "role": role,
                "parts": [{"text": msg.content}]
            })

        return "\n".join(system_parts).strip(), contents

    def _build_payload(
        self,
        messages: List[ChatMessage],
        temperature: float,
        max_tokens: int,
        options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Build the JSON body shared by the streaming and non-streaming calls.

        Args:
            messages: Standard chat message list.
            temperature: Sampling temperature.
            max_tokens: Value for ``maxOutputTokens``.
            options: Extra caller kwargs; ``top_p`` (default 0.95) and
                ``top_k`` (default 40) are read from here.

        Returns:
            Request payload dict. ``systemInstruction`` is included only when
            at least one system message is present.
        """
        system_instruction, contents = self._convert_messages_to_gemini(messages)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
                "maxOutputTokens": max_tokens,
                "topP": options.get("top_p", 0.95),
                "topK": options.get("top_k", 40)
            }
        }

        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        return payload

    def _endpoint_url(self, method: str, extra_query: str = "") -> str:
        """
        Return the full URL for a model method such as ``generateContent``.

        The API key is passed in the ``key`` query parameter. ``extra_query``,
        if given, is appended as additional ``&``-separated query text.
        """
        url = f"{self.base_url}/models/{self.model}:{method}?key={self.api_key}"
        return f"{url}&{extra_query}" if extra_query else url

    def _redact(self, text: str) -> str:
        """
        Mask the API key in *text*.

        The key travels in the URL query string, so client exception messages
        may echo it. This keeps it out of logs and returned error strings.
        """
        if self.api_key:
            return text.replace(self.api_key, "***")
        return text

    @staticmethod
    def _extract_text(candidate: Dict[str, Any]) -> str:
        """Concatenate the ``text`` of every part in a candidate's content."""
        parts = candidate.get("content", {}).get("parts", [])
        return "".join(part.get("text", "") for part in parts)

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """
        Send a non-streaming chat request.

        Args:
            messages: Standard chat message list.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            **kwargs: Optional ``top_p`` / ``top_k`` overrides.

        Returns:
            ``AdapterResponse`` with ``success=True`` and the first candidate's
            text and token usage on HTTP 200 with at least one candidate.
            Otherwise ``success=False`` with an error message. This method
            does not raise; exceptions are logged and converted to a failed
            response.
        """
        # NOTE(review): naive UTC timestamp kept because _calculate_latency
        # (defined in BaseAdapter) presumably subtracts it from another naive
        # value — confirm before switching to timezone-aware datetimes.
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = self._build_payload(messages, temperature, max_tokens, kwargs)
                url = self._endpoint_url("generateContent")

                response = await client.post(url, json=payload)

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()

                # Gemini may return no candidates, e.g. when the prompt is
                # blocked; surface the block reason when the API reports one.
                candidates = data.get("candidates", [])
                if not candidates:
                    error = "没有生成回复"
                    block_reason = (data.get("promptFeedback") or {}).get("blockReason")
                    if block_reason:
                        error += f" (blockReason: {block_reason})"
                    return AdapterResponse(
                        success=False,
                        error=error,
                        latency_ms=self._calculate_latency(start_time)
                    )

                candidate = candidates[0]
                usage = data.get("usageMetadata", {})

                return AdapterResponse(
                    success=True,
                    content=self._extract_text(candidate),
                    model=self.model,
                    finish_reason=candidate.get("finishReason", ""),
                    prompt_tokens=usage.get("promptTokenCount", 0),
                    completion_tokens=usage.get("candidatesTokenCount", 0),
                    total_tokens=usage.get("totalTokenCount", 0),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            message = self._redact(str(e))
            logger.error(f"Gemini请求异常: {message}")
            return AdapterResponse(
                success=False,
                error=message,
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """
        Send a streaming chat request and yield text fragments as they arrive.

        Uses ``streamGenerateContent`` with ``alt=sse``. Each non-empty
        ``text`` part of the first candidate in every SSE event is yielded
        separately.

        Errors are not raised. On a non-200 status, or on an exception, a
        single ``"[错误: ...]"`` string is yielded and the stream ends.
        """
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = self._build_payload(messages, temperature, max_tokens, kwargs)
                url = self._endpoint_url("streamGenerateContent", "alt=sse")

                async with client.stream("POST", url, json=payload) as response:
                    # Without this check an HTTP error body (which has no SSE
                    # data lines) would silently produce an empty stream.
                    if response.status_code != 200:
                        await response.aread()
                        error_text = response.text
                        logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
                        yield f"[错误: API错误: {response.status_code} - {error_text}]"
                        return

                    async for line in response.aiter_lines():
                        # Accept both "data: {...}" and "data:{...}".
                        if not line.startswith(self._SSE_DATA_PREFIX):
                            continue
                        data_str = line[len(self._SSE_DATA_PREFIX):].strip()
                        if not data_str:
                            continue

                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            # Best effort: skip malformed events, keep streaming.
                            continue

                        candidates = data.get("candidates", [])
                        if not candidates:
                            continue

                        parts = candidates[0].get("content", {}).get("parts", [])
                        for part in parts:
                            text = part.get("text", "")
                            if text:
                                yield text

        except Exception as e:
            message = self._redact(str(e))
            logger.error(f"Gemini流式请求异常: {message}")
            yield f"[错误: {message}]"

    async def test_connection(self) -> Dict[str, Any]:
        """
        Check that the API is reachable by sending a tiny prompt.

        Returns:
            Dict with ``success``, ``message`` and ``latency_ms``. The
            ``model`` key is included on success.
        """
        start_time = datetime.utcnow()

        try:
            response = await self.chat(
                messages=[ChatMessage(role="user", content="Hello, respond with 'OK'")],
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            return {
                "success": False,
                "message": response.error,
                "latency_ms": response.latency_ms
            }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }