AIChatRoom/backend/adapters/zhipu_adapter.py

"""
智谱AI适配器
支持智谱GLM系列模型
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client
class ZhipuAdapter(BaseAdapter):
    """
    Zhipu AI API adapter.
    Supports GLM-4, GLM-3, and other GLM models.
    """

    DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "glm-4",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )
    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat completion request."""
        start_time = datetime.utcnow()
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Zhipu API error: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API error: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )
                # The response body follows the OpenAI-compatible schema:
                # choices[0].message holds the reply, usage holds token counts.
                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})
                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time),
                    tool_calls=message.get("tool_calls", [])
                )
        except Exception as e:
            logger.error(f"Zhipu API request failed: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )
    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request, yielding content deltas."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }
                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    # Surface HTTP errors instead of silently iterating an
                    # error body (assumes an httpx-style response object).
                    if response.status_code != 200:
                        error_text = (await response.aread()).decode("utf-8", errors="replace")
                        logger.error(f"Zhipu API error: {response.status_code} - {error_text}")
                        yield f"[Error: {response.status_code} - {error_text}]"
                        return
                    # Parse the SSE stream: each event line is "data: <json>",
                    # terminated by a literal "data: [DONE]".
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                # Skip keep-alive or malformed chunks.
                                continue
        except Exception as e:
            logger.error(f"Zhipu streaming request failed: {e}")
            yield f"[Error: {str(e)}]"
    async def test_connection(self) -> Dict[str, Any]:
        """Test API connectivity with a minimal chat request."""
        start_time = datetime.utcnow()
        try:
            test_messages = [
                ChatMessage(role="user", content="Hello, please reply 'OK'")
            ]
            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )
            if response.success:
                return {
                    "success": True,
                    "message": "Connection successful",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }
        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
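

# --- Usage sketch (illustrative only, not part of the original file). It
# assumes the ChatMessage(role=..., content=...) shape used in
# test_connection above; "your-api-key" and _demo are placeholders. ---
import asyncio


async def _demo():
    adapter = ZhipuAdapter(api_key="your-api-key", model="glm-4")

    # Request/response call: one AdapterResponse with content and token usage.
    result = await adapter.chat(
        messages=[ChatMessage(role="user", content="Hello")]
    )
    if result.success:
        print(result.content)

    # Streaming call: content deltas are yielded as they arrive.
    async for chunk in adapter.chat_stream(
        messages=[ChatMessage(role="user", content="Hello")]
    ):
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(_demo())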