feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
16
backend/.env.example
Normal file
16
backend/.env.example
Normal file
@@ -0,0 +1,16 @@
|
||||
# MongoDB配置
|
||||
MONGODB_URL=mongodb://localhost:27017
|
||||
MONGODB_DB=ai_chatroom
|
||||
|
||||
# 服务配置
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
DEBUG=true
|
||||
|
||||
# 安全配置(生产环境请修改)
|
||||
SECRET_KEY=your-secret-key-change-in-production
|
||||
ENCRYPTION_KEY=your-encryption-key-32-bytes-long
|
||||
|
||||
# 代理配置(可选)
|
||||
# DEFAULT_HTTP_PROXY=http://127.0.0.1:7890
|
||||
# DEFAULT_HTTPS_PROXY=http://127.0.0.1:7890
|
||||
25
backend/Dockerfile
Normal file
25
backend/Dockerfile
Normal file
@@ -0,0 +1,25 @@
|
||||
# AI聊天室后端 Dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装Python依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY . .
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动命令
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
58
backend/adapters/__init__.py
Normal file
58
backend/adapters/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
AI接口适配器模块
|
||||
提供统一的AI调用接口
|
||||
"""
|
||||
from .base_adapter import BaseAdapter, AdapterResponse, ChatMessage
|
||||
from .minimax_adapter import MiniMaxAdapter
|
||||
from .zhipu_adapter import ZhipuAdapter
|
||||
from .openrouter_adapter import OpenRouterAdapter
|
||||
from .kimi_adapter import KimiAdapter
|
||||
from .deepseek_adapter import DeepSeekAdapter
|
||||
from .gemini_adapter import GeminiAdapter
|
||||
from .ollama_adapter import OllamaAdapter
|
||||
from .llmstudio_adapter import LLMStudioAdapter
|
||||
|
||||
__all__ = [
|
||||
"BaseAdapter",
|
||||
"AdapterResponse",
|
||||
"ChatMessage",
|
||||
"MiniMaxAdapter",
|
||||
"ZhipuAdapter",
|
||||
"OpenRouterAdapter",
|
||||
"KimiAdapter",
|
||||
"DeepSeekAdapter",
|
||||
"GeminiAdapter",
|
||||
"OllamaAdapter",
|
||||
"LLMStudioAdapter",
|
||||
]
|
||||
|
||||
# 适配器注册表
|
||||
ADAPTER_REGISTRY = {
|
||||
"minimax": MiniMaxAdapter,
|
||||
"zhipu": ZhipuAdapter,
|
||||
"openrouter": OpenRouterAdapter,
|
||||
"kimi": KimiAdapter,
|
||||
"deepseek": DeepSeekAdapter,
|
||||
"gemini": GeminiAdapter,
|
||||
"ollama": OllamaAdapter,
|
||||
"llmstudio": LLMStudioAdapter,
|
||||
}
|
||||
|
||||
|
||||
def get_adapter(provider_type: str) -> type:
|
||||
"""
|
||||
根据提供商类型获取对应的适配器类
|
||||
|
||||
Args:
|
||||
provider_type: 提供商类型标识
|
||||
|
||||
Returns:
|
||||
适配器类
|
||||
|
||||
Raises:
|
||||
ValueError: 未知的提供商类型
|
||||
"""
|
||||
adapter_class = ADAPTER_REGISTRY.get(provider_type.lower())
|
||||
if not adapter_class:
|
||||
raise ValueError(f"未知的AI提供商类型: {provider_type}")
|
||||
return adapter_class
|
||||
166
backend/adapters/base_adapter.py
Normal file
166
backend/adapters/base_adapter.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
AI适配器基类
|
||||
定义统一的AI调用接口
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息"""
|
||||
role: str # system, user, assistant
|
||||
content: str # 消息内容
|
||||
name: Optional[str] = None # 发送者名称(可选)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
d = {"role": self.role, "content": self.content}
|
||||
if self.name:
|
||||
d["name"] = self.name
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterResponse:
|
||||
"""适配器响应"""
|
||||
success: bool # 是否成功
|
||||
content: str = "" # 响应内容
|
||||
error: Optional[str] = None # 错误信息
|
||||
|
||||
# 统计信息
|
||||
prompt_tokens: int = 0 # 输入token数
|
||||
completion_tokens: int = 0 # 输出token数
|
||||
total_tokens: int = 0 # 总token数
|
||||
|
||||
# 元数据
|
||||
model: str = "" # 使用的模型
|
||||
finish_reason: str = "" # 结束原因
|
||||
latency_ms: float = 0.0 # 延迟(毫秒)
|
||||
|
||||
# 工具调用结果
|
||||
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.total_tokens == 0:
|
||||
self.total_tokens = self.prompt_tokens + self.completion_tokens
|
||||
|
||||
|
||||
class BaseAdapter(ABC):
|
||||
"""
|
||||
AI适配器基类
|
||||
所有AI提供商适配器必须继承此类
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
初始化适配器
|
||||
|
||||
Args:
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
model: 模型名称
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
timeout: 超时时间(秒)
|
||||
**kwargs: 额外参数
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.use_proxy = use_proxy
|
||||
self.proxy_config = proxy_config or {}
|
||||
self.timeout = timeout
|
||||
self.extra_params = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""
|
||||
发送聊天请求
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
适配器响应
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
发送流式聊天请求
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Yields:
|
||||
响应内容片段
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""
|
||||
测试API连接
|
||||
|
||||
Returns:
|
||||
测试结果字典,包含 success, message, latency_ms
|
||||
"""
|
||||
pass
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
messages: List[ChatMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建消息列表
|
||||
|
||||
Args:
|
||||
messages: ChatMessage列表
|
||||
|
||||
Returns:
|
||||
字典格式的消息列表
|
||||
"""
|
||||
return [msg.to_dict() for msg in messages]
|
||||
|
||||
def _calculate_latency(self, start_time: datetime) -> float:
|
||||
"""
|
||||
计算延迟
|
||||
|
||||
Args:
|
||||
start_time: 开始时间
|
||||
|
||||
Returns:
|
||||
延迟毫秒数
|
||||
"""
|
||||
return (datetime.utcnow() - start_time).total_seconds() * 1000
|
||||
197
backend/adapters/deepseek_adapter.py
Normal file
197
backend/adapters/deepseek_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
DeepSeek适配器
|
||||
支持DeepSeek大模型API
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class DeepSeekAdapter(BaseAdapter):
|
||||
"""
|
||||
DeepSeek API适配器
|
||||
兼容OpenAI API格式
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "deepseek-chat",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"DeepSeek API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=choice.get("finish_reason", ""),
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time),
|
||||
tool_calls=message.get("tool_calls", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
250
backend/adapters/gemini_adapter.py
Normal file
250
backend/adapters/gemini_adapter.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Gemini适配器
|
||||
支持Google Gemini大模型API
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class GeminiAdapter(BaseAdapter):
|
||||
"""
|
||||
Google Gemini API适配器
|
||||
使用Gemini的原生API格式
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "gemini-1.5-pro",
|
||||
use_proxy: bool = True, # Gemini通常需要代理
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _convert_messages_to_gemini(
|
||||
self,
|
||||
messages: List[ChatMessage]
|
||||
) -> tuple[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
将消息转换为Gemini格式
|
||||
|
||||
Args:
|
||||
messages: 标准消息列表
|
||||
|
||||
Returns:
|
||||
(system_instruction, contents)
|
||||
"""
|
||||
system_instruction = ""
|
||||
contents = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "system":
|
||||
system_instruction += msg.content + "\n"
|
||||
else:
|
||||
role = "user" if msg.role == "user" else "model"
|
||||
contents.append({
|
||||
"role": role,
|
||||
"parts": [{"text": msg.content}]
|
||||
})
|
||||
|
||||
return system_instruction.strip(), contents
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
system_instruction, contents = self._convert_messages_to_gemini(messages)
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": temperature,
|
||||
"maxOutputTokens": max_tokens,
|
||||
"topP": kwargs.get("top_p", 0.95),
|
||||
"topK": kwargs.get("top_k", 40)
|
||||
}
|
||||
}
|
||||
|
||||
# 添加系统指令
|
||||
if system_instruction:
|
||||
payload["systemInstruction"] = {
|
||||
"parts": [{"text": system_instruction}]
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/models/{self.model}:generateContent?key={self.api_key}"
|
||||
|
||||
response = await client.post(
|
||||
url,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"Gemini API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 检查是否有候选回复
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates:
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error="没有生成回复",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
text = "".join(part.get("text", "") for part in parts)
|
||||
|
||||
# 获取token使用情况
|
||||
usage = data.get("usageMetadata", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=text,
|
||||
model=self.model,
|
||||
finish_reason=candidate.get("finishReason", ""),
|
||||
prompt_tokens=usage.get("promptTokenCount", 0),
|
||||
completion_tokens=usage.get("candidatesTokenCount", 0),
|
||||
total_tokens=usage.get("totalTokenCount", 0),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
system_instruction, contents = self._convert_messages_to_gemini(messages)
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": temperature,
|
||||
"maxOutputTokens": max_tokens,
|
||||
"topP": kwargs.get("top_p", 0.95),
|
||||
"topK": kwargs.get("top_k", 40)
|
||||
}
|
||||
}
|
||||
|
||||
if system_instruction:
|
||||
payload["systemInstruction"] = {
|
||||
"parts": [{"text": system_instruction}]
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/models/{self.model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
197
backend/adapters/kimi_adapter.py
Normal file
197
backend/adapters/kimi_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Kimi适配器
|
||||
支持月之暗面Kimi大模型API
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class KimiAdapter(BaseAdapter):
|
||||
"""
|
||||
Kimi API适配器
|
||||
兼容OpenAI API格式
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "moonshot-v1-8k",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"Kimi API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=choice.get("finish_reason", ""),
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time),
|
||||
tool_calls=message.get("tool_calls", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Kimi请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Kimi流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
253
backend/adapters/llmstudio_adapter.py
Normal file
253
backend/adapters/llmstudio_adapter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
LLM Studio适配器
|
||||
支持本地LLM Studio服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class LLMStudioAdapter(BaseAdapter):
|
||||
"""
|
||||
LLM Studio API适配器
|
||||
兼容OpenAI API格式的本地服务
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "http://localhost:1234/v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "lm-studio", # LLM Studio使用固定key
|
||||
base_url: str = "",
|
||||
model: str = "local-model",
|
||||
use_proxy: bool = False, # 本地服务不需要代理
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 120, # 本地模型可能需要更长时间
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"LLM Studio API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=choice.get("finish_reason", ""),
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Studio请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Studio流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
# 首先检查服务是否在运行
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
# 获取模型列表
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "LLM Studio服务未运行或不可访问",
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
data = response.json()
|
||||
models = [m.get("id", "") for m in data.get("data", [])]
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "LLM Studio中没有加载的模型",
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
# 发送测试消息
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
async def list_models(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出LLM Studio中加载的模型
|
||||
|
||||
Returns:
|
||||
模型信息列表
|
||||
"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取LLM Studio模型列表失败: {e}")
|
||||
|
||||
return []
|
||||
251
backend/adapters/minimax_adapter.py
Normal file
251
backend/adapters/minimax_adapter.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
MiniMax适配器
|
||||
支持MiniMax大模型API
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class MiniMaxAdapter(BaseAdapter):
|
||||
"""
|
||||
MiniMax API适配器
|
||||
支持abab系列模型
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.minimax.chat/v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "abab6.5-chat",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
# MiniMax需要group_id
|
||||
self.group_id = kwargs.get("group_id", "")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# MiniMax使用特殊的消息格式
|
||||
minimax_messages = []
|
||||
bot_setting = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "system":
|
||||
bot_setting.append({
|
||||
"bot_name": "assistant",
|
||||
"content": msg.content
|
||||
})
|
||||
else:
|
||||
minimax_messages.append({
|
||||
"sender_type": "USER" if msg.role == "user" else "BOT",
|
||||
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
|
||||
"text": msg.content
|
||||
})
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": minimax_messages,
|
||||
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
|
||||
"temperature": temperature,
|
||||
"tokens_to_generate": max_tokens,
|
||||
"mask_sensitive_info": False,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/text/chatcompletion_v2"
|
||||
if self.group_id:
|
||||
url = f"{url}?GroupId={self.group_id}"
|
||||
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"MiniMax API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 检查API返回的错误
|
||||
if data.get("base_resp", {}).get("status_code", 0) != 0:
|
||||
error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {error_msg}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
reply = data.get("reply", "")
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=reply,
|
||||
model=self.model,
|
||||
finish_reason=data.get("output_sensitive", False) and "content_filter" or "stop",
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MiniMax请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
minimax_messages = []
|
||||
bot_setting = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "system":
|
||||
bot_setting.append({
|
||||
"bot_name": "assistant",
|
||||
"content": msg.content
|
||||
})
|
||||
else:
|
||||
minimax_messages.append({
|
||||
"sender_type": "USER" if msg.role == "user" else "BOT",
|
||||
"sender_name": msg.name or ("用户" if msg.role == "user" else "assistant"),
|
||||
"text": msg.content
|
||||
})
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": minimax_messages,
|
||||
"bot_setting": bot_setting if bot_setting else [{"bot_name": "assistant", "content": "你是一个有帮助的助手"}],
|
||||
"temperature": temperature,
|
||||
"tokens_to_generate": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
url = f"{self.base_url}/text/chatcompletion_v2"
|
||||
if self.group_id:
|
||||
url = f"{url}?GroupId={self.group_id}"
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MiniMax流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
241
backend/adapters/ollama_adapter.py
Normal file
241
backend/adapters/ollama_adapter.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Ollama适配器
|
||||
支持本地Ollama服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class OllamaAdapter(BaseAdapter):
|
||||
"""
|
||||
Ollama API适配器
|
||||
用于连接本地Ollama服务
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "http://localhost:11434"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "", # Ollama通常不需要API密钥
|
||||
base_url: str = "",
|
||||
model: str = "llama2",
|
||||
use_proxy: bool = False, # 本地服务通常不需要代理
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 120, # 本地模型可能需要更长时间
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
},
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/chat",
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"Ollama API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
message = data.get("message", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=data.get("done_reason", "stop"),
|
||||
prompt_tokens=data.get("prompt_eval_count", 0),
|
||||
completion_tokens=data.get("eval_count", 0),
|
||||
total_tokens=(
|
||||
data.get("prompt_eval_count", 0) +
|
||||
data.get("eval_count", 0)
|
||||
),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ollama请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
},
|
||||
"stream": True
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/api/chat",
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
message = data.get("message", {})
|
||||
content = message.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
|
||||
# 检查是否完成
|
||||
if data.get("done", False):
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ollama流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
# 首先检查服务是否在运行
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
# 检查模型是否存在
|
||||
response = await client.get(f"{self.base_url}/api/tags")
|
||||
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Ollama服务未运行或不可访问",
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
data = response.json()
|
||||
models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
|
||||
|
||||
model_name = self.model.split(":")[0]
|
||||
if model_name not in models:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"模型 {self.model} 未安装,可用模型: {', '.join(models)}",
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
# 发送测试消息
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
"""
|
||||
列出本地可用的模型
|
||||
|
||||
Returns:
|
||||
模型名称列表
|
||||
"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
response = await client.get(f"{self.base_url}/api/tags")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return [m.get("name", "") for m in data.get("models", [])]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Ollama模型列表失败: {e}")
|
||||
|
||||
return []
|
||||
201
backend/adapters/openrouter_adapter.py
Normal file
201
backend/adapters/openrouter_adapter.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
OpenRouter适配器
|
||||
支持通过OpenRouter访问多种AI模型
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class OpenRouterAdapter(BaseAdapter):
|
||||
"""
|
||||
OpenRouter API适配器
|
||||
兼容OpenAI API格式
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "openai/gpt-4-turbo",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
|
||||
"X-Title": kwargs.get("title", "AI ChatRoom")
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"OpenRouter API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=choice.get("finish_reason", ""),
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time),
|
||||
tool_calls=message.get("tool_calls", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenRouter请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": kwargs.get("referer", "https://ai-chatroom.local"),
|
||||
"X-Title": kwargs.get("title", "AI ChatRoom")
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenRouter流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="Hello, respond with 'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
197
backend/adapters/zhipu_adapter.py
Normal file
197
backend/adapters/zhipu_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
智谱AI适配器
|
||||
支持智谱GLM系列模型
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from loguru import logger
|
||||
|
||||
from .base_adapter import BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.proxy_handler import get_http_client
|
||||
|
||||
|
||||
class ZhipuAdapter(BaseAdapter):
|
||||
"""
|
||||
智谱AI API适配器
|
||||
支持GLM-4、GLM-3等模型
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "glm-4",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""发送聊天请求"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"智谱API错误: {response.status_code} - {error_text}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"API错误: {response.status_code} - {error_text}",
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
usage = data.get("usage", {})
|
||||
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", self.model),
|
||||
finish_reason=choice.get("finish_reason", ""),
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
latency_ms=self._calculate_latency(start_time),
|
||||
tool_calls=message.get("tool_calls", [])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智谱API请求异常: {e}")
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
latency_ms=self._calculate_latency(start_time)
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""发送流式聊天请求"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=self.use_proxy,
|
||||
proxy_config=self.proxy_config,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self._build_messages(messages),
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智谱流式请求异常: {e}")
|
||||
yield f"[错误: {str(e)}]"
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试API连接"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
test_messages = [
|
||||
ChatMessage(role="user", content="你好,请回复'OK'")
|
||||
]
|
||||
|
||||
response = await self.chat(
|
||||
messages=test_messages,
|
||||
temperature=0,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功",
|
||||
"model": response.model,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": str(e),
|
||||
"latency_ms": self._calculate_latency(start_time)
|
||||
}
|
||||
50
backend/config.py
Normal file
50
backend/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
应用配置模块
|
||||
从环境变量加载配置项
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置类"""
|
||||
|
||||
# MongoDB配置
|
||||
MONGODB_URL: str = "mongodb://localhost:27017"
|
||||
MONGODB_DB: str = "ai_chatroom"
|
||||
|
||||
# 服务配置
|
||||
HOST: str = "0.0.0.0"
|
||||
PORT: int = 8000
|
||||
DEBUG: bool = True
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY: str = "your-secret-key-change-in-production"
|
||||
ENCRYPTION_KEY: str = "your-encryption-key-32-bytes-long"
|
||||
|
||||
# CORS配置
|
||||
CORS_ORIGINS: list = ["http://localhost:3000", "http://127.0.0.1:3000"]
|
||||
|
||||
# WebSocket配置
|
||||
WS_HEARTBEAT_INTERVAL: int = 30
|
||||
|
||||
# 默认AI配置
|
||||
DEFAULT_TIMEOUT: int = 60
|
||||
DEFAULT_MAX_TOKENS: int = 2000
|
||||
DEFAULT_TEMPERATURE: float = 0.7
|
||||
|
||||
# 代理配置(全局默认)
|
||||
DEFAULT_HTTP_PROXY: Optional[str] = None
|
||||
DEFAULT_HTTPS_PROXY: Optional[str] = None
|
||||
|
||||
# 速率限制
|
||||
RATE_LIMIT_REQUESTS: int = 100
|
||||
RATE_LIMIT_PERIOD: int = 60 # 秒
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
settings = Settings()
|
||||
10
backend/database/__init__.py
Normal file
10
backend/database/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
from .connection import connect_db, close_db, get_database
|
||||
|
||||
__all__ = [
|
||||
"connect_db",
|
||||
"close_db",
|
||||
"get_database",
|
||||
]
|
||||
94
backend/database/connection.py
Normal file
94
backend/database/connection.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
MongoDB数据库连接模块
|
||||
使用Motor异步驱动
|
||||
"""
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
from beanie import init_beanie
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from config import settings
|
||||
|
||||
# 全局数据库客户端和数据库实例
|
||||
_client: Optional[AsyncIOMotorClient] = None
|
||||
_database: Optional[AsyncIOMotorDatabase] = None
|
||||
|
||||
|
||||
async def connect_db() -> None:
|
||||
"""
|
||||
连接MongoDB数据库
|
||||
初始化Beanie ODM
|
||||
"""
|
||||
global _client, _database
|
||||
|
||||
try:
|
||||
_client = AsyncIOMotorClient(settings.MONGODB_URL)
|
||||
_database = _client[settings.MONGODB_DB]
|
||||
|
||||
# 导入所有文档模型用于初始化Beanie
|
||||
from models.ai_provider import AIProvider
|
||||
from models.agent import Agent
|
||||
from models.chatroom import ChatRoom
|
||||
from models.message import Message
|
||||
from models.discussion_result import DiscussionResult
|
||||
from models.agent_memory import AgentMemory
|
||||
|
||||
# 初始化Beanie
|
||||
await init_beanie(
|
||||
database=_database,
|
||||
document_models=[
|
||||
AIProvider,
|
||||
Agent,
|
||||
ChatRoom,
|
||||
Message,
|
||||
DiscussionResult,
|
||||
AgentMemory,
|
||||
]
|
||||
)
|
||||
|
||||
logger.info(f"已连接到MongoDB数据库: {settings.MONGODB_DB}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""
|
||||
关闭数据库连接
|
||||
"""
|
||||
global _client
|
||||
|
||||
if _client:
|
||||
_client.close()
|
||||
logger.info("数据库连接已关闭")
|
||||
|
||||
|
||||
def get_database() -> AsyncIOMotorDatabase:
|
||||
"""
|
||||
获取数据库实例
|
||||
|
||||
Returns:
|
||||
MongoDB数据库实例
|
||||
|
||||
Raises:
|
||||
RuntimeError: 数据库未初始化
|
||||
"""
|
||||
if _database is None:
|
||||
raise RuntimeError("数据库未初始化,请先调用connect_db()")
|
||||
return _database
|
||||
|
||||
|
||||
def get_client() -> AsyncIOMotorClient:
|
||||
"""
|
||||
获取数据库客户端
|
||||
|
||||
Returns:
|
||||
MongoDB客户端实例
|
||||
|
||||
Raises:
|
||||
RuntimeError: 客户端未初始化
|
||||
"""
|
||||
if _client is None:
|
||||
raise RuntimeError("数据库客户端未初始化")
|
||||
return _client
|
||||
73
backend/main.py
Normal file
73
backend/main.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
AI聊天室后端主入口
|
||||
FastAPI应用启动文件
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from loguru import logger
|
||||
|
||||
from config import settings
|
||||
from database.connection import connect_db, close_db
|
||||
from routers import providers, agents, chatrooms, discussions
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
应用生命周期管理
|
||||
启动时连接数据库,关闭时断开连接
|
||||
"""
|
||||
logger.info("正在启动AI聊天室服务...")
|
||||
await connect_db()
|
||||
logger.info("数据库连接成功")
|
||||
yield
|
||||
logger.info("正在关闭AI聊天室服务...")
|
||||
await close_db()
|
||||
logger.info("服务已关闭")
|
||||
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="AI聊天室",
|
||||
description="多Agent协作讨论平台",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(providers.router, prefix="/api/providers", tags=["AI接口管理"])
|
||||
app.include_router(agents.router, prefix="/api/agents", tags=["Agent管理"])
|
||||
app.include_router(chatrooms.router, prefix="/api/chatrooms", tags=["聊天室管理"])
|
||||
app.include_router(discussions.router, prefix="/api/discussions", tags=["讨论结果"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径健康检查"""
|
||||
return {"message": "AI聊天室服务运行中", "version": "1.0.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=settings.DEBUG
|
||||
)
|
||||
25
backend/models/__init__.py
Normal file
25
backend/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
数据模型模块
|
||||
"""
|
||||
from .ai_provider import AIProvider, ProxyConfig, RateLimit
|
||||
from .agent import Agent, AgentCapabilities, AgentBehavior
|
||||
from .chatroom import ChatRoom, ChatRoomConfig
|
||||
from .message import Message, MessageType
|
||||
from .discussion_result import DiscussionResult
|
||||
from .agent_memory import AgentMemory, MemoryType
|
||||
|
||||
__all__ = [
|
||||
"AIProvider",
|
||||
"ProxyConfig",
|
||||
"RateLimit",
|
||||
"Agent",
|
||||
"AgentCapabilities",
|
||||
"AgentBehavior",
|
||||
"ChatRoom",
|
||||
"ChatRoomConfig",
|
||||
"Message",
|
||||
"MessageType",
|
||||
"DiscussionResult",
|
||||
"AgentMemory",
|
||||
"MemoryType",
|
||||
]
|
||||
168
backend/models/agent.py
Normal file
168
backend/models/agent.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Agent数据模型
|
||||
定义AI聊天代理的配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class AgentCapabilities:
|
||||
"""Agent能力配置"""
|
||||
memory_enabled: bool = False # 是否启用记忆
|
||||
mcp_tools: List[str] = [] # 可用的MCP工具
|
||||
skills: List[str] = [] # 可用的技能
|
||||
multimodal: bool = False # 是否支持多模态
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_enabled: bool = False,
|
||||
mcp_tools: Optional[List[str]] = None,
|
||||
skills: Optional[List[str]] = None,
|
||||
multimodal: bool = False
|
||||
):
|
||||
self.memory_enabled = memory_enabled
|
||||
self.mcp_tools = mcp_tools or []
|
||||
self.skills = skills or []
|
||||
self.multimodal = multimodal
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"memory_enabled": self.memory_enabled,
|
||||
"mcp_tools": self.mcp_tools,
|
||||
"skills": self.skills,
|
||||
"multimodal": self.multimodal
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentCapabilities":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
memory_enabled=data.get("memory_enabled", False),
|
||||
mcp_tools=data.get("mcp_tools", []),
|
||||
skills=data.get("skills", []),
|
||||
multimodal=data.get("multimodal", False)
|
||||
)
|
||||
|
||||
|
||||
class AgentBehavior:
|
||||
"""Agent行为配置"""
|
||||
speak_threshold: float = 0.5 # 发言阈值(判断是否需要发言)
|
||||
max_speak_per_round: int = 2 # 每轮最多发言次数
|
||||
speak_style: str = "balanced" # 发言风格: concise, balanced, detailed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
speak_threshold: float = 0.5,
|
||||
max_speak_per_round: int = 2,
|
||||
speak_style: str = "balanced"
|
||||
):
|
||||
self.speak_threshold = speak_threshold
|
||||
self.max_speak_per_round = max_speak_per_round
|
||||
self.speak_style = speak_style
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"speak_threshold": self.speak_threshold,
|
||||
"max_speak_per_round": self.max_speak_per_round,
|
||||
"speak_style": self.speak_style
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentBehavior":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
speak_threshold=data.get("speak_threshold", 0.5),
|
||||
max_speak_per_round=data.get("max_speak_per_round", 2),
|
||||
speak_style=data.get("speak_style", "balanced")
|
||||
)
|
||||
|
||||
|
||||
class Agent(Document):
|
||||
"""
|
||||
Agent文档模型
|
||||
存储AI代理的配置信息
|
||||
"""
|
||||
agent_id: str = Field(..., description="唯一标识")
|
||||
name: str = Field(..., description="Agent名称")
|
||||
role: str = Field(..., description="角色定义")
|
||||
system_prompt: str = Field(..., description="系统提示词")
|
||||
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||
|
||||
# 模型参数
|
||||
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
|
||||
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
|
||||
|
||||
# 能力配置
|
||||
capabilities: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"memory_enabled": False,
|
||||
"mcp_tools": [],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
},
|
||||
description="能力配置"
|
||||
)
|
||||
|
||||
# 行为配置
|
||||
behavior: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
},
|
||||
description="行为配置"
|
||||
)
|
||||
|
||||
# 外观配置
|
||||
avatar: Optional[str] = Field(default=None, description="头像URL")
|
||||
color: str = Field(default="#1890ff", description="代表颜色")
|
||||
|
||||
# 元数据
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "agents"
|
||||
|
||||
def get_capabilities(self) -> AgentCapabilities:
|
||||
"""获取能力配置对象"""
|
||||
return AgentCapabilities.from_dict(self.capabilities)
|
||||
|
||||
def get_behavior(self) -> AgentBehavior:
|
||||
"""获取行为配置对象"""
|
||||
return AgentBehavior.from_dict(self.behavior)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"agent_id": "product-manager",
|
||||
"name": "产品经理",
|
||||
"role": "产品规划和需求分析专家",
|
||||
"system_prompt": "你是一位经验丰富的产品经理,擅长分析用户需求...",
|
||||
"provider_id": "openrouter-gpt4",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"capabilities": {
|
||||
"memory_enabled": True,
|
||||
"mcp_tools": ["web_search"],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
},
|
||||
"behavior": {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
},
|
||||
"avatar": "https://example.com/avatar.png",
|
||||
"color": "#1890ff"
|
||||
}
|
||||
}
|
||||
123
backend/models/agent_memory.py
Normal file
123
backend/models/agent_memory.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Agent记忆数据模型
|
||||
定义Agent的记忆存储结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""记忆类型枚举"""
|
||||
SHORT_TERM = "short_term" # 短期记忆(会话内)
|
||||
LONG_TERM = "long_term" # 长期记忆(跨会话)
|
||||
EPISODIC = "episodic" # 情景记忆(特定事件)
|
||||
SEMANTIC = "semantic" # 语义记忆(知识性)
|
||||
|
||||
|
||||
class AgentMemory(Document):
|
||||
"""
|
||||
Agent记忆文档模型
|
||||
存储Agent的记忆内容
|
||||
"""
|
||||
memory_id: str = Field(..., description="唯一标识")
|
||||
agent_id: str = Field(..., description="Agent ID")
|
||||
|
||||
# 记忆内容
|
||||
memory_type: str = Field(
|
||||
default=MemoryType.SHORT_TERM.value,
|
||||
description="记忆类型"
|
||||
)
|
||||
content: str = Field(..., description="记忆内容")
|
||||
summary: str = Field(default="", description="内容摘要")
|
||||
|
||||
# 向量嵌入(用于相似度检索)
|
||||
embedding: List[float] = Field(default_factory=list, description="向量嵌入")
|
||||
|
||||
# 元数据
|
||||
importance: float = Field(default=0.5, ge=0, le=1, description="重要性评分")
|
||||
access_count: int = Field(default=0, description="访问次数")
|
||||
|
||||
# 关联信息
|
||||
source_room_id: Optional[str] = Field(default=None, description="来源聊天室ID")
|
||||
source_discussion_id: Optional[str] = Field(default=None, description="来源讨论ID")
|
||||
related_agents: List[str] = Field(default_factory=list, description="相关Agent列表")
|
||||
tags: List[str] = Field(default_factory=list, description="标签")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_accessed: datetime = Field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = Field(default=None, description="过期时间")
|
||||
|
||||
class Settings:
|
||||
name = "agent_memories"
|
||||
indexes = [
|
||||
[("agent_id", 1)],
|
||||
[("memory_type", 1)],
|
||||
[("importance", -1)],
|
||||
[("last_accessed", -1)],
|
||||
]
|
||||
|
||||
def access(self) -> None:
|
||||
"""
|
||||
记录访问,更新访问计数和时间
|
||||
"""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
检查记忆是否已过期
|
||||
|
||||
Returns:
|
||||
是否过期
|
||||
"""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def calculate_relevance_score(
|
||||
self,
|
||||
similarity: float,
|
||||
time_decay_factor: float = 0.1
|
||||
) -> float:
|
||||
"""
|
||||
计算综合相关性分数
|
||||
结合向量相似度、重要性和时间衰减
|
||||
|
||||
Args:
|
||||
similarity: 向量相似度 (0-1)
|
||||
time_decay_factor: 时间衰减因子
|
||||
|
||||
Returns:
|
||||
综合相关性分数
|
||||
"""
|
||||
# 计算时间衰减
|
||||
hours_since_access = (datetime.utcnow() - self.last_accessed).total_seconds() / 3600
|
||||
time_decay = 1.0 / (1.0 + time_decay_factor * hours_since_access)
|
||||
|
||||
# 综合评分
|
||||
score = (
|
||||
0.5 * similarity +
|
||||
0.3 * self.importance +
|
||||
0.2 * time_decay
|
||||
)
|
||||
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"memory_id": "mem-001",
|
||||
"agent_id": "product-manager",
|
||||
"memory_type": "long_term",
|
||||
"content": "在登录系统设计讨论中,团队决定采用OAuth2.0方案",
|
||||
"summary": "登录系统采用OAuth2.0",
|
||||
"importance": 0.8,
|
||||
"access_count": 5,
|
||||
"source_room_id": "product-design-room",
|
||||
"tags": ["登录", "OAuth", "认证"]
|
||||
}
|
||||
}
|
||||
149
backend/models/ai_provider.py
Normal file
149
backend/models/ai_provider.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
AI接口提供商数据模型
|
||||
定义AI服务配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
"""AI提供商类型枚举"""
|
||||
MINIMAX = "minimax"
|
||||
ZHIPU = "zhipu"
|
||||
OPENROUTER = "openrouter"
|
||||
KIMI = "kimi"
|
||||
DEEPSEEK = "deepseek"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
LLMSTUDIO = "llmstudio"
|
||||
|
||||
|
||||
class ProxyConfig:
|
||||
"""代理配置"""
|
||||
http_proxy: Optional[str] = None # HTTP代理地址
|
||||
https_proxy: Optional[str] = None # HTTPS代理地址
|
||||
no_proxy: List[str] = [] # 不使用代理的域名列表
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http_proxy: Optional[str] = None,
|
||||
https_proxy: Optional[str] = None,
|
||||
no_proxy: Optional[List[str]] = None
|
||||
):
|
||||
self.http_proxy = http_proxy
|
||||
self.https_proxy = https_proxy
|
||||
self.no_proxy = no_proxy or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"http_proxy": self.http_proxy,
|
||||
"https_proxy": self.https_proxy,
|
||||
"no_proxy": self.no_proxy
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ProxyConfig":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
http_proxy=data.get("http_proxy"),
|
||||
https_proxy=data.get("https_proxy"),
|
||||
no_proxy=data.get("no_proxy", [])
|
||||
)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
"""速率限制配置"""
|
||||
requests_per_minute: int = 60 # 每分钟请求数
|
||||
tokens_per_minute: int = 100000 # 每分钟token数
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requests_per_minute: int = 60,
|
||||
tokens_per_minute: int = 100000
|
||||
):
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.tokens_per_minute = tokens_per_minute
|
||||
|
||||
def to_dict(self) -> Dict[str, int]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"requests_per_minute": self.requests_per_minute,
|
||||
"tokens_per_minute": self.tokens_per_minute
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, int]) -> "RateLimit":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
requests_per_minute=data.get("requests_per_minute", 60),
|
||||
tokens_per_minute=data.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
|
||||
class AIProvider(Document):
|
||||
"""
|
||||
AI接口提供商文档模型
|
||||
存储各AI服务的配置信息
|
||||
"""
|
||||
provider_id: str = Field(..., description="唯一标识")
|
||||
provider_type: str = Field(..., description="提供商类型: minimax, zhipu等")
|
||||
name: str = Field(..., description="自定义名称")
|
||||
api_key: str = Field(default="", description="API密钥(加密存储)")
|
||||
base_url: str = Field(default="", description="API基础URL")
|
||||
model: str = Field(..., description="使用的模型名称")
|
||||
|
||||
# 代理配置
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理")
|
||||
proxy_config: Dict[str, Any] = Field(default_factory=dict, description="代理配置")
|
||||
|
||||
# 速率限制
|
||||
rate_limit: Dict[str, int] = Field(
|
||||
default_factory=lambda: {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||
description="速率限制配置"
|
||||
)
|
||||
|
||||
# 其他配置
|
||||
timeout: int = Field(default=60, description="超时时间(秒)")
|
||||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
|
||||
|
||||
# 元数据
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "ai_providers"
|
||||
|
||||
def get_proxy_config(self) -> ProxyConfig:
|
||||
"""获取代理配置对象"""
|
||||
return ProxyConfig.from_dict(self.proxy_config)
|
||||
|
||||
def get_rate_limit(self) -> RateLimit:
|
||||
"""获取速率限制配置对象"""
|
||||
return RateLimit.from_dict(self.rate_limit)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"provider_id": "openrouter-gpt4",
|
||||
"provider_type": "openrouter",
|
||||
"name": "OpenRouter GPT-4",
|
||||
"api_key": "sk-xxx",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"model": "openai/gpt-4-turbo",
|
||||
"use_proxy": True,
|
||||
"proxy_config": {
|
||||
"http_proxy": "http://127.0.0.1:7890",
|
||||
"https_proxy": "http://127.0.0.1:7890"
|
||||
},
|
||||
"timeout": 60
|
||||
}
|
||||
}
|
||||
131
backend/models/chatroom.py
Normal file
131
backend/models/chatroom.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
聊天室数据模型
|
||||
定义讨论聊天室的配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class ChatRoomStatus(str, Enum):
|
||||
"""聊天室状态枚举"""
|
||||
IDLE = "idle" # 空闲,等待开始
|
||||
ACTIVE = "active" # 讨论进行中
|
||||
PAUSED = "paused" # 暂停
|
||||
COMPLETED = "completed" # 已完成
|
||||
ERROR = "error" # 出错
|
||||
|
||||
|
||||
class ChatRoomConfig:
|
||||
"""聊天室配置"""
|
||||
max_rounds: int = 50 # 最大轮数(备用终止条件)
|
||||
message_history_size: int = 20 # 上下文消息数
|
||||
consensus_threshold: float = 0.8 # 共识阈值
|
||||
round_interval: float = 1.0 # 轮次间隔(秒)
|
||||
allow_user_interrupt: bool = True # 允许用户中断
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_rounds: int = 50,
|
||||
message_history_size: int = 20,
|
||||
consensus_threshold: float = 0.8,
|
||||
round_interval: float = 1.0,
|
||||
allow_user_interrupt: bool = True
|
||||
):
|
||||
self.max_rounds = max_rounds
|
||||
self.message_history_size = message_history_size
|
||||
self.consensus_threshold = consensus_threshold
|
||||
self.round_interval = round_interval
|
||||
self.allow_user_interrupt = allow_user_interrupt
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"max_rounds": self.max_rounds,
|
||||
"message_history_size": self.message_history_size,
|
||||
"consensus_threshold": self.consensus_threshold,
|
||||
"round_interval": self.round_interval,
|
||||
"allow_user_interrupt": self.allow_user_interrupt
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ChatRoomConfig":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
max_rounds=data.get("max_rounds", 50),
|
||||
message_history_size=data.get("message_history_size", 20),
|
||||
consensus_threshold=data.get("consensus_threshold", 0.8),
|
||||
round_interval=data.get("round_interval", 1.0),
|
||||
allow_user_interrupt=data.get("allow_user_interrupt", True)
|
||||
)
|
||||
|
||||
|
||||
class ChatRoom(Document):
|
||||
"""
|
||||
聊天室文档模型
|
||||
存储讨论聊天室的配置信息
|
||||
"""
|
||||
room_id: str = Field(..., description="唯一标识")
|
||||
name: str = Field(..., description="聊天室名称")
|
||||
description: str = Field(default="", description="描述")
|
||||
objective: str = Field(default="", description="当前讨论目标")
|
||||
|
||||
# 参与者
|
||||
agents: List[str] = Field(default_factory=list, description="Agent ID列表")
|
||||
moderator_agent_id: Optional[str] = Field(default=None, description="共识判断Agent ID")
|
||||
|
||||
# 配置
|
||||
config: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8,
|
||||
"round_interval": 1.0,
|
||||
"allow_user_interrupt": True
|
||||
},
|
||||
description="聊天室配置"
|
||||
)
|
||||
|
||||
# 状态
|
||||
status: str = Field(default=ChatRoomStatus.IDLE.value, description="当前状态")
|
||||
current_round: int = Field(default=0, description="当前轮次")
|
||||
current_discussion_id: Optional[str] = Field(default=None, description="当前讨论ID")
|
||||
|
||||
# 元数据
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||
|
||||
class Settings:
|
||||
name = "chatrooms"
|
||||
|
||||
def get_config(self) -> ChatRoomConfig:
|
||||
"""获取配置对象"""
|
||||
return ChatRoomConfig.from_dict(self.config)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""检查聊天室是否处于活动状态"""
|
||||
return self.status == ChatRoomStatus.ACTIVE.value
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"room_id": "product-design-room",
|
||||
"name": "产品设计讨论室",
|
||||
"description": "用于讨论新产品功能设计",
|
||||
"objective": "设计一个用户友好的登录系统",
|
||||
"agents": ["product-manager", "designer", "developer"],
|
||||
"moderator_agent_id": "moderator",
|
||||
"config": {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8
|
||||
},
|
||||
"status": "idle",
|
||||
"current_round": 0
|
||||
}
|
||||
}
|
||||
126
backend/models/discussion_result.py
Normal file
126
backend/models/discussion_result.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
讨论结果数据模型
|
||||
定义讨论结果的结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class DiscussionResult(Document):
|
||||
"""
|
||||
讨论结果文档模型
|
||||
存储讨论的最终结果
|
||||
"""
|
||||
discussion_id: str = Field(..., description="讨论唯一标识")
|
||||
room_id: str = Field(..., description="聊天室ID")
|
||||
objective: str = Field(..., description="讨论目标")
|
||||
|
||||
# 共识结果
|
||||
consensus_reached: bool = Field(default=False, description="是否达成共识")
|
||||
confidence: float = Field(default=0.0, ge=0, le=1, description="共识置信度")
|
||||
|
||||
# 结果摘要
|
||||
summary: str = Field(default="", description="讨论结果摘要")
|
||||
action_items: List[str] = Field(default_factory=list, description="行动项列表")
|
||||
unresolved_issues: List[str] = Field(default_factory=list, description="未解决的问题")
|
||||
key_decisions: List[str] = Field(default_factory=list, description="关键决策")
|
||||
|
||||
# 统计信息
|
||||
total_rounds: int = Field(default=0, description="总轮数")
|
||||
total_messages: int = Field(default=0, description="总消息数")
|
||||
participating_agents: List[str] = Field(default_factory=list, description="参与的Agent列表")
|
||||
agent_contributions: Dict[str, int] = Field(
|
||||
default_factory=dict,
|
||||
description="各Agent发言次数统计"
|
||||
)
|
||||
|
||||
# 状态
|
||||
status: str = Field(default="in_progress", description="状态: in_progress, completed, failed")
|
||||
end_reason: str = Field(default="", description="结束原因")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||
|
||||
class Settings:
|
||||
name = "discussions"
|
||||
indexes = [
|
||||
[("room_id", 1)],
|
||||
[("created_at", -1)],
|
||||
]
|
||||
|
||||
def mark_completed(
|
||||
self,
|
||||
consensus_reached: bool,
|
||||
confidence: float,
|
||||
summary: str,
|
||||
action_items: List[str] = None,
|
||||
unresolved_issues: List[str] = None,
|
||||
end_reason: str = "consensus"
|
||||
) -> None:
|
||||
"""
|
||||
标记讨论为已完成
|
||||
|
||||
Args:
|
||||
consensus_reached: 是否达成共识
|
||||
confidence: 置信度
|
||||
summary: 结果摘要
|
||||
action_items: 行动项
|
||||
unresolved_issues: 未解决问题
|
||||
end_reason: 结束原因
|
||||
"""
|
||||
self.consensus_reached = consensus_reached
|
||||
self.confidence = confidence
|
||||
self.summary = summary
|
||||
self.action_items = action_items or []
|
||||
self.unresolved_issues = unresolved_issues or []
|
||||
self.status = "completed"
|
||||
self.end_reason = end_reason
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def update_stats(
|
||||
self,
|
||||
total_rounds: int,
|
||||
total_messages: int,
|
||||
agent_contributions: Dict[str, int]
|
||||
) -> None:
|
||||
"""
|
||||
更新统计信息
|
||||
|
||||
Args:
|
||||
total_rounds: 总轮数
|
||||
total_messages: 总消息数
|
||||
agent_contributions: Agent贡献统计
|
||||
"""
|
||||
self.total_rounds = total_rounds
|
||||
self.total_messages = total_messages
|
||||
self.agent_contributions = agent_contributions
|
||||
self.participating_agents = list(agent_contributions.keys())
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"discussion_id": "disc-001",
|
||||
"room_id": "product-design-room",
|
||||
"objective": "设计用户登录系统",
|
||||
"consensus_reached": True,
|
||||
"confidence": 0.85,
|
||||
"summary": "团队一致同意采用OAuth2.0 + 手机验证码的混合认证方案...",
|
||||
"action_items": [
|
||||
"设计OAuth2.0集成方案",
|
||||
"开发短信验证服务",
|
||||
"编写安全测试用例"
|
||||
],
|
||||
"unresolved_issues": [
|
||||
"第三方登录的优先级排序"
|
||||
],
|
||||
"total_rounds": 15,
|
||||
"total_messages": 45,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
123
backend/models/message.py
Normal file
123
backend/models/message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
消息数据模型
|
||||
定义聊天消息的结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""消息类型枚举"""
|
||||
TEXT = "text" # 纯文本
|
||||
IMAGE = "image" # 图片
|
||||
FILE = "file" # 文件
|
||||
SYSTEM = "system" # 系统消息
|
||||
ACTION = "action" # 动作消息(如调用工具)
|
||||
|
||||
|
||||
class MessageAttachment:
|
||||
"""消息附件"""
|
||||
attachment_type: str # 附件类型: image, file
|
||||
url: str # 资源URL
|
||||
name: str # 文件名
|
||||
size: int = 0 # 文件大小(字节)
|
||||
mime_type: str = "" # MIME类型
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attachment_type: str,
|
||||
url: str,
|
||||
name: str,
|
||||
size: int = 0,
|
||||
mime_type: str = ""
|
||||
):
|
||||
self.attachment_type = attachment_type
|
||||
self.url = url
|
||||
self.name = name
|
||||
self.size = size
|
||||
self.mime_type = mime_type
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"attachment_type": self.attachment_type,
|
||||
"url": self.url,
|
||||
"name": self.name,
|
||||
"size": self.size,
|
||||
"mime_type": self.mime_type
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MessageAttachment":
|
||||
"""从字典创建"""
|
||||
return cls(
|
||||
attachment_type=data.get("attachment_type", ""),
|
||||
url=data.get("url", ""),
|
||||
name=data.get("name", ""),
|
||||
size=data.get("size", 0),
|
||||
mime_type=data.get("mime_type", "")
|
||||
)
|
||||
|
||||
|
||||
class Message(Document):
|
||||
"""
|
||||
消息文档模型
|
||||
存储聊天消息
|
||||
"""
|
||||
message_id: str = Field(..., description="唯一标识")
|
||||
room_id: str = Field(..., description="聊天室ID")
|
||||
discussion_id: str = Field(..., description="讨论ID")
|
||||
agent_id: Optional[str] = Field(default=None, description="发送Agent ID(系统消息为空)")
|
||||
|
||||
# 消息内容
|
||||
content: str = Field(..., description="消息内容")
|
||||
message_type: str = Field(default=MessageType.TEXT.value, description="消息类型")
|
||||
attachments: List[Dict[str, Any]] = Field(default_factory=list, description="附件列表")
|
||||
|
||||
# 元数据
|
||||
round: int = Field(default=0, description="所属轮次")
|
||||
token_count: int = Field(default=0, description="token数量")
|
||||
|
||||
# 工具调用相关
|
||||
tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用记录")
|
||||
tool_results: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用结果")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "messages"
|
||||
indexes = [
|
||||
[("room_id", 1), ("created_at", 1)],
|
||||
[("discussion_id", 1)],
|
||||
[("agent_id", 1)],
|
||||
]
|
||||
|
||||
def get_attachments(self) -> List[MessageAttachment]:
|
||||
"""获取附件对象列表"""
|
||||
return [MessageAttachment.from_dict(a) for a in self.attachments]
|
||||
|
||||
def is_from_agent(self, agent_id: str) -> bool:
|
||||
"""检查消息是否来自指定Agent"""
|
||||
return self.agent_id == agent_id
|
||||
|
||||
def is_system_message(self) -> bool:
|
||||
"""检查是否为系统消息"""
|
||||
return self.message_type == MessageType.SYSTEM.value
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"message_id": "msg-001",
|
||||
"room_id": "product-design-room",
|
||||
"discussion_id": "disc-001",
|
||||
"agent_id": "product-manager",
|
||||
"content": "我认为登录系统应该支持多种认证方式...",
|
||||
"message_type": "text",
|
||||
"round": 1,
|
||||
"token_count": 150
|
||||
}
|
||||
}
|
||||
42
backend/requirements.txt
Normal file
42
backend/requirements.txt
Normal file
@@ -0,0 +1,42 @@
|
||||
# FastAPI and server
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
python-multipart==0.0.6
|
||||
websockets==12.0
|
||||
|
||||
# MongoDB
|
||||
motor==3.3.2
|
||||
pymongo==4.6.1
|
||||
beanie==1.25.0
|
||||
|
||||
# HTTP client
|
||||
httpx==0.26.0
|
||||
aiohttp==3.9.1
|
||||
|
||||
# AI SDK clients
|
||||
openai==1.12.0
|
||||
google-generativeai==0.3.2
|
||||
zhipuai==2.0.1
|
||||
|
||||
# Data validation
|
||||
pydantic==2.6.0
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# Security
|
||||
cryptography==42.0.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
tenacity==8.2.3
|
||||
numpy==1.26.4
|
||||
|
||||
# For embeddings and vector search
|
||||
sentence-transformers==2.3.1
|
||||
|
||||
# Rate limiting
|
||||
slowapi==0.1.9
|
||||
|
||||
# Logging
|
||||
loguru==0.7.2
|
||||
14
backend/routers/__init__.py
Normal file
14
backend/routers/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
API路由模块
|
||||
"""
|
||||
from . import providers
|
||||
from . import agents
|
||||
from . import chatrooms
|
||||
from . import discussions
|
||||
|
||||
__all__ = [
|
||||
"providers",
|
||||
"agents",
|
||||
"chatrooms",
|
||||
"discussions",
|
||||
]
|
||||
314
backend/routers/agents.py
Normal file
314
backend/routers/agents.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Agent管理路由
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from services.agent_service import AgentService, AGENT_TEMPLATES
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
class CapabilitiesModel(BaseModel):
|
||||
"""能力配置模型"""
|
||||
memory_enabled: bool = False
|
||||
mcp_tools: List[str] = []
|
||||
skills: List[str] = []
|
||||
multimodal: bool = False
|
||||
|
||||
|
||||
class BehaviorModel(BaseModel):
|
||||
"""行为配置模型"""
|
||||
speak_threshold: float = 0.5
|
||||
max_speak_per_round: int = 2
|
||||
speak_style: str = "balanced"
|
||||
|
||||
|
||||
class AgentCreateRequest(BaseModel):
|
||||
"""创建Agent请求"""
|
||||
name: str = Field(..., description="Agent名称")
|
||||
role: str = Field(..., description="角色定义")
|
||||
system_prompt: str = Field(..., description="系统提示词")
|
||||
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
|
||||
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
|
||||
capabilities: Optional[CapabilitiesModel] = None
|
||||
behavior: Optional[BehaviorModel] = None
|
||||
avatar: Optional[str] = None
|
||||
color: str = "#1890ff"
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"name": "产品经理",
|
||||
"role": "产品规划和需求分析专家",
|
||||
"system_prompt": "你是一位经验丰富的产品经理...",
|
||||
"provider_id": "openrouter-abc123",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AgentUpdateRequest(BaseModel):
|
||||
"""更新Agent请求"""
|
||||
name: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
temperature: Optional[float] = Field(default=None, ge=0, le=2)
|
||||
max_tokens: Optional[int] = Field(default=None, gt=0)
|
||||
capabilities: Optional[CapabilitiesModel] = None
|
||||
behavior: Optional[BehaviorModel] = None
|
||||
avatar: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Agent响应"""
|
||||
agent_id: str
|
||||
name: str
|
||||
role: str
|
||||
system_prompt: str
|
||||
provider_id: str
|
||||
temperature: float
|
||||
max_tokens: int
|
||||
capabilities: Dict[str, Any]
|
||||
behavior: Dict[str, Any]
|
||||
avatar: Optional[str]
|
||||
color: str
|
||||
enabled: bool
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class AgentTestRequest(BaseModel):
|
||||
"""Agent测试请求"""
|
||||
message: str = "你好,请简单介绍一下你自己。"
|
||||
|
||||
|
||||
class AgentTestResponse(BaseModel):
|
||||
"""Agent测试响应"""
|
||||
success: bool
|
||||
message: str
|
||||
response: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
tokens: Optional[int] = None
|
||||
latency_ms: Optional[float] = None
|
||||
|
||||
|
||||
class TemplateResponse(BaseModel):
|
||||
"""模板响应"""
|
||||
template_id: str
|
||||
name: str
|
||||
role: str
|
||||
system_prompt: str
|
||||
color: str
|
||||
|
||||
|
||||
class GeneratePromptRequest(BaseModel):
|
||||
"""生成提示词请求"""
|
||||
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||
name: str = Field(..., description="Agent名称")
|
||||
role: str = Field(..., description="角色定位")
|
||||
description: Optional[str] = Field(None, description="额外描述(可选)")
|
||||
|
||||
|
||||
class GeneratePromptResponse(BaseModel):
|
||||
"""生成提示词响应"""
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
tokens: Optional[int] = None
|
||||
|
||||
|
||||
# ============ 路由处理 ============
|
||||
|
||||
@router.post("", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_agent(request: AgentCreateRequest):
|
||||
"""
|
||||
创建新的Agent
|
||||
"""
|
||||
try:
|
||||
agent = await AgentService.create_agent(
|
||||
name=request.name,
|
||||
role=request.role,
|
||||
system_prompt=request.system_prompt,
|
||||
provider_id=request.provider_id,
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
capabilities=request.capabilities.dict() if request.capabilities else None,
|
||||
behavior=request.behavior.dict() if request.behavior else None,
|
||||
avatar=request.avatar,
|
||||
color=request.color
|
||||
)
|
||||
|
||||
return _to_response(agent)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建失败")
|
||||
|
||||
|
||||
@router.get("", response_model=List[AgentResponse])
|
||||
async def list_agents(enabled_only: bool = False):
|
||||
"""
|
||||
获取所有Agent
|
||||
"""
|
||||
agents = await AgentService.get_all_agents(enabled_only)
|
||||
return [_to_response(a) for a in agents]
|
||||
|
||||
|
||||
@router.get("/templates", response_model=List[TemplateResponse])
|
||||
async def list_templates():
|
||||
"""
|
||||
获取Agent预设模板
|
||||
"""
|
||||
return [
|
||||
TemplateResponse(
|
||||
template_id=tid,
|
||||
name=t["name"],
|
||||
role=t["role"],
|
||||
system_prompt=t["system_prompt"],
|
||||
color=t["color"]
|
||||
)
|
||||
for tid, t in AGENT_TEMPLATES.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/generate-prompt", response_model=GeneratePromptResponse)
|
||||
async def generate_prompt(request: GeneratePromptRequest):
|
||||
"""
|
||||
使用AI生成Agent系统提示词
|
||||
"""
|
||||
result = await AgentService.generate_system_prompt(
|
||||
provider_id=request.provider_id,
|
||||
name=request.name,
|
||||
role=request.role,
|
||||
description=request.description
|
||||
)
|
||||
return GeneratePromptResponse(**result)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentResponse)
|
||||
async def get_agent(agent_id: str):
|
||||
"""
|
||||
获取指定Agent
|
||||
"""
|
||||
agent = await AgentService.get_agent(agent_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||
return _to_response(agent)
|
||||
|
||||
|
||||
@router.put("/{agent_id}", response_model=AgentResponse)
|
||||
async def update_agent(agent_id: str, request: AgentUpdateRequest):
|
||||
"""
|
||||
更新Agent配置
|
||||
"""
|
||||
update_data = request.dict(exclude_unset=True)
|
||||
|
||||
# 转换嵌套模型
|
||||
if "capabilities" in update_data and update_data["capabilities"]:
|
||||
if hasattr(update_data["capabilities"], "dict"):
|
||||
update_data["capabilities"] = update_data["capabilities"].dict()
|
||||
if "behavior" in update_data and update_data["behavior"]:
|
||||
if hasattr(update_data["behavior"], "dict"):
|
||||
update_data["behavior"] = update_data["behavior"].dict()
|
||||
|
||||
try:
|
||||
agent = await AgentService.update_agent(agent_id, **update_data)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||
return _to_response(agent)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{agent_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_agent(agent_id: str):
|
||||
"""
|
||||
删除Agent
|
||||
"""
|
||||
success = await AgentService.delete_agent(agent_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Agent不存在")
|
||||
|
||||
|
||||
@router.post("/{agent_id}/test", response_model=AgentTestResponse)
|
||||
async def test_agent(agent_id: str, request: AgentTestRequest = None):
|
||||
"""
|
||||
测试Agent对话
|
||||
"""
|
||||
message = request.message if request else "你好,请简单介绍一下你自己。"
|
||||
result = await AgentService.test_agent(agent_id, message)
|
||||
return AgentTestResponse(**result)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/duplicate", response_model=AgentResponse)
|
||||
async def duplicate_agent(agent_id: str, new_name: Optional[str] = None):
|
||||
"""
|
||||
复制Agent
|
||||
"""
|
||||
agent = await AgentService.duplicate_agent(agent_id, new_name)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="源Agent不存在")
|
||||
return _to_response(agent)
|
||||
|
||||
|
||||
@router.post("/from-template/{template_id}", response_model=AgentResponse)
|
||||
async def create_from_template(template_id: str, provider_id: str):
|
||||
"""
|
||||
从模板创建Agent
|
||||
"""
|
||||
if template_id not in AGENT_TEMPLATES:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
|
||||
template = AGENT_TEMPLATES[template_id]
|
||||
|
||||
try:
|
||||
agent = await AgentService.create_agent(
|
||||
name=template["name"],
|
||||
role=template["role"],
|
||||
system_prompt=template["system_prompt"],
|
||||
provider_id=provider_id,
|
||||
color=template["color"]
|
||||
)
|
||||
return _to_response(agent)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
def _to_response(agent) -> AgentResponse:
|
||||
"""
|
||||
转换为响应模型
|
||||
"""
|
||||
return AgentResponse(
|
||||
agent_id=agent.agent_id,
|
||||
name=agent.name,
|
||||
role=agent.role,
|
||||
system_prompt=agent.system_prompt,
|
||||
provider_id=agent.provider_id,
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens,
|
||||
capabilities=agent.capabilities,
|
||||
behavior=agent.behavior,
|
||||
avatar=agent.avatar,
|
||||
color=agent.color,
|
||||
enabled=agent.enabled,
|
||||
created_at=agent.created_at.isoformat(),
|
||||
updated_at=agent.updated_at.isoformat()
|
||||
)
|
||||
387
backend/routers/chatrooms.py
Normal file
387
backend/routers/chatrooms.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""
|
||||
聊天室管理路由
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from services.chatroom_service import ChatRoomService
|
||||
from services.discussion_engine import DiscussionEngine
|
||||
from services.message_router import MessageRouter
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
class ChatRoomConfigModel(BaseModel):
|
||||
"""聊天室配置模型"""
|
||||
max_rounds: int = 50
|
||||
message_history_size: int = 20
|
||||
consensus_threshold: float = 0.8
|
||||
round_interval: float = 1.0
|
||||
allow_user_interrupt: bool = True
|
||||
|
||||
|
||||
class ChatRoomCreateRequest(BaseModel):
|
||||
"""创建聊天室请求"""
|
||||
name: str = Field(..., description="聊天室名称")
|
||||
description: str = Field(default="", description="描述")
|
||||
agents: List[str] = Field(default=[], description="Agent ID列表")
|
||||
moderator_agent_id: Optional[str] = Field(default=None, description="主持人Agent ID")
|
||||
config: Optional[ChatRoomConfigModel] = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"name": "产品设计讨论室",
|
||||
"description": "用于讨论新产品功能设计",
|
||||
"agents": ["agent-abc123", "agent-def456"],
|
||||
"moderator_agent_id": "agent-xyz789"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ChatRoomUpdateRequest(BaseModel):
|
||||
"""更新聊天室请求"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
agents: Optional[List[str]] = None
|
||||
moderator_agent_id: Optional[str] = None
|
||||
config: Optional[ChatRoomConfigModel] = None
|
||||
|
||||
|
||||
class ChatRoomResponse(BaseModel):
|
||||
"""聊天室响应"""
|
||||
room_id: str
|
||||
name: str
|
||||
description: str
|
||||
objective: str
|
||||
agents: List[str]
|
||||
moderator_agent_id: Optional[str]
|
||||
config: Dict[str, Any]
|
||||
status: str
|
||||
current_round: int
|
||||
current_discussion_id: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
completed_at: Optional[str]
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""消息响应"""
|
||||
message_id: str
|
||||
room_id: str
|
||||
discussion_id: str
|
||||
agent_id: Optional[str]
|
||||
content: str
|
||||
message_type: str
|
||||
round: int
|
||||
created_at: str
|
||||
|
||||
|
||||
class StartDiscussionRequest(BaseModel):
|
||||
"""启动讨论请求"""
|
||||
objective: str = Field(..., description="讨论目标")
|
||||
|
||||
|
||||
class DiscussionStatusResponse(BaseModel):
|
||||
"""讨论状态响应"""
|
||||
is_active: bool
|
||||
room_id: str
|
||||
discussion_id: Optional[str] = None
|
||||
current_round: int = 0
|
||||
status: str
|
||||
|
||||
|
||||
# ============ 路由处理 ============
|
||||
|
||||
@router.post("", response_model=ChatRoomResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_chatroom(request: ChatRoomCreateRequest):
|
||||
"""
|
||||
创建新的聊天室
|
||||
"""
|
||||
try:
|
||||
chatroom = await ChatRoomService.create_chatroom(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
agents=request.agents,
|
||||
moderator_agent_id=request.moderator_agent_id,
|
||||
config=request.config.dict() if request.config else None
|
||||
)
|
||||
|
||||
return _to_response(chatroom)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"创建聊天室失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建失败")
|
||||
|
||||
|
||||
@router.get("", response_model=List[ChatRoomResponse])
|
||||
async def list_chatrooms():
|
||||
"""
|
||||
获取所有聊天室
|
||||
"""
|
||||
chatrooms = await ChatRoomService.get_all_chatrooms()
|
||||
return [_to_response(c) for c in chatrooms]
|
||||
|
||||
|
||||
@router.get("/{room_id}", response_model=ChatRoomResponse)
|
||||
async def get_chatroom(room_id: str):
|
||||
"""
|
||||
获取指定聊天室
|
||||
"""
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
return _to_response(chatroom)
|
||||
|
||||
|
||||
@router.put("/{room_id}", response_model=ChatRoomResponse)
|
||||
async def update_chatroom(room_id: str, request: ChatRoomUpdateRequest):
|
||||
"""
|
||||
更新聊天室配置
|
||||
"""
|
||||
update_data = request.dict(exclude_unset=True)
|
||||
|
||||
if "config" in update_data and update_data["config"]:
|
||||
if hasattr(update_data["config"], "dict"):
|
||||
update_data["config"] = update_data["config"].dict()
|
||||
|
||||
try:
|
||||
chatroom = await ChatRoomService.update_chatroom(room_id, **update_data)
|
||||
if not chatroom:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
return _to_response(chatroom)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{room_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_chatroom(room_id: str):
|
||||
"""
|
||||
删除聊天室
|
||||
"""
|
||||
success = await ChatRoomService.delete_chatroom(room_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
|
||||
|
||||
@router.post("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
|
||||
async def add_agent_to_chatroom(room_id: str, agent_id: str):
|
||||
"""
|
||||
向聊天室添加Agent
|
||||
"""
|
||||
try:
|
||||
chatroom = await ChatRoomService.add_agent(room_id, agent_id)
|
||||
if not chatroom:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
return _to_response(chatroom)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{room_id}/agents/{agent_id}", response_model=ChatRoomResponse)
|
||||
async def remove_agent_from_chatroom(room_id: str, agent_id: str):
|
||||
"""
|
||||
从聊天室移除Agent
|
||||
"""
|
||||
chatroom = await ChatRoomService.remove_agent(room_id, agent_id)
|
||||
if not chatroom:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
return _to_response(chatroom)
|
||||
|
||||
|
||||
@router.get("/{room_id}/messages", response_model=List[MessageResponse])
|
||||
async def get_chatroom_messages(
|
||||
room_id: str,
|
||||
limit: int = 50,
|
||||
skip: int = 0,
|
||||
discussion_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
获取聊天室消息历史
|
||||
"""
|
||||
messages = await ChatRoomService.get_messages(
|
||||
room_id, limit, skip, discussion_id
|
||||
)
|
||||
return [_message_to_response(m) for m in messages]
|
||||
|
||||
|
||||
@router.post("/{room_id}/start", response_model=DiscussionStatusResponse)
|
||||
async def start_discussion(room_id: str, request: StartDiscussionRequest):
|
||||
"""
|
||||
启动讨论
|
||||
"""
|
||||
try:
|
||||
# 异步启动讨论(不等待完成)
|
||||
import asyncio
|
||||
asyncio.create_task(
|
||||
DiscussionEngine.start_discussion(room_id, request.objective)
|
||||
)
|
||||
|
||||
# 等待一小段时间让讨论初始化
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
|
||||
return DiscussionStatusResponse(
|
||||
is_active=True,
|
||||
room_id=room_id,
|
||||
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||
current_round=chatroom.current_round if chatroom else 0,
|
||||
status=chatroom.status if chatroom else "unknown"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{room_id}/pause", response_model=DiscussionStatusResponse)
|
||||
async def pause_discussion(room_id: str):
|
||||
"""
|
||||
暂停讨论
|
||||
"""
|
||||
success = await DiscussionEngine.pause_discussion(room_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="没有进行中的讨论")
|
||||
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
return DiscussionStatusResponse(
|
||||
is_active=False,
|
||||
room_id=room_id,
|
||||
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||
current_round=chatroom.current_round if chatroom else 0,
|
||||
status="paused"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{room_id}/resume", response_model=DiscussionStatusResponse)
|
||||
async def resume_discussion(room_id: str):
|
||||
"""
|
||||
恢复讨论
|
||||
"""
|
||||
success = await DiscussionEngine.resume_discussion(room_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="聊天室不在暂停状态")
|
||||
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
return DiscussionStatusResponse(
|
||||
is_active=True,
|
||||
room_id=room_id,
|
||||
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||
current_round=chatroom.current_round if chatroom else 0,
|
||||
status="active"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{room_id}/stop", response_model=DiscussionStatusResponse)
|
||||
async def stop_discussion(room_id: str):
|
||||
"""
|
||||
停止讨论
|
||||
"""
|
||||
success = await DiscussionEngine.stop_discussion(room_id)
|
||||
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
return DiscussionStatusResponse(
|
||||
is_active=False,
|
||||
room_id=room_id,
|
||||
discussion_id=chatroom.current_discussion_id if chatroom else None,
|
||||
current_round=chatroom.current_round if chatroom else 0,
|
||||
status="stopping" if success else chatroom.status if chatroom else "unknown"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{room_id}/status", response_model=DiscussionStatusResponse)
|
||||
async def get_discussion_status(room_id: str):
|
||||
"""
|
||||
获取讨论状态
|
||||
"""
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
raise HTTPException(status_code=404, detail="聊天室不存在")
|
||||
|
||||
is_active = DiscussionEngine.is_discussion_active(room_id)
|
||||
|
||||
return DiscussionStatusResponse(
|
||||
is_active=is_active,
|
||||
room_id=room_id,
|
||||
discussion_id=chatroom.current_discussion_id,
|
||||
current_round=chatroom.current_round,
|
||||
status=chatroom.status
|
||||
)
|
||||
|
||||
|
||||
# ============ WebSocket端点 ============
|
||||
|
||||
@router.websocket("/ws/{room_id}")
|
||||
async def chatroom_websocket(websocket: WebSocket, room_id: str):
|
||||
"""
|
||||
聊天室WebSocket连接
|
||||
"""
|
||||
# 验证聊天室存在
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
await websocket.close(code=4004, reason="聊天室不存在")
|
||||
return
|
||||
|
||||
await MessageRouter.connect(room_id, websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 保持连接,接收客户端消息(如心跳)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 处理心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
await MessageRouter.disconnect(room_id, websocket)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
await MessageRouter.disconnect(room_id, websocket)
|
||||
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
def _to_response(chatroom) -> ChatRoomResponse:
|
||||
"""
|
||||
转换为响应模型
|
||||
"""
|
||||
return ChatRoomResponse(
|
||||
room_id=chatroom.room_id,
|
||||
name=chatroom.name,
|
||||
description=chatroom.description,
|
||||
objective=chatroom.objective,
|
||||
agents=chatroom.agents,
|
||||
moderator_agent_id=chatroom.moderator_agent_id,
|
||||
config=chatroom.config,
|
||||
status=chatroom.status,
|
||||
current_round=chatroom.current_round,
|
||||
current_discussion_id=chatroom.current_discussion_id,
|
||||
created_at=chatroom.created_at.isoformat(),
|
||||
updated_at=chatroom.updated_at.isoformat(),
|
||||
completed_at=chatroom.completed_at.isoformat() if chatroom.completed_at else None
|
||||
)
|
||||
|
||||
|
||||
def _message_to_response(message) -> MessageResponse:
|
||||
"""
|
||||
转换消息为响应模型
|
||||
"""
|
||||
return MessageResponse(
|
||||
message_id=message.message_id,
|
||||
room_id=message.room_id,
|
||||
discussion_id=message.discussion_id,
|
||||
agent_id=message.agent_id,
|
||||
content=message.content,
|
||||
message_type=message.message_type,
|
||||
round=message.round,
|
||||
created_at=message.created_at.isoformat()
|
||||
)
|
||||
136
backend/routers/discussions.py
Normal file
136
backend/routers/discussions.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
讨论结果路由
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.discussion_result import DiscussionResult
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 响应模型 ============
|
||||
|
||||
class DiscussionResponse(BaseModel):
|
||||
"""讨论结果响应"""
|
||||
discussion_id: str
|
||||
room_id: str
|
||||
objective: str
|
||||
consensus_reached: bool
|
||||
confidence: float
|
||||
summary: str
|
||||
action_items: List[str]
|
||||
unresolved_issues: List[str]
|
||||
key_decisions: List[str]
|
||||
total_rounds: int
|
||||
total_messages: int
|
||||
participating_agents: List[str]
|
||||
agent_contributions: Dict[str, int]
|
||||
status: str
|
||||
end_reason: str
|
||||
created_at: str
|
||||
completed_at: Optional[str]
|
||||
|
||||
|
||||
class DiscussionListResponse(BaseModel):
|
||||
"""讨论列表响应"""
|
||||
discussions: List[DiscussionResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ============ 路由处理 ============
|
||||
|
||||
@router.get("", response_model=DiscussionListResponse)
|
||||
async def list_discussions(
|
||||
room_id: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
skip: int = 0
|
||||
):
|
||||
"""
|
||||
获取讨论结果列表
|
||||
"""
|
||||
query = {}
|
||||
if room_id:
|
||||
query["room_id"] = room_id
|
||||
|
||||
discussions = await DiscussionResult.find(query).sort(
|
||||
"-created_at"
|
||||
).skip(skip).limit(limit).to_list()
|
||||
|
||||
total = await DiscussionResult.find(query).count()
|
||||
|
||||
return DiscussionListResponse(
|
||||
discussions=[_to_response(d) for d in discussions],
|
||||
total=total
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{discussion_id}", response_model=DiscussionResponse)
|
||||
async def get_discussion(discussion_id: str):
|
||||
"""
|
||||
获取指定讨论结果
|
||||
"""
|
||||
discussion = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == discussion_id
|
||||
)
|
||||
|
||||
if not discussion:
|
||||
raise HTTPException(status_code=404, detail="讨论记录不存在")
|
||||
|
||||
return _to_response(discussion)
|
||||
|
||||
|
||||
@router.get("/room/{room_id}", response_model=List[DiscussionResponse])
|
||||
async def get_room_discussions(room_id: str, limit: int = 10):
|
||||
"""
|
||||
获取聊天室的讨论历史
|
||||
"""
|
||||
discussions = await DiscussionResult.find(
|
||||
{"room_id": room_id}
|
||||
).sort("-created_at").limit(limit).to_list()
|
||||
|
||||
return [_to_response(d) for d in discussions]
|
||||
|
||||
|
||||
@router.get("/room/{room_id}/latest", response_model=DiscussionResponse)
|
||||
async def get_latest_discussion(room_id: str):
|
||||
"""
|
||||
获取聊天室最新的讨论结果
|
||||
"""
|
||||
discussion = await DiscussionResult.find(
|
||||
{"room_id": room_id}
|
||||
).sort("-created_at").first_or_none()
|
||||
|
||||
if not discussion:
|
||||
raise HTTPException(status_code=404, detail="没有找到讨论记录")
|
||||
|
||||
return _to_response(discussion)
|
||||
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
def _to_response(discussion: DiscussionResult) -> DiscussionResponse:
|
||||
"""
|
||||
转换为响应模型
|
||||
"""
|
||||
return DiscussionResponse(
|
||||
discussion_id=discussion.discussion_id,
|
||||
room_id=discussion.room_id,
|
||||
objective=discussion.objective,
|
||||
consensus_reached=discussion.consensus_reached,
|
||||
confidence=discussion.confidence,
|
||||
summary=discussion.summary,
|
||||
action_items=discussion.action_items,
|
||||
unresolved_issues=discussion.unresolved_issues,
|
||||
key_decisions=discussion.key_decisions,
|
||||
total_rounds=discussion.total_rounds,
|
||||
total_messages=discussion.total_messages,
|
||||
participating_agents=discussion.participating_agents,
|
||||
agent_contributions=discussion.agent_contributions,
|
||||
status=discussion.status,
|
||||
end_reason=discussion.end_reason,
|
||||
created_at=discussion.created_at.isoformat(),
|
||||
completed_at=discussion.completed_at.isoformat() if discussion.completed_at else None
|
||||
)
|
||||
241
backend/routers/providers.py
Normal file
241
backend/routers/providers.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
AI接口管理路由
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from services.ai_provider_service import AIProviderService
|
||||
from utils.encryption import mask_api_key
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
class ProxyConfigModel(BaseModel):
|
||||
"""代理配置模型"""
|
||||
http_proxy: Optional[str] = None
|
||||
https_proxy: Optional[str] = None
|
||||
no_proxy: List[str] = []
|
||||
|
||||
|
||||
class RateLimitModel(BaseModel):
|
||||
"""速率限制模型"""
|
||||
requests_per_minute: int = 60
|
||||
tokens_per_minute: int = 100000
|
||||
|
||||
|
||||
class ProviderCreateRequest(BaseModel):
|
||||
"""创建AI接口请求"""
|
||||
provider_type: str = Field(..., description="提供商类型: minimax, zhipu, openrouter, kimi, deepseek, gemini, ollama, llmstudio")
|
||||
name: str = Field(..., description="自定义名称")
|
||||
model: str = Field(..., description="模型名称")
|
||||
api_key: str = Field(default="", description="API密钥")
|
||||
base_url: str = Field(default="", description="API基础URL")
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理")
|
||||
proxy_config: Optional[ProxyConfigModel] = None
|
||||
rate_limit: Optional[RateLimitModel] = None
|
||||
timeout: int = Field(default=60, description="超时时间(秒)")
|
||||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"provider_type": "openrouter",
|
||||
"name": "OpenRouter GPT-4",
|
||||
"model": "openai/gpt-4-turbo",
|
||||
"api_key": "sk-xxx",
|
||||
"use_proxy": True,
|
||||
"proxy_config": {
|
||||
"http_proxy": "http://127.0.0.1:7890",
|
||||
"https_proxy": "http://127.0.0.1:7890"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
"""更新AI接口请求"""
|
||||
name: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
use_proxy: Optional[bool] = None
|
||||
proxy_config: Optional[ProxyConfigModel] = None
|
||||
rate_limit: Optional[RateLimitModel] = None
|
||||
timeout: Optional[int] = None
|
||||
extra_params: Optional[Dict[str, Any]] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
"""AI接口响应"""
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
name: str
|
||||
api_key_masked: str
|
||||
base_url: str
|
||||
model: str
|
||||
use_proxy: bool
|
||||
proxy_config: Dict[str, Any]
|
||||
rate_limit: Dict[str, int]
|
||||
timeout: int
|
||||
extra_params: Dict[str, Any]
|
||||
enabled: bool
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TestConfigRequest(BaseModel):
|
||||
"""测试配置请求"""
|
||||
provider_type: str
|
||||
api_key: str
|
||||
base_url: str = ""
|
||||
model: str = ""
|
||||
use_proxy: bool = False
|
||||
proxy_config: Optional[ProxyConfigModel] = None
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
class TestResponse(BaseModel):
|
||||
"""测试响应"""
|
||||
success: bool
|
||||
message: str
|
||||
model: Optional[str] = None
|
||||
latency_ms: Optional[float] = None
|
||||
|
||||
|
||||
# ============ 路由处理 ============
|
||||
|
||||
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_provider(request: ProviderCreateRequest):
|
||||
"""
|
||||
创建新的AI接口配置
|
||||
"""
|
||||
try:
|
||||
provider = await AIProviderService.create_provider(
|
||||
provider_type=request.provider_type,
|
||||
name=request.name,
|
||||
model=request.model,
|
||||
api_key=request.api_key,
|
||||
base_url=request.base_url,
|
||||
use_proxy=request.use_proxy,
|
||||
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
|
||||
rate_limit=request.rate_limit.dict() if request.rate_limit else None,
|
||||
timeout=request.timeout,
|
||||
extra_params=request.extra_params
|
||||
)
|
||||
|
||||
return _to_response(provider)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"创建AI接口失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建失败")
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProviderResponse])
|
||||
async def list_providers(enabled_only: bool = False):
|
||||
"""
|
||||
获取所有AI接口配置
|
||||
"""
|
||||
providers = await AIProviderService.get_all_providers(enabled_only)
|
||||
return [_to_response(p) for p in providers]
|
||||
|
||||
|
||||
@router.get("/{provider_id}", response_model=ProviderResponse)
|
||||
async def get_provider(provider_id: str):
|
||||
"""
|
||||
获取指定AI接口配置
|
||||
"""
|
||||
provider = await AIProviderService.get_provider(provider_id)
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||
return _to_response(provider)
|
||||
|
||||
|
||||
@router.put("/{provider_id}", response_model=ProviderResponse)
|
||||
async def update_provider(provider_id: str, request: ProviderUpdateRequest):
|
||||
"""
|
||||
更新AI接口配置
|
||||
"""
|
||||
update_data = request.dict(exclude_unset=True)
|
||||
|
||||
# 转换嵌套模型
|
||||
if "proxy_config" in update_data and update_data["proxy_config"]:
|
||||
update_data["proxy_config"] = update_data["proxy_config"].dict() if hasattr(update_data["proxy_config"], "dict") else update_data["proxy_config"]
|
||||
if "rate_limit" in update_data and update_data["rate_limit"]:
|
||||
update_data["rate_limit"] = update_data["rate_limit"].dict() if hasattr(update_data["rate_limit"], "dict") else update_data["rate_limit"]
|
||||
|
||||
try:
|
||||
provider = await AIProviderService.update_provider(provider_id, **update_data)
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||
return _to_response(provider)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_provider(provider_id: str):
|
||||
"""
|
||||
删除AI接口配置
|
||||
"""
|
||||
success = await AIProviderService.delete_provider(provider_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="AI接口不存在")
|
||||
|
||||
|
||||
@router.post("/{provider_id}/test", response_model=TestResponse)
|
||||
async def test_provider(provider_id: str):
|
||||
"""
|
||||
测试AI接口连接
|
||||
"""
|
||||
result = await AIProviderService.test_provider(provider_id)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
@router.post("/test", response_model=TestResponse)
|
||||
async def test_provider_config(request: TestConfigRequest):
|
||||
"""
|
||||
测试AI接口配置(不保存)
|
||||
"""
|
||||
result = await AIProviderService.test_provider_config(
|
||||
provider_type=request.provider_type,
|
||||
api_key=request.api_key,
|
||||
base_url=request.base_url,
|
||||
model=request.model,
|
||||
use_proxy=request.use_proxy,
|
||||
proxy_config=request.proxy_config.dict() if request.proxy_config else None,
|
||||
timeout=request.timeout
|
||||
)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
def _to_response(provider) -> ProviderResponse:
|
||||
"""
|
||||
转换为响应模型
|
||||
"""
|
||||
return ProviderResponse(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
name=provider.name,
|
||||
api_key_masked=mask_api_key(provider.api_key) if provider.api_key else "",
|
||||
base_url=provider.base_url,
|
||||
model=provider.model,
|
||||
use_proxy=provider.use_proxy,
|
||||
proxy_config=provider.proxy_config,
|
||||
rate_limit=provider.rate_limit,
|
||||
timeout=provider.timeout,
|
||||
extra_params=provider.extra_params,
|
||||
enabled=provider.enabled,
|
||||
created_at=provider.created_at.isoformat(),
|
||||
updated_at=provider.updated_at.isoformat()
|
||||
)
|
||||
22
backend/services/__init__.py
Normal file
22
backend/services/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
业务服务模块
|
||||
"""
|
||||
from .ai_provider_service import AIProviderService
|
||||
from .agent_service import AgentService
|
||||
from .chatroom_service import ChatRoomService
|
||||
from .message_router import MessageRouter
|
||||
from .discussion_engine import DiscussionEngine
|
||||
from .consensus_manager import ConsensusManager
|
||||
from .mcp_service import MCPService
|
||||
from .memory_service import MemoryService
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"AgentService",
|
||||
"ChatRoomService",
|
||||
"MessageRouter",
|
||||
"DiscussionEngine",
|
||||
"ConsensusManager",
|
||||
"MCPService",
|
||||
"MemoryService",
|
||||
]
|
||||
438
backend/services/agent_service.py
Normal file
438
backend/services/agent_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
Agent服务
|
||||
管理AI代理的配置
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.agent import Agent
|
||||
from services.ai_provider_service import AIProviderService
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""
|
||||
Agent服务类
|
||||
负责Agent的CRUD操作
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_agent(
|
||||
cls,
|
||||
name: str,
|
||||
role: str,
|
||||
system_prompt: str,
|
||||
provider_id: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
capabilities: Optional[Dict[str, Any]] = None,
|
||||
behavior: Optional[Dict[str, Any]] = None,
|
||||
avatar: Optional[str] = None,
|
||||
color: str = "#1890ff"
|
||||
) -> Agent:
|
||||
"""
|
||||
创建新的Agent
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
role: 角色定义
|
||||
system_prompt: 系统提示词
|
||||
provider_id: 使用的AI接口ID
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
capabilities: 能力配置
|
||||
behavior: 行为配置
|
||||
avatar: 头像URL
|
||||
color: 代表颜色
|
||||
|
||||
Returns:
|
||||
创建的Agent文档
|
||||
"""
|
||||
# 验证AI接口存在
|
||||
provider = await AIProviderService.get_provider(provider_id)
|
||||
if not provider:
|
||||
raise ValueError(f"AI接口不存在: {provider_id}")
|
||||
|
||||
# 生成唯一ID
|
||||
agent_id = f"agent-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 默认能力配置
|
||||
default_capabilities = {
|
||||
"memory_enabled": False,
|
||||
"mcp_tools": [],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
}
|
||||
if capabilities:
|
||||
default_capabilities.update(capabilities)
|
||||
|
||||
# 默认行为配置
|
||||
default_behavior = {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
}
|
||||
if behavior:
|
||||
default_behavior.update(behavior)
|
||||
|
||||
# 创建文档
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
name=name,
|
||||
role=role,
|
||||
system_prompt=system_prompt,
|
||||
provider_id=provider_id,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
capabilities=default_capabilities,
|
||||
behavior=default_behavior,
|
||||
avatar=avatar,
|
||||
color=color,
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await agent.insert()
|
||||
|
||||
logger.info(f"创建Agent: {agent_id} ({name})")
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
async def get_agent(cls, agent_id: str) -> Optional[Agent]:
|
||||
"""
|
||||
获取指定Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent文档或None
|
||||
"""
|
||||
return await Agent.find_one(Agent.agent_id == agent_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_agents(
|
||||
cls,
|
||||
enabled_only: bool = False
|
||||
) -> List[Agent]:
|
||||
"""
|
||||
获取所有Agent
|
||||
|
||||
Args:
|
||||
enabled_only: 是否只返回启用的Agent
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
if enabled_only:
|
||||
return await Agent.find(Agent.enabled == True).to_list()
|
||||
return await Agent.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def get_agents_by_ids(
|
||||
cls,
|
||||
agent_ids: List[str]
|
||||
) -> List[Agent]:
|
||||
"""
|
||||
根据ID列表获取多个Agent
|
||||
|
||||
Args:
|
||||
agent_ids: Agent ID列表
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
return await Agent.find(
|
||||
{"agent_id": {"$in": agent_ids}}
|
||||
).to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
**kwargs
|
||||
) -> Optional[Agent]:
|
||||
"""
|
||||
更新Agent配置
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的Agent或None
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return None
|
||||
|
||||
# 如果更新了provider_id,验证其存在
|
||||
if "provider_id" in kwargs:
|
||||
provider = await AIProviderService.get_provider(kwargs["provider_id"])
|
||||
if not provider:
|
||||
raise ValueError(f"AI接口不存在: {kwargs['provider_id']}")
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(agent, key):
|
||||
setattr(agent, key, value)
|
||||
|
||||
await agent.save()
|
||||
|
||||
logger.info(f"更新Agent: {agent_id}")
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
async def delete_agent(cls, agent_id: str) -> bool:
|
||||
"""
|
||||
删除Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return False
|
||||
|
||||
await agent.delete()
|
||||
|
||||
logger.info(f"删除Agent: {agent_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def test_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
test_message: str = "你好,请简单介绍一下你自己。"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试Agent对话
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
test_message: 测试消息
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Agent不存在: {agent_id}"
|
||||
}
|
||||
|
||||
if not agent.enabled:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Agent已禁用"
|
||||
}
|
||||
|
||||
# 构建消息
|
||||
messages = [
|
||||
{"role": "system", "content": agent.system_prompt},
|
||||
{"role": "user", "content": test_message}
|
||||
]
|
||||
|
||||
# 调用AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=agent.provider_id,
|
||||
messages=messages,
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "测试成功",
|
||||
"response": response.content,
|
||||
"model": response.model,
|
||||
"tokens": response.total_tokens,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def duplicate_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
new_name: Optional[str] = None
|
||||
) -> Optional[Agent]:
|
||||
"""
|
||||
复制Agent
|
||||
|
||||
Args:
|
||||
agent_id: 源Agent ID
|
||||
new_name: 新Agent名称
|
||||
|
||||
Returns:
|
||||
新创建的Agent或None
|
||||
"""
|
||||
source_agent = await cls.get_agent(agent_id)
|
||||
if not source_agent:
|
||||
return None
|
||||
|
||||
return await cls.create_agent(
|
||||
name=new_name or f"{source_agent.name} (副本)",
|
||||
role=source_agent.role,
|
||||
system_prompt=source_agent.system_prompt,
|
||||
provider_id=source_agent.provider_id,
|
||||
temperature=source_agent.temperature,
|
||||
max_tokens=source_agent.max_tokens,
|
||||
capabilities=source_agent.capabilities,
|
||||
behavior=source_agent.behavior,
|
||||
avatar=source_agent.avatar,
|
||||
color=source_agent.color
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def generate_system_prompt(
|
||||
cls,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
role: str,
|
||||
description: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
使用AI生成Agent系统提示词
|
||||
|
||||
Args:
|
||||
provider_id: AI接口ID
|
||||
name: Agent名称
|
||||
role: 角色定位
|
||||
description: 额外描述(可选)
|
||||
|
||||
Returns:
|
||||
生成结果,包含success和生成的prompt
|
||||
"""
|
||||
# 验证AI接口存在
|
||||
provider = await AIProviderService.get_provider(provider_id)
|
||||
if not provider:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"AI接口不存在: {provider_id}"
|
||||
}
|
||||
|
||||
# 构建生成提示词的请求
|
||||
generate_prompt = f"""请为一个AI Agent编写系统提示词(system prompt)。
|
||||
|
||||
Agent名称:{name}
|
||||
角色定位:{role}
|
||||
{f'补充说明:{description}' if description else ''}
|
||||
|
||||
要求:
|
||||
1. 提示词应简洁专业,控制在200字以内
|
||||
2. 明确该Agent的核心职责和专业领域
|
||||
3. 说明在多Agent讨论中应该关注什么
|
||||
4. 使用中文编写
|
||||
5. 不要包含任何问候语或开场白,直接给出提示词内容
|
||||
|
||||
请直接输出系统提示词,不要有任何额外的解释或包装。"""
|
||||
|
||||
try:
|
||||
messages = [{"role": "user", "content": generate_prompt}]
|
||||
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=provider_id,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if response.success:
|
||||
# 清理可能的包装文本
|
||||
content = response.content.strip()
|
||||
# 移除可能的markdown代码块标记
|
||||
if content.startswith("```"):
|
||||
lines = content.split("\n")
|
||||
content = "\n".join(lines[1:])
|
||||
if content.endswith("```"):
|
||||
content = content[:-3]
|
||||
content = content.strip()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"prompt": content,
|
||||
"model": response.model,
|
||||
"tokens": response.total_tokens
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error or "生成失败"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"生成系统提示词失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"生成失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Agent预设模板
|
||||
AGENT_TEMPLATES = {
|
||||
"product_manager": {
|
||||
"name": "产品经理",
|
||||
"role": "产品规划和需求分析专家",
|
||||
"system_prompt": """你是一位经验丰富的产品经理,擅长:
|
||||
- 分析用户需求和痛点
|
||||
- 制定产品策略和路线图
|
||||
- 平衡业务目标和用户体验
|
||||
- 与团队协作推进产品迭代
|
||||
|
||||
在讨论中,你需要从产品角度出发,关注用户价值、商业可行性和优先级排序。
|
||||
请用专业但易懂的语言表达观点。""",
|
||||
"color": "#1890ff"
|
||||
},
|
||||
"developer": {
|
||||
"name": "开发工程师",
|
||||
"role": "技术实现和架构设计专家",
|
||||
"system_prompt": """你是一位资深的软件开发工程师,擅长:
|
||||
- 系统架构设计
|
||||
- 代码实现和优化
|
||||
- 技术方案评估
|
||||
- 性能和安全考量
|
||||
|
||||
在讨论中,你需要从技术角度出发,关注实现可行性、技术债务和最佳实践。
|
||||
请提供具体的技术建议和潜在风险评估。""",
|
||||
"color": "#52c41a"
|
||||
},
|
||||
"designer": {
|
||||
"name": "设计师",
|
||||
"role": "用户体验和界面设计专家",
|
||||
"system_prompt": """你是一位专业的UI/UX设计师,擅长:
|
||||
- 用户体验设计
|
||||
- 界面视觉设计
|
||||
- 交互流程优化
|
||||
- 设计系统构建
|
||||
|
||||
在讨论中,你需要从设计角度出发,关注用户体验、视觉美感和交互流畅性。
|
||||
请提供设计建议并考虑可用性和一致性。""",
|
||||
"color": "#eb2f96"
|
||||
},
|
||||
"moderator": {
|
||||
"name": "主持人",
|
||||
"role": "讨论主持和共识判断专家",
|
||||
"system_prompt": """你是讨论的主持人,负责:
|
||||
- 引导讨论方向
|
||||
- 总结各方观点
|
||||
- 判断是否达成共识
|
||||
- 提炼行动要点
|
||||
|
||||
在讨论中,你需要保持中立,促进有效沟通,并在适当时机总结讨论成果。
|
||||
当各方观点趋于一致时,请明确指出并总结共识内容。""",
|
||||
"color": "#722ed1"
|
||||
}
|
||||
}
|
||||
364
backend/services/ai_provider_service.py
Normal file
364
backend/services/ai_provider_service.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
AI接口提供商服务
|
||||
管理AI接口的配置和调用
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.ai_provider import AIProvider
|
||||
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.encryption import encrypt_api_key, decrypt_api_key
|
||||
from utils.rate_limiter import rate_limiter
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
"""
|
||||
AI接口提供商服务类
|
||||
负责AI接口的CRUD操作和调用
|
||||
"""
|
||||
|
||||
# 缓存适配器实例
|
||||
_adapter_cache: Dict[str, BaseAdapter] = {}
|
||||
|
||||
@classmethod
|
||||
async def create_provider(
|
||||
cls,
|
||||
provider_type: str,
|
||||
name: str,
|
||||
model: str,
|
||||
api_key: str = "",
|
||||
base_url: str = "",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
rate_limit: Optional[Dict[str, int]] = None,
|
||||
timeout: int = 60,
|
||||
extra_params: Optional[Dict[str, Any]] = None
|
||||
) -> AIProvider:
|
||||
"""
|
||||
创建新的AI接口配置
|
||||
|
||||
Args:
|
||||
provider_type: 提供商类型
|
||||
name: 自定义名称
|
||||
model: 模型名称
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
rate_limit: 速率限制配置
|
||||
timeout: 超时时间
|
||||
extra_params: 额外参数
|
||||
|
||||
Returns:
|
||||
创建的AIProvider文档
|
||||
"""
|
||||
# 验证提供商类型
|
||||
try:
|
||||
get_adapter(provider_type)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"不支持的提供商类型: {provider_type}")
|
||||
|
||||
# 生成唯一ID
|
||||
provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 加密API密钥
|
||||
encrypted_key = encrypt_api_key(api_key) if api_key else ""
|
||||
|
||||
# 创建文档
|
||||
provider = AIProvider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
name=name,
|
||||
api_key=encrypted_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config or {},
|
||||
rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||
timeout=timeout,
|
||||
extra_params=extra_params or {},
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await provider.insert()
|
||||
|
||||
# 注册速率限制
|
||||
rate_limiter.register(
|
||||
provider_id,
|
||||
provider.rate_limit.get("requests_per_minute", 60),
|
||||
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
logger.info(f"创建AI接口配置: {provider_id} ({name})")
|
||||
return provider
|
||||
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
|
||||
"""
|
||||
获取指定AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
AIProvider文档或None
|
||||
"""
|
||||
return await AIProvider.find_one(AIProvider.provider_id == provider_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_providers(
|
||||
cls,
|
||||
enabled_only: bool = False
|
||||
) -> List[AIProvider]:
|
||||
"""
|
||||
获取所有AI接口配置
|
||||
|
||||
Args:
|
||||
enabled_only: 是否只返回启用的接口
|
||||
|
||||
Returns:
|
||||
AIProvider列表
|
||||
"""
|
||||
if enabled_only:
|
||||
return await AIProvider.find(AIProvider.enabled == True).to_list()
|
||||
return await AIProvider.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_provider(
|
||||
cls,
|
||||
provider_id: str,
|
||||
**kwargs
|
||||
) -> Optional[AIProvider]:
|
||||
"""
|
||||
更新AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的AIProvider或None
|
||||
"""
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
# 如果更新了API密钥,需要加密
|
||||
if "api_key" in kwargs and kwargs["api_key"]:
|
||||
kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(provider, key):
|
||||
setattr(provider, key, value)
|
||||
|
||||
await provider.save()
|
||||
|
||||
# 清除适配器缓存
|
||||
cls._adapter_cache.pop(provider_id, None)
|
||||
|
||||
# 更新速率限制
|
||||
if "rate_limit" in kwargs:
|
||||
rate_limiter.unregister(provider_id)
|
||||
rate_limiter.register(
|
||||
provider_id,
|
||||
provider.rate_limit.get("requests_per_minute", 60),
|
||||
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
logger.info(f"更新AI接口配置: {provider_id}")
|
||||
return provider
|
||||
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: str) -> bool:
|
||||
"""
|
||||
删除AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider:
|
||||
return False
|
||||
|
||||
await provider.delete()
|
||||
|
||||
# 清除缓存和速率限制
|
||||
cls._adapter_cache.pop(provider_id, None)
|
||||
rate_limiter.unregister(provider_id)
|
||||
|
||||
logger.info(f"删除AI接口配置: {provider_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
|
||||
"""
|
||||
获取AI接口的适配器实例
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
适配器实例或None
|
||||
"""
|
||||
# 检查缓存
|
||||
if provider_id in cls._adapter_cache:
|
||||
return cls._adapter_cache[provider_id]
|
||||
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider or not provider.enabled:
|
||||
return None
|
||||
|
||||
# 解密API密钥
|
||||
api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""
|
||||
|
||||
# 创建适配器
|
||||
adapter_class = get_adapter(provider.provider_type)
|
||||
adapter = adapter_class(
|
||||
api_key=api_key,
|
||||
base_url=provider.base_url,
|
||||
model=provider.model,
|
||||
use_proxy=provider.use_proxy,
|
||||
proxy_config=provider.proxy_config,
|
||||
timeout=provider.timeout,
|
||||
**provider.extra_params
|
||||
)
|
||||
|
||||
# 缓存适配器
|
||||
cls._adapter_cache[provider_id] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
@classmethod
|
||||
async def chat(
|
||||
cls,
|
||||
provider_id: str,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""
|
||||
调用AI接口进行对话
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
messages: 消息列表 [{"role": "user", "content": "..."}]
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
适配器响应
|
||||
"""
|
||||
adapter = await cls.get_adapter(provider_id)
|
||||
if not adapter:
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"AI接口不存在或未启用: {provider_id}"
|
||||
)
|
||||
|
||||
# 检查速率限制
|
||||
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
|
||||
if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error="请求频率超限,请稍后重试"
|
||||
)
|
||||
|
||||
# 转换消息格式
|
||||
chat_messages = [
|
||||
ChatMessage(
|
||||
role=m.get("role", "user"),
|
||||
content=m.get("content", ""),
|
||||
name=m.get("name")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
# 调用适配器
|
||||
response = await adapter.chat(
|
||||
messages=chat_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
测试AI接口连接
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
adapter = await cls.get_adapter(provider_id)
|
||||
if not adapter:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"AI接口不存在或未启用: {provider_id}"
|
||||
}
|
||||
|
||||
return await adapter.test_connection()
|
||||
|
||||
@classmethod
|
||||
async def test_provider_config(
|
||||
cls,
|
||||
provider_type: str,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试AI接口配置(不保存)
|
||||
|
||||
Args:
|
||||
provider_type: 提供商类型
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
model: 模型名称
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
timeout: 超时时间
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
adapter_class = get_adapter(provider_type)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的提供商类型: {provider_type}"
|
||||
}
|
||||
|
||||
adapter = adapter_class(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return await adapter.test_connection()
|
||||
357
backend/services/chatroom_service.py
Normal file
357
backend/services/chatroom_service.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
聊天室服务
|
||||
管理聊天室的创建和状态
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||
from models.message import Message
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
class ChatRoomService:
|
||||
"""
|
||||
聊天室服务类
|
||||
负责聊天室的CRUD操作
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_chatroom(
|
||||
cls,
|
||||
name: str,
|
||||
description: str = "",
|
||||
agents: Optional[List[str]] = None,
|
||||
moderator_agent_id: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
) -> ChatRoom:
|
||||
"""
|
||||
创建新的聊天室
|
||||
|
||||
Args:
|
||||
name: 聊天室名称
|
||||
description: 描述
|
||||
agents: Agent ID列表
|
||||
moderator_agent_id: 主持人Agent ID
|
||||
config: 聊天室配置
|
||||
|
||||
Returns:
|
||||
创建的ChatRoom文档
|
||||
"""
|
||||
# 验证Agent存在
|
||||
if agents:
|
||||
existing_agents = await AgentService.get_agents_by_ids(agents)
|
||||
existing_ids = {a.agent_id for a in existing_agents}
|
||||
missing_ids = set(agents) - existing_ids
|
||||
if missing_ids:
|
||||
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||
|
||||
# 验证主持人Agent
|
||||
if moderator_agent_id:
|
||||
moderator = await AgentService.get_agent(moderator_agent_id)
|
||||
if not moderator:
|
||||
raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")
|
||||
|
||||
# 生成唯一ID
|
||||
room_id = f"room-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8,
|
||||
"round_interval": 1.0,
|
||||
"allow_user_interrupt": True
|
||||
}
|
||||
if config:
|
||||
default_config.update(config)
|
||||
|
||||
# 创建文档
|
||||
chatroom = ChatRoom(
|
||||
room_id=room_id,
|
||||
name=name,
|
||||
description=description,
|
||||
objective="",
|
||||
agents=agents or [],
|
||||
moderator_agent_id=moderator_agent_id,
|
||||
config=default_config,
|
||||
status=ChatRoomStatus.IDLE.value,
|
||||
current_round=0,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await chatroom.insert()
|
||||
|
||||
logger.info(f"创建聊天室: {room_id} ({name})")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
获取指定聊天室
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
ChatRoom文档或None
|
||||
"""
|
||||
return await ChatRoom.find_one(ChatRoom.room_id == room_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_chatrooms(cls) -> List[ChatRoom]:
|
||||
"""
|
||||
获取所有聊天室
|
||||
|
||||
Returns:
|
||||
ChatRoom列表
|
||||
"""
|
||||
return await ChatRoom.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_chatroom(
|
||||
cls,
|
||||
room_id: str,
|
||||
**kwargs
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
更新聊天室配置
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 验证Agent
|
||||
if "agents" in kwargs:
|
||||
existing_agents = await AgentService.get_agents_by_ids(kwargs["agents"])
|
||||
existing_ids = {a.agent_id for a in existing_agents}
|
||||
missing_ids = set(kwargs["agents"]) - existing_ids
|
||||
if missing_ids:
|
||||
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||
|
||||
# 验证主持人
|
||||
if "moderator_agent_id" in kwargs and kwargs["moderator_agent_id"]:
|
||||
moderator = await AgentService.get_agent(kwargs["moderator_agent_id"])
|
||||
if not moderator:
|
||||
raise ValueError(f"主持人Agent不存在: {kwargs['moderator_agent_id']}")
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(chatroom, key):
|
||||
setattr(chatroom, key, value)
|
||||
|
||||
await chatroom.save()
|
||||
|
||||
logger.info(f"更新聊天室: {room_id}")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def delete_chatroom(cls, room_id: str) -> bool:
|
||||
"""
|
||||
删除聊天室
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return False
|
||||
|
||||
# 删除相关消息
|
||||
await Message.find(Message.room_id == room_id).delete()
|
||||
|
||||
await chatroom.delete()
|
||||
|
||||
logger.info(f"删除聊天室: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
向聊天室添加Agent
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 验证Agent存在
|
||||
agent = await AgentService.get_agent(agent_id)
|
||||
if not agent:
|
||||
raise ValueError(f"Agent不存在: {agent_id}")
|
||||
|
||||
# 添加Agent
|
||||
if agent_id not in chatroom.agents:
|
||||
chatroom.agents.append(agent_id)
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
从聊天室移除Agent
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 移除Agent
|
||||
if agent_id in chatroom.agents:
|
||||
chatroom.agents.remove(agent_id)
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def set_objective(
|
||||
cls,
|
||||
room_id: str,
|
||||
objective: str
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
设置讨论目标
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
objective: 讨论目标
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
return await cls.update_chatroom(room_id, objective=objective)
|
||||
|
||||
@classmethod
|
||||
async def update_status(
|
||||
cls,
|
||||
room_id: str,
|
||||
status: ChatRoomStatus
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
更新聊天室状态
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
chatroom.status = status.value
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
|
||||
if status == ChatRoomStatus.COMPLETED:
|
||||
chatroom.completed_at = datetime.utcnow()
|
||||
|
||||
await chatroom.save()
|
||||
|
||||
logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
增加轮次计数
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
chatroom.current_round += 1
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def get_messages(
|
||||
cls,
|
||||
room_id: str,
|
||||
limit: int = 50,
|
||||
skip: int = 0,
|
||||
discussion_id: Optional[str] = None
|
||||
) -> List[Message]:
|
||||
"""
|
||||
获取聊天室消息历史
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
limit: 返回数量限制
|
||||
skip: 跳过数量
|
||||
discussion_id: 讨论ID(可选)
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
query = {"room_id": room_id}
|
||||
if discussion_id:
|
||||
query["discussion_id"] = discussion_id
|
||||
|
||||
return await Message.find(query).sort(
|
||||
"-created_at"
|
||||
).skip(skip).limit(limit).to_list()
|
||||
|
||||
@classmethod
|
||||
async def get_recent_messages(
|
||||
cls,
|
||||
room_id: str,
|
||||
count: int = 20,
|
||||
discussion_id: Optional[str] = None
|
||||
) -> List[Message]:
|
||||
"""
|
||||
获取最近的消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
count: 消息数量
|
||||
discussion_id: 讨论ID(可选)
|
||||
|
||||
Returns:
|
||||
消息列表(按时间正序)
|
||||
"""
|
||||
messages = await cls.get_messages(
|
||||
room_id,
|
||||
limit=count,
|
||||
discussion_id=discussion_id
|
||||
)
|
||||
return list(reversed(messages)) # 返回正序
|
||||
227
backend/services/consensus_manager.py
Normal file
227
backend/services/consensus_manager.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
共识管理器
|
||||
判断讨论是否达成共识
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.agent import Agent
|
||||
from models.chatroom import ChatRoom
|
||||
from services.ai_provider_service import AIProviderService
|
||||
|
||||
|
||||
class ConsensusManager:
|
||||
"""
|
||||
共识管理器
|
||||
使用主持人Agent判断讨论共识
|
||||
"""
|
||||
|
||||
# 共识判断提示词模板
|
||||
CONSENSUS_PROMPT = """你是讨论的主持人,负责判断讨论是否达成共识。
|
||||
|
||||
讨论目标:{objective}
|
||||
|
||||
对话历史:
|
||||
{history}
|
||||
|
||||
请仔细分析对话内容,判断:
|
||||
1. 参与者是否对核心问题达成一致意见?
|
||||
2. 是否还有重要分歧未解决?
|
||||
3. 讨论结果是否足够明确和可执行?
|
||||
|
||||
请以JSON格式回复(不要包含任何其他文字):
|
||||
{{
|
||||
"consensus_reached": true或false,
|
||||
"confidence": 0到1之间的数字,
|
||||
"summary": "讨论结果摘要,简洁概括达成的共识或当前状态",
|
||||
"action_items": ["具体的行动项列表"],
|
||||
"unresolved_issues": ["未解决的问题列表"],
|
||||
"key_decisions": ["关键决策列表"]
|
||||
}}
|
||||
|
||||
注意:
|
||||
- consensus_reached为true表示核心问题已有明确结论
|
||||
- confidence表示你对共识判断的信心程度
|
||||
- 如果讨论仍有争议或不够深入,应该返回false
|
||||
- action_items应该是具体可执行的任务
|
||||
- 请确保返回有效的JSON格式"""
|
||||
|
||||
@classmethod
|
||||
async def check_consensus(
|
||||
cls,
|
||||
moderator: Agent,
|
||||
context: "DiscussionContext",
|
||||
chatroom: ChatRoom
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否达成共识
|
||||
|
||||
Args:
|
||||
moderator: 主持人Agent
|
||||
context: 讨论上下文
|
||||
chatroom: 聊天室
|
||||
|
||||
Returns:
|
||||
共识判断结果
|
||||
"""
|
||||
from services.discussion_engine import DiscussionContext
|
||||
|
||||
# 构建历史记录
|
||||
history_text = ""
|
||||
for msg in context.messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
else:
|
||||
history_text += f"[系统]: {msg.content}\n\n"
|
||||
|
||||
if not history_text:
|
||||
return {
|
||||
"consensus_reached": False,
|
||||
"confidence": 0,
|
||||
"summary": "讨论尚未开始",
|
||||
"action_items": [],
|
||||
"unresolved_issues": [],
|
||||
"key_decisions": []
|
||||
}
|
||||
|
||||
# 构建提示词
|
||||
prompt = cls.CONSENSUS_PROMPT.format(
|
||||
objective=context.objective,
|
||||
history=history_text
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用主持人Agent的AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=moderator.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3, # 使用较低温度以获得更一致的结果
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
logger.error(f"共识判断失败: {response.error}")
|
||||
return cls._default_result("AI接口调用失败")
|
||||
|
||||
# 解析JSON响应
|
||||
content = response.content.strip()
|
||||
|
||||
# 尝试提取JSON部分
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 尝试提取JSON块
|
||||
import re
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析共识判断结果: {content}")
|
||||
return cls._default_result("无法解析AI响应")
|
||||
else:
|
||||
return cls._default_result("AI响应格式错误")
|
||||
|
||||
# 验证和规范化结果
|
||||
return cls._normalize_result(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"共识判断异常: {e}")
|
||||
return cls._default_result(str(e))
|
||||
|
||||
@classmethod
|
||||
async def generate_summary(
|
||||
cls,
|
||||
moderator: Agent,
|
||||
context: "DiscussionContext"
|
||||
) -> str:
|
||||
"""
|
||||
生成讨论摘要
|
||||
|
||||
Args:
|
||||
moderator: 主持人Agent
|
||||
context: 讨论上下文
|
||||
|
||||
Returns:
|
||||
讨论摘要
|
||||
"""
|
||||
from services.discussion_engine import DiscussionContext
|
||||
|
||||
# 构建历史记录
|
||||
history_text = ""
|
||||
for msg in context.messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
|
||||
prompt = f"""请为以下讨论生成一份简洁的摘要。
|
||||
|
||||
讨论目标:{context.objective}
|
||||
|
||||
对话记录:
|
||||
{history_text}
|
||||
|
||||
请提供:
|
||||
1. 讨论的主要观点和结论
|
||||
2. 参与者的立场和建议
|
||||
3. 最终的决策或共识(如果有)
|
||||
|
||||
摘要应该简洁明了,控制在300字以内。"""
|
||||
|
||||
try:
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=moderator.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.5,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return response.content.strip()
|
||||
else:
|
||||
return "无法生成摘要"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成摘要异常: {e}")
|
||||
return "生成摘要时发生错误"
|
||||
|
||||
@classmethod
|
||||
def _default_result(cls, error: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
返回默认结果
|
||||
|
||||
Args:
|
||||
error: 错误信息
|
||||
|
||||
Returns:
|
||||
默认共识结果
|
||||
"""
|
||||
return {
|
||||
"consensus_reached": False,
|
||||
"confidence": 0,
|
||||
"summary": error if error else "共识判断失败",
|
||||
"action_items": [],
|
||||
"unresolved_issues": [],
|
||||
"key_decisions": []
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _normalize_result(cls, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化共识结果
|
||||
|
||||
Args:
|
||||
result: 原始结果
|
||||
|
||||
Returns:
|
||||
规范化的结果
|
||||
"""
|
||||
return {
|
||||
"consensus_reached": bool(result.get("consensus_reached", False)),
|
||||
"confidence": max(0, min(1, float(result.get("confidence", 0)))),
|
||||
"summary": str(result.get("summary", "")),
|
||||
"action_items": list(result.get("action_items", [])),
|
||||
"unresolved_issues": list(result.get("unresolved_issues", [])),
|
||||
"key_decisions": list(result.get("key_decisions", []))
|
||||
}
|
||||
589
backend/services/discussion_engine.py
Normal file
589
backend/services/discussion_engine.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
讨论引擎
|
||||
实现自由讨论的核心逻辑
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||
from models.agent import Agent
|
||||
from models.message import Message, MessageType
|
||||
from models.discussion_result import DiscussionResult
|
||||
from services.ai_provider_service import AIProviderService
|
||||
from services.agent_service import AgentService
|
||||
from services.chatroom_service import ChatRoomService
|
||||
from services.message_router import MessageRouter
|
||||
from services.consensus_manager import ConsensusManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscussionContext:
|
||||
"""讨论上下文"""
|
||||
discussion_id: str
|
||||
room_id: str
|
||||
objective: str
|
||||
current_round: int = 0
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到上下文"""
|
||||
self.messages.append(message)
|
||||
if message.agent_id:
|
||||
self.agent_speak_counts[message.agent_id] = \
|
||||
self.agent_speak_counts.get(message.agent_id, 0) + 1
|
||||
|
||||
def get_recent_messages(self, count: int = 20) -> List[Message]:
|
||||
"""获取最近的消息"""
|
||||
return self.messages[-count:] if len(self.messages) > count else self.messages
|
||||
|
||||
def get_agent_speak_count(self, agent_id: str) -> int:
|
||||
"""获取Agent在当前轮次的发言次数"""
|
||||
return self.agent_speak_counts.get(agent_id, 0)
|
||||
|
||||
def reset_round_counts(self) -> None:
|
||||
"""重置轮次发言计数"""
|
||||
self.agent_speak_counts.clear()
|
||||
|
||||
|
||||
class DiscussionEngine:
|
||||
"""
|
||||
讨论引擎
|
||||
实现多Agent自由讨论的核心逻辑
|
||||
"""
|
||||
|
||||
# 活跃的讨论: room_id -> DiscussionContext
|
||||
_active_discussions: Dict[str, DiscussionContext] = {}
|
||||
|
||||
# 停止信号
|
||||
_stop_signals: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
async def start_discussion(
|
||||
cls,
|
||||
room_id: str,
|
||||
objective: str
|
||||
) -> Optional[DiscussionResult]:
|
||||
"""
|
||||
启动讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
objective: 讨论目标
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
# 获取聊天室
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
raise ValueError(f"聊天室不存在: {room_id}")
|
||||
|
||||
if not chatroom.agents:
|
||||
raise ValueError("聊天室没有Agent参与")
|
||||
|
||||
if not objective:
|
||||
raise ValueError("讨论目标不能为空")
|
||||
|
||||
# 检查是否已有活跃讨论
|
||||
if room_id in cls._active_discussions:
|
||||
raise ValueError("聊天室已有进行中的讨论")
|
||||
|
||||
# 创建讨论
|
||||
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建讨论结果记录
|
||||
discussion_result = DiscussionResult(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective,
|
||||
status="in_progress",
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
await discussion_result.insert()
|
||||
|
||||
# 创建讨论上下文
|
||||
context = DiscussionContext(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective
|
||||
)
|
||||
cls._active_discussions[room_id] = context
|
||||
cls._stop_signals[room_id] = False
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
status=ChatRoomStatus.ACTIVE.value,
|
||||
objective=objective,
|
||||
current_discussion_id=discussion_id,
|
||||
current_round=0
|
||||
)
|
||||
|
||||
# 广播讨论开始
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_started", {
|
||||
"discussion_id": discussion_id,
|
||||
"objective": objective
|
||||
})
|
||||
|
||||
# 发送系统消息
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=None,
|
||||
content=f"讨论开始\n\n目标:{objective}",
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=0
|
||||
)
|
||||
|
||||
logger.info(f"讨论开始: {room_id} - {objective}")
|
||||
|
||||
# 运行讨论循环
|
||||
try:
|
||||
result = await cls._run_discussion_loop(chatroom, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"讨论异常: {e}")
|
||||
await cls._handle_discussion_error(room_id, discussion_id, str(e))
|
||||
raise
|
||||
finally:
|
||||
# 清理
|
||||
cls._active_discussions.pop(room_id, None)
|
||||
cls._stop_signals.pop(room_id, None)
|
||||
|
||||
@classmethod
|
||||
async def stop_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
停止讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
cls._stop_signals[room_id] = True
|
||||
logger.info(f"收到停止讨论信号: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def pause_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
暂停讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_paused")
|
||||
|
||||
logger.info(f"讨论暂停: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def resume_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
恢复讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
|
||||
|
||||
logger.info(f"讨论恢复: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def _run_discussion_loop(
|
||||
cls,
|
||||
chatroom: ChatRoom,
|
||||
context: DiscussionContext
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
运行讨论循环
|
||||
|
||||
Args:
|
||||
chatroom: 聊天室
|
||||
context: 讨论上下文
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = chatroom.room_id
|
||||
config = chatroom.get_config()
|
||||
|
||||
# 获取所有Agent
|
||||
agents = await AgentService.get_agents_by_ids(chatroom.agents)
|
||||
agent_map = {a.agent_id: a for a in agents}
|
||||
|
||||
# 获取主持人(用于共识判断)
|
||||
moderator = None
|
||||
if chatroom.moderator_agent_id:
|
||||
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
|
||||
|
||||
consecutive_no_speak = 0 # 连续无人发言的轮次
|
||||
|
||||
while context.current_round < config.max_rounds:
|
||||
# 检查停止信号
|
||||
if cls._stop_signals.get(room_id, False):
|
||||
break
|
||||
|
||||
# 检查暂停状态
|
||||
current_chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# 增加轮次
|
||||
context.current_round += 1
|
||||
context.reset_round_counts()
|
||||
|
||||
# 广播轮次信息
|
||||
await MessageRouter.broadcast_round_info(
|
||||
room_id,
|
||||
context.current_round,
|
||||
config.max_rounds
|
||||
)
|
||||
|
||||
# 更新聊天室轮次
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
current_round=context.current_round
|
||||
)
|
||||
|
||||
# 本轮是否有人发言
|
||||
round_has_message = False
|
||||
|
||||
# 遍历所有Agent,判断是否发言
|
||||
for agent_id in chatroom.agents:
|
||||
agent = agent_map.get(agent_id)
|
||||
if not agent or not agent.enabled:
|
||||
continue
|
||||
|
||||
# 检查本轮发言次数限制
|
||||
behavior = agent.get_behavior()
|
||||
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
|
||||
continue
|
||||
|
||||
# 判断是否发言
|
||||
should_speak, content = await cls._should_agent_speak(
|
||||
agent, context, chatroom
|
||||
)
|
||||
|
||||
if should_speak and content:
|
||||
# 广播输入状态
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, True)
|
||||
|
||||
# 保存并广播消息
|
||||
message = await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=MessageType.TEXT.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 更新上下文
|
||||
context.add_message(message)
|
||||
round_has_message = True
|
||||
|
||||
# 广播输入结束
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, False)
|
||||
|
||||
# 轮次间隔
|
||||
await asyncio.sleep(config.round_interval)
|
||||
|
||||
# 检查是否需要共识判断
|
||||
if round_has_message and moderator:
|
||||
consecutive_no_speak = 0
|
||||
|
||||
# 每隔几轮检查一次共识
|
||||
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
|
||||
if consensus_result.get("consensus_reached", False):
|
||||
confidence = consensus_result.get("confidence", 0)
|
||||
if confidence >= config.consensus_threshold:
|
||||
# 达成共识,结束讨论
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"consensus"
|
||||
)
|
||||
else:
|
||||
consecutive_no_speak += 1
|
||||
|
||||
# 连续多轮无人发言,检查共识或结束
|
||||
if consecutive_no_speak >= 3:
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"no_more_discussion"
|
||||
)
|
||||
else:
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
|
||||
"no_more_discussion"
|
||||
)
|
||||
|
||||
# 达到最大轮次
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
else:
|
||||
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
|
||||
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"max_rounds"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _should_agent_speak(
|
||||
cls,
|
||||
agent: Agent,
|
||||
context: DiscussionContext,
|
||||
chatroom: ChatRoom
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断Agent是否应该发言
|
||||
|
||||
Args:
|
||||
agent: Agent实例
|
||||
context: 讨论上下文
|
||||
chatroom: 聊天室
|
||||
|
||||
Returns:
|
||||
(是否发言, 发言内容)
|
||||
"""
|
||||
# 构建判断提示词
|
||||
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
|
||||
|
||||
history_text = ""
|
||||
for msg in recent_messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
else:
|
||||
history_text += f"[系统]: {msg.content}\n\n"
|
||||
|
||||
prompt = f"""你是{agent.name},角色是{agent.role}。
|
||||
|
||||
{agent.system_prompt}
|
||||
|
||||
当前讨论目标:{context.objective}
|
||||
|
||||
对话历史:
|
||||
{history_text if history_text else "(还没有对话)"}
|
||||
|
||||
当前是第{context.current_round}轮讨论。
|
||||
|
||||
请根据你的角色判断:
|
||||
1. 你是否有新的观点或建议要分享?
|
||||
2. 你是否需要回应其他人的观点?
|
||||
3. 当前讨论是否需要你的专业意见?
|
||||
|
||||
如果你认为需要发言,请直接给出你的发言内容。
|
||||
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"。
|
||||
|
||||
注意:
|
||||
- 请保持发言简洁有力,每次发言控制在200字以内
|
||||
- 避免重复已经说过的内容
|
||||
- 如果已经达成共识或接近共识,可以选择PASS"""
|
||||
|
||||
try:
|
||||
# 调用AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=agent.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
|
||||
return False, ""
|
||||
|
||||
content = response.content.strip()
|
||||
|
||||
# 判断是否PASS
|
||||
if content.upper() == "PASS" or content.upper().startswith("PASS"):
|
||||
return False, ""
|
||||
|
||||
return True, content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
|
||||
return False, ""
|
||||
|
||||
@classmethod
|
||||
async def _finalize_discussion(
|
||||
cls,
|
||||
context: DiscussionContext,
|
||||
consensus_result: Dict[str, Any],
|
||||
end_reason: str
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
完成讨论,保存结果
|
||||
|
||||
Args:
|
||||
context: 讨论上下文
|
||||
consensus_result: 共识判断结果
|
||||
end_reason: 结束原因
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = context.room_id
|
||||
|
||||
# 获取讨论结果记录
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == context.discussion_id
|
||||
)
|
||||
|
||||
if discussion_result:
|
||||
# 更新统计
|
||||
discussion_result.update_stats(
|
||||
total_rounds=context.current_round,
|
||||
total_messages=len(context.messages),
|
||||
agent_contributions=context.agent_speak_counts
|
||||
)
|
||||
|
||||
# 标记完成
|
||||
discussion_result.mark_completed(
|
||||
consensus_reached=consensus_result.get("consensus_reached", False),
|
||||
confidence=consensus_result.get("confidence", 0),
|
||||
summary=consensus_result.get("summary", ""),
|
||||
action_items=consensus_result.get("action_items", []),
|
||||
unresolved_issues=consensus_result.get("unresolved_issues", []),
|
||||
end_reason=end_reason
|
||||
)
|
||||
|
||||
await discussion_result.save()
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
|
||||
|
||||
# 发送系统消息
|
||||
summary_text = f"""讨论结束
|
||||
|
||||
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
|
||||
置信度:{consensus_result.get("confidence", 0):.0%}
|
||||
|
||||
摘要:{consensus_result.get("summary", "无")}
|
||||
|
||||
行动项:
|
||||
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or "无"}
|
||||
|
||||
未解决问题:
|
||||
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or "无"}
|
||||
|
||||
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
|
||||
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=None,
|
||||
content=summary_text,
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 广播讨论结束
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
|
||||
"discussion_id": context.discussion_id,
|
||||
"consensus_reached": consensus_result.get("consensus_reached", False),
|
||||
"end_reason": end_reason
|
||||
})
|
||||
|
||||
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
|
||||
|
||||
return discussion_result
|
||||
|
||||
@classmethod
|
||||
async def _handle_discussion_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""
|
||||
处理讨论错误
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
error: 错误信息
|
||||
"""
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
|
||||
|
||||
# 更新讨论结果
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == discussion_id
|
||||
)
|
||||
if discussion_result:
|
||||
discussion_result.status = "failed"
|
||||
discussion_result.end_reason = f"error: {error}"
|
||||
discussion_result.updated_at = datetime.utcnow()
|
||||
await discussion_result.save()
|
||||
|
||||
# 广播错误
|
||||
await MessageRouter.broadcast_error(room_id, error)
|
||||
|
||||
@classmethod
|
||||
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
|
||||
"""
|
||||
获取活跃的讨论上下文
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
讨论上下文或None
|
||||
"""
|
||||
return cls._active_discussions.get(room_id)
|
||||
|
||||
@classmethod
|
||||
def is_discussion_active(cls, room_id: str) -> bool:
|
||||
"""
|
||||
检查是否有活跃讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否活跃
|
||||
"""
|
||||
return room_id in cls._active_discussions
|
||||
252
backend/services/mcp_service.py
Normal file
252
backend/services/mcp_service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
MCP服务
|
||||
管理MCP工具的集成和调用
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MCPService:
|
||||
"""
|
||||
MCP工具服务
|
||||
集成MCP服务器,提供工具调用能力
|
||||
"""
|
||||
|
||||
# MCP服务器配置目录
|
||||
MCP_CONFIG_DIR = Path(os.getenv("CURSOR_MCP_DIR", "~/.cursor/mcps")).expanduser()
|
||||
|
||||
# 已注册的工具: server_name -> List[tool_info]
|
||||
_registered_tools: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
# Agent工具映射: agent_id -> List[tool_name]
|
||||
_agent_tools: Dict[str, List[str]] = {}
|
||||
|
||||
@classmethod
|
||||
async def initialize(cls) -> None:
|
||||
"""
|
||||
初始化MCP服务
|
||||
扫描并注册可用的MCP工具
|
||||
"""
|
||||
logger.info("初始化MCP服务...")
|
||||
|
||||
if not cls.MCP_CONFIG_DIR.exists():
|
||||
logger.warning(f"MCP配置目录不存在: {cls.MCP_CONFIG_DIR}")
|
||||
return
|
||||
|
||||
# 扫描MCP服务器目录
|
||||
for server_dir in cls.MCP_CONFIG_DIR.iterdir():
|
||||
if server_dir.is_dir():
|
||||
await cls._scan_server(server_dir)
|
||||
|
||||
logger.info(f"MCP服务初始化完成,已注册 {len(cls._registered_tools)} 个服务器")
|
||||
|
||||
@classmethod
|
||||
async def _scan_server(cls, server_dir: Path) -> None:
|
||||
"""
|
||||
扫描MCP服务器目录
|
||||
|
||||
Args:
|
||||
server_dir: 服务器目录
|
||||
"""
|
||||
server_name = server_dir.name
|
||||
tools_dir = server_dir / "tools"
|
||||
|
||||
if not tools_dir.exists():
|
||||
return
|
||||
|
||||
tools = []
|
||||
for tool_file in tools_dir.glob("*.json"):
|
||||
try:
|
||||
with open(tool_file, "r", encoding="utf-8") as f:
|
||||
tool_info = json.load(f)
|
||||
tool_info["_file"] = str(tool_file)
|
||||
tools.append(tool_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"加载MCP工具配置失败: {tool_file} - {e}")
|
||||
|
||||
if tools:
|
||||
cls._registered_tools[server_name] = tools
|
||||
logger.debug(f"注册MCP服务器: {server_name}, 工具数: {len(tools)}")
|
||||
|
||||
@classmethod
|
||||
def list_servers(cls) -> List[str]:
|
||||
"""
|
||||
列出所有可用的MCP服务器
|
||||
|
||||
Returns:
|
||||
服务器名称列表
|
||||
"""
|
||||
return list(cls._registered_tools.keys())
|
||||
|
||||
@classmethod
|
||||
def list_tools(cls, server: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出可用的MCP工具
|
||||
|
||||
Args:
|
||||
server: 服务器名称(可选,不指定则返回所有)
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
if server:
|
||||
return cls._registered_tools.get(server, [])
|
||||
|
||||
# 返回所有工具
|
||||
all_tools = []
|
||||
for server_name, tools in cls._registered_tools.items():
|
||||
for tool in tools:
|
||||
tool_copy = tool.copy()
|
||||
tool_copy["server"] = server_name
|
||||
all_tools.append(tool_copy)
|
||||
|
||||
return all_tools
|
||||
|
||||
@classmethod
|
||||
def get_tool(cls, server: str, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定工具的信息
|
||||
|
||||
Args:
|
||||
server: 服务器名称
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具信息或None
|
||||
"""
|
||||
tools = cls._registered_tools.get(server, [])
|
||||
for tool in tools:
|
||||
if tool.get("name") == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def call_tool(
|
||||
cls,
|
||||
server: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用MCP工具
|
||||
|
||||
Args:
|
||||
server: 服务器名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
调用结果
|
||||
"""
|
||||
tool = cls.get_tool(server, tool_name)
|
||||
if not tool:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"工具不存在: {server}/{tool_name}"
|
||||
}
|
||||
|
||||
# TODO: 实际的MCP工具调用逻辑
|
||||
# 这里需要根据MCP协议实现工具调用
|
||||
# 目前返回模拟结果
|
||||
logger.info(f"调用MCP工具: {server}/{tool_name}, 参数: {arguments}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": f"MCP工具调用: {tool_name}",
|
||||
"tool": tool_name,
|
||||
"server": server,
|
||||
"arguments": arguments
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_tool_for_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
tool_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
为Agent注册可用工具
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
tool_name: 工具名称(格式: server/tool_name)
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if agent_id not in cls._agent_tools:
|
||||
cls._agent_tools[agent_id] = []
|
||||
|
||||
if tool_name not in cls._agent_tools[agent_id]:
|
||||
cls._agent_tools[agent_id].append(tool_name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def unregister_tool_for_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
tool_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
为Agent注销工具
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否注销成功
|
||||
"""
|
||||
if agent_id in cls._agent_tools:
|
||||
if tool_name in cls._agent_tools[agent_id]:
|
||||
cls._agent_tools[agent_id].remove(tool_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_agent_tools(cls, agent_id: str) -> List[str]:
|
||||
"""
|
||||
获取Agent可用的工具列表
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
工具名称列表
|
||||
"""
|
||||
return cls._agent_tools.get(agent_id, [])
|
||||
|
||||
@classmethod
|
||||
def get_tools_for_prompt(cls, agent_id: str) -> str:
|
||||
"""
|
||||
获取用于提示词的工具描述
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
工具描述文本
|
||||
"""
|
||||
tool_names = cls.get_agent_tools(agent_id)
|
||||
if not tool_names:
|
||||
return ""
|
||||
|
||||
descriptions = []
|
||||
for full_name in tool_names:
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
server, tool_name = parts
|
||||
tool = cls.get_tool(server, tool_name)
|
||||
if tool:
|
||||
desc = tool.get("description", "无描述")
|
||||
descriptions.append(f"- {tool_name}: {desc}")
|
||||
|
||||
if not descriptions:
|
||||
return ""
|
||||
|
||||
return "你可以使用以下工具:\n" + "\n".join(descriptions)
|
||||
416
backend/services/memory_service.py
Normal file
416
backend/services/memory_service.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
记忆服务
|
||||
管理Agent的记忆存储和检索
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from models.agent_memory import AgentMemory, MemoryType
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
Agent记忆服务
|
||||
提供记忆的存储、检索和管理功能
|
||||
"""
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
_embedding_model = None
|
||||
|
||||
@classmethod
|
||||
def _get_embedding_model(cls):
|
||||
"""
|
||||
获取嵌入模型实例(延迟加载)
|
||||
"""
|
||||
if cls._embedding_model is None:
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
||||
logger.info("嵌入模型加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入模型加载失败: {e}")
|
||||
return None
|
||||
return cls._embedding_model
|
||||
|
||||
@classmethod
|
||||
async def create_memory(
|
||||
cls,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
memory_type: str = MemoryType.SHORT_TERM.value,
|
||||
importance: float = 0.5,
|
||||
source_room_id: Optional[str] = None,
|
||||
source_discussion_id: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
expires_in_hours: Optional[int] = None
|
||||
) -> AgentMemory:
|
||||
"""
|
||||
创建新的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
content: 记忆内容
|
||||
memory_type: 记忆类型
|
||||
importance: 重要性评分
|
||||
source_room_id: 来源聊天室
|
||||
source_discussion_id: 来源讨论
|
||||
tags: 标签
|
||||
expires_in_hours: 过期时间(小时)
|
||||
|
||||
Returns:
|
||||
创建的AgentMemory文档
|
||||
"""
|
||||
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 生成向量嵌入
|
||||
embedding = await cls._generate_embedding(content)
|
||||
|
||||
# 生成摘要
|
||||
summary = content[:100] + "..." if len(content) > 100 else content
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if expires_in_hours:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||
|
||||
memory = AgentMemory(
|
||||
memory_id=memory_id,
|
||||
agent_id=agent_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
summary=summary,
|
||||
embedding=embedding,
|
||||
importance=importance,
|
||||
source_room_id=source_room_id,
|
||||
source_discussion_id=source_discussion_id,
|
||||
tags=tags or [],
|
||||
created_at=datetime.utcnow(),
|
||||
last_accessed=datetime.utcnow(),
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
await memory.insert()
|
||||
|
||||
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
|
||||
"""
|
||||
获取指定记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
AgentMemory文档或None
|
||||
"""
|
||||
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
|
||||
|
||||
@classmethod
|
||||
async def get_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[AgentMemory]:
|
||||
"""
|
||||
获取Agent的记忆列表
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
return await AgentMemory.find(query).sort(
|
||||
"-importance", "-last_accessed"
|
||||
).limit(limit).to_list()
|
||||
|
||||
@classmethod
|
||||
async def search_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
memory_type: Optional[str] = None,
|
||||
min_relevance: float = 0.3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索相关记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型(可选)
|
||||
min_relevance: 最小相关性阈值
|
||||
|
||||
Returns:
|
||||
带相关性分数的记忆列表
|
||||
"""
|
||||
# 生成查询向量
|
||||
query_embedding = await cls._generate_embedding(query)
|
||||
if not query_embedding:
|
||||
# 无法生成向量时,使用文本匹配
|
||||
return await cls._text_search(agent_id, query, limit, memory_type)
|
||||
|
||||
# 获取Agent的所有记忆
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
# 计算相似度
|
||||
results = []
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
if memory.embedding:
|
||||
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
|
||||
relevance = memory.calculate_relevance_score(similarity)
|
||||
|
||||
if relevance >= min_relevance:
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": similarity,
|
||||
"relevance": relevance
|
||||
})
|
||||
|
||||
# 按相关性排序
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
|
||||
# 更新访问记录
|
||||
for item in results[:limit]:
|
||||
memory = item["memory"]
|
||||
memory.access()
|
||||
await memory.save()
|
||||
|
||||
return results[:limit]
|
||||
|
||||
@classmethod
|
||||
async def update_memory(
|
||||
cls,
|
||||
memory_id: str,
|
||||
**kwargs
|
||||
) -> Optional[AgentMemory]:
|
||||
"""
|
||||
更新记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的AgentMemory或None
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return None
|
||||
|
||||
# 如果更新了内容,重新生成嵌入
|
||||
if "content" in kwargs:
|
||||
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
|
||||
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(memory, key):
|
||||
setattr(memory, key, value)
|
||||
|
||||
await memory.save()
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def delete_memory(cls, memory_id: str) -> bool:
|
||||
"""
|
||||
删除记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return False
|
||||
|
||||
await memory.delete()
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def delete_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
删除Agent的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
|
||||
Returns:
|
||||
删除的数量
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
result = await AgentMemory.find(query).delete()
|
||||
return result.deleted_count if result else 0
|
||||
|
||||
@classmethod
|
||||
async def cleanup_expired_memories(cls) -> int:
|
||||
"""
|
||||
清理过期的记忆
|
||||
|
||||
Returns:
|
||||
清理的数量
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
result = await AgentMemory.find(
|
||||
{"expires_at": {"$lt": now}}
|
||||
).delete()
|
||||
|
||||
count = result.deleted_count if result else 0
|
||||
if count > 0:
|
||||
logger.info(f"清理了 {count} 条过期记忆")
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
async def consolidate_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
min_importance: float = 0.7,
|
||||
max_age_days: int = 30
|
||||
) -> None:
|
||||
"""
|
||||
整合记忆(将重要的短期记忆转为长期记忆)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
min_importance: 最小重要性阈值
|
||||
max_age_days: 最大年龄(天)
|
||||
"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
|
||||
|
||||
# 查找符合条件的短期记忆
|
||||
memories = await AgentMemory.find({
|
||||
"agent_id": agent_id,
|
||||
"memory_type": MemoryType.SHORT_TERM.value,
|
||||
"importance": {"$gte": min_importance},
|
||||
"created_at": {"$lt": cutoff_date}
|
||||
}).to_list()
|
||||
|
||||
for memory in memories:
|
||||
memory.memory_type = MemoryType.LONG_TERM.value
|
||||
memory.expires_at = None # 长期记忆不过期
|
||||
await memory.save()
|
||||
|
||||
if memories:
|
||||
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
|
||||
|
||||
@classmethod
|
||||
async def _generate_embedding(cls, text: str) -> List[float]:
|
||||
"""
|
||||
生成文本的向量嵌入
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
向量嵌入列表
|
||||
"""
|
||||
model = cls._get_embedding_model()
|
||||
if model is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
embedding = model.encode(text, convert_to_numpy=True)
|
||||
return embedding.tolist()
|
||||
except Exception as e:
|
||||
logger.warning(f"生成嵌入失败: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
|
||||
"""
|
||||
计算余弦相似度
|
||||
|
||||
Args:
|
||||
vec1: 向量1
|
||||
vec2: 向量2
|
||||
|
||||
Returns:
|
||||
相似度 (0-1)
|
||||
"""
|
||||
if not vec1 or not vec2:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
a = np.array(vec1)
|
||||
b = np.array(vec2)
|
||||
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
return float(max(0, similarity))
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
async def _text_search(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
memory_type: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
文本搜索(后备方案)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
# 简单的文本匹配
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
content_lower = memory.content.lower()
|
||||
if query_lower in content_lower:
|
||||
# 计算简单的匹配分数
|
||||
score = len(query_lower) / len(content_lower)
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": score,
|
||||
"relevance": score * memory.importance
|
||||
})
|
||||
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
return results[:limit]
|
||||
335
backend/services/message_router.py
Normal file
335
backend/services/message_router.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
消息路由服务
|
||||
管理消息的发送和广播
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Callable, Set
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from fastapi import WebSocket
|
||||
|
||||
from models.message import Message, MessageType
|
||||
from models.chatroom import ChatRoom
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""WebSocket连接信息"""
|
||||
websocket: WebSocket
|
||||
room_id: str
|
||||
connected_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""
|
||||
消息路由器
|
||||
管理WebSocket连接和消息广播
|
||||
"""
|
||||
|
||||
# 房间连接映射: room_id -> Set[WebSocket]
|
||||
_room_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
# 消息回调: 用于外部订阅消息
|
||||
_message_callbacks: List[Callable] = []
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
建立WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if room_id not in cls._room_connections:
|
||||
cls._room_connections[room_id] = set()
|
||||
|
||||
cls._room_connections[room_id].add(websocket)
|
||||
|
||||
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
|
||||
|
||||
@classmethod
|
||||
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
断开WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
if room_id in cls._room_connections:
|
||||
cls._room_connections[room_id].discard(websocket)
|
||||
|
||||
# 清理空房间
|
||||
if not cls._room_connections[room_id]:
|
||||
del cls._room_connections[room_id]
|
||||
|
||||
logger.info(f"WebSocket连接断开: {room_id}")
|
||||
|
||||
@classmethod
|
||||
async def broadcast_to_room(
|
||||
cls,
|
||||
room_id: str,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向聊天室广播消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
message: 消息内容
|
||||
"""
|
||||
if room_id not in cls._room_connections:
|
||||
return
|
||||
|
||||
# 获取所有连接
|
||||
connections = cls._room_connections[room_id].copy()
|
||||
|
||||
# 并发发送
|
||||
tasks = []
|
||||
for websocket in connections:
|
||||
tasks.append(cls._send_message(room_id, websocket, message))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@classmethod
|
||||
async def _send_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
websocket: WebSocket,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向单个WebSocket发送消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
message: 消息内容
|
||||
"""
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
# 移除断开的连接
|
||||
await cls.disconnect(room_id, websocket)
|
||||
|
||||
@classmethod
|
||||
async def save_and_broadcast_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
agent_id: Optional[str],
|
||||
content: str,
|
||||
message_type: str = MessageType.TEXT.value,
|
||||
round_num: int = 0,
|
||||
attachments: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_results: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
保存消息并广播
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
agent_id: 发送Agent ID
|
||||
content: 消息内容
|
||||
message_type: 消息类型
|
||||
round_num: 轮次号
|
||||
attachments: 附件
|
||||
tool_calls: 工具调用
|
||||
tool_results: 工具结果
|
||||
|
||||
Returns:
|
||||
保存的Message文档
|
||||
"""
|
||||
# 创建消息
|
||||
message = Message(
|
||||
message_id=f"msg-{uuid.uuid4().hex[:12]}",
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
attachments=attachments or [],
|
||||
round=round_num,
|
||||
token_count=len(content) // 4, # 粗略估计
|
||||
tool_calls=tool_calls or [],
|
||||
tool_results=tool_results or [],
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await message.insert()
|
||||
|
||||
# 构建广播消息
|
||||
broadcast_data = {
|
||||
"type": "message",
|
||||
"data": {
|
||||
"message_id": message.message_id,
|
||||
"room_id": message.room_id,
|
||||
"discussion_id": message.discussion_id,
|
||||
"agent_id": message.agent_id,
|
||||
"content": message.content,
|
||||
"message_type": message.message_type,
|
||||
"round": message.round,
|
||||
"created_at": message.created_at.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# 广播消息
|
||||
await cls.broadcast_to_room(room_id, broadcast_data)
|
||||
|
||||
# 触发回调
|
||||
for callback in cls._message_callbacks:
|
||||
try:
|
||||
await callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"消息回调执行失败: {e}")
|
||||
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
async def broadcast_status(
|
||||
cls,
|
||||
room_id: str,
|
||||
status: str,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播状态更新
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
status: 状态类型
|
||||
data: 附加数据
|
||||
"""
|
||||
message = {
|
||||
"type": "status",
|
||||
"status": status,
|
||||
"data": data or {},
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_typing(
|
||||
cls,
|
||||
room_id: str,
|
||||
agent_id: str,
|
||||
is_typing: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
广播Agent输入状态
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
is_typing: 是否正在输入
|
||||
"""
|
||||
message = {
|
||||
"type": "typing",
|
||||
"agent_id": agent_id,
|
||||
"is_typing": is_typing,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_round_info(
|
||||
cls,
|
||||
room_id: str,
|
||||
round_num: int,
|
||||
total_rounds: int
|
||||
) -> None:
|
||||
"""
|
||||
广播轮次信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
round_num: 当前轮次
|
||||
total_rounds: 最大轮次
|
||||
"""
|
||||
message = {
|
||||
"type": "round",
|
||||
"round": round_num,
|
||||
"total_rounds": total_rounds,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
error: str,
|
||||
agent_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播错误信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
error: 错误信息
|
||||
agent_id: 相关Agent ID
|
||||
"""
|
||||
message = {
|
||||
"type": "error",
|
||||
"error": error,
|
||||
"agent_id": agent_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
def register_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注册消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收Message参数
|
||||
"""
|
||||
cls._message_callbacks.append(callback)
|
||||
|
||||
@classmethod
|
||||
def unregister_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注销消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数
|
||||
"""
|
||||
if callback in cls._message_callbacks:
|
||||
cls._message_callbacks.remove(callback)
|
||||
|
||||
@classmethod
|
||||
def get_connection_count(cls, room_id: str) -> int:
|
||||
"""
|
||||
获取房间连接数
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
连接数
|
||||
"""
|
||||
return len(cls._room_connections.get(room_id, set()))
|
||||
|
||||
@classmethod
|
||||
def get_all_room_ids(cls) -> List[str]:
|
||||
"""
|
||||
获取所有活跃房间ID
|
||||
|
||||
Returns:
|
||||
房间ID列表
|
||||
"""
|
||||
return list(cls._room_connections.keys())
|
||||
13
backend/utils/__init__.py
Normal file
13
backend/utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
工具函数模块
|
||||
"""
|
||||
from .encryption import encrypt_api_key, decrypt_api_key
|
||||
from .proxy_handler import get_http_client
|
||||
from .rate_limiter import RateLimiter
|
||||
|
||||
__all__ = [
|
||||
"encrypt_api_key",
|
||||
"decrypt_api_key",
|
||||
"get_http_client",
|
||||
"RateLimiter",
|
||||
]
|
||||
97
backend/utils/encryption.py
Normal file
97
backend/utils/encryption.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
加密工具模块
|
||||
用于API密钥的加密和解密
|
||||
"""
|
||||
import base64
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from loguru import logger
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""
|
||||
获取Fernet加密器实例
|
||||
使用配置的加密密钥派生加密密钥
|
||||
|
||||
Returns:
|
||||
Fernet加密器
|
||||
"""
|
||||
# 使用PBKDF2从密钥派生32字节密钥
|
||||
salt = b"ai_chatroom_salt" # 固定salt,实际生产环境应使用随机salt
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(
|
||||
kdf.derive(settings.ENCRYPTION_KEY.encode())
|
||||
)
|
||||
return Fernet(key)
|
||||
|
||||
|
||||
def encrypt_api_key(api_key: str) -> str:
|
||||
"""
|
||||
加密API密钥
|
||||
|
||||
Args:
|
||||
api_key: 原始API密钥
|
||||
|
||||
Returns:
|
||||
加密后的密钥字符串
|
||||
"""
|
||||
if not api_key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
encrypted = fernet.encrypt(api_key.encode())
|
||||
return encrypted.decode()
|
||||
except Exception as e:
|
||||
logger.error(f"API密钥加密失败: {e}")
|
||||
raise ValueError("加密失败")
|
||||
|
||||
|
||||
def decrypt_api_key(encrypted_key: str) -> str:
|
||||
"""
|
||||
解密API密钥
|
||||
|
||||
Args:
|
||||
encrypted_key: 加密的密钥字符串
|
||||
|
||||
Returns:
|
||||
解密后的原始API密钥
|
||||
"""
|
||||
if not encrypted_key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
decrypted = fernet.decrypt(encrypted_key.encode())
|
||||
return decrypted.decode()
|
||||
except Exception as e:
|
||||
logger.error(f"API密钥解密失败: {e}")
|
||||
raise ValueError("解密失败,密钥可能已损坏或被篡改")
|
||||
|
||||
|
||||
def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
|
||||
"""
|
||||
掩码API密钥,用于安全显示
|
||||
|
||||
Args:
|
||||
api_key: 原始API密钥
|
||||
visible_chars: 末尾可见字符数
|
||||
|
||||
Returns:
|
||||
掩码后的密钥 (如: ****abc1)
|
||||
"""
|
||||
if not api_key:
|
||||
return ""
|
||||
|
||||
if len(api_key) <= visible_chars:
|
||||
return "*" * len(api_key)
|
||||
|
||||
return "*" * (len(api_key) - visible_chars) + api_key[-visible_chars:]
|
||||
135
backend/utils/proxy_handler.py
Normal file
135
backend/utils/proxy_handler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
代理处理模块
|
||||
处理HTTP代理配置
|
||||
"""
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
def get_proxy_dict(
|
||||
use_proxy: bool,
|
||||
proxy_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
获取代理配置字典
|
||||
|
||||
Args:
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
|
||||
Returns:
|
||||
代理配置字典或None
|
||||
"""
|
||||
if not use_proxy:
|
||||
return None
|
||||
|
||||
proxies = {}
|
||||
|
||||
if proxy_config:
|
||||
http_proxy = proxy_config.get("http_proxy")
|
||||
https_proxy = proxy_config.get("https_proxy")
|
||||
else:
|
||||
# 使用全局默认代理
|
||||
http_proxy = settings.DEFAULT_HTTP_PROXY
|
||||
https_proxy = settings.DEFAULT_HTTPS_PROXY
|
||||
|
||||
if http_proxy:
|
||||
proxies["http://"] = http_proxy
|
||||
if https_proxy:
|
||||
proxies["https://"] = https_proxy
|
||||
|
||||
return proxies if proxies else None
|
||||
|
||||
|
||||
def get_http_client(
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取配置好的HTTP异步客户端
|
||||
|
||||
Args:
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
timeout: 超时时间(秒)
|
||||
**kwargs: 其他httpx参数
|
||||
|
||||
Returns:
|
||||
配置好的httpx.AsyncClient实例
|
||||
"""
|
||||
proxies = get_proxy_dict(use_proxy, proxy_config)
|
||||
|
||||
client_kwargs = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
if proxies:
|
||||
client_kwargs["proxies"] = proxies
|
||||
logger.debug(f"HTTP客户端使用代理: {proxies}")
|
||||
|
||||
return httpx.AsyncClient(**client_kwargs)
|
||||
|
||||
|
||||
async def test_proxy_connection(
|
||||
proxy_config: Dict[str, Any],
|
||||
test_url: str = "https://www.google.com"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试代理连接是否可用
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置
|
||||
test_url: 测试URL
|
||||
|
||||
Returns:
|
||||
测试结果字典,包含 success, message, latency_ms
|
||||
"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=True,
|
||||
proxy_config=proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
import time
|
||||
start = time.time()
|
||||
response = await client.get(test_url)
|
||||
latency = (time.time() - start) * 1000
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "代理连接正常",
|
||||
"latency_ms": round(latency, 2)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理返回状态码: {response.status_code}",
|
||||
"latency_ms": round(latency, 2)
|
||||
}
|
||||
|
||||
except httpx.ProxyError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理连接失败: {str(e)}",
|
||||
"latency_ms": None
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "代理连接超时",
|
||||
"latency_ms": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"连接错误: {str(e)}",
|
||||
"latency_ms": None
|
||||
}
|
||||
233
backend/utils/rate_limiter.py
Normal file
233
backend/utils/rate_limiter.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
速率限制器模块
|
||||
使用令牌桶算法控制请求频率
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBucket:
|
||||
"""令牌桶"""
|
||||
capacity: int # 桶容量
|
||||
tokens: float = field(init=False) # 当前令牌数
|
||||
refill_rate: float # 每秒填充速率
|
||||
last_refill: float = field(default_factory=time.time)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokens = float(self.capacity)
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""填充令牌"""
|
||||
now = time.time()
|
||||
elapsed = now - self.last_refill
|
||||
self.tokens = min(
|
||||
self.capacity,
|
||||
self.tokens + elapsed * self.refill_rate
|
||||
)
|
||||
self.last_refill = now
|
||||
|
||||
def consume(self, tokens: int = 1) -> bool:
|
||||
"""
|
||||
尝试消费令牌
|
||||
|
||||
Args:
|
||||
tokens: 要消费的令牌数
|
||||
|
||||
Returns:
|
||||
是否消费成功
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
return False
|
||||
|
||||
def wait_time(self, tokens: int = 1) -> float:
|
||||
"""
|
||||
计算需要等待的时间
|
||||
|
||||
Args:
|
||||
tokens: 需要的令牌数
|
||||
|
||||
Returns:
|
||||
需要等待的秒数
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
return 0.0
|
||||
needed = tokens - self.tokens
|
||||
return needed / self.refill_rate
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
速率限制器
|
||||
管理多个提供商的速率限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._buckets: Dict[str, TokenBucket] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
provider_id: str,
|
||||
requests_per_minute: int = 60,
|
||||
tokens_per_minute: int = 100000
|
||||
) -> None:
|
||||
"""
|
||||
注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
requests_per_minute: 每分钟请求数
|
||||
tokens_per_minute: 每分钟token数
|
||||
"""
|
||||
# 请求限制桶
|
||||
self._buckets[f"{provider_id}:requests"] = TokenBucket(
|
||||
capacity=requests_per_minute,
|
||||
refill_rate=requests_per_minute / 60.0
|
||||
)
|
||||
|
||||
# Token限制桶
|
||||
self._buckets[f"{provider_id}:tokens"] = TokenBucket(
|
||||
capacity=tokens_per_minute,
|
||||
refill_rate=tokens_per_minute / 60.0
|
||||
)
|
||||
|
||||
# 创建锁
|
||||
self._locks[provider_id] = asyncio.Lock()
|
||||
|
||||
logger.debug(
|
||||
f"注册速率限制: {provider_id} - "
|
||||
f"{requests_per_minute}请求/分钟, "
|
||||
f"{tokens_per_minute}tokens/分钟"
|
||||
)
|
||||
|
||||
def unregister(self, provider_id: str) -> None:
|
||||
"""
|
||||
取消注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
"""
|
||||
self._buckets.pop(f"{provider_id}:requests", None)
|
||||
self._buckets.pop(f"{provider_id}:tokens", None)
|
||||
self._locks.pop(provider_id, None)
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(非阻塞)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
# 未注册,默认允许
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if lock:
|
||||
async with lock:
|
||||
if request_bucket.consume(1) and token_bucket.consume(estimated_tokens):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def acquire_wait(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1,
|
||||
max_wait: float = 60.0
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(阻塞等待)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
max_wait: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if not lock:
|
||||
return True
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
async with lock:
|
||||
# 计算需要等待的时间
|
||||
request_wait = request_bucket.wait_time(1)
|
||||
token_wait = token_bucket.wait_time(estimated_tokens)
|
||||
wait_time = max(request_wait, token_wait)
|
||||
|
||||
if wait_time == 0:
|
||||
request_bucket.consume(1)
|
||||
token_bucket.consume(estimated_tokens)
|
||||
return True
|
||||
|
||||
# 检查是否超时
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed + wait_time > max_wait:
|
||||
logger.warning(
|
||||
f"速率限制等待超时: {provider_id}, "
|
||||
f"需要等待{wait_time:.2f}秒"
|
||||
)
|
||||
return False
|
||||
|
||||
# 在锁外等待
|
||||
await asyncio.sleep(min(wait_time, 1.0))
|
||||
|
||||
def get_status(self, provider_id: str) -> Optional[Dict[str, any]]:
|
||||
"""
|
||||
获取提供商的速率限制状态
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return None
|
||||
|
||||
request_bucket._refill()
|
||||
token_bucket._refill()
|
||||
|
||||
return {
|
||||
"requests_remaining": int(request_bucket.tokens),
|
||||
"requests_capacity": request_bucket.capacity,
|
||||
"tokens_remaining": int(token_bucket.tokens),
|
||||
"tokens_capacity": token_bucket.capacity
|
||||
}
|
||||
|
||||
|
||||
# 全局速率限制器实例
|
||||
rate_limiter = RateLimiter()
|
||||
Reference in New Issue
Block a user