"""
|
|
Gemini适配器
|
|
支持Google Gemini大模型API
|
|
"""
|
|
import json
|
|
from datetime import datetime
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
from loguru import logger
|
|
|
|
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
|
from utils.proxy_handler import get_http_client
|
|
|
|
|
|


class GeminiAdapter(BaseAdapter):
    """
    Google Gemini API adapter.

    Uses Gemini's native API format.
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "gemini-1.5-pro",
        use_proxy: bool = True,  # Gemini usually requires a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    def _convert_messages_to_gemini(
        self,
        messages: List[ChatMessage]
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Convert standard messages into the Gemini request format.

        Args:
            messages: List of standard chat messages.

        Returns:
            A (system_instruction, contents) tuple.
        """
        system_instruction = ""
        contents = []

        for msg in messages:
            if msg.role == "system":
                # System messages are merged into a single system instruction
                system_instruction += msg.content + "\n"
            else:
                # Gemini only accepts the roles "user" and "model"
                role = "user" if msg.role == "user" else "model"
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })

        return system_instruction.strip(), contents
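
    # Illustrative example of the conversion above: given
    #   [ChatMessage(role="system", content="Be brief."),
    #    ChatMessage(role="user", content="Hi")]
    # the method returns roughly
    #   ("Be brief.", [{"role": "user", "parts": [{"text": "Hi"}]}])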

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)

                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }

                # Add the system instruction, if any
                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }

                # Gemini's REST API takes the key as a query parameter
                url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"

                response = await client.post(
                    url,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Gemini API error: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API error: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()

                # Check whether any candidate replies were returned
                candidates = data.get("candidates", [])
                if not candidates:
                    return AdapterResponse(
                        success=False,
                        error="No reply was generated",
                        latency_ms=self._calculate_latency(start_time)
                    )

                candidate = candidates[0]
                content = candidate.get("content", {})
                parts = content.get("parts", [])
                text = "".join(part.get("text", "") for part in parts)

                # Token usage reported by the API
                usage = data.get("usageMetadata", {})

                return AdapterResponse(
                    success=True,
                    content=text,
                    model=self.model,
                    finish_reason=candidate.get("finishReason", ""),
                    prompt_tokens=usage.get("promptTokenCount", 0),
                    completion_tokens=usage.get("candidatesTokenCount", 0),
                    total_tokens=usage.get("totalTokenCount", 0),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            logger.error(f"Gemini request failed: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )
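
    # For reference, a successful generateContent response parsed by chat()
    # above looks roughly like (trimmed):
    #   {"candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"},
    #                    "finishReason": "STOP"}],
    #    "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 2,
    #                      "totalTokenCount": 3}}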

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)

                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }

                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }

                # alt=sse makes the endpoint stream server-sent events
                url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"

                async with client.stream(
                    "POST",
                    url,
                    json=payload
                ) as response:
                    # Each SSE "data:" line carries a partial response as JSON
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            try:
                                data = json.loads(data_str)
                                candidates = data.get("candidates", [])
                                if candidates:
                                    content = candidates[0].get("content", {})
                                    parts = content.get("parts", [])
                                    for part in parts:
                                        text = part.get("text", "")
                                        if text:
                                            yield text
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"Gemini streaming request failed: {e}")
            yield f"[Error: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "Connection successful",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
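

# A minimal usage sketch, not part of the adapter itself. It assumes the
# project wiring imported above (BaseAdapter, ChatMessage, get_http_client)
# behaves as shown, that a real Gemini API key is supplied, and that the
# module is run as part of its package (e.g. via `python -m`), since it uses
# a relative import.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        adapter = GeminiAdapter(
            api_key="YOUR_GEMINI_API_KEY",  # placeholder key
            model="gemini-1.5-pro",
            use_proxy=False,  # enable and pass proxy_config if a proxy is required
        )

        # Connectivity check
        print(await adapter.test_connection())

        # Non-streaming request
        result = await adapter.chat(
            messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
            max_tokens=100,
        )
        print(result.success, result.content)

        # Streaming request
        async for chunk in adapter.chat_stream(
            messages=[ChatMessage(role="user", content="Count from 1 to 5.")]
        ):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())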