feat: AI chat room, a multi-agent collaborative discussion platform

- Agent management with AI-assisted system prompt generation
- Support for multiple AI providers (OpenRouter, Zhipu, MiniMax, etc.)
- Chat room and discussion engine
- Real-time message push over WebSocket
- Frontend: React + Ant Design
- Backend: FastAPI + MongoDB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
backend/adapters/__init__.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
AI adapter package.
Provides a unified interface for calling AI providers.
"""
from .base_adapter import BaseAdapter, AdapterResponse, ChatMessage
from .minimax_adapter import MiniMaxAdapter
from .zhipu_adapter import ZhipuAdapter
from .openrouter_adapter import OpenRouterAdapter
from .kimi_adapter import KimiAdapter
from .deepseek_adapter import DeepSeekAdapter
from .gemini_adapter import GeminiAdapter
from .ollama_adapter import OllamaAdapter
from .llmstudio_adapter import LLMStudioAdapter

__all__ = [
    "BaseAdapter",
    "AdapterResponse",
    "ChatMessage",
    "MiniMaxAdapter",
    "ZhipuAdapter",
    "OpenRouterAdapter",
    "KimiAdapter",
    "DeepSeekAdapter",
    "GeminiAdapter",
    "OllamaAdapter",
    "LLMStudioAdapter",
]

# Adapter registry: maps provider type identifiers to adapter classes
ADAPTER_REGISTRY = {
    "minimax": MiniMaxAdapter,
    "zhipu": ZhipuAdapter,
    "openrouter": OpenRouterAdapter,
    "kimi": KimiAdapter,
    "deepseek": DeepSeekAdapter,
    "gemini": GeminiAdapter,
    "ollama": OllamaAdapter,
    "llmstudio": LLMStudioAdapter,
}


def get_adapter(provider_type: str) -> type:
    """
    Return the adapter class for a provider type.

    Args:
        provider_type: provider type identifier

    Returns:
        The adapter class.

    Raises:
        ValueError: unknown provider type
    """
    adapter_class = ADAPTER_REGISTRY.get(provider_type.lower())
    if not adapter_class:
        raise ValueError(f"未知的AI提供商类型: {provider_type}")
    return adapter_class
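For context, a minimal sketch of how a caller might resolve an adapter through this registry. The provider name, API key, and model string below are placeholder values, not part of this commit:

import asyncio
from adapters import get_adapter, ChatMessage

async def demo():
    # "deepseek" must be one of the keys in ADAPTER_REGISTRY above;
    # the key and model strings here are placeholders.
    AdapterCls = get_adapter("deepseek")
    adapter = AdapterCls(api_key="sk-...", model="deepseek-chat")
    resp = await adapter.chat([ChatMessage(role="user", content="ping")])
    print(resp.success, resp.content, resp.total_tokens)

asyncio.run(demo())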
backend/adapters/base_adapter.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""
AI adapter base class.
Defines the unified interface that every provider adapter implements.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, AsyncGenerator
from datetime import datetime


@dataclass
class ChatMessage:
    """A single chat message."""
    role: str  # system, user, assistant
    content: str  # message text
    name: Optional[str] = None  # optional sender name

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict."""
        d = {"role": self.role, "content": self.content}
        if self.name:
            d["name"] = self.name
        return d


@dataclass
class AdapterResponse:
    """Response returned by an adapter call."""
    success: bool  # whether the call succeeded
    content: str = ""  # response text
    error: Optional[str] = None  # error message, if any

    # usage statistics
    prompt_tokens: int = 0  # input tokens
    completion_tokens: int = 0  # output tokens
    total_tokens: int = 0  # total tokens

    # metadata
    model: str = ""  # model actually used
    finish_reason: str = ""  # finish reason
    latency_ms: float = 0.0  # latency in milliseconds

    # tool call results
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)

    def __post_init__(self):
        if self.total_tokens == 0:
            self.total_tokens = self.prompt_tokens + self.completion_tokens


class BaseAdapter(ABC):
    """
    AI adapter base class.
    Every AI provider adapter must inherit from this class.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        """
        Initialize the adapter.

        Args:
            api_key: API key
            base_url: API base URL
            model: model name
            use_proxy: whether to use a proxy
            proxy_config: proxy configuration
            timeout: timeout in seconds
            **kwargs: extra parameters
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.use_proxy = use_proxy
        self.proxy_config = proxy_config or {}
        self.timeout = timeout
        self.extra_params = kwargs

    @abstractmethod
    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """
        Send a chat request.

        Args:
            messages: message list
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate
            **kwargs: extra parameters

        Returns:
            The adapter response.
        """
        pass

    @abstractmethod
    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """
        Send a streaming chat request.

        Args:
            messages: message list
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate
            **kwargs: extra parameters

        Yields:
            Chunks of the response text.
        """
        pass

    @abstractmethod
    async def test_connection(self) -> Dict[str, Any]:
        """
        Test the API connection.

        Returns:
            A result dict containing success, message, latency_ms.
        """
        pass

    def _build_messages(
        self,
        messages: List[ChatMessage]
    ) -> List[Dict[str, Any]]:
        """
        Build the request message list.

        Args:
            messages: list of ChatMessage objects

        Returns:
            The messages as a list of dicts.
        """
        return [msg.to_dict() for msg in messages]

    def _calculate_latency(self, start_time: datetime) -> float:
        """
        Calculate elapsed latency.

        Args:
            start_time: start time

        Returns:
            Latency in milliseconds.
        """
        return (datetime.utcnow() - start_time).total_seconds() * 1000
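To make the contract concrete, here is a minimal, hypothetical adapter that satisfies the abstract interface without calling any external API. It is illustrative only and not part of this commit:

class EchoAdapter(BaseAdapter):
    """Hypothetical adapter that echoes the last user message (illustration only)."""

    async def chat(self, messages, temperature=0.7, max_tokens=2000, **kwargs) -> AdapterResponse:
        start = datetime.utcnow()
        last = messages[-1].content if messages else ""
        return AdapterResponse(
            success=True,
            content=last,
            model=self.model,
            latency_ms=self._calculate_latency(start),
        )

    async def chat_stream(self, messages, temperature=0.7, max_tokens=2000, **kwargs):
        # Stream the echoed reply one character at a time.
        for ch in (messages[-1].content if messages else ""):
            yield ch

    async def test_connection(self) -> Dict[str, Any]:
        return {"success": True, "message": "ok", "latency_ms": 0.0}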
backend/adapters/deepseek_adapter.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""
DeepSeek adapter.
Supports the DeepSeek LLM API.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class DeepSeekAdapter(BaseAdapter):
    """
    DeepSeek API adapter.
    Compatible with the OpenAI API format.
    """

    DEFAULT_BASE_URL = "https://api.deepseek.com/v1"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "deepseek-chat",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }

                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"DeepSeek API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time),
                    tool_calls=message.get("tool_calls", [])
                )

        except Exception as e:
            logger.error(f"DeepSeek请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"DeepSeek流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="你好,请回复'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
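As a usage sketch (the API key below is a placeholder, and a valid DeepSeek key is assumed), the non-streaming and streaming calls can be driven like this:

import asyncio
from adapters.deepseek_adapter import DeepSeekAdapter
from adapters import ChatMessage

async def demo():
    adapter = DeepSeekAdapter(api_key="sk-...")  # placeholder key
    msgs = [ChatMessage(role="user", content="Hello")]
    resp = await adapter.chat(msgs, temperature=0.3, max_tokens=100)
    print(resp.content if resp.success else resp.error)
    # Streaming: chunks are yielded as they are decoded from the SSE lines.
    async for chunk in adapter.chat_stream(msgs, max_tokens=100):
        print(chunk, end="")

asyncio.run(demo())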
backend/adapters/gemini_adapter.py (new file, 250 lines)
@@ -0,0 +1,250 @@
"""
Gemini adapter.
Supports the Google Gemini LLM API.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class GeminiAdapter(BaseAdapter):
    """
    Google Gemini API adapter.
    Uses Gemini's native API format.
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "gemini-1.5-pro",
        use_proxy: bool = True,  # Gemini usually requires a proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    def _convert_messages_to_gemini(
        self,
        messages: List[ChatMessage]
    ) -> tuple[str, List[Dict[str, Any]]]:
        """
        Convert messages to Gemini's format.

        Args:
            messages: standard message list

        Returns:
            (system_instruction, contents)
        """
        system_instruction = ""
        contents = []

        for msg in messages:
            if msg.role == "system":
                system_instruction += msg.content + "\n"
            else:
                role = "user" if msg.role == "user" else "model"
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })

        return system_instruction.strip(), contents

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)

                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }

                # Add the system instruction, if any
                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }

                url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"

                response = await client.post(
                    url,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()

                # Check whether any candidate reply was returned
                candidates = data.get("candidates", [])
                if not candidates:
                    return AdapterResponse(
                        success=False,
                        error="没有生成回复",
                        latency_ms=self._calculate_latency(start_time)
                    )

                candidate = candidates[0]
                content = candidate.get("content", {})
                parts = content.get("parts", [])
                text = "".join(part.get("text", "") for part in parts)

                # Token usage
                usage = data.get("usageMetadata", {})

                return AdapterResponse(
                    success=True,
                    content=text,
                    model=self.model,
                    finish_reason=candidate.get("finishReason", ""),
                    prompt_tokens=usage.get("promptTokenCount", 0),
                    completion_tokens=usage.get("candidatesTokenCount", 0),
                    total_tokens=usage.get("totalTokenCount", 0),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            logger.error(f"Gemini请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                system_instruction, contents = self._convert_messages_to_gemini(messages)

                payload = {
                    "contents": contents,
                    "generationConfig": {
                        "temperature": temperature,
                        "maxOutputTokens": max_tokens,
                        "topP": kwargs.get("top_p", 0.95),
                        "topK": kwargs.get("top_k", 40)
                    }
                }

                if system_instruction:
                    payload["systemInstruction"] = {
                        "parts": [{"text": system_instruction}]
                    }

                url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"

                async with client.stream(
                    "POST",
                    url,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            try:
                                data = json.loads(data_str)
                                candidates = data.get("candidates", [])
                                if candidates:
                                    content = candidates[0].get("content", {})
                                    parts = content.get("parts", [])
                                    for part in parts:
                                        text = part.get("text", "")
                                        if text:
                                            yield text
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"Gemini流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
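For reference, the role mapping in _convert_messages_to_gemini turns an OpenAI-style conversation into Gemini's contents format roughly like this (illustrative values only):

# Input ChatMessages:
#   system:    "You are concise."
#   user:      "Hi"
#   assistant: "Hello!"
# Output:
#   system_instruction == "You are concise."
#   contents == [
#       {"role": "user",  "parts": [{"text": "Hi"}]},
#       {"role": "model", "parts": [{"text": "Hello!"}]},
#   ]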
backend/adapters/kimi_adapter.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""
Kimi adapter.
Supports the Moonshot AI Kimi LLM API.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class KimiAdapter(BaseAdapter):
    """
    Kimi API adapter.
    Compatible with the OpenAI API format.
    """

    DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "moonshot-v1-8k",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }

                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Kimi API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time),
                    tool_calls=message.get("tool_calls", [])
                )

        except Exception as e:
            logger.error(f"Kimi请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"Kimi流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="你好,请回复'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
backend/adapters/llmstudio_adapter.py (new file, 253 lines)
@@ -0,0 +1,253 @@
"""
LLM Studio adapter.
Supports a local LLM Studio server.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class LLMStudioAdapter(BaseAdapter):
    """
    LLM Studio API adapter.
    A local server compatible with the OpenAI API format.
    """

    DEFAULT_BASE_URL = "http://localhost:1234/v1"

    def __init__(
        self,
        api_key: str = "lm-studio",  # LLM Studio uses a fixed key
        base_url: str = "",
        model: str = "local-model",
        use_proxy: bool = False,  # local service, no proxy needed
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 120,  # local models may need more time
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }

                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            logger.error(f"LLM Studio请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"LLM Studio流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            # First check whether the service is running
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                # Fetch the model list
                response = await client.get(
                    f"{self.base_url}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"}
                )

                if response.status_code != 200:
                    return {
                        "success": False,
                        "message": "LLM Studio服务未运行或不可访问",
                        "latency_ms": self._calculate_latency(start_time)
                    }

                data = response.json()
                models = [m.get("id", "") for m in data.get("data", [])]

                if not models:
                    return {
                        "success": False,
                        "message": "LLM Studio中没有加载的模型",
                        "latency_ms": self._calculate_latency(start_time)
                    }

            # Send a test message
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }

    async def list_models(self) -> List[Dict[str, Any]]:
        """
        List the models loaded in LLM Studio.

        Returns:
            A list of model info dicts.
        """
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                response = await client.get(
                    f"{self.base_url}/models",
                    headers={"Authorization": f"Bearer {self.api_key}"}
                )

                if response.status_code == 200:
                    data = response.json()
                    return data.get("data", [])

        except Exception as e:
            logger.error(f"获取LLM Studio模型列表失败: {e}")

        return []
backend/adapters/minimax_adapter.py (new file, 251 lines)
@@ -0,0 +1,251 @@
"""
MiniMax adapter.
Supports the MiniMax LLM API.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class MiniMaxAdapter(BaseAdapter):
    """
    MiniMax API adapter.
    Supports the abab model family.
    """

    DEFAULT_BASE_URL = "https://api.minimax.chat/v1"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "abab6.5-chat",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )
        # MiniMax requires a group_id
        self.group_id = kwargs.get("group_id", "")

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                # MiniMax uses its own message format
                minimax_messages = []
                bot_setting = []

                for msg in messages:
                    if msg.role == "system":
                        bot_setting.append({
                            "bot_name": "assistant",
                            "content": msg.content
                        })
                    else:
                        minimax_messages.append({
                            "sender_type": "USER" if msg.role == "user" else "BOT",
                            "sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
                            "text": msg.content
                        })

                payload = {
                    "model": self.model,
                    "messages": minimax_messages,
                    "bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
                    "temperature": temperature,
                    "tokens_to_generate": max_tokens,
                    "mask_sensitive_info": False,
                    **kwargs
                }

                url = f"{self.base_url}/text/chatcompletion_v2"
                if self.group_id:
                    url = f"{url}?GroupId={self.group_id}"

                response = await client.post(
                    url,
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"MiniMax API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()

                # Check for an API-level error in the response body
                if data.get("base_resp", {}).get("status_code", 0) != 0:
                    error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {error_msg}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                reply = data.get("reply", "")
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=reply,
                    model=self.model,
                    finish_reason="content_filter" if data.get("output_sensitive", False) else "stop",
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            logger.error(f"MiniMax请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                minimax_messages = []
                bot_setting = []

                for msg in messages:
                    if msg.role == "system":
                        bot_setting.append({
                            "bot_name": "assistant",
                            "content": msg.content
                        })
                    else:
                        minimax_messages.append({
                            "sender_type": "USER" if msg.role == "user" else "BOT",
                            "sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
                            "text": msg.content
                        })

                payload = {
                    "model": self.model,
                    "messages": minimax_messages,
                    "bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
                    "temperature": temperature,
                    "tokens_to_generate": max_tokens,
                    "stream": True,
                    **kwargs
                }

                url = f"{self.base_url}/text/chatcompletion_v2"
                if self.group_id:
                    url = f"{url}?GroupId={self.group_id}"

                async with client.stream(
                    "POST",
                    url,
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"MiniMax流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="你好,请回复'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
backend/adapters/ollama_adapter.py (new file, 241 lines)
@@ -0,0 +1,241 @@
"""
Ollama adapter.
Supports a local Ollama server.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class OllamaAdapter(BaseAdapter):
    """
    Ollama API adapter.
    Connects to a local Ollama server.
    """

    DEFAULT_BASE_URL = "http://localhost:11434"

    def __init__(
        self,
        api_key: str = "",  # Ollama normally needs no API key
        base_url: str = "",
        model: str = "llama2",
        use_proxy: bool = False,  # local service, usually no proxy
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 120,  # local models may need more time
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "options": {
                        "temperature": temperature,
                        "num_predict": max_tokens,
                    },
                    "stream": False
                }

                response = await client.post(
                    f"{self.base_url}/api/chat",
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"Ollama API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                message = data.get("message", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=data.get("done_reason", "stop"),
                    prompt_tokens=data.get("prompt_eval_count", 0),
                    completion_tokens=data.get("eval_count", 0),
                    total_tokens=(
                        data.get("prompt_eval_count", 0) +
                        data.get("eval_count", 0)
                    ),
                    latency_ms=self._calculate_latency(start_time)
                )

        except Exception as e:
            logger.error(f"Ollama请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "options": {
                        "temperature": temperature,
                        "num_predict": max_tokens,
                    },
                    "stream": True
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/api/chat",
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line:
                            try:
                                data = json.loads(line)
                                message = data.get("message", {})
                                content = message.get("content", "")
                                if content:
                                    yield content

                                # Check whether the stream is done
                                if data.get("done", False):
                                    break
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"Ollama流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            # First check whether the service is running
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                # Check that the requested model exists
                response = await client.get(f"{self.base_url}/api/tags")

                if response.status_code != 200:
                    return {
                        "success": False,
                        "message": "Ollama服务未运行或不可访问",
                        "latency_ms": self._calculate_latency(start_time)
                    }

                data = response.json()
                models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]

                model_name = self.model.split(":")[0]
                if model_name not in models:
                    return {
                        "success": False,
                        "message": f"模型 {self.model} 未安装,可用模型: {', '.join(models)}",
                        "latency_ms": self._calculate_latency(start_time)
                    }

            # Send a test message
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }

    async def list_models(self) -> List[str]:
        """
        List locally available models.

        Returns:
            A list of model names.
        """
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=10
            ) as client:
                response = await client.get(f"{self.base_url}/api/tags")

                if response.status_code == 200:
                    data = response.json()
                    return [m.get("name", "") for m in data.get("models", [])]

        except Exception as e:
            logger.error(f"获取Ollama模型列表失败: {e}")

        return []
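A short local-usage sketch; it assumes an Ollama server listening on the default port and that the model tag below is installed (both are assumptions, not part of this commit):

import asyncio
from adapters.ollama_adapter import OllamaAdapter

async def demo():
    adapter = OllamaAdapter(model="llama2")   # model tag is an assumption
    print(await adapter.list_models())        # e.g. ["llama2:latest", ...]
    print(await adapter.test_connection())    # {"success": True, "message": "连接成功", ...}

asyncio.run(demo())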
backend/adapters/openrouter_adapter.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
OpenRouter adapter.
Provides access to many AI models through OpenRouter.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class OpenRouterAdapter(BaseAdapter):
    """
    OpenRouter API adapter.
    Compatible with the OpenAI API format.
    """

    DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "openai/gpt-4-turbo",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                    "HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
                    "X-Title": kwargs.get("title", "AI ChatRoom")
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }

                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"OpenRouter API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time),
                    tool_calls=message.get("tool_calls", [])
                )

        except Exception as e:
            logger.error(f"OpenRouter请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                    "HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
                    "X-Title": kwargs.get("title", "AI ChatRoom")
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"OpenRouter流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="Hello, respond with 'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }
backend/adapters/zhipu_adapter.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""
Zhipu AI adapter.
Supports the Zhipu GLM model family.
"""
import json
from datetime import datetime
from typing import List, Dict, Any, Optional, AsyncGenerator
from loguru import logger

from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
from utils.proxy_handler import get_http_client


class ZhipuAdapter(BaseAdapter):
    """
    Zhipu AI API adapter.
    Supports GLM-4, GLM-3, and related models.
    """

    DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"

    def __init__(
        self,
        api_key: str,
        base_url: str = "",
        model: str = "glm-4",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.DEFAULT_BASE_URL,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )

    async def chat(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """Send a chat request."""
        start_time = datetime.utcnow()

        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    **kwargs
                }

                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )

                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"智谱API错误: {response.status_code} - {error_text}")
                    return AdapterResponse(
                        success=False,
                        error=f"API错误: {response.status_code} - {error_text}",
                        latency_ms=self._calculate_latency(start_time)
                    )

                data = response.json()
                choice = data.get("choices", [{}])[0]
                message = choice.get("message", {})
                usage = data.get("usage", {})

                return AdapterResponse(
                    success=True,
                    content=message.get("content", ""),
                    model=data.get("model", self.model),
                    finish_reason=choice.get("finish_reason", ""),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                    latency_ms=self._calculate_latency(start_time),
                    tool_calls=message.get("tool_calls", [])
                )

        except Exception as e:
            logger.error(f"智谱API请求异常: {e}")
            return AdapterResponse(
                success=False,
                error=str(e),
                latency_ms=self._calculate_latency(start_time)
            )

    async def chat_stream(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Send a streaming chat request."""
        try:
            async with get_http_client(
                use_proxy=self.use_proxy,
                proxy_config=self.proxy_config,
                timeout=self.timeout
            ) as client:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                payload = {
                    "model": self.model,
                    "messages": self._build_messages(messages),
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True,
                    **kwargs
                }

                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data_str = line[6:]
                            if data_str.strip() == "[DONE]":
                                break
                            try:
                                data = json.loads(data_str)
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")
                                if content:
                                    yield content
                            except json.JSONDecodeError:
                                continue

        except Exception as e:
            logger.error(f"智谱流式请求异常: {e}")
            yield f"[错误: {str(e)}]"

    async def test_connection(self) -> Dict[str, Any]:
        """Test the API connection."""
        start_time = datetime.utcnow()

        try:
            test_messages = [
                ChatMessage(role="user", content="你好,请回复'OK'")
            ]

            response = await self.chat(
                messages=test_messages,
                temperature=0,
                max_tokens=10
            )

            if response.success:
                return {
                    "success": True,
                    "message": "连接成功",
                    "model": response.model,
                    "latency_ms": response.latency_ms
                }
            else:
                return {
                    "success": False,
                    "message": response.error,
                    "latency_ms": response.latency_ms
                }

        except Exception as e:
            return {
                "success": False,
                "message": str(e),
                "latency_ms": self._calculate_latency(start_time)
            }