# AIChatRoom/backend/adapters/gemini_adapter.py
"""
Gemini适配器
支持Google Gemini大模型API
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class GeminiAdapter(BaseAdapter):
    """
    Google Gemini API adapter.
    Uses Gemini's native API format.
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "gemini-1.5-pro",
        use_proxy: bool = True,  # Gemini usually requires a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    def _convert_messages_to_gemini(
        self,
        messages: List[ChatMessage]
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Convert messages to the Gemini format.

        Args:
            messages: Standard message list.

        Returns:
            (system_instruction, contents)
        """
        system_instruction = ""
        contents = []
        for msg in messages:
            if msg.role == "system":
                # System messages are folded into a single system instruction
                system_instruction += msg.content + "\n"
            else:
                # Gemini only knows the "user" and "model" roles
                role = "user" if msg.role == "user" else "model"
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })
        return system_instruction.strip(), contents
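
    # Illustrative example of the conversion above (hypothetical message contents):
    #   [system "Be brief.", user "Hi", assistant "Hello"]  maps to
    #   system_instruction = "Be brief."
    #   contents = [
    #       {"role": "user", "parts": [{"text": "Hi"}]},
    #       {"role": "model", "parts": [{"text": "Hello"}]},
    #   ]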

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)
                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }
                # Add the system instruction if present
                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }
                url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
                response = await client.post(
                    url,
                    json=payload
                )
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Gemini API error: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API error: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )
                data = response.json()
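                # Abbreviated shape of a generateContent response, as assumed by the
                # parsing below (illustrative; fields other than these are ignored):
                #   {
                #     "candidates": [{"content": {"parts": [{"text": "..."}]},
                #                     "finishReason": "STOP"}],
                #     "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 2,
                #                       "totalTokenCount": 3}
                #   }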
                # Check whether any candidate replies were returned
                candidates = data.get("candidates", [])
                if not candidates:
                    return AdapterResponse(
                        success=False,
                        error="No reply was generated",
                        latency_ms=self._calculate_latency(start_time)
                    )
                candidate = candidates[0]
                content = candidate.get("content", {})
                parts = content.get("parts", [])
                text = "".join(part.get("text", "") for part in parts)
                # Get token usage
                usage = data.get("usageMetadata", {})
                return AdapterResponse(
                    success=True,
                    content=text,
                    model=self.model,
                    finish_reason=candidate.get("finishReason", ""),
                    prompt_tokens=usage.get("promptTokenCount", 0),
                    completion_tokens=usage.get("candidatesTokenCount", 0),
                    total_tokens=usage.get("totalTokenCount", 0),
                    latency_ms=self._calculate_latency(start_time)
                )
        except Exception as e:
            logger.error(f"Gemini request failed: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)
                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }
                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }
                url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"
                async with client.stream(
                    "POST",
                    url,
                    json=payload
                ) as response:
                    if response.status_code != 200:
                        # Non-200 responses carry no SSE data; surface the error instead
                        error_text = (await response.aread()).decode(errors="ignore")
                        logger.error(f"Gemini API error: {response.status_code} - {error_text}")
                        yield f"[Error: {response.status_code}]"
                        return
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            try:
                                data = json.loads(data_str)
                                candidates = data.get("candidates", [])
                                if candidates:
                                    content = candidates[0].get("content", {})
                                    parts = content.get("parts", [])
                                    for part in parts:
                                        text = part.get("text", "")
                                        if text:
                                            yield text
                            except json.JSONDecodeError:
                                # Skip keep-alive or malformed chunks
                                continue
        except Exception as e:
            logger.error(f"Gemini streaming request failed: {e}")
            yield f"[Error: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()
        try:
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]
            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )
            if response.success:
                return {
                    "success": True,
                    "message": "Connection successful",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }
        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
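

# --- Minimal usage sketch (not part of the adapter) ---------------------------
# Assumes the package is importable (e.g. run with `python -m adapters.gemini_adapter`
# from the backend directory) and that a GEMINI_API_KEY environment variable holds a
# valid key; the variable name and the example prompts are illustrative assumptions.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        adapter = GeminiAdapter(
            api_key=os.environ.get("GEMINI_API_KEY", ""),
            model="gemini-1.5-pro",
            use_proxy=False,  # adjust to your network environment
        )
        # Quick connectivity check
        print(await adapter.test_connection())
        # Single-turn chat
        reply = await adapter.chat(
            messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
            temperature=0.7,
            max_tokens=100,
        )
        print(reply.content if reply.success else reply.error)
        # Streaming chat
        async for chunk in adapter.chat_stream(
            messages=[ChatMessage(role="user", content="Count from 1 to 5.")]
        ):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())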